close

點擊下方卡片,關注「新機器視覺」公眾號

重磅乾貨,第一時間送達


知乎作者:小小將文僅分享,侵刪

原文鏈接:https://zhuanlan.zhihu.com/p/430563265

導讀
在計算機視覺中,模型的性能不僅取決於模型結構和數據質量,還與優化器、損失函數和數據增強等諸多方面有着密切聯繫。本文總結了在分類任務中常用的數據增強方法,希望能對各位有所幫助。

一個模型的性能除了和網絡結構本身有關,還非常依賴具體的訓練策略,比如優化器,數據增強以及正則化策略等(當然也很訓練數據強相關,訓練數據量往往決定模型性能的上線)。近年來,圖像分類模型在ImageNet數據集的top1 acc已經由原來的56.5(AlexNet,2012)提升至90.88(CoAtNet,2021,用了額外的數據集JFT-3B),這進步除了主要歸功於模型,算力和數據的提升,也與訓練策略的提升緊密相關。最近剛興起的vision transformer相比CNN模型往往也需要更heavy的數據增強和正則化策略。這裡簡單介紹圖像分類訓練技巧中的常用數據增強策略。

baseline

ImageNet數據集訓練常用的數據增強策略如下,訓練過程的數據增強包括隨機縮放裁剪(RandomResizedCrop,這種處理方式源自谷歌的Inception,所以稱為 Inception-style pre-processing)和水平翻轉(RandomHorizontalFlip),而測試階段是執行縮放和中心裁剪。這其實是一種輕量級的策略,這裡稱之為baseline。torchvision的實現的ResNet50訓練採用的策略就是這個,在ImageNet上的top1 acc可以達到76.1。

fromtorchvisionimporttransformsnormalize=transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])#訓練train_transform=transforms.Compose([#這裡的scale指的是面積,ratio是寬高比#具體實現每次先隨機確定scale和ratio,可以生成w和h,然後隨機確定裁剪位置進行crop#最後是resize到targetsizetransforms.RandomResizedCrop(224,scale=(0.08,1.0),ratio=(3./4.,4./3.)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize])#測試test_transform=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize,])AutoAugment

谷歌在2018年提出通過AutoML來自動搜索數據增強策略,稱之為AutoAugment(算是自動數據增強開山之作)。搜索方法採用強化學習,和NAS類似,只不過搜索空間是數據增強策略,而不是網絡架構。在搜索空間裡,一個policy包含5個sub-policies,每個sub-policy包含兩個串行的圖像增強操作,每個增強操作有兩個超參數:進行該操作的概率和圖像增強的幅度(magnitude,這個表示數據增強的強度,比如對於旋轉,旋轉的角度就是增強幅度,旋轉角度越大,增強越大)。每個policy在執行時,首先隨機從5個策略中隨機選擇一個sub-policy,然後序列執行兩個圖像操作。

搜索空間一共有16種圖像增強類型,具體如下所示,大部分操作都定義了圖像增強的幅度範圍,在搜索時需要將幅度值離散化,具體地是將幅度值在定義範圍內均勻地取10個值。

論文在不同的數據集上( CIFAR-10 , SVHN, ImageNet)做了實驗,這裡給出在ImageNet數據集上搜索得到的最優policy(最後實際上是將搜索得到的前5個最好的policies合成了一個policy,所以這裡包含25個sub-policies):

