近年來,谷歌於 2018 年推出的 JAX 迎來了迅猛發展,很多研究者對其寄予厚望,希望它可以取代 TensorFlow 等眾多深度學習框架。但 JAX 是否真的適合所有人使用呢?這篇文章對 JAX 的方方面面展開了深入探討,希望可以給研究者選擇深度學習框架時提供有益的參考。自 2018 年底推出以來,JAX 的受歡迎程度一直在穩步提升。2020 年,DeepMind 宣布使用 JAX 來加速其研究。越來越多來自谷歌大腦(Google Brain)和其他機構的項目也都在使用 JAX。目前,在 JAX 的 GitHub 項目主頁,Star 量已經達到了 16.3k。
項目地址:https://github.com/google/jaxJAX 是一個非常有前途的項目,並且用戶一直在穩步增長。JAX 已經在深度學習、機器人 / 控制系統、貝葉斯方法和科學模擬等諸多領域得到了廣泛應用。
如此,是否意味着 JAX 也將成為下一個大型深度學習框架?近日,發表在 AssemblyAI 博客上的文章《Why You Should (or Shouldn't) Be Using JAX in 2022》中,作者 Ryan O'Connor 為我們深入解讀了 JAX 的概念、使用 JAX 的理由以及是否應該使用 JAX 等。JAX 不是一個深度學習框架或庫,其設計初衷也不是成為一個深度學習框架或庫。簡而言之,JAX 是一個包含可組合函數轉換的數值計算庫。正如我們所看到的,深度學習只是 JAX 功能的一小部分:
JAX 的定位科學計算(Scientific Computing)和函數轉換(Function Transformations)的交叉融合,具有除訓練深度學習模型以外的一系列能力,包括如下:即時編譯(Just-in-Time Compilation)
自動並行化(Automatic Parallelization)
自動向量化(Automatic Vectorization)
自動微分(Automatic Differentiation)
簡而言之,是速度。這是 JAX 與任何用例相關的一種通用能力。讓我們使用 NumPy 和 JAX 對矩陣的前三個冪求和(按元素)。首先是 NumPy 實現。我們發現,該計算大約需要 851 毫秒。JAX 僅在 5.54 毫秒內執行完成該計算,速度是 NumPy 的 150 倍以上。JAX 的速度比 NumPy 快了 N 個數量級。需要注意,JAX 使用的是 TPU,NumPy 使用了 CPU,以此強調 JAX 的速度上限遠高於 NumPy。NumPy 加速器。NumPy 是使用 Python 進行科學計算的基礎包之一,但它僅與 CPU 兼容。JAX 提供了 NumPy 的實現(具有幾乎相同的 API),可以非常輕鬆地在 GPU 和 TPU 上運行。對於許多用戶而言,僅此一項功能就足以證明使用 JAX 的合理性;
XLA。XLA(Accelerated Linear Algebra)是專為線性代數設計的全程序優化編譯器。JAX 建立在 XLA 之上,顯著提高了計算速度上限;
JIT。JAX 允許用戶使用 XLA 將自己的函數轉換為即時編譯(JIT)版本。這意味着可以通過在計算函數中添加一個簡單的函數裝飾器(decorator)來將計算速度提高几個數量級;
Auto-differentiation。JAX 將 Autograd(自動區分原生 Python 代碼和 NumPy 代碼)和 XLA 結合在一起,它的自動微分能力在科學計算的許多領域都至關重要。JAX 提供了幾個強大的自動微分工具;
深度學習。雖然 JAX 本身不是深度學習框架,但它的確為深度學習提供了一個很好的基礎。很多構建在 JAX 之上的庫旨在提供深度學習功能,包括 Flax、Haiku 和 Elegy。甚至在最近的一些 PyTorch 與 TensorFlow 文章中強調了 JAX 作為一個值得關注的「框架」,並推薦其用於基於 TPU 的深度學習研究。JAX 對 Hessians 的高效計算也與深度學習相關,因為它們使高階優化技術更加可行;
通用可微分編程範式(General Differentiable Programming Paradigm )。雖然我們可以使用 JAX 來構建和訓練深度學習模型,但它也為通用可微編程提供了一個框架。這意味着 JAX 可以通過使用基於模型的機器學習方法來解決問題,從而可以利用數十年研究建立起的給定領域的先驗知識。
到目前為止,我們已經討論了 XLA 以及它如何允許 JAX 在加速器上實現 NumPy;但請記住,這只是 JAX 定義的一半。JAX 不僅為強大的科學計算提供了工具,而且還為可組合的函數轉換提供了工具。舉例來說如果我們對標量值函數 f(x) 使用梯度函數轉換,那麼我們將得到一個向量值函數 f'(x),它給出了函數在 f(x) 域中任意點的梯度。在函數上使用 grad() 可以讓我們得到域中任意點的梯度JAX 包含了一個可擴展系統來實現這樣的函數轉換,有四種典型方式:Grad() 進行自動微分;
Vmap() 自動向量化;
Pmap() 並行化計算;
Jit() 將函數轉換為即時編譯版本。
訓練機器學習模型需要反向傳播。在 JAX 中,就像在 Autograd 中一樣,用戶可以使用 grad() 函數來計算梯度。舉例來說,如下是對函數 f(x) = abs(x^3) 求導。我們可以看到,當求 x=2 和 x=-3 處的函數及其導數時,我們得到了預期的結果。那麼 grad() 能微分到什麼程度?JAX 通過重複應用 grad() 使得微分變得很容易,如下程序我們可以看到,輸出函數的三階導數給出了 f'''(x)=6 的恆定預期輸出。可能有人會問,grad() 可以用在哪些方面?標量值函數:grad() 採用標量值函數的梯度,將標量 / 向量映射到標量函數。此外還有向量值函數:對於將向量映射到向量的向量值函數,梯度的類似物是雅可比矩陣。使用 jacfwd() 和 jacrev(),JAX 返回一個函數,該函數在域中的某個點求值時產生雅可比矩陣。從深度學習角度來看,JAX 使得計算 Hessians 變得非常簡單和高效。由於 XLA,JAX 可以比 PyTorch 更快地計算 Hessians,這使得實現諸如 AdaHessian 這樣的高階優化更加快速。下面代碼是在 PyTorch 中對一個簡單的輸入總和進行 Hessian:正如我們所看到的,上述計算大約需要 16.3 ms,在 JAX 中嘗試相同的計算:使用 JAX,計算僅需 1.55 毫秒,比 PyTorch 快 10 倍以上:JAX 可以非常快速地計算 Hessians,使得高階優化更加可行。JAX 在其 API 中還有另一種變換:vmap() 自動向量化。以下是矢量化向量加法展示:
分布式計算變得越來越重要,在深度學習中尤其如此,如下圖所示,SOTA 模型已經發展到超大規模。得益於 XLA,JAX 可以輕鬆地在加速器上進行計算,但 JAX 也可以輕鬆地使用多個加速器進行計算,即使用單個命令 - pmap() 執行 SPMD 程序的分布式訓練。我們以向量矩陣乘法為例,如下為非並行向量矩陣乘法:
使用 JAX,我們可以輕鬆地將這些計算分布在 4 個 TPU 上,只需將操作包裝在 pmap() 中即可。這允許用戶在每個 TPU 上同時執行一個點積,顯着提高了計算速度(對於大型計算而言)。
JIT 編譯是一種執行代碼的方法,介於解釋(interpretation)和 AoT(ahead-of-time)編譯之間。重要的是,JIT 編譯器在運行時將代碼編譯成快速的可執行文件,但代價是首次運行速度較慢。JIT 不是一次將一個操作分配給 GPU 內核,而是使用 XLA 將一系列操作編譯成一個內核,從而為函數提供端到端編譯的高效 XLA 實現。以下圖為例,代碼定義了一個函數:用三種方式計算 5000 x 5000 矩陣——一次使用 NumPy,一次使用 JAX,還有一次在 JIT 編譯的函數版本上使用 JAX。我們首先在 CPU 上進行實驗:JAX 對於逐元素計算明顯更快,尤其是在使用 jit 時。我們看到 JAX 比 NumPy 快 2.3 倍以上,當我們 JIT 函數時,JAX 比 NumPy 快 30 倍。這些結果已經令人印象深刻,但讓我們繼續看,讓 JAX 在 TPU 上進行計算:當 JAX 在 TPU 上執行相同的計算時,它的相對性能會進一步提升(NumPy 計算仍在 CPU 上執行,因為它不支持 TPU 計算)在這種情況下,我們可以看到 JAX 比 NumPy 快了驚人的 13 倍,如果我們同時在 TPU 上 JIT 函數和計算,我們會發現 JAX 比 NumPy 快 80 倍。當然,這種速度的大幅提升是有代價的。JAX 對 JIT 允許的函數進行了限制,儘管通常允許僅涉及上述 NumPy 操作的函數。此外,通過 Python 控制流進行 JIT 處理存在一些限制,因此在編寫函數時須牢記這一點。很遺憾,這個問題的答案還是「視情況而定」。是否遷移到 JAX 取決於你的情況和目標。為具體分析是否應該(或不應該)在 2022 年使用 JAX,這裡將建議匯總到下面的流程圖中,並針對不同的興趣領域提供不同的圖表。如果你對 JAX 在通用計算感興趣,首先要問的問題就是——是否只嘗試在加速器上運行 NumPy?如果答案是肯定的,那麼你顯然應該開始遷移到 JAX。如果你不只處理數字而是參與動態計算建模,那麼是否應該使用 JAX 將取決於具體用例。如果大部分工作是在 Python 中使用大量自定義代碼完成的,那麼開始學習 JAX 以增強工作流程是值得的。如果大部分工作不在 Python 中,但你想構建的是某種基於模型 / 神經網絡的混合系統,那麼使用 JAX 可能是值得的。如果大部分工作不使用 Python,或者你正在使用一些專門的軟件進行研究(熱力學、半導體等),那麼 JAX 可能是不合適的工具,除非你想從這些程序中導出數據,用來做自定義計算。如果你感興趣的領域更接近物理 / 數學並包含計算方法(動力系統、微分幾何、統計物理)並且大部分工作都在例如 Mathematica 上,那麼堅持使用目前的工具才是值得的,特別是在已有大型自定義代碼庫的情形下。雖然我們已經強調過,JAX 不是專為深度學習構建的通用框架,但 JAX 速度很快且具有自動微分功能,你肯定想知道使用 JAX 進行深度學習是什麼樣的。若想在 TPU 上進行訓練,那麼你應該開始使用 JAX,尤其是如果當前正在使用的是 PyTorch。雖然有 PyTorch-XLA 存在,但使用 JAX 進行 TPU 訓練絕對是更好的體驗。如果你正在研究的是「非標準」架構 / 建模,例如 SDE-Nets,那麼也絕對應該嘗試一下 JAX。此外,如果你想利用高階優化技術,JAX 也是要嘗試的東西。如果你不是在構建特殊的架構,只是在 GPU 上訓練常見的架構,那麼你現在可能應該堅持使用 PyTorch 或 TensorFlow。然而,這個建議可能會在未來一兩年內快速發生變化。雖然 PyTorch 仍然在研究領域占據主導地位,但使用 JAX 的論文數量一直在穩步增長。隨着 DeepMind 和谷歌重量級玩家不斷開發用於 JAX 的高級深度學習 API,在幾年內 JAX 可能會出現爆炸性的增長率。這意味着你至少應該稍微熟悉一下 JAX,如果你是研究人員的話更應如此。如果你有興趣了解深度學習並實現一些想法,你應該使用 JAX 或 PyTorch。如果你想自上而下學習深度學習,或有一些 Python 軟件的經驗,則應該從 PyTorch 入手。如果你想自下而上地學習深度學習,或具有數學背景,你可能會發現 JAX 很直觀。在這種情況下,在進行任何大型項目之前,請確保了解如何使用 JAX。如果你對深度學習感興趣,又想轉行相關的職位,那麼你需要使用 PyTorch 或 TensorFlow。儘管最好是同時熟悉兩個框架,但你必須知道 TensorFlow 被普遍認為是「行業」框架,不同框架的職位發布數量證明了這一點:如果你是一個沒有數學或軟件背景但想學習深度學習的初學者,那麼你不會想使用 JAX。相反,Keras 是更好的選擇。雖然上文已經討論了很多 JAX 的正面反饋,它有潛力極大地提升用戶程序的性能。但作者同時列舉了以下四條不該使用 JAX 的理由:JAX 仍然被官方認為是一個實驗性框架。JAX 是一個相對「年輕」的項目。目前,JAX 仍被視為一個研究項目,而不是成熟的谷歌產品,因此如果用戶正在考慮遷移到 JAX,請記住這一點;
使用 JAX 一定要勤勉。調試的時間成本,或者更嚴重的是,未跟蹤副作用(untracked side effects)的風險可能導致那些沒有紮實掌握函數式編程的用戶不適用 JAX。在開始將它用於正式項目之前,請確保自己了解使用 JAX 的常見缺陷;
JAX 沒有針對 CPU 計算進行優化。鑑於 JAX 是以「加速器優先」的方式開發的,因此每個操作的分派並未針對 JAX 進行完全優化。在某些情況下,NumPy 實際上可能比 JAX 更快,尤其是對於小型程序而言,這是因為 JAX 引入了開銷;
JAX 與 Windows 不兼容。目前在 Windows 上不支持 JAX。如果用戶使用 Windows 系統但仍想嘗試 JAX,可以使用 Colab 或將其安裝在虛擬機(VM)上。
原文鏈接:https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/
©THE END
轉載請聯繫本公眾號獲得授權
投稿或尋求報道:content@jiqizhixin.com