close
作者丨科技猛獸 編輯丨極市平台

導讀

實數網絡在圖像領域取得極大成功,但在音頻中,信號特徵大多數是複數,如頻譜等。簡單分離實部虛部,或者考慮幅度和相位角都丟失了複數原本的關係。論文按照複數計算的定義,設計了深度複數網絡,能對複數的輸入數據進行卷積、激活、批規範化等操作。這裡對論文提出的幾種複數操作進行介紹,並給出簡單的 Pytorch 實現方法。

雖然叫深度複數網絡,但裡面的操作實際上還是在實數空間進行的。但通過實數的層實現類似於複數計算的操作。

目錄

1 PyTorch 中的複數張量形式

2 複數神經網絡背景

3 複數卷積操作3.1 複數卷積原理3.2 複數卷積 PyTorch 實現3.3 複數的反向傳播3.4 柯西-黎曼方程 (Cauchy–Riemann Equation)

4 複數 LSTM 操作4.1 複數 LSTM 原理4.2 複數 LSTM PyTorch 實現

5 複數激活函數5.1 複數激活函數原理5.2 複數激活函數 PyTorch 實現

6 複數 Dropout6.1 複數 Dropout原理6.2 複數 Dropout PyTorch 實現

7 複數權重初始化7.1 複數權重初始化原理

8 複數 Batch Normalization8.1 複數 BN 原理8.2 複數 BN PyTorch 實現

9 完整模型搭建

1 PyTorch 中的複數張量形式

PyTorch 1.8 及之後都支持2種複數形式的 Tensor,它們分別是:

意味着 torch 中有表示 complex 的張量形式,即:

torch.complex(real, imag, *, out=None) → Tensor

構造一個複數張量,其實部等於 real,虛部等於 imag。

Parameters

real (Tensor): 複數張量的實數部分。必須為 float 或 double。
imag (Tensor): 複數張量的虛部。dtype 必須與實部 real 相同。

關鍵字參數:

out (Tensor): 如果輸入為 torch.float32 ,則必須為 torch.complex64 。如果輸入為 torch.float64 ,則必須為 torch.complex128 。

torch.is_complex(input)

返回 input 是不是複數形式,也就是torch.complex64, 和torch.complex128中的一種。

2 複數神經網絡背景

眾所周知, 從計算、生物和信號處理的角度來看,使用複數有許多優點。所以,複數相對於實數具有更強的表達能力。若能夠藉助複數設計神經網絡,則非常具有吸引力。但是一個難題是如何設計配套的各種網絡的 building block,比如說 complex BN,complex weight initialization 等等。

複數神經網絡也有一些生物學上的優勢,即:若網絡中的數據都是實數,則只能代表某個中間輸出的具體的值的大小;反之,若網絡中的數據都是複數,則不僅能代表某個中間輸出的具體的值的大小 (複數的模長),還可以代表時間的概念 (複數的相位)。具有相似相位的輸入神經元是同步的 (synchronous),因為它們在複數運算中是相加的,而異步神經元相加則具有破壞性 (asynchronous),因此相互干擾。

複數神經網絡也有一些信號處理方面的優勢,即:複數蘊含着相位信息,而語音信號中的相位信息影響其可懂度。奧本海姆的研究表明,在圖像的相位中存在的信息量足以恢復以其幅值編碼的大部分信息。事實上,相位信息在對物體的形狀,邊緣和方向進行編碼時,提供了對物體的詳細描述。

本文開發了適當的工具和一個通用的框架來訓練具有複雜參數的深層神經網絡。

3 複數卷積操作3.1 複數卷積原理

任意的一個複數 ,其實部為 ,虛部為 。作者將複數的實部和虛部表示為邏輯上不同的實值實體,並在內部使用實值算術模擬複數運算。假設一個卷積核,權重是 ,則它可以表示成 個複數權重。

複數域上執行傳統的實值二維卷積:

複數卷積核:

複數輸入張量:

複數卷積過程:

在具體實現中,可以使用下圖1所示的簡單結構實現。

圖1:複數域上執行傳統的實值二維卷積的過程

如下圖1所示,把上式寫成矩陣的形式,就有:


3.2 複數卷積 PyTorch 實現

PyTorch 實現複數的操作基於 apply_complex 這個方法。

