close

點擊下方卡片,關注「CVer」公眾號

AI/CV重磅乾貨,第一時間送達


轉載自:集智書童
PyramidTNT:Improved Transformer-in-Transformer Baselines with Pyramid Architecture
論文:https://arxiv.org/abs/2201.00978
代碼(剛剛開源):https://github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch

Transformer在計算機視覺任務方面取得了很大的進展。Transformer-in-Transformer (TNT)體系結構利用內部Transformer和外部Transformer來提取局部和全局表示。在這項工作中,通過引入2種先進的設計來提出新的TNT Baseline:

Pyramid Architecture

Convolutional Stem

新的「PyramidTNT」通過建立層次表示,顯著地改進了原來的TNT。PyramidTNT相較於之前最先進的Vision Transformer具有更好的性能,如Swin-Transformer。

1簡介

Vision Transformer為計算機視覺提供了一種新的解決思路。從ViT開始,提出了一系列改進Vision Transformer體系結構的工作。

PVT介紹了Vision Transformer的金字塔網絡體系結構

T2T-ViT-14 遞歸地將相鄰的Token聚合為一個Token,以提取局部結構,減少Token的數量

TNT 利用 inner Transformer和outer Transformer來建模 word-level 和 sentence-level 的視覺表示

Swin-Transformer提出了一種分層Transformer,其表示由Shifted windows來進行計算

隨着近年來的研究進展,Vision Transformer的性能已經可以優於卷積神經網絡(CNN)。而本文的這項工作是建立了基於TNT框架的改進的 Vision Transformer Baseline。這裡主要引入了兩個主要的架構修改:

Pyramid Architecture:逐漸降低分辨率,提取多尺度表示

Convolutional Stem:修補Stem和穩定訓練

這裡作者還使用了幾個其他技巧來進一步提高效率。新的Transformer被命名為PyramidTNT。

對圖像分類和目標檢測的實驗證明了金字塔檢測的優越性。具體來說,PyramidTNT-S在只有3.3B FLOPs的情況下獲得了82.0%的ImageNet分類準確率,明顯優於原來的TNT-S和Swin-T。

對於COCO檢測,PyramidTNT-S比現有的Transformer和MLP檢測模型以更少的計算成本實現42.0的mAP。

2本文方法2.1 Convolutional Stem

給定一個輸入圖像,TNT模型首先將圖像分割成多個patch,並進一步將每個patch視為一個sub-patch序列。然後應用線性層將sub-patch投射到visual word vector(又稱token)。這些視覺word被拼接在一起並轉換成一個visual sentence vector。

肖奧等人發現在ViT中使用多個卷積作為Stem可以提高優化穩定性,也能提高性能。在此基礎上,本文構造了一個金字塔的卷積Stem。利用3×3卷積的堆棧產生visual word vector ,其中C是visual word vector的維度。同樣也可以得到visual sentence vector ,其中D是visual sentence vector 的維度。word-level 和 sentence-level位置編碼分別添加到visual words和sentences上,和原始的TNT一樣。

