
極市導讀
本文主要是收集了一些在使用pytorch自帶的amp下loss nan的情況及對應處理方案。>>加入極市CV技術交流群,走在計算機視覺的最前沿
如果要解決問題,首先就要明確原因:為什麼全精度訓練時不會nan,但是半精度就開始nan?這其實分了三種情況:
1&2我想放到後面討論,因為其實大部分報nan都是第三種情況。這裡來先看看3。什麼情況下會出現情況3?這個討論給出了不錯的解釋:
【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17
給大家翻譯翻譯:在使用ce loss 或者 bceloss的時候,會有log的操作,在半精度情況下,一些非常小的數值會被直接捨入到0,log(0)等於啥?——等於nan啊!
於是邏輯就理通了:回傳的梯度因為log而變為nan->網絡參數nan-> 每輪輸出都變成nan。(;´Д`)
How?問題定義清楚,那解決方案就非常簡單了,只需要在涉及到log計算時,把輸入從half精度轉回float32:
x = x.float()x_sigmoid = torch.sigmoid(x)一些思考&廢話這裡我接着討論下我第一次看到nan之後,企圖直接copy別人的解決方案,但解決不掉時踩過的坑。比如:
有些blog會建議你從默認的1e-8 改為 1e-3,比如這篇:【pytorch1.1 半精度訓練 Adam RMSprop 優化器 Nan 問題】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195
經過上面的分析,我們就能知道為什麼這種方法不行——這個方案是針對優化器的數值穩定性做的修改,而loss計算這一步在優化器之前,如果loss直接nan,優化器的eps是救不回來的(托腮)。
那麼這個方案在哪些場景下有效?——在loss輸出不是nan時(感覺說了一句廢話)。optimizer的eps是保證在進行除法backwards時,分母不出現0時需要加上的微小量。在半精度情況下,分母加上1e-8就仿佛聽君一席話,因此,需要把eps調大一點。
GradScaler是autocast的好夥伴,在官方教程上就和autocast配套使用:
from torch.cuda.amp import autocast, GradScaler...scaler = GradScaler()for epoch in epochs: for input, target in data: optimizer.zero_grad() with autocast(): output = model(input) loss = loss_fn(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()具體原理不是我這篇文章討論的範圍,網上很多教程都說得很清楚了,比如這個就不錯:
【Gemfield:PyTorch的自動混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789
但是我這裡想討論另一點:scaler.step(optimizer)的運行原理。
在初始化GradScaler的時候,有一個參數enabled,值默認為True。如果為True,那麼在調用scaler方法時會做梯度縮放來調整loss,以防半精度狀況下,梯度值過大或者過小從而被nan或者inf。而且,它還會判斷本輪loss是否是nan,如果是,那麼本輪計算的梯度不會回傳,同時,當前的scale係數乘上backoff_factor,縮減scale的大小_。_
那麼,為什麼這一步已經判斷了loss是不是nan,還是會出現網絡損失持續nan的情況呢?
這時我們就得再往前思考一步了:為什麼loss會變成nan?回到文章一開始說的:
(1)計算loss 時,出現了除以0的情況;
(2)loss過大,被半精度判斷為inf;
(3)網絡直接輸出了nan。
(1)&(2),其實是可以通過scaler.step(optimizer)解決的,分別由optimizer和scaler幫我們捕捉到了nan的異常。但(3)不行,(3)意味着部分甚至全部的網絡參數已經變成nan了。這可能是在更之前的梯度回傳過程中除以0導致的——首先【回傳的梯度不是nan】,所以scaler不會捕捉異常;其次,由於使用了半精度,optimizer接收到了【已經因為精度損失而變為nan的loss】,nan不管加上多大的eps,都還是nan,所以optimizer也無法處理異常,最終導致網絡參數nan。
所以3,只能通過本文一開始提出的方案來解決。其實,大部分分類問題在使用半精度時出現nan的情況都是第3種情況,也只能通過把精度轉回為float32,或者在計算log時加上微小量來避免(但這樣會損失精度)。
參考【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17
如果覺得有用,就請分享到朋友圈吧!
公眾號後台回復「transformer」獲取最新Transformer綜述論文下載~

#CV技術社群邀請函#

備註:姓名-學校/公司-研究方向-城市(如:小極-北大-目標檢測-深圳)
即可申請加入極市目標檢測/圖像分割/工業檢測/人臉/醫學影像/3D/SLAM/自動駕駛/超分辨率/姿態估計/ReID/GAN/圖像增強/OCR/視頻理解等技術交流群
每月大咖直播分享、真實項目需求對接、求職內推、算法競賽、乾貨資訊匯總、與10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發者互動交流~
