PyTorch Geometric Graph Embedding
Using SAGEConv in PyTorch Geometric module for embedding graphs
Graph representation learning (or graph embedding) is the process of transforming graph-structured data into vectors. Fixed-length vectors are much easier to handle in downstream analysis than the graph itself. Ideally, these vectors should capture the graph structure (topological information) in addition to the node features. We use graph neural networks (GNNs) to perform this transformation. For a basic high-level idea of GNNs, you can take a peek at the following article.
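In this article, I will talk about the GraphSAGE architecture, which is a variant of message passing neural networks (MPNNs). MPNN is a general framework in which each node repeatedly gathers messages from its neighbours to update its own representation; most GNNs can be expressed, and efficiently implemented, in this form.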
Generalized GNN representation
Any MPNN can be formally represented using the two functions aggregate and combine.
The aggregate function governs how neighbour information is gathered or aggregated for a given node.
The combine function governs how the information of a node itself is combined with the information from the neighbourhood.
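To make the two functions concrete, here is a minimal sketch of one message-passing round in plain Python, with no PyTorch Geometric involved. `mpnn_layer`, `sum_aggregate` and `add_combine` are hypothetical names chosen for illustration, and element-wise sum plus addition is just one simple choice of aggregate and combine, not the one GraphSAGE uses.

```python
# A minimal, framework-free sketch of one generic MPNN layer.
# `aggregate` and `combine` are plain Python callables supplied by the caller.

def mpnn_layer(features, neighbours, aggregate, combine):
    """Apply one round of message passing.

    features:   dict mapping node id -> list of floats (current embedding h_v)
    neighbours: dict mapping node id -> list of neighbour ids N(v)
    """
    updated = {}
    for v, h_v in features.items():
        # Gather the neighbours' current embeddings and aggregate them.
        messages = [features[u] for u in neighbours.get(v, [])]
        a_v = aggregate(messages, dim=len(h_v))
        # Combine the node's own embedding with the aggregated message.
        updated[v] = combine(h_v, a_v)
    return updated


def sum_aggregate(messages, dim):
    # Element-wise sum; an empty neighbourhood yields a zero vector.
    return [sum(m[i] for m in messages) for i in range(dim)]


def add_combine(h_v, a_v):
    # Element-wise addition of self and neighbourhood information.
    return [x + y for x, y in zip(h_v, a_v)]


# Usage on a tiny path graph 0 - 1 - 2
feats = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
adj = {0: [1], 1: [0, 2], 2: [1]}
out = mpnn_layer(feats, adj, sum_aggregate, add_combine)
print(out)  # {0: [1.0, 1.0], 1: [2.0, 2.0], 2: [1.0, 2.0]}
```

Swapping in different `aggregate` and `combine` callables is all it takes to describe a different MPNN variant.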
GraphSAGE
GraphSAGE stands for Graph-SAmple-and-aggreGatE. Let’s first define the aggregate and combine functions for GraphSAGE.
Aggregate: use the element-wise mean of the neighbours' features
Combine: concatenate the aggregated features with the current node's features
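In equation form, a single GraphSAGE layer k with the mean aggregator can be written as below, following the notation of the original paper: h_v^(k) is node v's representation after layer k (with h_v^(0) the input features), N(v) its sampled neighbourhood, W^(k) a learnable weight matrix and σ a non-linearity. The paper additionally normalizes each h_v^(k) to unit L2 norm after the layer.

```latex
\begin{aligned}
\mathbf{h}_{\mathcal{N}(v)}^{(k)} &= \operatorname{mean}\left(\left\{\mathbf{h}_u^{(k-1)} : u \in \mathcal{N}(v)\right\}\right) && \text{(aggregate)}\\
\mathbf{h}_v^{(k)} &= \sigma\left(\mathbf{W}^{(k)} \cdot \operatorname{concat}\left(\mathbf{h}_v^{(k-1)}, \mathbf{h}_{\mathcal{N}(v)}^{(k)}\right)\right) && \text{(combine)}
\end{aligned}
```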
Graphical Explanation
GraphSAGE layers can be visualized as follows. For a given node v, we aggregate all of its neighbours using mean aggregation. The result is concatenated with node v's own features and passed through a fully connected layer (a learnable weight matrix) followed by a non-linearity such as ReLU.
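The layer described above can be sketched in plain Python as follows. `sage_layer` and its helpers are hypothetical, the weights are fixed by hand rather than learned, and dense lists stand in for tensors. Because W acts on concat(h_v, h_N(v)), it can be split column-wise into two matrices, giving W_1 h_v + W_2 h_N(v); PyTorch Geometric's `SAGEConv` documents its update in this equivalent two-matrix form.

```python
# Minimal sketch of one GraphSAGE layer with the mean aggregator, in plain
# Python. Weights are passed in explicitly; nothing here is trained.

def mean_aggregate(vectors, dim):
    # Element-wise mean of the neighbour vectors (zeros if there are none).
    if not vectors:
        return [0.0] * dim
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def linear(weight, x):
    # weight is a list of rows (out_dim x in_dim); returns weight @ x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def relu(x):
    return [max(0.0, xi) for xi in x]


def sage_layer(features, neighbours, weight):
    out = {}
    for v, h_v in features.items():
        h_nbr = mean_aggregate([features[u] for u in neighbours[v]], len(h_v))
        # Combine: concatenate self features with the aggregated neighbours,
        # then apply the weight matrix and a ReLU non-linearity.
        out[v] = relu(linear(weight, h_v + h_nbr))
    return out


feats = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
adj = {0: [1], 1: [0, 2], 2: [1]}
# 1 output unit; the weight acts on concat(h_v, h_nbr), i.e. a 4-dim input.
W = [[1.0, -1.0, 0.5, 0.5]]
print(sage_layer(feats, adj, W))  # {0: [1.5], 1: [0.0], 2: [0.5]}
```

Stacking two such layers lets each node see information from neighbours up to two hops away.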
One can easily use GraphSAGE through a framework such as PyTorch Geometric. Before we go there, let's set up a use case. One major use of graph embeddings is visualization, so let's build a GNN with GraphSAGE to visualize the Cora dataset. Note that I am using the example provided in the PyTorch Geometric repository, with a few tweaks.
GraphSAGE Specifics
The key idea of GraphSAGE is its sampling strategy, which enables the architecture to scale to very large graphs. Instead of using the full neighbourhood, at each layer only up to K neighbours are sampled for each node. As usual, the aggregator must be order invariant (such as mean, max or min), because a node's neighbours have no natural ordering.
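A minimal sketch of the per-node sampling step in plain Python. `sample_neighbours` is a hypothetical helper; sampling uniformly without replacement, and keeping neighbourhoods that already have at most K members intact, are assumptions of this sketch (the cap K is called `k` in the code).

```python
import random


def sample_neighbours(neighbours, nodes, k, rng=random):
    """Uniformly sample at most k neighbours per node, without replacement.

    neighbours: dict mapping node id -> list of neighbour ids
    nodes:      the nodes whose neighbourhoods we need at this layer
    """
    sampled = {}
    for v in nodes:
        nbrs = neighbours.get(v, [])
        # Small neighbourhoods are kept whole; large ones are capped at k.
        sampled[v] = list(nbrs) if len(nbrs) <= k else rng.sample(nbrs, k)
    return sampled


# A "hub" node 0 with 100 neighbours, and a leaf node 1.
adj = {0: list(range(1, 101)), 1: [0]}
rng = random.Random(0)
batch_nbrs = sample_neighbours(adj, [0, 1], k=10, rng=rng)
print(len(batch_nbrs[0]), batch_nbrs[1])  # 10 [0]
```

Capping every node at k neighbours bounds the work of a layer by roughly the number of nodes times k, no matter how large a hub's degree is.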
Loss Function
Here we train in an unsupervised manner, so instead of labels we use the graph's topological structure to define the loss.
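The unsupervised loss from the original GraphSAGE paper, for a node u with final-layer embedding z_u, is given below. σ is the logistic sigmoid, v a positive node that co-occurs with u (e.g. on a short random walk), P_n a negative-sampling distribution and Q the number of negative samples.

```latex
J_{\mathcal{G}}(\mathbf{z}_u) = -\log\left(\sigma\left(\mathbf{z}_u^{\top}\mathbf{z}_v\right)\right) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} \log\left(\sigma\left(-\mathbf{z}_u^{\top}\mathbf{z}_{v_n}\right)\right)
```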
Here z_u denotes the final-layer output (embedding) of node u, and z_{v_n} the embedding of a negatively sampled node v_n. In simple terms, the second term says that for a negative pair (node u and a random node v_n), the negated dot product -z_u^T z_{v_n} should be maximized; in other words, the embeddings of random node pairs should be pushed apart. The first term says the opposite for node v, a node that should be embedded close to u. This v is called a positive node and is typically obtained through a random walk starting from u. The subscript v_n ~ P_n(v) on the expectation E indicates that the negative nodes are drawn from a negative-sampling distribution. In the actual implementation, we take direct neighbours as positive samples and random nodes as negative samples.
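The loss can be computed for a single node in plain Python as below. `unsup_loss` is a hypothetical helper; the expectation over negatives is approximated by summing over the Q sampled negatives (Q times their mean), and log(σ(x)) is evaluated in a numerically stable form.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def unsup_loss(z_u, z_pos, z_negs):
    """Loss for a single node u with one positive and Q = len(z_negs) negatives.

    Q * E[...] over negatives is estimated by summing over the Q samples.
    """
    pos_term = -log_sigmoid(dot(z_u, z_pos))
    neg_term = -sum(log_sigmoid(-dot(z_u, z_n)) for z_n in z_negs)
    return pos_term + neg_term


z_u = [1.0, 0.0]
# Positive aligned with u, negative pointing away: low loss.
good = unsup_loss(z_u, z_pos=[1.0, 0.0], z_negs=[[-1.0, 0.0]])
# The reverse situation: high loss.
bad = unsup_loss(z_u, z_pos=[-1.0, 0.0], z_negs=[[1.0, 0.0]])
print(good < bad)  # True
```

In the PyTorch Geometric unsupervised GraphSAGE example that this article adapts, each node gets one positive and one negative sample (Q = 1) and both terms are averaged over the batch.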
Building a Graph Embedding Network
We can start by importing the following Python modules:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import random_walk
from sklearn.linear_model import LogisticRegression
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
# NeighborSampler lives in torch_geometric.loader in PyG 2.x
# (it was torch_geometric.data.NeighborSampler in older releases)
from torch_geometric.loader import NeighborSampler as RawNeighborSampler
import umap
import matplotlib.pyplot as plt
import seaborn as sns
Initialize the Cora dataset:
dataset = 'Cora'
path = './data'
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
Note that the dataset object is a collection of graphs. Cora contains a single graph, so we pick the graph at index 0.
Sampler components: here we subclass NeighborSampler and override its sample method so that each batch contains positive and negative samples:
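class NeighborSampler(RawNeighborSampler):
    def sample(self, batch):
        batch = torch.tensor(batch)
        # row/col: COO indices of the sparse adjacency matrix (Cora is undirected)
        row, col, _ = self.adj_t.coo()
        # A 1-step random walk from each node; column 1 (the step taken) is a
        # random direct neighbour, used as the positive node
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]
        # Random targets from the whole adjacency matrix serve as negatives
        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)
        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super().sample(batch)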
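The same positive/negative pairing can be sketched without torch_cluster, in plain Python. `make_pos_neg` is a hypothetical helper: a one-step random walk is simulated by choosing a uniformly random neighbour, negatives are drawn uniformly from all nodes, and returning an isolated node as its own positive is an assumption of this sketch.

```python
import random


def make_pos_neg(neighbours, batch, num_nodes, rng=random):
    """For each node in batch, pick one random direct neighbour (positive)
    and one uniformly random node (negative)."""
    pos, neg = [], []
    for v in batch:
        nbrs = neighbours[v]
        # A 1-step random walk from v lands on a uniformly chosen neighbour;
        # an isolated node simply stays where it is.
        pos.append(rng.choice(nbrs) if nbrs else v)
        # Negatives are uniform over all nodes, so they can occasionally be
        # true neighbours; this noise is tolerated in practice.
        neg.append(rng.randrange(num_nodes))
    return pos, neg


adj = {0: [1, 2], 1: [0], 2: [0], 3: []}
batch = [0, 1, 2, 3]
pos, neg = make_pos_neg(adj, batch, num_nodes=4, rng=random.Random(1))
# The nodes whose embeddings are needed: batch, then positives, then negatives.
all_nodes = batch + pos + neg
print(pos[1:])  # [0, 0, 3]
```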
The GNN can be declared in PyTorch as follows:
class SAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers):
        super(SAGE, self).__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else hidden_channels
            self.convs.append(SAGEConv(in_channels, hidden_channels))

    # Mini-batch pass over the sampled bipartite graphs produced by NeighborSampler
    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            # The target nodes of this layer are the first size[1] nodes
            x_target = x[:size[1]]
            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    # Full-graph pass (e.g. to embed all nodes at once after training)
    def full_forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != self.num_layers - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        return x
Note that we use the SAGEConv layers from the PyTorch Geometric framework. In the forward pass, NeighborSampler provides, for each layer, the sampled edges (edge_index, in local node indices) and the sizes of the source and target node sets; the target nodes are always listed first, which is why x[:size[1]] selects them. This is a rather complex module, so I suggest reading the minibatch algorithm in the GraphSAGE paper (page 12) and the NeighborSampler documentation in PyTorch Geometric.
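To see why x[:size[1]] picks out the target nodes, here is a simplified, hypothetical re-implementation of the bookkeeping in plain Python. `build_minibatch` and `toy_forward` are illustrative names, not PyTorch Geometric APIs, and the weight-free toy update (own features plus neighbour mean) stands in for SAGEConv. It assumes, as NeighborSampler does, that the nodes updated at each hop sit at the front of the node list, with newly sampled neighbours appended after them.

```python
import random


def build_minibatch(neighbours, batch, sizes, rng=random):
    """Hypothetical sketch of layer-wise neighbour sampling.

    Returns n_id, the global ids of every node whose input features are
    needed, and one (edges, (num_src, num_dst)) entry per layer. Edges are
    (src, dst) pairs of local indices into n_id. Layers are returned from
    the outermost hop inwards, the order in which SAGE.forward consumes adjs.
    """
    n_id = list(batch)                      # target nodes always come first
    local = {v: i for i, v in enumerate(n_id)}
    layers = []
    for k in sizes:
        num_dst = len(n_id)                 # nodes updated by this layer
        edges = []
        for dst in range(num_dst):
            nbrs = neighbours.get(n_id[dst], [])
            chosen = nbrs if len(nbrs) <= k else rng.sample(nbrs, k)
            for u in chosen:
                if u not in local:          # new nodes are appended at the end
                    local[u] = len(n_id)
                    n_id.append(u)
                edges.append((local[u], dst))
        layers.append((edges, (len(n_id), num_dst)))
    return n_id, layers[::-1]


def toy_forward(features, n_id, layers):
    """Weight-free stand-in for SAGE.forward: h_v + mean of sampled neighbours."""
    x = [features[v] for v in n_id]
    for edges, (num_src, num_dst) in layers:
        assert len(x) == num_src
        x_target = x[:num_dst]              # same slicing as x[:size[1]]
        agg = [[0.0] * len(x[0]) for _ in range(num_dst)]
        cnt = [0] * num_dst
        for src, dst in edges:
            agg[dst] = [a + b for a, b in zip(agg[dst], x[src])]
            cnt[dst] += 1
        x = [[t + (a / c if c else 0.0) for t, a in zip(x_target[i], agg[i])]
             for i, c in enumerate(cnt)]
    return x


# Path graph 0 - 1 - 2 with 1-dimensional features; two layers, target node 0.
adj = {0: [1], 1: [0, 2], 2: [1]}
feats = {0: [1.0], 1: [2.0], 2: [4.0]}
n_id, layers = build_minibatch(adj, batch=[0], sizes=[10, 10])
print(n_id, [size for _, size in layers])  # [0, 1, 2] [(3, 2), (2, 1)]
print(toy_forward(feats, n_id, layers))    # [[7.5]]
```

Because every neighbourhood here is smaller than the sample size, nothing is dropped and the result equals a full-graph pass with the same toy update, mirroring how forward and full_forward in the SAGE class agree when no neighbours are dropped (ignoring dropout).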
Visualization
The following UMAP plot is obtained from the raw node features alone, without using the graph structure.
When we embed the nodes with GraphSAGE first, we get a better embedding:
We can see that the GraphSAGE embeddings are much better separated than the naive UMAP embedding of the raw features. The result is not perfect and needs more work, but I hope it is a good enough demonstration to convey the idea. 😊
Hope you enjoyed this article!
The complete code and Jupyter Notebook are available here.