close
↑ 點擊藍字關注極市平台
作者丨Lart

編輯丨極市平台

極市導讀

one-hot 形式的編碼在深度學習任務中非常常見,但是卻並不是一種很自然的數據存儲方式。所以大多數情況下都需要我們自己手動轉換。本文儘可能將基於 pytorch 中常用方法來實現one-hot編碼的方式整理了下,希望對大家有用。>>加入極市CV技術交流群,走在計算機視覺的最前沿

前言

one-hot 形式的編碼在深度學習任務中非常常見,但是卻並不是一種很自然的數據存儲方式。所以大多數情況下都需要我們自己手動轉換。雖然思路很直接,就是將類別拆分成一一對應的 0-1 向量,但是具體實現起來確實還是需要思考下的。實際上 pytorch 自身在nn.functional中已經提供了one_hot方法來快速應用。但是這並不能影響我們的思考與實踐:>!所以本文儘可能將基於 pytorch 中常用方法來實現one-hot編碼的方式整理了下,希望有用。

主要的方式有這麼幾種:

for循環
scatter
index_select
原始文檔:

https://www.yuque.com/lart/ugkv9f/src5w8

代碼倉庫:

https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch

for循環

這種方法非常直觀,說白了就是對一個空白(全零)張量中的指定位置進行賦值(賦 1)操作即可。關鍵在於如何設定索引。下面設計了兩種本質相同但由於指定維度不同而導致些許差異的方案。

defbhw_to_onehot_by_for(bhw_tensor:torch.Tensor,num_classes:int):"""Args:bhw_tensor:b,h,wnum_classes:Returns:b,h,w,num_classes"""assertbhw_tensor.ndim==3,bhw_tensor.shapeassertnum_classes>bhw_tensor.max(),torch.unique(bhw_tensor)one_hot=bhw_tensor.new_zeros(size=(num_classes,*bhw_tensor.shape))foriinrange(num_classes):one_hot[i,bhw_tensor==i]=1one_hot=one_hot.permute(1,2,3,0)returnone_hotdefbhw_to_onehot_by_for_V1(bhw_tensor:torch.Tensor,num_classes:int):"""Args:bhw_tensor:b,h,wnum_classes:Returns:b,h,w,num_classes"""assertbhw_tensor.ndim==3,bhw_tensor.shapeassertnum_classes>bhw_tensor.max(),torch.unique(bhw_tensor)one_hot=bhw_tensor.new_zeros(size=(*bhw_tensor.shape,num_classes))foriinrange(num_classes):one_hot[...,i][bhw_tensor==i]=1returnone_hotscatter

該方法應該是網上大多數簡潔的one_hot寫法的常用形式了。其實際上主要的作用是向 tensor 中指定的位置上賦值。

由於其可以使用專門構造的索引矩陣來作為索引,所以更加靈活。當然,靈活帶來的也就是理解上的困難。官方文檔中提供的解釋非常直觀:

'''https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html*(intdim,Tensorindex,Tensorsrc)*(intdim,Tensorindex,Tensorsrc,*,strreduce)*(intdim,Tensorindex,Numbervalue)*(intdim,Tensorindex,Numbervalue,*,strreduce)'''self[index[i][j][k]][j][k]=src[i][j][k]#ifdim==0self[i][index[i][j][k]][k]=src[i][j][k]#ifdim==1self[i][j][index[i][j][k]]=src[i][j][k]#ifdim==2

文檔中使用的是原地置換(in-place)版本,並且基於替換值為src,即 tensor 的情況下來解釋。實際上在我們的應用中主要基於原地置換版本並搭配替換值為標量浮點數value的形式。

上述的形式中,我們可以看到,通過指定參數 tensor index,我們就可以將src中(i,j,k)的值放置到方法調用者(這裡是self)的指定位置上。該指定位置由index的(i,j,k)處的值替換坐標(i,j,k)中的dim位置的值來構成(這裡也反映出來了index tensor 的一個要求,就是維度數量要和self、src(如果src為 tensor 的話。後文中使用的是具體的標量值 1,即src替換為value)一致)。這倒是和one-hot的概念非常吻合。因為one-hot本身形式上的含義就是對於第i類數據,第i個位置為 1,其餘位置為 0。所以對全零 tensor 使用scatter_是可以非常容易的構造出one-hottensor 的,即對對應於類別編號的位置放置 1 即可。

對於我們的問題而言,index非常適合使用輸入的包含類別編號的 tensor(形狀為B,H,W)來表示。基於這樣的思考,可以構思出兩種不同的策略:

defbhw_to_onehot_by_scatter(bhw_tensor:torch.Tensor,num_classes:int):"""Args:bhw_tensor:b,h,wnum_classes:Returns:b,h,w,num_classes"""assertbhw_tensor.ndim==3,bhw_tensor.shapeassertnum_classes>bhw_tensor.max(),torch.unique(bhw_tensor)one_hot=torch.zeros(size=(math.prod(bhw_tensor.shape),num_classes))one_hot.scatter_(dim=1,index=bhw_tensor.reshape(-1,1),value=1)one_hot=one_hot.reshape(*bhw_tensor.shape,num_classes)returnone_hotdefbhw_to_onehot_by_scatter_V1(bhw_tensor:torch.Tensor,num_classes:int):"""Args:bhw_tensor:b,h,wnum_classes:Returns:b,h,w,num_classes"""assertbhw_tensor.ndim==3,bhw_tensor.shapeassertnum_classes>bhw_tensor.max(),torch.unique(bhw_tensor)one_hot=torch.zeros(size=(*bhw_tensor.shape,num_classes))one_hot.scatter_(dim=-1,index=bhw_tensor[...,None],value=1)returnone_hot