def apply_complex(fr, fi, input, dtype = torch.complex64): return (fr(input.real)-fi(input.imag)).type(dtype) \ + 1j*(fr(input.imag)+fi(input.real)).type(dtype)

這個函數需要傳入2個操作 (nn.Conv2d, nn.Linear 等等) 和 torch.complex64 類型的 input。fr(input.real): 卷積核的實部 * (輸入的實部)。fi(input.imag): 卷積核的虛部 * (輸入的虛部)fr(input.imag): 卷積核的實部 * (輸入的虛部)fi(input.real): 卷積核的虛部 * (輸入的實部)input 類型: torch.complex64返回值類型: torch.complex64

因此,利用 Pytorch 的 nn.Conv2D 實現,嚴格遵守上面複數卷積的定義式:

class ComplexConv2d(Module): def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 0, dilation=1, groups=1, bias=True): super(ComplexConv2d, self).__init__() self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) def forward(self,input): return apply_complex(self.conv_r, self.conv_i, input)

同理還可以實現 Pytorch 的 nn.Linear和 Pytorch 的 nn.ConvTranspose2d:

class ComplexLinear(Module): def __init__(self, in_features, out_features): super(ComplexLinear, self).__init__() self.fc_r = Linear(in_features, out_features) self.fc_i = Linear(in_features, out_features) def forward(self, input): return apply_complex(self.fc_r, self.fc_i, input)class ComplexConvTranspose2d(Module): def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'): super(ComplexConvTranspose2d, self).__init__() self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, padding_mode) self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, padding_mode)

具體實現的思路相似,都是藉助了 apply_complex 函數,傳入2個操作 (nn.Conv2d, nn.Linear 等等) 和 torch.complex64 類型的 input,然後在 ComplexLinear (或 ComplexConvTranspose2d) 中分別計算。

3.3 複數的反向傳播

為了在複數神經網絡中進行反向傳播,一個充分條件是網絡訓練的目標函數和激活函數對網絡中每個 complex parameter 的實部和虛部都是可微的。通常損失函數都是實數,則複數 chain rule 如下:

如果 是實數損失函數, 為復變量,滿足 ,則有:


如果現在有另一個複數 ,且 ,則根據偏導數的鏈式法則:


3.4 柯西-黎曼方程 (Cauchy–Riemann Equation)

一個解析函數的解析性 (Holomorphism, analyticity) 確保一個複變函數在其域中的每個點的鄰域中都是可微的。假設有複變函數 ,和一個複數 ,並有: (和 是 實函數)。這意味着函數 的導數,即: 在 的鄰域內都存在。

,也就是說 可以沿着不同的方向逼近0 (可以沿着實軸,虛軸或中間)。所以說當 沿着實軸方向逼近0時,有:


當 沿着虛軸方向逼近0時,有:


根據式4和式5我們可以推導出:


所以,為了使得上式恆成立,就必須滿足下式:


式7就被稱為柯西-黎曼方程 (Cauchy–Riemann Equation)。所以:

定理1:設函數 定義在區域 內,則 在 內一點 可導的充要條件是 與 在點 可微,並且在該點滿足柯西-黎曼方程 (Cauchy–Riemann Equation)。

4 複數 LSTM 操作

卷積 LSTM 類似於完全連接 LSTM。唯一的區別是,沒有使用矩陣乘法來執行計算,而是使用卷積運算。實值卷積 LSTM 的計算定義如下:


式中, 代表 sigmoid 激活函數, 代表 element-wise 的乘法, 代表實數卷積。 分別是 input gate,forget gate,和 output gate。 和 代表 cell 和 hidden states。

對於復卷積 LSTM,只需用復卷積運算代替實值卷積運算。elementwise multiplication 保持不變,Sigmoid 和 tanh 分別在實部和虛部進行。

5 複數激活函數5.1 複數激活函數原理

複數激活函數在之前的工作裡面已有研究:modReLU 和 zReLU。modReLU 可以表示為:


式中, , 是 的相位, 是可學習的參數。設置參數 的原因是 總是正的,所以參數 使得激活函數可以到達 dead zone 的位置。modReLU 的特點是激活函數前後複數的相位是不變的,但是modReLU 激活函數不滿足柯西-黎曼方程 (Cauchy–Riemann Equations),因此它不是解析的。而複數激活函數需要滿足 Cauchy-Riemann Equations 才是解析的,才能進行複數微分操作。

