對抗生成網絡GAN應用場景
有一個原始的輸入,想按照我的一個要求或者按照我的一個目標儘可能完美的生成我想要的東西,比如根據夏天的圖片生成冬天的圖片,根據人臉圖片隨機生成一個人臉。
超分辨率重構
告訴網絡什麼樣的是低分辨率,什麼樣的是高分辨率,設計一個損失函數,學一學兩者之間的聯繫,就可以根據低分辨率生成高分辨率的圖片了或者根據側臉生成全臉圖片。
對抗生成網絡定義
生成 :想要什麼生成什麼
對抗:用最強的矛捅最硬的盾
罪犯印假錢,罪犯希望印出來的假幣越真越好,直到能夠騙過警察;而警察希望能夠分辨真假。
隨機的變量(Random Vector)比如100維的向量,這是噪音變量或垃圾變量,然後經過生成網絡(Generator Network)進行轉換得到假的數據,生成網絡希望它生成的假的數據能夠騙過判決器(Discriminator Network)的識別;將真實的數據和假的數據都輸入到判決器網絡中,讓判決器能夠識別真假。判決網絡把生成網絡生成的數據識別為假的,把真實數據識別成真的。
生成網絡可以用傳統的神經網絡去做,比如輸入28x28x1=748個像素點的圖片,最終生成的結果也是28x28x1=748個像素點的圖片
也可以用卷積神經網絡來做(生成圖像數據的話,最終生成一個特徵圖)特徵圖(H W C)C=3就會讓網絡往真實圖片上靠攏
即定義生成的結果是什麼樣的,讓網絡去學習怎樣生成這樣的。無論網絡怎麼定義,最終要做成一件什麼事情,是由損失函數決定的,損失函數決定了整個網絡的走勢。
導入類庫
定義損失函數
損失函數2個參數,一個是預測值,一個是標籤值。
定義隨機的輸入數據
這些輸入數據就當作是預測結果了。
把預測結果全部傳入Sigmoid函數中去,映射為0-1之間的值,滿足了損失函數對於預測參數的要求。
所有的預測結果必須映射到0-1的範圍當中。
判決網絡其實是一個二分類網絡,做的好不好,就是看真的是不是判斷為真的,假的判斷為假的即0/1問題
上述為定義真實的標籤值。有了預測值,有了標籤數據,看損失值是怎麼計算的
損失值計算公式
t[i]是概率值即預測的概率值
log(o[i]) 對數
概率值 x 對數 + (1-概率值)x 對數
代碼實現損失函數公式
第一個值0/1 就是實際的標籤值即真實的類別,1-0/1是錯誤判斷的類別
共有9個樣本(真實值),計算每個樣本的損失值
0-1之間的對數值都是負的,前面加個-號,使得變為正數
相當於9個樣本的損失值加在一起,除以9,求平均值
BCELoss
loss函數第一個參數是預測值,第二個參數是真實值
m是sigmoid函數,將預測值映射到0-1之間
計算結果是手動計算是一樣的,
BCEWithLogitsLoss
構建GAN網絡
損失函數定義好之後,構建GAN網絡
構建生成器網絡
比如輸入28x28x1=784個像素點的圖片,這裡沒有用卷積,而是用最簡單最基本的全連接做的。
in_feat=100 表示輸入100維的向量,out_feat=128 表示中間隱層是128個神經元,128個特徵;然後加上Relu激活函數
第一個block,輸入100維向量,輸出128維的向量;第二個block 輸入128轉換成256;第三個block輸入256轉換成512;第四個block輸入512轉換成1024;要得到的特徵數需跟輸入一致的,輸入是28x28x1=748個特徵,所以需要將1024轉換成748。
構建判決器網絡
第一層是全連接層。判斷一張圖片的真假,輸入784個像素點的圖片(生成網絡生成的圖片和原始圖片是一樣的,都是784個像素點),經過幾個全連接層(512、256)和激活函數,最終得到一個預測值,再把預測值傳入sigmoid函數,將預測概率映射到0-1之間。
使用gpu計算,速度更快,100個epoch,幾分鐘就訓練完了。
數據預處理
如果數據之前沒有下載,就先進行下載,然後進行數據預處理這些常規操作。
定義優化器
訓練的過程
針對一個batch的數據,定義真假標籤
真實數據標籤定義為1
假數據標籤為0
實際數據
這是一個4維的數據:batch size * channel * h * w
zero_grad 表示 梯度清零
z表示隨機構建的100維的向量
imgs.shape[0]表示batch的個數
opt.latent_dim=100 表示100維的向量
random.normal是隨機的高斯分布 均值為0 標準差為1
有多少個batch就生成多少個隨機向量
生成網絡
通過生成網絡把100維的向量轉換成784維的
生成器生成的數據讓判決器認為是真的
第一個參數是預測值,第二個參數是真實值(標籤值)。傳入的標籤值都是1,即告訴判決器,我生成器生成的數據是真的(生成器希望能夠騙到判決器,逃脫判決器的法眼)
判決器認為生成器生成的數據是假的,而生成器希望能夠騙過判決器的法眼,讓判決器認為它生成的數據是真的
判決器認為真實數據是真的
real_imgs
讀進來的實際的數據
判決器得有能力將實際數據認為是真數據
判決器得有能力識別生成的數據是假數據
loss平均值=(真數據損失值+假數據損失值)/2
將GAN生成的數據保存下來
這是100個epoch生成的結果
生成的效果和MNIST數據集手寫數字差不多了,這還是用的最簡單最基本的方法僅迭代100次、全連接網絡的對抗生成網絡GAN。