這兩種形式的差異的根源在於對形狀的處理上。由此帶來了scatter不同的應用形式。

對於第一種形式,將B,H,W三個維度合併,這樣的好處是對通道(類別)的索引的理解變得直觀起來。

one_hot=torch.zeros(size=(math.prod(bhw_tensor.shape),num_classes))one_hot.scatter_(dim=1,index=bhw_tensor.reshape(-1,1),value=1)

這裡將類別維度和其他維度直接分離,移到了末位。通過dim指定該維度,於是就有了這樣的對應關係:

zero_tensor[abc,index[abc][d]]=value#d=0

而在第二種情況下仍然保留了前面的三個維度,類別維度依然移動到最後一位。

one_hot=torch.zeros(size=(*bhw_tensor.shape,num_classes))one_hot.scatter_(dim=-1,index=bhw_tensor[...,None],value=1)

此時的對應關係是這樣的:

zero_tensor[a,b,c,index[a][b][c][d]]=value#d=0

另外在 pytorch 分類模型庫 timm 中,也使用了類似的方法:

#https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19defone_hot(x,num_classes,on_value=1.,off_value=0.,device='cuda'):x=x.long().view(-1,1)returntorch.full((x.size()[0],num_classes),off_value,device=device).scatter_(1,x,on_value)index_selecttorch.index_select(input,dim,index,*,out=None)→Tensor-input(Tensor)–theinputtensor.-dim(int)–thedimensioninwhichweindex-index(IntTensororLongTensor)–the1-Dtensorcontainingtheindicestoindex

該函數如其名,就是用索引來選擇 tensor 的指定維度的子 tensor 的。

想要理解這一方法的動機,實際上需要反過來,從類別標籤的角度看待one-hot編碼。

對於原始從小到大排布的類別序號對應的one-hot編碼成的矩陣就是一個單位矩陣。所以每個類別對應的就是該單位矩陣的特定的列(或者行)。這一需求恰好符合index_select的功能。所以我們可以使用其實現one_hot編碼,只需要使用類別序號索引特定的列或者行即可。下面就是一個例子:

defbhw_to_onehot_by_index_select(bhw_tensor:torch.Tensor,num_classes:int):"""Args:bhw_tensor:b,h,wnum_classes:Returns:b,h,w,num_classes"""assertbhw_tensor.ndim==3,bhw_tensor.shapeassertnum_classes>bhw_tensor.max(),torch.unique(bhw_tensor)one_hot=torch.eye(num_classes).index_select(dim=0,index=bhw_tensor.reshape(-1))one_hot=one_hot.reshape(*bhw_tensor.shape,num_classes)returnone_hot性能對比

整體代碼可見:https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch

下面展示了不同方法的大致的相對性能(因為後台在跑程序,可能並不是十分準確,建議大家自行測試)。可以看到,pytorch 自帶的函數在 CPU 上效率並不是很高,但是在 GPU 上表現良好。其中有趣的是,基於index_select的形式表現非常亮眼。

1.10.0GeForceRTX2080Ticpu('bhw_to_onehot_by_for',0.5411529541015625)('bhw_to_onehot_by_for_V1',0.4515676498413086)('bhw_to_onehot_by_scatter',0.0686192512512207)('bhw_to_onehot_by_scatter_V1',0.08529376983642578)('bhw_to_onehot_by_index_select',0.05156970024108887)('F.one_hot',0.07366824150085449)gpu('bhw_to_onehot_by_for',0.005235433578491211)('bhw_to_onehot_by_for_V1',0.045584678649902344)('bhw_to_onehot_by_scatter',0.0025513172149658203)('bhw_to_onehot_by_scatter_V1',0.0024869441986083984)('bhw_to_onehot_by_index_select',0.002012014389038086)('F.one_hot',0.0024051666259765625)

如果覺得有用,就請分享到朋友圈吧!


△點擊卡片關注極市平台,獲取最新CV乾貨

公眾號後台回復「transformer」獲取最新Transformer綜述論文下載~

極市乾貨

課程/比賽:珠港澳人工智能算法大賽|保姆級零基礎人工智能教程
算法trick:目標檢測比賽中的tricks集錦|從39個kaggle競賽中總結出來的圖像分割的Tips和Tricks
技術綜述:一文弄懂各種loss function|工業圖像異常檢測最新研究總結(2019-2020)

#極市平台簽約作者#





Lart

知乎:人民藝術家

CSDN:有為少年


大連理工大學在讀博士

研究領域:主要方向為圖像分割,但多從事於二值圖像分割的研究。也會關注其他領域,例如分類和檢測等方向的發展。


作品精選

實踐教程 | PyTorch中相對位置編碼的理解
實操教程 | 使用Docker為無網絡環境搭建深度學習環境
實踐教程 | 一文讓你把Docker用起來!




投稿方式:
添加小編微信Fengcall(微信號:fengcall19),備註:姓名-投稿
△長按添加極市平台小編

覺得有用麻煩給個在看啦~
arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()