modReLU 不滿足 Cauchy-Riemann Equations。

接下來是 zReLU 激活函數。


zReLU 在 的時候不滿足,即在x和y的正半軸不滿足 Cauchy-Riemann Equations。

接下來是 CReLU 激活函數。它的設計初衷是要滿足柯西-黎曼方程 (Cauchy–Riemann Equation),所以 CReLU 激活函數分別在實部和虛部上進行激活操作,即:


當實部和虛部同時為嚴格正或嚴格負時,CReLU 激活函數滿足 Cauchy-Riemann Equations,這裡我給個簡單證明。

證明: 設 ,則 。


得證。

CReLU 只在實部虛部同時大於零或同時小於零的時候滿足 Cauchy-Riemann Equations,即在第2,4象限不滿足。

5.2 複數激活函數 PyTorch 實現from torch.nn.functional import reludef complex_relu(input): return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)6 複數 Dropout6.1 複數 Dropout 原理

複數 Dropout 個人感覺實部虛部需要同時置0,作者源碼中沒用到 Dropout 層。

6.2 複數 Dropout PyTorch 實現from torch.nn.functional import dropoutdef complex_dropout(input, p=0.5, training=True): # need to have the same dropout mask for real and imaginary part, mask = torch.ones(*input.shape, dtype = torch.float32) mask = dropout(mask, p, training)*1/(1-p) mask.type(input.dtype) return mask*input7 複數權重初始化7.1 複數權重初始化原理

作者介紹了兩種初始化方法的複數形式:Glorot,Kaiming 初始化,複數形式的權重可以表示為:


式中, 和 分別是權重的幅值和相位。權重的方差可以這樣計算:


根據以上2式有:


當一個隨機二維向量的兩個分量呈獨立的、均值為0,有着相同的方差的正態分布時,這個向量的模呈瑞利分布 (Rayleigh Distribution)。

瑞利分布簡介

設 和 是相互獨立的隨機變量,並且均服從零均值的高斯分布:


有一個新的隨機變量 ,它與 和 是的關係是:,則變量 服從瑞利分布。

我們推導一下變量 的概率分布:

代表 和 的聯合分布在以 為半徑的圓內的概率。而 和 的分布是獨立的高斯分布,則有:


所以變量 的概率分布是 。


所以複數權重的模型 的分布服從期望為 ,方差為 的瑞利分布,即參數為 的瑞利分布。我們進一步有:

如果我們按照 Glorot 初始化的方式,則需要滿足 ,此時有: 。如果我們按照 Kaiming 初始化的方式,則需要滿足 ,此時有: 。

到這裡發現 的大小與 的幅值有關,而與具體的相位無關。所以利用 之間的均勻分布來初始化相位。通過執行式12中的相量乘以幅值,就完成了複數權重的初始化。

8 複數 Batch Normalization8.1 複數 BN 原理

首先肯定不能用常規的 BN 方法,否則實部和虛部的分布就不能保證了。但正如常規 BN 方法,首先要對輸入進行0均值1方差的操作,只是方法有所不同。複數 BN 的方法是:


協方差矩陣 是:


通過15式的操作,可以確保輸出的均值為0,協方差為1,相關為0。這裡有一點需要特別注意的是求協方差矩陣 的逆矩陣。


輸入要乘以 的逆平方根,即:


BN 中還有 和 兩個參數,因此最終結果如下:


因為歸一化的輸入 的模長是1,所以 和 被初始化為 。而 , , , 被初始化為0。

8.2 複數 BN PyTorch 實現

定義 self.register_buffer('running_mean', torch.zeros(num_features, dtype = torch.complex64)) 為 BN 的 momentum 的均值;定義 self.register_buffer('running_covar', torch.zeros(num_features,3)) 為 BN 的 momentum 的方差。

