close

©作者 |機器之心編輯部

來源 |機器之心


來自 DeepMind 等機構的研究者提出了一個通用神經算法學習器,其能夠學習解決包括排序、搜索、貪心算法、動態規劃、圖形算法等經典算法任務,達到專家模型平均水平。

近年來,基於深度神經網絡的機器學習系統取得了巨大的進步,尤其是在以感知為主的任務方面。這些模型通常需要在分布內泛化,這意味着它們的訓練集和驗證集需要有輸入預期分布。相比之下,想要模型在推理任務上表現出色,這就要求即使在分布外(out-of-distribution, OOD)泛化時模型也能提供合理的輸出。

然而,多數神經網絡在 OOD 方面表現不佳。事實上,可以進行神經推理的架構需要算法對齊、自監督學習等其他算法的輔助。更進一步講,這些模型需要在基於觀察的基礎上,對生成的新知識有一定的穩健性,特別是當這些知識脫離訓練數據域時。

本文中, 來自 DeepMind 等機構的研究者提出一個通用神經算法學習器:具有單一參數集的 GNN,其能夠同時學習解決經典算法任務,包括排序、搜索、貪心算法、動態規劃、圖形算法、字符串算法和幾何算法,達到專家模型平均水平。

具體地,該研究利用 CLRS 基準從實證上表明,就像在感知領域取得的成功一樣,通用算法學習器可以通過整合知識來構建。也就是說,只要我們能學會在單任務模式下很好地執行算法,就有可能在多任務模式下有效地學習算法。

受此啟發,該研究對 CLRS 的輸入表示、訓練機制和處理器架構進行一系列改進,與現有技術相比,改進後的平均單任務性能提高了 20% 多。然後,本文利用這些改進對多任務學習器進行消融實驗。結果表明,通用學習器能夠有效地整合由專家模型捕獲的知識。

論文標題:
A Generalist Neural Algorithmic Learner

論文鏈接:

https://arxiv.org/pdf/2209.11142.pdf

可以說這項研究是一個重要的里程碑,表明即使在具有完全不同的控制流任務中,該研究也可以有意義地整合推理能力,並在多個任務中超過相應的單任務專家的 OOD 性能。

正如佐治亞理工學院機器學習博士生 Aran Komatsuzaki 所總結的:「本文構建了一個通用神經算法學習器,能夠學習執行各種算法的單個 GNN 處理器,例如排序、搜索、動態規劃、路徑查找和幾何。」


研究介紹

研究者提出的通用神經算法學習器如下圖 1 所示。


論文第 3 章是主旨部分,主要介紹了表示、訓練機制和架構的改進,使得單個模型的性能明顯優於之前在 CLRS-30 上發布的 SOTA 技術。


CLRS 基準定義了五種類型的特性:標量(scalar)、分類、掩碼、mask_one 和指針,它們都有自己的編碼和解碼策略以及損失函數。

本文中具體的改進包括但不僅限於:

數據集和訓練:移除 teacher forcing。在評估時,模型無法訪問數據集中的 hint,只能依靠已有的 hint 進行預測。在先前的模型中,訓練期間提供了概率為 0.5 的 ground-truth hint,在沒有 teacher forcing 的情況下,當存在 scalar hints 時,損失傾向於沿軌跡無界增長,從而破壞了訓練的穩定性。

這項工作整合了幾個重要的穩定變化,足以完全消除 teacher forcing 帶來的影響,使訓練與評估保持一致。由於 teacher forcing 的存在,排序算法和 Kruskal 算法的性能顯著下降。在移除了 teacher forcing 之後,本文還對訓練數據進行了擴充,以防止模型過擬合。

Soft hint 傳播。本文將 softmax 用於分類,mask_one 、指針類型、logistic sigmoid 用於掩碼類型。如果沒有這些 soft hints,排序算法的性能會下降(類似於有 teacher forcing 的情況)。

利用編碼器初始化和梯度裁剪提高訓練穩定性。該研究使用 Xavier 進行初始化,從而有效地減少了輸入維度僅為 scalar hint 的初始權重。此外,該研究還對編碼器、解碼器、網絡處理器進行了改進。

對模型改進之後得到一組超參數模型,經過訓練,該模型在 CLRS-30 上達到了 SOTA 性能。下表 1 和表 2 顯示了包括 Memnet、MPNN、PGN 等模型在內的 micro-F_1 得分。



下圖 2 顯示了改進模型與 SOTA 模型之間的比較。本文的模型比次優模型(見表 1)平均性能提高了 20% 以上,並且除了一個算法系列之外,所有算法的性能都比其他模型有了顯著提高。



從實驗可以看出,有兩個算法系列具有顯著的 OOD 性能改進:第一個是幾何算法,現在求解接準確率約 94% OOD ,而之前的最佳結果約為 73%;第二個是字符串算法,模型現在求解準確率超過 49%,而之前的最佳值約為 3%。與之前的 SOTA 相比,本文在 24 種算法中準確率超過 60%,17 種算法的準確率超過 80%,11 種算法的準確率超過 90%。

下圖 3 比較了單任務 Triplet-GMPNN 與多任務模型的性能。



為了獨立評估模型改進的效果,該研究還進行了消融實驗。下圖 4a 顯示了 vanilla 訓練和分塊訓練在性能上的顯著差異;圖 4b 顯示了累積消融的結果:逐漸刪除單個改進部分的結果。




更多閱讀



#投 稿通 道#

讓你的文字被更多人看到




如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。

總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋樑,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。

PaperWeekly 鼓勵高校實驗室或個人,在我們的平台上分享各類優質內容,可以是最新論文解讀,也可以是學術熱點剖析、科研心得或競賽經驗講解等。我們的目的只有一個,讓知識真正流動起來。

📝稿件基本要求:

• 文章確係個人原創作品,未曾在公開渠道發表,如為其他平台已發表或待發表的文章,請明確標註

• 稿件建議以markdown格式撰寫,文中配圖以附件形式發送,要求圖片清晰,無版權問題

• PaperWeekly 尊重原作者署名權,並將為每篇被採納的原創首發稿件,提供業內具有競爭力稿酬,具體依據文章閱讀量和文章質量階梯制結算

📬投稿通道:

• 投稿郵箱:hr@paperweekly.site

• 來稿請備註即時聯繫方式(微信),以便我們在稿件選用的第一時間聯繫作者

• 您也可以直接添加小編微信(pwbot02)快速投稿,備註:姓名-投稿

△長按添加PaperWeekly小編

🔍

現在,在「知乎」也能找到我們了

進入知乎首頁搜索「PaperWeekly」

點擊「關注」訂閱我們的專欄吧


·


arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()