#operation,probability,magnitude(("Posterize",0.4,8),("Rotate",0.6,9)),(("Solarize",0.6,5),("AutoContrast",0.6,None)),(("Equalize",0.8,None),("Equalize",0.6,None)),(("Posterize",0.6,7),("Posterize",0.6,6)),(("Equalize",0.4,None),("Solarize",0.2,4)),(("Equalize",0.4,None),("Rotate",0.8,8)),(("Solarize",0.6,3),("Equalize",0.6,None)),(("Posterize",0.8,5),("Equalize",1.0,None)),(("Rotate",0.2,3),("Solarize",0.6,8)),(("Equalize",0.6,None),("Posterize",0.4,6)),(("Rotate",0.8,8),("Color",0.4,0)),(("Rotate",0.4,9),("Equalize",0.6,None)),(("Equalize",0.0,None),("Equalize",0.8,None)),(("Invert",0.6,None),("Equalize",1.0,None)),(("Color",0.6,4),("Contrast",1.0,8)),(("Rotate",0.8,8),("Color",1.0,2)),(("Color",0.8,8),("Solarize",0.8,7)),(("Sharpness",0.4,7),("Invert",0.6,None)),(("ShearX",0.6,5),("Equalize",1.0,None)),(("Color",0.4,0),("Equalize",0.6,None)),(("Equalize",0.4,None),("Solarize",0.2,4)),(("Solarize",0.6,5),("AutoContrast",0.6,None)),(("Invert",0.6,None),("Equalize",1.0,None)),(("Color",0.6,4),("Contrast",1.0,8)),(("Equalize",0.8,None),("Equalize",0.6,None))

基於搜索得到的AutoAugment訓練可以將ResNet50在ImageNet數據集上的top1 acc從76.3提升至77.6。一個比較重要的問題,這些從某一個數據集搜索得到的策略是否只對固定的數據集有效,論文也通過具體實驗證明了AutoAugment的遷移能力,比如將ImageNet數據集上得到的策略用在5個 FGVC數據集(與ImageNet圖像輸入大小相似)也均有提升。

目前torchvision庫已經實現了AutoAugment,具體使用如下所示(注意AutoAug前也需要包括一個RandomResizedCrop):