class _ComplexBatchNorm(Module): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_ComplexBatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter(torch.Tensor(num_features,3)) self.bias = Parameter(torch.Tensor(num_features,2)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.zeros(num_features, dtype = torch.complex64)) self.register_buffer('running_covar', torch.zeros(num_features,3)) self.running_covar[:,0] = 1.4142135623730951 self.running_covar[:,1] = 1.4142135623730951 self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) else: self.register_parameter('running_mean', None) self.register_parameter('running_covar', None) self.register_parameter('num_batches_tracked', None) self.reset_parameters() def reset_running_stats(self): if self.track_running_stats: self.running_mean.zero_() self.running_covar.zero_() self.running_covar[:,0] = 1.4142135623730951 self.running_covar[:,1] = 1.4142135623730951 self.num_batches_tracked.zero_() def reset_parameters(self): self.reset_running_stats() if self.affine: init.constant_(self.weight[:,:2],1.4142135623730951) init.zeros_(self.weight[:,2]) init.zeros_(self.bias)

前向傳播時,首先計算當前輸入的均值:

mean_r = input.real.mean([0, 2, 3]).type(torch.complex64)mean_i = input.imag.mean([0, 2, 3]).type(torch.complex64)mean = mean_r + 1j*mean_i

再進行滑動平均:

if self.training and self.track_running_stats:# update running meanwith torch.no_grad():self.running_mean = exponential_average_factor * mean+ (1 - exponential_average_factor) * self.running_mean

輸入減去均值:

input = input - mean[None, :, None, None]

計算協方差矩陣的值並做滑動平均:

if self.training or (not self.training and not self.track_running_stats):# Elements of the covariance matrix (biased for train)n = input.numel() / input.size(1)Crr = 1./n*input.real.pow(2).sum(dim=[0,2,3])+self.epsCii = 1./n*input.imag.pow(2).sum(dim=[0,2,3])+self.epsCri = (input.real.mul(input.imag)).mean(dim=[0,2,3])else:Crr = self.running_covar[:,0]+self.epsCii = self.running_covar[:,1]+self.epsCri = self.running_covar[:,2]#+self.eps

if self.training and self.track_running_stats:with torch.no_grad():self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\+ (1 - exponential_average_factor) * self.running_covar[:,0]

self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\+ (1 - exponential_average_factor) * self.running_covar[:,1]

self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\+ (1 - exponential_average_factor) * self.running_covar[:,2]

計算協方差矩陣的逆平方根 :

det = Crr*Cii-Cri.pow(2) s = torch.sqrt(det) t = torch.sqrt(Cii+Crr + 2 * s) inverse_st = 1.0 / (s * t) Rrr = (Cii + s) * inverse_st Rii = (Crr + s) * inverse_st Rri = -Cri * inverse_st

乘以 再加上 :

if self.affine: input = (self.weight[None,:,0,None,None]*input.real+self.weight[None,:,2,None,None]*input.imag+\ self.bias[None,:,0,None,None]).type(torch.complex64) \ +1j*(self.weight[None,:,2,None,None]*input.real+self.weight[None,:,1,None,None]*input.imag+\ self.bias[None,:,1,None,None]).type(torch.complex64)

完整的 Complex BN 代碼:

class ComplexBatchNorm2d(_ComplexBatchNorm): def forward(self, input): exponential_average_factor = 0.0 if self.training and self.track_running_stats: if self.num_batches_tracked is not None: self.num_batches_tracked += 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum if self.training or (not self.training and not self.track_running_stats): # calculate mean of real and imaginary part # mean does not support automatic differentiation for outputs with complex dtype. mean_r = input.real.mean([0, 2, 3]).type(torch.complex64) mean_i = input.imag.mean([0, 2, 3]).type(torch.complex64) mean = mean_r + 1j*mean_i else: mean = self.running_mean if self.training and self.track_running_stats: # update running mean with torch.no_grad(): self.running_mean = exponential_average_factor * mean\ + (1 - exponential_average_factor) * self.running_mean input = input - mean[None, :, None, None] if self.training or (not self.training and not self.track_running_stats): # Elements of the covariance matrix (biased for train) n = input.numel() / input.size(1) Crr = 1./n*input.real.pow(2).sum(dim=[0,2,3])+self.eps Cii = 1./n*input.imag.pow(2).sum(dim=[0,2,3])+self.eps Cri = (input.real.mul(input.imag)).mean(dim=[0,2,3]) else: Crr = self.running_covar[:,0]+self.eps Cii = self.running_covar[:,1]+self.eps Cri = self.running_covar[:,2]#+self.eps if self.training and self.track_running_stats: with torch.no_grad(): self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\ + (1 - exponential_average_factor) * self.running_covar[:,0] self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\ + (1 - exponential_average_factor) * self.running_covar[:,1] self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\ + (1 - exponential_average_factor) * self.running_covar[:,2] # calculate the inverse square root the covariance matrix det = Crr*Cii-Cri.pow(2) s = torch.sqrt(det) t = torch.sqrt(Cii+Crr + 2 * s) inverse_st = 1.0 / (s * t) Rrr = (Cii + s) * inverse_st Rii = (Crr + s) * inverse_st Rri = -Cri * inverse_st input = (Rrr[None,:,None,None]*input.real+Rri[None,:,None,None]*input.imag).type(torch.complex64) \ + 1j*(Rii[None,:,None,None]*input.imag+Rri[None,:,None,None]*input.real).type(torch.complex64) if self.affine: input = (self.weight[None,:,0,None,None]*input.real+self.weight[None,:,2,None,None]*input.imag+\ self.bias[None,:,0,None,None]).type(torch.complex64) \ +1j*(self.weight[None,:,2,None,None]*input.real+self.weight[None,:,1,None,None]*input.imag+\ self.bias[None,:,1,None,None]).type(torch.complex64) return input9 完整模型搭建

