【GiantPandaCV導語】
CoAt=Convolution + Attention,paperwithcode榜單第一名,通過結合卷積與Transformer實現性能上的突破,方法部分設計非常規整,層層深入考慮模型的架構設計。
data:image/s3,"s3://crabby-images/60478/604786b6e91c004735e9943d0ee05150ca584f0d" alt=""
Transformer模型的容量大,由於缺乏正確的歸納偏置,泛化能力要比卷積網絡差。
提出了CoAtNets模型族:
這部分主要關注如何將conv與transformer以一種最優的方式結合:
卷積方面谷歌使用的是經典的MBConv, 使用深度可分離卷積來捕獲空間之間的交互。
卷積操作的表示:代表i周邊的位置,也即卷積處理的感受野。
自注意力表示:表示全局空間感受野。
融合方法一:先求和,再softmax
融合方法二:先softmax,再求和
出於參數量、計算兩方面的考慮,論文打算採用第二種融合方法。
垂直布局設計決定好合併卷積與注意力的方式後應該考慮如何構建網絡整體架構,主要有三個方面的考量:
第二種方法實現效率不夠高,第三種方法性能不夠好,因此採用第一種方法,如何設計降採樣的方式也有幾種方案:
data:image/s3,"s3://crabby-images/57e97/57e97b16d03bd0ff38eaccffafb4f1655fe75efd" alt=""
採用卷積以及MBConv,從的幾個模塊採用Transformer 結構。具體Transformer內部有以下幾個變體:C代表卷積,T代表Transformer
初步測試模型泛化能力
data:image/s3,"s3://crabby-images/2e3b4/2e3b4d9c835b9714696dd0c06c0d1484f510e0ce" alt=""
泛化能力排序為:(證明架構中還是需要存在想當比例的卷積操作)
data:image/s3,"s3://crabby-images/dde8d/dde8d52843b7bcef8b81637590cbd0267058088f" alt=""
初步測試模型容量
主要是從JFT以及ImageNet-1k上不同的表現來判定的,排序結果為:
data:image/s3,"s3://crabby-images/f3321/f33217d94fff3d74218fede87e1fd9a41f8aacf8" alt=""
測試模型遷移能力
data:image/s3,"s3://crabby-images/8d7e7/8d7e7cace9ab3e354e9c8b99e75d3a83271b0a9f" alt=""
為了進一步比較CCTT與CTTT,進行了遷移能力測試,發現CCTT能夠超越CTTT。
最終CCTT勝出!
實驗與SOTA模型比較結果:
data:image/s3,"s3://crabby-images/bb42b/bb42bb82fe29ffe125cca1b39fedcd9b2069b11b" alt=""
實驗結果:
data:image/s3,"s3://crabby-images/0abb8/0abb8bbd5adcf3498875b697477eebd321963044" alt=""
消融實驗:
data:image/s3,"s3://crabby-images/0f19c/0f19c5433198082b2e01f01114a7d9ceede0a40e" alt=""
data:image/s3,"s3://crabby-images/d16a4/d16a472bb2cbc74db84d9749d6141bac51f48972" alt=""
data:image/s3,"s3://crabby-images/a5a93/a5a93be6632f0dec0a5c9ca1115381d990d1f07d" alt=""
淺層使用的MBConv模塊如下:
classMBConv(nn.Module):def__init__(self,inp,oup,image_size,downsample=False,expansion=4):super().__init__()self.downsample=downsamplestride=1ifself.downsample==Falseelse2hidden_dim=int(inp*expansion)ifself.downsample:self.pool=nn.MaxPool2d(3,2,1)self.proj=nn.Conv2d(inp,oup,1,1,0,bias=False)ifexpansion==1:self.conv=nn.Sequential(#dwnn.Conv2d(hidden_dim,hidden_dim,3,stride,1,groups=hidden_dim,bias=False),nn.BatchNorm2d(hidden_dim),nn.GELU(),#pw-linearnn.Conv2d(hidden_dim,oup,1,1,0,bias=False),nn.BatchNorm2d(oup),)else:self.conv=nn.Sequential(#pw#down-sampleinthefirstconvnn.Conv2d(inp,hidden_dim,1,stride,0,bias=False),nn.BatchNorm2d(hidden_dim),nn.GELU(),#dwnn.Conv2d(hidden_dim,hidden_dim,3,1,1,groups=hidden_dim,bias=False),nn.BatchNorm2d(hidden_dim),nn.GELU(),SE(inp,hidden_dim),#pw-linearnn.Conv2d(hidden_dim,oup,1,1,0,bias=False),nn.BatchNorm2d(oup),)self.conv=PreNorm(inp,self.conv,nn.BatchNorm2d)defforward(self,x):ifself.downsample:returnself.proj(self.pool(x))+self.conv(x)else:returnx+self.conv(x)主要關注Attention Block設計,引入Relative Position:
classAttention(nn.Module):def__init__(self,inp,oup,image_size,heads=8,dim_head=32,dropout=0.):super().__init__()inner_dim=dim_head*headsproject_out=not(heads==1anddim_head==inp)self.ih,self.iw=image_sizeself.heads=headsself.scale=dim_head**-0.5#parametertableofrelativepositionbiasself.relative_bias_table=nn.Parameter(torch.zeros((2*self.ih-1)*(2*self.iw-1),heads))coords=torch.meshgrid((torch.arange(self.ih),torch.arange(self.iw)))coords=torch.flatten(torch.stack(coords),1)relative_coords=coords[:,:,None]-coords[:,None,:]relative_coords[0]+=self.ih-1relative_coords[1]+=self.iw-1relative_coords[0]*=2*self.iw-1relative_coords=rearrange(relative_coords,'chw->hwc')relative_index=relative_coords.sum(-1).flatten().unsqueeze(1)self.register_buffer("relative_index",relative_index)self.attend=nn.Softmax(dim=-1)self.to_qkv=nn.Linear(inp,inner_dim*3,bias=False)self.to_out=nn.Sequential(nn.Linear(inner_dim,oup),nn.Dropout(dropout))ifproject_outelsenn.Identity()defforward(self,x):qkv=self.to_qkv(x).chunk(3,dim=-1)q,k,v=map(lambdat:rearrange(t,'bn(hd)->bhnd',h=self.heads),qkv)dots=torch.matmul(q,k.transpose(-1,-2))*self.scale#Use"gather"formoreefficiencyonGPUsrelative_bias=self.relative_bias_table.gather(0,self.relative_index.repeat(1,self.heads))relative_bias=rearrange(relative_bias,'(hw)c->1chw',h=self.ih*self.iw,w=self.ih*self.iw)dots=dots+relative_biasattn=self.attend(dots)out=torch.matmul(attn,v)out=rearrange(out,'bhnd->bn(hd)')out=self.to_out(out)returnout參考https://arxiv.org/pdf/2106.04803.pdf
https://github.com/chinhsuanwu/coatnet-pytorch
data:image/s3,"s3://crabby-images/d9754/d9754422e0bda1c48c884f59ba01a423da483fb0" alt=""