close

本文對LightGCN模型部分的代碼進行了解讀,對相應部分進行了簡單的注釋幫助大家理解。筆者第一次嘗試代碼閱讀分享,有什麼不足之處或者建議可以給我留言哦,感謝。

Dropout

在圖上實施dropout,以一定概率忽略一部分邊

def__dropout_x(self, x, keep_prob):# 獲取self.Graph中的大小,下標和值,Graph採用稀疏矩陣的表示方法SparseTensorsize = x.size()index = x.indices().t()values = x.values()# 通過rand得到len(values)數量的隨機數,加上keep_probrandom_index = torch.rand(len(values)) + keep_prob# 通過對這些數字取int使得小於1的為0,在通過bool()將0->false,大於等於1的取Truerandom_index = random_index.int().bool()# 利用上面得到的True,False數組選取下標,從而dropout了為False的下標index = index[random_index]# 由於dropout在訓練和測試過程中的不一致,所以需要除以pvalues = values[random_index]/keep_prob# 得到新的graphg = torch.sparse.FloatTensor(index.t(), values, size)returngdef__dropout(self, keep_prob):ifself.A_split:graph = []forg inself.Graph:graph.append(self.__dropout_x(g, keep_prob))else:graph = self.__dropout_x(self.Graph, keep_prob)returngraph

消息傳播

computer函數是LightGCN類中用於進行圖信息傳播的實現方法,整體上通過在整個圖上進行矩陣計算得到所有用戶和商品的embedding。

defcomputer(self):"""propagate methods for lightGCN"""# 得到所有用戶和所有商品的embeddingusers_emb = self.embedding_user.weightitems_emb = self.embedding_item.weightall_emb = torch.cat([users_emb, items_emb])# torch.split(all_emb , [self.num_users, self.num_items])embs = [all_emb]# 判斷是否需要dropoutifself.config['dropout']:ifself.training:print("droping")g_droped = self.__dropout(self.keep_prob)else:g_droped = self.Graph else:g_droped = self.Graph # 根據層數對圖進行信息傳播和聚合考慮n-hop# 通過稀疏矩陣乘法對Graph進行n_layers次的計算forlayer inrange(self.n_layers):ifself.A_split:temp_emb = []forf inrange(len(g_droped)):temp_emb.append(torch.sparse.mm(g_droped[f], all_emb))side_emb = torch.cat(temp_emb, dim=0)all_emb = side_embelse:all_emb = torch.sparse.mm(g_droped, all_emb)embs.append(all_emb)embs = torch.stack(embs, dim=1)#print(embs.size())# 對每一層得到的輸出求均值,以此將不同層的信息進行融合light_out = torch.mean(embs, dim=1)users, items = torch.split(light_out, [self.num_users, self.num_items])returnusers, items


損失構建

在computer函數計算得到所有用戶和商品經過消息傳播後的embedding之後,getEmbedding根據當前用戶和商品查詢出需要用到的embedding以及當前用戶和商品的原始embedding,即未經GCN的embedding。

傳播後的embedding用於計算bpr損失,原始embedding用於計算L2正則項。

defgetEmbedding(self, users, pos_items, neg_items):# 得到需要計算相似度的用戶和商品的embeddingall_users, all_items = self.computer()users_emb = all_users[users]pos_emb = all_items[pos_items]neg_emb = all_items[neg_items]# 沒經過傳播的embedding,用於後續正則項計算users_emb_ego = self.embedding_user(users)pos_emb_ego = self.embedding_item(pos_items)neg_emb_ego = self.embedding_item(neg_items)returnusers_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_egodefbpr_loss(self, users, pos, neg):(users_emb, pos_emb, neg_emb, userEmb0, posEmb0, negEmb0) = self.getEmbedding(users.long(), pos.long(), neg.long())# 這個損失計算的是LightGCN論文中損失函數中的正則項,即做了一個L2正則reg_loss = (1/2)*(userEmb0.norm(2).pow(2) + posEmb0.norm(2).pow(2) +negEmb0.norm(2).pow(2))/float(len(users))# 通過乘法計算用戶和商品的相似度pos_scores = torch.mul(users_emb, pos_emb)pos_scores = torch.sum(pos_scores, dim=1)neg_scores = torch.mul(users_emb, neg_emb)neg_scores = torch.sum(neg_scores, dim=1)# pair-wise的排序損失loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))
歡迎乾貨投稿 \論文宣傳\合作交流推薦閱讀

強化學習推薦系統的模型結構與特點總結

一文理解PyTorch:附代碼實例
推薦系統之FM算法原理及實現(附代碼)

由於公眾號試行亂序推送,您可能不再準時收到機器學習與推薦算法的推送。為了第一時間收到本號的乾貨內容, 請將本號設為星標,以及常點文末右下角的「在看」。

喜歡的話點個在看吧👇
arrow
arrow
    全站熱搜
    創作者介紹
    創作者 鑽石舞台 的頭像
    鑽石舞台

    鑽石舞台

    鑽石舞台 發表在 痞客邦 留言(0) 人氣()