消息传递网络

1 原理

Message Passing 是图网络中学习 node embedding 的重要方法。

公式

$$
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}i^{(k-1)}, \square{j \in \mathcal{N} (i)} , \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}j^{(k-1)},\mathbf{ e}{j,i}\right) \right),
$$

x 表示表格节点的 embedding,e 表示边的特征,ϕ 表示 message 函数,□ 表示聚合 aggregation 函数,γ 表示 update 函数。上标表示层的 index,比如说,当 k = 1 时,x 则表示所有输入网络的图结构的数据。

函数

  • propagate(edge_index, size=None, **kwargs)
    这个函数最终会调用 message 和 update 函数。

  • message(**kwargs)
    这个函数定义了对于每个节点对 (xi,xj),怎样生成信息(message)。

  • update(aggr_out, **kwargs)
    这个函数利用聚合好的信息(message)更新每个节点的 embedding。

传播消息

1
torch_geometric.nn.MessagePassing.propagate(edge_index, size=None, **kwargs)
  • 初始化时调用这个函数开始传播消息。输入边的索引和所有其他的用于构造消息和更新节点嵌入的数据。值得注意的是,这个propagate()方法不仅可以在形状为 [N, N] 的对称邻接矩阵中交换消息,还可以在一般的稀疏分配矩阵中交换消息,例如二分图,形状为[N, M]的话只需要 size=(N, M) 作为一个额外的参数就可以。如果设置为None,那么分配的矩阵会被假设为对称的。对于具有两组独立节点和索引的二分图,并且每组保存自己的信息,这种拆分可以通过在传递过程中将信息作为一个元组来标记。例如 x=(x_row, x_col),以此表明不同集合中的节点关系。

创建消息

1
torch_geometric.nn.MessagePassing.message()
  • 以函数的方式构造消息,如果 flow=”source_to_target”,那么对每条边做此操作,即将消息从节点传播到节点;如果 flow=”target_to_source”,那么对每条边做此操作,即将消息从节点传播到节点。任何用于构造消息的输入都可以作为参数传递给 propagate() 函数。另外,通过将_i或者_j加在变量名后面可以将特征映射到相应的节点和,例如 x_i 和 x_j。

更新表示

1
torch_geometric.nn.MessagePassing.update()
  • 以函数 的方式对每一个节点 进行更新节点嵌入的操作。将聚合操作后的输出结果作为第一个参数,任何需要在初始化时传递给 propagate() 函数的参数作为输入。

MessagePassing 类是实现各种图神经网络模型的基础。从一个图神经层到另一个图神经进行的消息传递操作可以分为三个计算层次:

  • 第一层计算函数操作,这属于创建消息的过程。输入为上一层的节点特征和边关系,可以指定消息的流向。(边关系以索引的形式进行查找)利用了图结构数据的边关系。

  • 第二层计算函数操作,这属于邻域聚合的过程。经过第一层的操作,已经按边关系建立了节点间的消息,表示节点的(一阶)邻居节点,即限定了消息只在指定节点的邻域范围内传递。的 add,mean 和 max 都不会因为邻居节点的顺序排列问题而产生不同的结果。利用了图结构数据中的结构信息。

  • 第三层计算函数操作,这属于更新特征的过程。经过邻域上的消息传递之后,将邻域上的信息聚合到目标节点,然后更新节点的特征,作为这一层的输出。

SAGE示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class SAGEConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(SAGEConv, self).__init__(aggr='max')
self.lin = torch.nn.Linear(in_channels, out_channels)
self.act = torch.nn.ReLU()

def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]


edge_index, _ = remove_self_loops(edge_index)
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))


return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

def message(self, x_j):
# x_j has shape [E, in_channels]

x_j = self.lin(x_j)
x_j = self.act(x_j)

return x_j

def update(self, aggr_out, x):
# aggr_out has shape [N, out_channels]


new_embedding = torch.cat([aggr_out, x], dim=1)

new_embedding = self.update_lin(new_embedding)
new_embedding = self.update_act(new_embedding)

return new_embedding