classStem(nn.Module):"""ImagetoVisualWordEmbedding"""def__init__(self,img_size=224,in_chans=3,outer_dim=768,inner_dim=24):super().__init__()img_size=to_2tuple(img_size)self.img_size=img_sizeself.inner_dim=inner_dimself.num_patches=img_size[0]//8*img_size[1]//8self.num_words=16self.common_conv=nn.Sequential(nn.Conv2d(in_chans,inner_dim*2,3,stride=2,padding=1),nn.BatchNorm2d(inner_dim*2),nn.ReLU(inplace=True),)#利用innerTransformer來建模word-levelself.inner_convs=nn.Sequential(nn.Conv2d(inner_dim*2,inner_dim,3,stride=1,padding=1),nn.BatchNorm2d(inner_dim),nn.ReLU(inplace=False),)#利用outerTransformer來建模sentence-level的視覺表示self.outer_convs=nn.Sequential(nn.Conv2d(inner_dim*2,inner_dim*4,3,stride=2,padding=1),nn.BatchNorm2d(inner_dim*4),nn.ReLU(inplace=True),nn.Conv2d(inner_dim*4,inner_dim*8,3,stride=2,padding=1),nn.BatchNorm2d(inner_dim*8),nn.ReLU(inplace=True),nn.Conv2d(inner_dim*8,outer_dim,3,stride=1,padding=1),nn.BatchNorm2d(outer_dim),nn.ReLU(inplace=False),)self.unfold=nn.Unfold(kernel_size=4,padding=0,stride=4)defforward(self,x):B,C,H,W=x.shapeH_out,W_out=H//8,W//8H_in,W_in=4,4x=self.common_conv(x)#inner_tokens建模wordlevel表徵inner_tokens=self.inner_convs(x)#B,C,H,Winner_tokens=self.unfold(inner_tokens).transpose(1,2)#B,N,Ck2inner_tokens=inner_tokens.reshape(B*H_out*W_out,self.inner_dim,H_in*W_in).transpose(1,2)#B*N,C,4*4#outer_tokens建模sentencelevel表徵outer_tokens=self.outer_convs(x)#B,C,H_out,W_outouter_tokens=outer_tokens.permute(0,2,3,1).reshape(B,H_out*W_out,-1)returninner_tokens,outer_tokens,(H_out,W_out),(H_in,W_in)2.2 Pyramid Architecture

原始的TNT網絡在繼ViT之後的每個塊中保持相同數量的token。visual words和visual sentences的數量從下到上保持不變。

本文受PVT的啟發,為TNT構建了4個不同數量的Token階段,如圖1(b)。所示在這4個階段中,visual words的空間形狀分別設置為H/2×W/2、H/4×W/4、H/8×W/8、H/16×W/16;visual sentences的空間形狀分別設置為H/8×W/8、H/16×W/16、H/32×W/32、H/64×W/64。下採樣操作是通過stride=2的卷積來實現的。每個階段由幾個TNT塊組成,TNT塊在word-level 和 sentence-level特徵上操作。最後,利用全局平均池化操作,將輸出的visual sentences融合成一個向量作為圖像表示。