fromtorchvision.transformsimportautoaugment,transformstrain_transform=transforms.Compose([transforms.RandomResizedCrop(crop_size,interpolation=interpolation),transforms.RandomHorizontalFlip(hflip_prob),#這裡policy屬於torchvision.transforms.autoaugment.AutoAugmentPolicy,#對於ImageNet就是AutoAugmentPolicy.IMAGENET#此時aa_policy=autoaugment.AutoAugmentPolicy('imagenet')autoaugment.AutoAugment(policy=aa_policy,interpolation=interpolation),transforms.PILToTensor(),transforms.ConvertImageDtype(torch.float),transforms.Normalize(mean=mean,std=std)])RandAugment

AutoAugment存在的一個問題是搜索空間巨大,這使得搜索只能在代理任務中進行:使用小的模型在ImageNet的一個小的子集( 120類和6000圖片)搜索。谷歌在2019年又提出了一個更簡單的數據增強策略:RandAugment。這篇論文首先發現AutoAugment這樣在小數據集上搜索出來的策略在大的數據集上應用會存在問題,這主要是因為數據增強策略和模型大小和數據量大小存在強相關,如下圖所示可以看到模型或者訓練數據量越大,其最優的數據增強的幅度越大,這說明AutoAugment得到的結果應該是次優的。另外,Population Based Augmentation這篇論文發現最優的數據增強幅度是隨訓練過程增加,而且不同的增強操作遵循類似的規律,這啟發作者採用固定的增強幅度而不是去搜索。RandAugment相比AutoAugment的策略空間很小(vs),所以它不需要採用代理任務,甚至直接採用簡單的網格搜索。

具體地,RandAugment共包含兩個超參數:圖像增強操作的數量N和一個全局的增強幅度M,其實現代碼如下所示,每次從候選操作集合(共14種策略)隨機選擇N個操作(等概率),然後串行執行(這裡沒有判斷概率,是一定執行)。這裡的M取值範圍為{0, . . . , 30}(每個圖像增強操作歸一化到同樣的幅度範圍),而N取值範圍一般為 {1, 2, 3}。

#Identity是恆等變換,不做任何增強transforms=['Identity','AutoContrast','Equalize','Rotate','Solarize','Color','Posterize','Contrast','Brightness','Sharpness','ShearX','ShearY','TranslateX','TranslateY']defrandaugment(N,M):"""Generateasetofdistortions.Args:N:Numberofaugmentationtransformationstoapplysequentially.M:Magnitudeforallthetransformations."""sampled_ops=np.random.choice(transforms,N)return[(op,M)foropinsampled_ops]

對於ResNet50,其搜索得到的N=2,M=9,RandAugment相比AutoAugment可以在ImageNet得到相似的效果(77.6),不過DeiT中發現使用RandAugment效果更好一些( DeiT-B:81.8 vs 81.2)。目前torchvision庫也已經實現了RandAugment,具體使用如下所示:

fromtorchvision.transformsimportautoaugment,transformstrain_transform=transforms.Compose([transforms.RandomResizedCrop(crop_size,interpolation=interpolation),transforms.RandomHorizontalFlip(hflip_prob),autoaugment.RandAugment(interpolation=interpolation),transforms.PILToTensor(),transforms.ConvertImageDtype(torch.float),transforms.Normalize(mean=mean,std=std)])TrivialAugment

雖然RandAugment的搜索空間極小,但是對於不同的數據集還是需要確定最優的N和M,這依然有較大的實驗成本。RandAugment後,華為提出了UniformAugment,這種策略不需要搜索也能取得較好的結果。不過這裡我們介紹一項更新的工作:TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation。TrivialAugment也不需要任何搜索,整個方法非常簡單:每次隨機選擇一個圖像增強操作,然後隨機確定它的增強幅度,並對圖像進行增強。由於沒有任何超參數,所以不需要任何搜索。從實驗結果上看,TA可以在多個數據集上取得更好的結果,如在ImageNet數據集上,ResNet50的top1 acc可以達到78.1,超過RandAugment。

TrivialAugment的圖像增強集合和RandAugment基本一樣,不過TA也定義了一套更寬的增強幅度,目前torchvision中已經實現了TrivialAugmentWide,具體使用代碼如下所示:

fromtorchvision.transformsimportautoaugment,transformsaugmentation_space={#op_name:(magnitudes,signed)"Identity":(torch.tensor(0.0),False),"ShearX":(torch.linspace(0.0,0.99,num_bins),True),"ShearY":(torch.linspace(0.0,0.99,num_bins),True),"TranslateX":(torch.linspace(0.0,32.0,num_bins),True),"TranslateY":(torch.linspace(0.0,32.0,num_bins),True),"Rotate":(torch.linspace(0.0,135.0,num_bins),True),"Brightness":(torch.linspace(0.0,0.99,num_bins),True),"Color":(torch.linspace(0.0,0.99,num_bins),True),"Contrast":(torch.linspace(0.0,0.99,num_bins),True),"Sharpness":(torch.linspace(0.0,0.99,num_bins),True),"Posterize":(8-(torch.arange(num_bins)/((num_bins-1)/6)).round().int(),False),"Solarize":(torch.linspace(255.0,0.0,num_bins),False),"AutoContrast":(torch.tensor(0.0),False),"Equalize":(torch.tensor(0.0),False),}train_transform=transforms.Compose([transforms.RandomResizedCrop(crop_size,interpolation=interpolation),transforms.RandomHorizontalFlip(hflip_prob),autoaugment.TrivialAugmentWide(interpolation=interpolation),transforms.PILToTensor(),transforms.ConvertImageDtype(torch.float),transforms.Normalize(mean=mean,std=std)])RandomErasing

RandomErasing是廈門大學在2017年提出的一種簡單的數據增強(這個策略和同期的CutOut基本一樣),基本原理是:隨機從圖像中擦除一個矩形區域而不改變圖像的原始標籤。DeiT的訓練策略中也包括了RandomErasing。

目前torchvision也實現了RandomErasing,其具體使用代碼如下(注意這個op不支持PIL圖像,需要在轉換為torch.tensor後使用):

train_transform=transforms.Compose([transforms.RandomResizedCrop(224,scale=(0.08,1.0),ratio=(3./4.,4./3.)),transforms.RandomHorizontalFlip(),transforms.PILToTensor()transforms.ConvertImageDtype(torch.float),normalize,#scale是指相對於原圖的擦除面積範圍#ratio是指擦除區域的寬高比#value是指擦除區域的值,如果是int,也可以是tuple(RGB3個通道值),或者是str,需為'random',表示隨機生成transforms.RandomErasing(p=0.5,scale=(0.02,0.33),ratio=(0.3,3.3),value=0,inplace=False),])MixUp

MixUp在FAIR在2017年提出的一種數據增強方法:兩張不同的圖像隨機線性組合,而同時生成線性組合的標籤。

這裡的和是兩張不同的圖像,和是它們對應的one-hot標籤,而是線性組合係數,每次執行時隨機生成。假定圖像分類任務是2分類(區分狗和貓),兩張輸入圖像分別是狗和貓(如下圖所示),它們對應的one-hot標籤分別是[1,0]和[0, 1]。在進行mixup之前,首先對它們進行必要的數據增強得到aug_img1和aug_img2,然後隨機生成線性組合係數,對於得到的圖像是mix_img1,標籤變為[0.7, 0.3],而得到的圖像是mix_img2,標籤變為[0.3, 0.7]。

目前timm和torchvision中已經實現了mixup,這裡以torchvision為例來講述具體的代碼實現。由於mixup需要兩個輸入,而不單單是對當前圖像進行操作,所以一般是在得到batch數據後再進行mixup,這也意味着圖像也已經完成了其它的數據增強如RandAugment,對於batch中的每個樣本可以隨機選擇另外一個樣本進行mixup。具體的實現代碼如下所示:

#fromhttps://github.com/pytorch/vision/blob/main/references/classification/transforms.pyclassRandomMixup(torch.nn.Module):"""RandomlyapplyMixuptotheprovidedbatchandtargets.Theclassimplementsthedataaugmentationsasdescribedinthepaper`"mixup:BeyondEmpiricalRiskMinimization"<https://arxiv.org/abs/1710.09412>`_.Args:num_classes(int):numberofclassesusedforone-hotencoding.p(float):probabilityofthebatchbeingtransformed.Defaultvalueis0.5.alpha(float):hyperparameteroftheBetadistributionusedformixup.Defaultvalueis1.0.#beta分布超參數inplace(bool):booleantomakethistransforminplace.DefaultsettoFalse."""def__init__(self,num_classes:int,p:float=0.5,alpha:float=1.0,inplace:bool=False)->None:super().__init__()assertnum_classes>0,"Pleaseprovideavalidpositivevalueforthenum_classes."assertalpha>0,"Alphaparamcan'tbezero."self.num_classes=num_classesself.p=pself.alpha=alphaself.inplace=inplacedefforward(self,batch:Tensor,target:Tensor)->Tuple[Tensor,Tensor]:"""Args:batch(Tensor):Floattensorofsize(B,C,H,W)target(Tensor):Integertensorofsize(B,)Returns:Tensor:Randomlytransformedbatch."""ifbatch.ndim!=4:raiseValueError(f"Batchndimshouldbe4.Got{batch.ndim}")iftarget.ndim!=1:raiseValueError(f"Targetndimshouldbe1.Got{target.ndim}")ifnotbatch.is_floating_point():raiseTypeError(f"Batchdtypeshouldbeafloattensor.Got{batch.dtype}.")iftarget.dtype!=torch.int64:raiseTypeError(f"Targetdtypeshouldbetorch.int64.Got{target.dtype}")ifnotself.inplace:batch=batch.clone()target=target.clone()#建立one-hot標籤iftarget.ndim==1:target=torch.nn.functional.one_hot(target,num_classes=self.num_classes).to(dtype=batch.dtype)#判斷是否進行mixupiftorch.rand(1).item()>=self.p:returnbatch,target#這裡將batch數據平移一個單位,產生mixup的圖像對,這意味着每個圖像與相鄰的下一個圖像進行mixup#timm實現是通過flip來做的,這意味着第一個圖像和最後一個圖像進行mixup#It'sfastertorollthebatchbyoneinsteadofshufflingittocreateimagepairsbatch_rolled=batch.roll(1,0)target_rolled=target.roll(1,0)#隨機生成組合係數#Implementedasonmixuppaper,page3.lambda_param=float(torch._sample_dirichlet(torch.tensor([self.alpha,self.alpha]))[0])batch_rolled.mul_(1.0-lambda_param)batch.mul_(lambda_param).add_(batch_rolled)#得到mixup後的圖像target_rolled.mul_(1.0-lambda_param)target.mul_(lambda_param).add_(target_rolled)#得到mixup後的標籤returnbatch,target

然後可以將MixUp操作放在DataLoader的collate_fn中,這個函數要實現的是將多個樣本合併成一個mini-batch,所以可以將MixUp插在得到mini-batch後,具體實現如下所示:

fromtorch.utils.data.dataloaderimportdefault_collatemixup_transform=RandomMixup(num_classes,p=1.0,alpha=mixup_alpha)collate_fn=lambdabatch:mixup_transform(*default_collate(batch))data_loader=torch.utils.data.DataLoader(dataset,batch_size=batch_size,sampler=train_sampler,collate_fn=collate_fn)

對於MixUp,還要注意兩個兩點。第一個是如果同時採用了label smoothing,那麼在創建one-hot標籤時要直接得到smooth後的標籤,具體實現如下(參考timm):

defone_hot(x,num_classes,on_value=1.,off_value=0.,device='cuda'):x=x.long().view(-1,1)returntorch.full((x.size()[0],num_classes),off_value,device=device).scatter_(1,x,on_value)off_value=smoothing/num_classeson_value=1.-smoothing+off_valuesmooth_one_hot=one_hot(target,num_classes,on_value=on_value,off_value=off_value)

第二個要注意的是MixUp後得到標籤時soft label,不能直接採用torch.nn.CrossEntropyLoss來計算loss,而是直接計算交叉熵(參考timm):

classSoftTargetCrossEntropy(nn.Module):def__init__(self):super(SoftTargetCrossEntropy,self).__init__()defforward(self,x:torch.Tensor,target:torch.Tensor)->torch.Tensor:loss=torch.sum(-target*F.log_softmax(x,dim=-1),dim=-1)returnloss.mean()

注意在PyTorch1.10版本之後,torch.nn.CrossEntropyLoss已經支持直接送入的target是probabilities for each class,原來只支持target是class indices;而且也支持label_smoothing參數,所以上述兩個注意點就不再需要了。

說到計算loss,timm作者近期在ResNet strikes back: An improved training procedure in timm指出採用MixUp後可以將多分類改成多標籤分類(multi-label classification),即從N分類變成N個2分類(直接採用BinaryCrossEntropy),這應該更符合MixUp後圖像的語義,從對比實驗來看效果有微弱的提升。MixUp除了可以用於圖像分類任務,還可以用於物體檢測任務中,比如YOLOX就採用了MixUp,這裡面的做法是對圖像mixup後,其box為兩個圖像的box的合併集合,而沒有對標籤軟化,這塊也可以見論文Bag of Freebies for Training Object Detection Neural Networks。

CutMix

CutMix是2019年提出的一項和MixUp和類似的數據增強策略,它也是同時對兩個圖像和標籤進行混合,與MixUp不同的是它的圖像混合方式。CutMix不是對兩個圖像線性組合,而是從另外一張圖像隨機剪切一個patch並粘貼到第一張圖像上,patch的起始坐標隨機生成,而寬高是由來控制:

這裡和是原始圖像的寬和高,所以其實決定的是patch和原圖的面積比:。下圖展示了分別取0.7和0.3的混合效果,越小,粘貼的patch越大。對於標籤,其處理方式和MixUp一樣,通過來得到兩張圖像的線性組合。

CutMix做了ImageNet上的對比實驗,相比MixUp,ResNet50的top1 acc大約能提升一個點(77.4 vs 78.6):

目前timm和torchvision中也已經實現了CutMix,這裡還是以torchvision為例來講述具體的代碼實現,如下所示(和MixUp基本類似,只不過內部處理存在差異):

classRandomCutmix(torch.nn.Module):"""RandomlyapplyCutmixtotheprovidedbatchandtargets.Theclassimplementsthedataaugmentationsasdescribedinthepaper`"CutMix:RegularizationStrategytoTrainStrongClassifierswithLocalizableFeatures"<https://arxiv.org/abs/1905.04899>`_.Args:num_classes(int):numberofclassesusedforone-hotencoding.p(float):probabilityofthebatchbeingtransformed.Defaultvalueis0.5.alpha(float):hyperparameteroftheBetadistributionusedforcutmix.Defaultvalueis1.0.inplace(bool):booleantomakethistransforminplace.DefaultsettoFalse."""def__init__(self,num_classes:int,p:float=0.5,alpha:float=1.0,inplace:bool=False)->None:super().__init__()assertnum_classes>0,"Pleaseprovideavalidpositivevalueforthenum_classes."assertalpha>0,"Alphaparamcan'tbezero."self.num_classes=num_classesself.p=pself.alpha=alphaself.inplace=inplacedefforward(self,batch:Tensor,target:Tensor)->Tuple[Tensor,Tensor]:"""Args:batch(Tensor):Floattensorofsize(B,C,H,W)target(Tensor):Integertensorofsize(B,)Returns:Tensor:Randomlytransformedbatch."""ifbatch.ndim!=4:raiseValueError(f"Batchndimshouldbe4.Got{batch.ndim}")iftarget.ndim!=1:raiseValueError(f"Targetndimshouldbe1.Got{target.ndim}")ifnotbatch.is_floating_point():raiseTypeError(f"Batchdtypeshouldbeafloattensor.Got{batch.dtype}.")iftarget.dtype!=torch.int64:raiseTypeError(f"Targetdtypeshouldbetorch.int64.Got{target.dtype}")ifnotself.inplace:batch=batch.clone()target=target.clone()iftarget.ndim==1:target=torch.nn.functional.one_hot(target,num_classes=self.num_classes).to(dtype=batch.dtype)iftorch.rand(1).item()>=self.p:returnbatch,target#It'sfastertorollthebatchbyoneinsteadofshufflingittocreateimagepairsbatch_rolled=batch.roll(1,0)target_rolled=target.roll(1,0)#Implementedasoncutmixpaper,page12(withminorcorrectionsontypos).lambda_param=float(torch._sample_dirichlet(torch.tensor([self.alpha,self.alpha]))[0])W,H=F.get_image_size(batch)#確定patch的起點r_x=torch.randint(W,(1,))r_y=torch.randint(H,(1,))#確定patch的w和h(其實是一半大小)r=0.5*math.sqrt(1.0-lambda_param)r_w_half=int(r*W)r_h_half=int(r*H)#越界處理x1=int(torch.clamp(r_x-r_w_half,min=0))y1=int(torch.clamp(r_y-r_h_half,min=0))x2=int(torch.clamp(r_x+r_w_half,max=W))y2=int(torch.clamp(r_y+r_h_half,max=H))batch[:,:,y1:y2,x1:x2]=batch_rolled[:,:,y1:y2,x1:x2]#由于越界處理,λ可能發生改變,所以要重新計算lambda_param=float(1.0-(x2-x1)*(y2-y1)/(W*H))target_rolled.mul_(1.0-lambda_param)target.mul_(lambda_param).add_(target_rolled)returnbatch,target

其它使用和MixUp一樣。

Repeated Augmentation

Repeated Augmentation (RA)是FAIR在MultiGrain提出的一種抽樣策略,一般情況下,訓練的mini-batch包含的增強過的sample都是來自不同的圖像,但是RA這種抽樣策略允許一個mini-batch中包含來自同一個圖像的不同增強版本,此時mini-batch的各個樣本並非是完全獨立的,這相當於對同一個樣本進行重複抽樣,所以稱為Repeated Augmentation。這篇論文認為在一個mini-batch學習來自同一個圖像的不同增強版本能讓模型更容易學習到增強不變的特徵。關於RA,其實另外一篇較早的論文Augment your batch: better training with larger batches也提出了類似的策略,另外DeepMind在最近的論文Drawing Multiple Augmentation Samples Per Image During Training Efficiently Decreases Test Error也進一步通過實驗來證明這種策略的效果。

DeiT的訓練也採用了RA,嚴格來說RA不屬於數據增強策略,而是一種mini-batch抽樣方法,這裡也簡單給出DeiT實現的RA(可以替換torch.utils.data.DistributedSampler):

classRASampler(torch.utils.data.Sampler):"""Samplerthatrestrictsdataloadingtoasubsetofthedatasetfordistributed,withrepeatedaugmentation.Itensuresthatdifferenteachaugmentedversionofasamplewillbevisibletoadifferentprocess(GPU)Heavilybasedontorch.utils.data.DistributedSampler"""def__init__(self,dataset,num_replicas=None,rank=None,shuffle=True):ifnum_replicasisNone:ifnotdist.is_available():raiseRuntimeError("Requiresdistributedpackagetobeavailable")num_replicas=dist.get_world_size()ifrankisNone:ifnotdist.is_available():raiseRuntimeError("Requiresdistributedpackagetobeavailable")rank=dist.get_rank()self.dataset=datasetself.num_replicas=num_replicasself.rank=rankself.epoch=0#重複採樣後每個replica的樣本量self.num_samples=int(math.ceil(len(self.dataset)*3.0/self.num_replicas))#重複採樣後的總樣本量self.total_size=self.num_samples*self.num_replicas#self.num_selected_samples=int(math.ceil(len(self.dataset)/self.num_replicas))#每個replica實際樣本量,即不重複採樣時的每個replica的樣本量self.num_selected_samples=int(math.floor(len(self.dataset)//256*256/self.num_replicas))self.shuffle=shuffledef__iter__(self):#deterministicallyshufflebasedonepochg=torch.Generator()g.manual_seed(self.epoch)ifself.shuffle:indices=torch.randperm(len(self.dataset),generator=g).tolist()else:indices=list(range(len(self.dataset)))#addextrasamplestomakeitevenlydivisibleindices=[eleforeleinindicesforiinrange(3)]#重複3次indices+=indices[:(self.total_size-len(indices))]assertlen(indices)==self.total_size#subsample:使得同一個樣本的重複版本進入不同的進程(GPU)indices=indices[self.rank:self.total_size:self.num_replicas]assertlen(indices)==self.num_samplesreturniter(indices[:self.num_selected_samples])#截取實際樣本量def__len__(self):returnself.num_selected_samplesdefset_epoch(self,epoch):self.epoch=epoch小結

這裡簡單介紹了幾種常用且有效的數據增強策略,這些策略在vision transformer模型被使用,而且timm訓練的ResNet新baseline也使用了這些策略。

參考

Training data-efficient image transformers & distillation through attention (https://arxiv.org/abs/2012.12877)
AutoAugment: Learning Augmentation Policies from Data (https://arxiv.org/abs/1805.09501)
RandAugment: Practical automated data augmentation with a reduced search space (https://arxiv.org/abs/1909.13719)
TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation (https://arxiv.org/abs/2103.10158)
Random Erasing Data Augmentation(https://arxiv.org/abs/1708.04896)
Augment your batch: better training with larger batches (https://arxiv.org/abs/1901.09335)
MultiGrain: a unified image embedding for classes and instances(https://arxiv.org/abs/1902.05509)
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

本文僅做學術分享,如有侵權,請聯繫刪文。

—THE END—
arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()