close

戳我,查看GAN的系列專輯~!

等你着陸!【GAN生成對抗網絡】知識星球!

轉載zlhroughlove@知乎編輯極市平台僅分享,侵權

來源https://zhuanlan.zhihu.com/p/441317369

前言

在使用Pytorch建模時,常見的流程為先寫Model,再寫Dataset,最後寫Trainer。Dataset 是整個項目開發中投入時間第二多,也是中間關鍵的步驟。往往需要事先對於其設計有明確的思考,不然可能會因為Dataset的一些問題又要去調整Model,Trainer。本文將目前開發中的一些思考以及遇到的問題做一個總結,提供給各位讀者一個比較通用的模版,拋磚引玉~

一、Dataset的定義fromtorch.utils.dataimportDataset,DataLoader,RandomSampler

對於不同類型的建模任務,模型的輸入各不相同。自然語言,多模態,點擊率預估,往往這些場景輸入模型的數據並不是來自於單一文件,而且可能無法全部存入內存。Dataset需要整合項目的數據,對於單條樣本涉及到的數據做一個提取與歸納。不但如此,項目可能還涉及到多種模型,任務的訓練。Dataset需要為不同的模型以及訓練任務提供不同的單條樣本輸入,作為一個數據生成器,把後續模型訓練任務需要的所有基礎數據,標籤全返回了。所以往往我們可以定義一個BaseDataset類,繼承torch.utils.data.Dataset,這個類可以初始化一些文件路徑,配置等。後面不同的模型訓練任務定義相應的Dataset類繼承BaseDataset。

Dataset通用的結構為:

classBaseDataset(Dataset):def__init__(self,config):self.config=configifos.path.isfile(config.file_path)isFalse:raiseValueError(f"Inputfilepath{config.file_path}notfound")logger.info(f"Creatingfeaturesfromdatasetfileat{config.file_path}")#一次性全讀進內存self.data=joblib.load(config.file_path)self.nums=len(self.data)def__len__(self):returnself.numsdef__getitem__(self,i)->Dict[str,tensor]:sample_i=self.data[i]return{"f1":torch.tensor(sample_i["f1"]).long(),"f2":torch.tensor(sample_i["f2"]).long(),torch.LongTensor([sample_i["label"]])}

如果無法全部讀取進內存需要再__getitem__方法內構建數據,做自然語言則可以吧tokenizer初始化到該類中,在__getitem__方法內完成tokenizer。改方法的輸出推薦做成字典形式。

對於不同的訓練任務可以通過以下方法返迴響應的數據生成器

defbuild_dataset(task_type,features,**kwargs):asserttask_typein['task1','task2'],'taskmismatch'iftask_type=='task1':dataset=task1Dataset(features))else:dataset=task2Dataset(features)returndataset

有時模型的訓練任務需要做數據增強,對比學習,構造多種的預訓練任務輸入。Dataset的職能邊界是提供一套基礎的單樣本數據輸入生成器。如果是MLM任務,可以在Dataset內生成maskposition以及label。如果是在batch內的對比學習則應該在DataLoader生產batch數據後再進行。

二、DataLoader的定義

DataLoader的作用是對Dataset進行多進程高效地構建每個訓練批次的數據。傳入的數據可以認為是長度為batch大小的多個__getitem__ 方法返回的字典list。DataLoader的職能邊界是根據Dataset提供的單條樣本數據有選擇的構建一個batch的模型輸入數據。

其通常的結構為對Train,Valid,Test分別建立:

train_sampler=RandomSampler(train_dataset)train_loader=DataLoader(dataset=train_dataset,batch_size=args.train_batch_size,sampler=train_sampler,shuffle=(train_samplerisNone)collate_fn=None,#一般不用設置num_workers=4)

首先對於sampler 還有一種定義方式:

sampler=torch.utils.data.distributed.DistributedSampler(dataset)

至於batch內數據是否需要做shuffle也需要根據損失函數確定(對比學習慎用)

DataLoader會自動合併__getitem__ 方法返回的字典內每個key內每個tensor,在tensor的第0維度新增一個batch大小的維度。如果該方法返回的每條樣本長度不同無法拼接,batchsize>1就會報錯。但是又一些任務在還沒有確定後續的批樣本對應的任務時,Dataset可能返回的字典里每個key可能就是長度不同的tensor,甚至是list,這時候需要使用collate_fn參數告訴DataLoader如何取樣。我們可以定義自己的函數來準確地實現想要的功能。

如果__getitem__方法返回的是tuple((list, list)) 可以使用:

defmerge_sample(x):returnzip(*x)train_loader=DataLoader(dataset=train_dataset,batch_size=args.train_batch_size,sampler=train_sampler,shuffle=(train_samplerisNone)collate_fn=merge_sample,num_workers=4)

拼接數據,後續再做進一步處理。(此時list內數據還是不等長,無法轉為tensor)

如果__getitem_方法返回的是Dict[str,tensor],自定義的collate_fn方法內需要實現:List[Dict[str,tensor(xx)]]->Dict[str,tensor(bs,xx)]的操作,pad_sequence過程也可以在自定義方法內實現。(總之collate_fn中不但可以處理不等長數據,還可以對一個batch的數據做精修。當然也可以在DataLoader之後再做修改batch內的數據。)

值得注意的是在cpu環境下,如果要自定義collate_fn,num_workers必須設置為0,不然就會有問題..

通過以下方式可以檢查一下輸入後續模型的數據是否已經是想要的格式

forstep,batch_datainenumerate(train_loader):ifstep<1:print(batch_data)else:break

之後數據將數據放入gpu device, 一個batch的數據進入device端後就與內存上的數據不再互相干擾。之後數據就可以餵給模型了:

loss=model(**batch_data)

forkeyinbatch_data.keys():batch_data[key]=batch_data[key].to(device)

猜您喜歡:

超110篇!CVPR 2021最全GAN論文匯總梳理!

超100篇!CVPR 2020最全GAN論文梳理匯總!

拆解組新的GAN:解耦表徵MixNMatch

StarGAN第2版:多域多樣性圖像生成

附下載 |《可解釋的機器學習》中文版

附下載 |《TensorFlow 2.0 深度學習算法實戰》

附下載 |《計算機視覺中的數學方法》分享

《基於深度學習的表面缺陷檢測方法綜述》

《零樣本圖像分類綜述: 十年進展》

《基於深度神經網絡的少樣本學習綜述》

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()