使用複數卷積,BN,激活函數搭建一個簡單的完整模型。使用 MNIST 數據集,用文中提到的方法生成虛部。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import Subsetfrom torchvision import datasets, transformsfrom complexPyTorch.complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinearfrom complexPyTorch.complexLayers import ComplexDropout2d, NaiveComplexBatchNorm2dfrom complexPyTorch.complexLayers import ComplexBatchNorm1dfrom complexPyTorch.complexFunctions import complex_relu, complex_max_pool2dbatch_size = 64n_train = 1000n_test = 100trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])train_set = datasets.MNIST('../data', train=True, transform=trans, download=True)train_set = Subset(train_set, torch.arange(n_train))test_set = datasets.MNIST('../data', train=False, transform=trans, download=True)test_set = Subset(test_set, torch.arange(n_test))train_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(test_set, batch_size= batch_size, shuffle=True)class ComplexNet(nn.Module): def __init__(self): super(ComplexNet, self).__init__() self.conv1 = ComplexConv2d(1, 10, 5, 1) self.bn2d = ComplexBatchNorm2d(10, track_running_stats = False) self.conv2 = ComplexConv2d(10, 20, 5, 1) self.fc1 = ComplexLinear(4*4*20, 500) self.dropout = ComplexDropout2d(p = 0.3) self.bn1d = ComplexBatchNorm1d(500, track_running_stats = False) self.fc2 = ComplexLinear(500, 10) def forward(self,x): x = self.conv1(x) x = complex_relu(x) x = complex_max_pool2d(x, 2, 2) x = self.bn2d(x) x = self.conv2(x) x = complex_relu(x) x = complex_max_pool2d(x, 2, 2) x = x.view(-1,4*4*20) x = self.fc1(x) x = self.dropout(x) x = complex_relu(x) x = self.bn1d(x) x = self.fc2(x) x = x.abs() x = F.log_softmax(x, dim=1) return x device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = ComplexNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9)def train(model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target =data.to(device).type(torch.complex64), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() if batch_idx % 100 == 0: print('Train\t Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item()) ) def test(model, device, test_loader, optimizer, epoch): model.eval() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device).type(torch.complex64), target.to(device) output = model(data) loss = F.nll_loss(output, target) if batch_idx % 100 == 0: print('Test\t Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(test_loader.dataset), 100. * batch_idx / len(test_loader), loss.item()) )# Run training on 4 epochsfor epoch in range(4): train(model, device, train_loader, optimizer, epoch) test(model, device, test_loader, optimizer, epoch)

主要參考文獻

1 "Deep Complex Networks"

2 論文作者給出的源碼地址,使用Theano後端的Keras實現:' https://github.com/ChihebTrabelsi/deep_complex_networks'

3 'https://github.com/wavefrontshaping/complexPyTorch' 給出了部分操作的Pytorch實現版本。

4 深度學習:深度複數網絡(Deep Complex Networks)-從論文到pytorch實現:https://www.daimajiaoliu.com/daima/485c571e9100403

如果覺得有用,就請分享到朋友圈吧!


覺得有用麻煩給個在看啦~
arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()