classSentenceAggregation(nn.Module):"""SentenceAggregation"""def__init__(self,dim_in,dim_out,stride=2,act_layer=nn.GELU):super().__init__()self.stride=strideself.norm=nn.LayerNorm(dim_in)self.conv=nn.Sequential(nn.Conv2d(dim_in,dim_out,kernel_size=2*stride-1,padding=stride-1,stride=stride),)defforward(self,x,H,W):B,N,C=x.shape#B,N,Cx=self.norm(x)x=x.transpose(1,2).reshape(B,C,H,W)x=self.conv(x)H,W=math.ceil(H/self.stride),math.ceil(W/self.stride)x=x.reshape(B,-1,H*W).transpose(1,2)returnx,H,WclassWordAggregation(nn.Module):"""WordAggregation"""def__init__(self,dim_in,dim_out,stride=2,act_layer=nn.GELU):super().__init__()self.stride=strideself.dim_out=dim_outself.norm=nn.LayerNorm(dim_in)self.conv=nn.Sequential(nn.Conv2d(dim_in,dim_out,kernel_size=2*stride-1,padding=stride-1,stride=stride),)defforward(self,x,H_out,W_out,H_in,W_in):B_N,M,C=x.shape#B*N,M,Cx=self.norm(x)x=x.reshape(-1,H_out,W_out,H_in,W_in,C)#paddingtofit(1333,800)indetection.pad_input=(H_out%2==1)or(W_out%2==1)ifpad_input:x=F.pad(x.permute(0,3,4,5,1,2),(0,W_out%2,0,H_out%2))x=x.permute(0,4,5,1,2,3)#patchmergex1=x[:,0::2,0::2,:,:,:]#B,H/2,W/2,H_in,W_in,Cx2=x[:,1::2,0::2,:,:,:]x3=x[:,0::2,1::2,:,:,:]x4=x[:,1::2,1::2,:,:,:]x=torch.cat([torch.cat([x1,x2],3),torch.cat([x3,x4],3)],4)#B,H/2,W/2,2*H_in,2*W_in,Cx=x.reshape(-1,2*H_in,2*W_in,C).permute(0,3,1,2)#B_N/4,C,2*H_in,2*W_inx=self.conv(x)#B_N/4,C,H_in,W_inx=x.reshape(-1,self.dim_out,M).transpose(1,2)returnxclassStage(nn.Module):"""PyramidTNTstage"""def__init__(self,num_blocks,outer_dim,inner_dim,outer_head,inner_head,num_patches,num_words,mlp_ratio=4.,qkv_bias=False,qk_scale=None,drop=0.,attn_drop=0.,drop_path=0.,act_layer=nn.GELU,norm_layer=nn.LayerNorm,se=0,sr_ratio=1):super().__init__()blocks=[]drop_path=drop_pathifisinstance(drop_path,list)else[drop_path]*num_blocksforjinrange(num_blocks):ifj==0:_inner_dim=inner_dimelifj==1andnum_blocks>6:_inner_dim=inner_dimelse:_inner_dim=-1blocks.append(Block(outer_dim,_inner_dim,outer_head=outer_head,inner_head=inner_head,num_words=num_words,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,drop=drop,attn_drop=attn_drop,drop_path=drop_path[j],act_layer=act_layer,norm_layer=norm_layer,se=se,sr_ratio=sr_ratio))self.blocks=nn.ModuleList(blocks)self.relative_pos=nn.Parameter(torch.randn(1,outer_head,num_patches,num_patches//sr_ratio//sr_ratio))defforward(self,inner_tokens,outer_tokens,H_out,W_out,H_in,W_in):forblkinself.blocks:inner_tokens,outer_tokens=blk(inner_tokens,outer_tokens,H_out,W_out,H_in,W_in,self.relative_pos)returninner_tokens,outer_tokensclassPyramidTNT(nn.Module):"""PyramidTNT"""def__init__(self,configs=None,img_size=224,in_chans=3,num_classes=1000,mlp_ratio=4.,qkv_bias=False,qk_scale=None,drop_rate=0.,attn_drop_rate=0.,drop_path_rate=0.,norm_layer=nn.LayerNorm,se=0):super().__init__()self.num_classes=num_classesdepths=configs['depths']outer_dims=configs['outer_dims']inner_dims=configs['inner_dims']outer_heads=configs['outer_heads']inner_heads=configs['inner_heads']sr_ratios=[4,2,1,1]dpr=[x.item()forxintorch.linspace(0,drop_path_rate,sum(depths))]#stochasticdepthdecayruleself.num_features=outer_dims[-1]#num_featuresforconsistencywithothermodelsself.patch_embed=Stem(img_size=img_size,in_chans=in_chans,outer_dim=outer_dims[0],inner_dim=inner_dims[0])num_patches=self.patch_embed.num_patchesnum_words=self.patch_embed.num_wordsself.outer_pos=nn.Parameter(torch.zeros(1,num_patches,outer_dims[0]))self.inner_pos=nn.Parameter(torch.zeros(1,num_words,inner_dims[0]))self.pos_drop=nn.Dropout(p=drop_rate)depth=0self.word_merges=nn.ModuleList([])self.sentence_merges=nn.ModuleList([])self.stages=nn.ModuleList([])#搭建PyramidTNT所需要的4個Stageforiinrange(4):ifi>0:self.word_merges.append(WordAggregation(inner_dims[i-1],inner_dims[i],stride=2))self.sentence_merges.append(SentenceAggregation(outer_dims[i-1],outer_dims[i],stride=2))self.stages.append(Stage(depths[i],outer_dim=outer_dims[i],inner_dim=inner_dims[i],outer_head=outer_heads[i],inner_head=inner_heads[i],num_patches=num_patches//(2**i)//(2**i),num_words=num_words,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,drop=drop_rate,attn_drop=attn_drop_rate,drop_path=dpr[depth:depth+depths[i]],norm_layer=norm_layer,se=se,sr_ratio=sr_ratios[i]))depth+=depths[i]self.norm=norm_layer(outer_dims[-1])#Classifierheadself.head=nn.Linear(outer_dims[-1],num_classes)ifnum_classes>0elsenn.Identity()defforward_features(self,x):inner_tokens,outer_tokens,(H_out,W_out),(H_in,W_in)=self.patch_embed(x)inner_tokens=inner_tokens+self.inner_pos#B*N,8*8,Couter_tokens=outer_tokens+self.pos_drop(self.outer_pos)#B,N,Dforiinrange(4):ifi>0:inner_tokens=self.word_merges[i-1](inner_tokens,H_out,W_out,H_in,W_in)outer_tokens,H_out,W_out=self.sentence_merges[i-1](outer_tokens,H_out,W_out)inner_tokens,outer_tokens=self.stages[i](inner_tokens,outer_tokens,H_out,W_out,H_in,W_in)outer_tokens=self.norm(outer_tokens)returnouter_tokens.mean(dim=1)defforward(self,x):#特徵提取層,可以作為Backbone用到下游任務x=self.forward_features(x)#分類層x=self.head(x)returnx2.3 其他的Tricks

除了修改網絡體系結構外,還採用了幾種Vision Transformer的高級技巧。

在自注意力模塊上添加相對位置編碼,以更好地表示Token之間的相對位置。

前兩個階段利用Linear spatial reduction attention(LSRA)來降低長序列自注意力的計算複雜度。

3實驗3.1 分類

表3顯示了ImageNet-1K分類結果。與原來的TNT相比,PyramidTNT實現了更好的圖像分類精度。例如,與TNT-S相比,使用少1.9B的TNT-S的Top-1精度高0.5%。這裡還將PyramidTNT與其他具有代表性的CNN、MLP和基於Transformer的模型進行了比較。從結果中可以看到PyramidTNT是最先進的Vision Transformer。

3.2 目標檢測

表4報告了「1x」訓練計劃下的目標檢測和實例分割的結果。PyramidTNT-S在One-Stage和Two-Stage檢測器上都顯著優於其他Backbone,且計算成本相似。例如,基於PyramidTNT-S的RetinaNet達到了42.0 AP和57.7AP-L,分別高出使用Swin-Transformer的模型0.5AP和2.2APL。

這些結果表明,PyramidTNT體系結構可以更好地捕獲大型物體的全局信息。金字塔的簡單的上採樣策略和較小的空間形狀使AP-S從一個大規模的推廣。

3.3 實例分割

PyramidTNT-S在Mask R-CNN和Cascade Mask R-CNN上的AP-m可以獲得更好的AP-b和AP-m,顯示出更好的特徵表示能力。例如,在ParamidTNN約束上,MaskR-CNN-S超過Hire-MLPS 的0.9AP-b。

上面論文和代碼下載

後台回覆:PTNT,即可下載上述論文和代碼

後台回覆:CVPR2021,即可下載CVPR 2021論文和代碼開源的論文合集

後台回覆:ICCV2021,即可下載ICCV2021論文和代碼開源的論文合集

後台回覆:Transformer綜述,即可下載最新的3篇Transformer綜述PDF

重磅!Transformer交流群成立

掃碼添加CVer助手,可申請加入CVer-Transformer微信交流群,方向已涵蓋:目標檢測、圖像分割、目標跟蹤、人臉檢測&識別、OCR、姿態估計、超分辨率、SLAM、醫療影像、Re-ID、GAN、NAS、深度估計、自動駕駛、強化學習、車道線檢測、模型剪枝&壓縮、去噪、去霧、去雨、風格遷移、遙感圖像、行為識別、視頻理解、圖像融合、圖像檢索、論文投稿&交流、Transformer、PyTorch和TensorFlow等群。

一定要備註:研究方向+地點+學校/公司+暱稱(如Transformer+上海+上交+卡卡),根據格式備註,可更快被通過且邀請進群

▲長按加小助手微信,進交流群


▲點擊上方卡片,關注CVer公眾號

整理不易,請點讚和在看

arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()