Introduction to Graph Neural Networks

Intro

A Graph Neural Network (GNN) takes data represented as a graph as input and maps it to a learned output. The input graph can be directed or undirected, and its nodes and edges can be labeled or unlabeled.

Matrix Representation

Suppose the input graph has n nodes and each node has k features. Stacking the node feature vectors as rows gives an n x k matrix X = (x1, x2, ..., xn)T, so X[i, j] is the value of the jth feature of the ith node. This is known as the node label (or feature) matrix. Unfortunately, this representation imposes an ordering on the nodes, which the graph itself does not have. Multiplying X by an n x n permutation matrix P reorders its rows. Each row of P is a one-hot vector and each column contains exactly one 1: if row i of P has its 1 in column j, then row i of PX is row j of X.
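To make the row convention concrete, here is a minimal NumPy sketch. The helper name `permutation_matrix` is hypothetical, and indices are 0-based as usual in Python. It builds P from a list `perm` so that row i of P has its 1 in column `perm[i]`, then checks that PX lists the rows of X in that order.

```python
import numpy as np

def permutation_matrix(perm):
    """Return the n x n permutation matrix P with (P @ X)[i] == X[perm[i]]."""
    n = len(perm)
    P = np.zeros((n, n))
    # Row i is a one-hot vector whose 1 sits in column perm[i].
    P[np.arange(n), perm] = 1.0
    return P

# 3 nodes with 2 features each (rows are nodes).
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])

perm = [2, 0, 1]
P = permutation_matrix(perm)
PX = P @ X  # rows of X in the order 2, 0, 1: [[5, 6], [1, 2], [3, 4]]
```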

A function f is permutation invariant if f(PX) = f(X) for every permutation matrix P; that is, the output is the same no matter how the nodes are ordered. One example is the Deep Sets model \(f(X) = f_1\left(\sum_{i=1}^{n} f_2(x_i)\right)\), where f1 and f2 are learnable functions. Because a sum does not depend on the order of its terms, reordering the rows of X leaves the output unchanged.
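A quick numerical check of this property, as a minimal sketch: `f2` and `f1` below are small randomly initialized layers standing in for the learnable functions (their sizes and activations are arbitrary assumptions, not part of Deep Sets itself), and the sum over nodes runs along axis 0 of X.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, hidden, out = 5, 3, 4, 2

# Hypothetical parameters for f2 (per-node encoder) and f1 (readout).
W_f2 = rng.normal(size=(k, hidden))
W_f1 = rng.normal(size=(hidden, out))

def f2(X):
    # Applied to every node's feature vector (row) independently.
    return np.maximum(X @ W_f2, 0.0)

def f1(s):
    return np.tanh(s @ W_f1)

def deep_sets(X):
    # Summing over nodes (axis 0) discards their order.
    return f1(f2(X).sum(axis=0))

X = rng.normal(size=(n, k))
P = np.eye(n)[rng.permutation(n)]  # random permutation matrix
invariant = np.allclose(deep_sets(P @ X), deep_sets(X))  # True
```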

Graph Convolution

Graph convolutions provide a parameter-efficient way to learn features that depend on how nodes are connected to each other, because the same weights are shared by every node. For a basic graph convolution, a node's updated feature vector is its previous feature vector transformed by W1, plus the sum of its neighbors' feature vectors transformed by W2, plus a bias b: \[x_{i}' = x_{i}W_{1}+\sum_{j \in N(i)} x_jW_2+b\] where \(N(i)\) is the set of neighbors of node \(i\).
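The per-node formula can be sketched directly in NumPy. This is a minimal illustration, assuming neighbors are stored as a Python dict of adjacency lists; the function name `graph_conv_node`, the small path graph, and the hand-picked weights are all hypothetical, and nodes are 0-indexed.

```python
import numpy as np

def graph_conv_node(X, neighbors, i, W1, W2, b):
    """Basic graph convolution for node i (0-based).

    X:         (n, k) node feature matrix.
    neighbors: dict mapping each node to the list of its adjacent nodes, i.e. N(i).
    W1, W2:    (k, k_out) weight matrices shared by all nodes; b: (k_out,) bias.
    """
    self_term = X[i] @ W1
    # Sum the neighbors' features first; by linearity this equals summing x_j @ W2.
    neighbor_sum = np.zeros(X.shape[1])
    for j in neighbors[i]:
        neighbor_sum += X[j]
    return self_term + neighbor_sum @ W2 + b

# Path graph 0 - 1 - 2 with one-hot features and small hand-picked weights.
X = np.eye(3)
neighbors = {0: [1], 1: [0, 2], 2: [1]}
W1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
W2 = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 0.0]])
b = np.array([0.5, 0.5])

# Apply the same parameters to every node.
X_new = np.stack([graph_conv_node(X, neighbors, i, W1, W2, b) for i in range(len(X))])
# X_new[1] = [0, 1] + [3, 0] + [0.5, 0.5] = [3.5, 1.5]
```

Because W1, W2, and b are reused for every node, the number of parameters depends only on the input and output feature sizes, not on the number of nodes.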

Here's an example. Consider the following 4-node undirected graph, with edges 1-2, 1-3, 2-3, and 3-4:

Figure 1: Example graph: 4 nodes, undirected

Let's say nodes have 3 features, and each node's feature vector is a one-hot encoding of its node type. Here's the feature matrix: \[X = \begin{bmatrix} 1 & 0 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix} \]

Additionally, let's say W1 and W2 each map 3 features to 2 features, and b is a bias vector with 2 entries: \[W_1 = \begin{bmatrix} 0.4 & 0.5 \\ 0.3 & 0.8 \\ 0.1 & 0.6 \\ \end{bmatrix} \quad W_2 = \begin{bmatrix} 0.5 & 0.3 \\ 0.8 & 0.2 \\ 0.6 & 0.9 \\ \end{bmatrix} \quad b = \begin{bmatrix} 0.7 & 0.2 \end{bmatrix} \]

We will now calculate the convolution of node 2, whose neighbors are nodes 1 and 3: \[x_2' = x_2 W_1 + (x_1 + x_3) W_2 + b\] Since \(x_2 = \begin{bmatrix} 1 & 0 & 0 \end{bmatrix}\) and \(x_1 + x_3 = \begin{bmatrix} 1 & 1 & 0 \end{bmatrix}\), this is \[x_2' = \begin{bmatrix} 0.4 & 0.5 \end{bmatrix} + \begin{bmatrix} 1.3 & 0.5 \end{bmatrix} + \begin{bmatrix} 0.7 & 0.2 \end{bmatrix} = \begin{bmatrix} 2.4 & 1.2 \end{bmatrix}\]

In this example we determined which nodes were adjacent to node 2 by looking at the graph. In practice we need a way to determine this automatically, and the adjacency matrix provides it. An adjacency matrix indicates how nodes are connected: a nonzero value in row i, column j means that node i is connected to node j. The values can be edge weights; for an unweighted graph like ours, they are 1s. Because our graph is undirected, its adjacency matrix is symmetric. Here's the adjacency matrix for our example graph: \[A = \begin{bmatrix} 0 & 1 & 1 & 0 \\ 1 & 0 & 1 & 0 \\ 1 & 1 & 0 & 1\\ 0 & 0 & 1 & 0\\ \end{bmatrix} \]

The convolution, which uses the adjacency matrix A to determine neighbors automatically, is: \[X' = XW_1 + AXW_2 + b\] where b is added to every row. Row i of AX is the sum of the feature vectors of node i's neighbors, so row i of X' is exactly \(x_i'\).
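The whole worked example can be checked numerically with the matrix form. Below is a minimal NumPy sketch using the X, W1, W2, and b values above; it assumes a symmetric adjacency matrix, since the graph is undirected (so the 1 linking nodes 3 and 4 appears in both row 3 and row 4). Python indexes from 0, so node 2 is row index 1.

```python
import numpy as np

# Feature matrix from the example (rows are nodes 1-4; one-hot node types).
X = np.array([[1, 0, 0],
              [1, 0, 0],
              [0, 1, 0],
              [0, 0, 1]], dtype=float)

W1 = np.array([[0.4, 0.5],
               [0.3, 0.8],
               [0.1, 0.6]])
W2 = np.array([[0.5, 0.3],
               [0.8, 0.2],
               [0.6, 0.9]])
b = np.array([0.7, 0.2])

# Symmetric adjacency matrix for the undirected graph with edges 1-2, 1-3, 2-3, 3-4.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

# X' = X W1 + A X W2 + b, with b broadcast across every row.
X_new = X @ W1 + A @ X @ W2 + b
# Row index 1 (node 2) matches the hand calculation: [2.4, 1.2]
```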

Message Passing Neural Networks

The convolution described previously is a special case of the more general message passing neural network (MPNN) framework (described here). In this framework, each node has a "hidden state", which is essentially its embedding (or features) at time step t. The framework uses two functions:

  1. A message function \(M_t(h_i^t, h_j^t, e_{ij})\), where \(h_i^t\) is the hidden state of node i at time t, \(h_j^t\) is the hidden state of an adjacent node j at time t, and \(e_{ij}\) is the edge (or its features) connecting node i to node j.
  2. An update function \(U_t(h_i^t, m_i^{t+1})\), where \(m_i^{t+1}\) is the aggregated message that node i receives for time step t+1.

The message \(m_i^{t+1}\) is computed by summing the message function over all of node i's neighbors: \[m_i^{t+1} = \sum_{j \in N(i)} M_t(h_i^t, h_j^t, e_{ij})\] where N(i) is the set of neighbors of node i.

The new hidden state \(h_i^{t+1}\) can then be computed using the update function: \[h_i^{t+1} = U_t(h_i^t, m_i^{t+1})\]

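For the convolution example above, the message function is \(M_t = h_j^t W_2\) and the update function is \(U_t = h_i^t W_1 + m_i^{t+1} + b\). Substituting both into the update gives \(h_i^{t+1} = h_i^t W_1 + \sum_{j \in N(i)} h_j^t W_2 + b\), which is the graph convolution with \(x_i\) replaced by \(h_i^t\).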

Note that the initial hidden state is the input feature vector, i.e. row i of X: \(h_i^0 = x_i\).
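The framework can be sketched as one generic step that takes the message and update functions as arguments. This is a minimal NumPy illustration, not a reference MPNN implementation: the names `mpnn_step`, `conv_message`, and `conv_update` are hypothetical, the weights are random, and edge features are unused because this message function ignores \(e_{ij}\). Plugging in the convolution's \(M_t\) and \(U_t\) reproduces the matrix form \(XW_1 + AXW_2 + b\).

```python
import numpy as np

def mpnn_step(H, neighbors, edge_feats, message_fn, update_fn):
    """One message-passing step over all nodes (0-based indices).

    H:          (n, d) hidden states h^t, one row per node.
    neighbors:  dict mapping node i to the list N(i) of its neighbors.
    edge_feats: dict mapping (i, j) to e_ij; pairs missing from the dict give None.
    """
    H_next = []
    for i in range(H.shape[0]):
        # m_i^{t+1} = sum over j in N(i) of M_t(h_i^t, h_j^t, e_ij)
        m_i = sum(message_fn(H[i], H[j], edge_feats.get((i, j))) for j in neighbors[i])
        # h_i^{t+1} = U_t(h_i^t, m_i^{t+1})
        H_next.append(update_fn(H[i], m_i))
    return np.stack(H_next)

rng = np.random.default_rng(0)
W1 = rng.normal(size=(3, 2))
W2 = rng.normal(size=(3, 2))
b = rng.normal(size=2)

def conv_message(h_i, h_j, e_ij):
    return h_j @ W2              # M_t = h_j^t W2 (ignores h_i and e_ij)

def conv_update(h_i, m_i):
    return h_i @ W1 + m_i + b    # U_t = h_i^t W1 + m_i^{t+1} + b

# The 4-node example graph with 0-based labels: edges 0-1, 0-2, 1-2, 2-3.
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
A = np.zeros((4, 4))
for i, js in neighbors.items():
    A[i, js] = 1.0

H0 = rng.normal(size=(4, 3))     # h_i^0 = x_i
H1 = mpnn_step(H0, neighbors, {}, conv_message, conv_update)
matches_conv = np.allclose(H1, H0 @ W1 + A @ H0 @ W2 + b)  # True
```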

Pairwise Margin-Based Hinge Loss

Let e be a real edge, represented as a triplet of a source embedding, a relation embedding, and a destination embedding. The score f compares the source embedding with the destination embedding transformed (translated) by the relation embedding: f(e) = similarity(Θ_source, Θ_relation + Θ_destination). Despite the name, f is used here like a distance: for a real edge we want this score to be low. A fake (negative) edge e' is the same as e but with either the source or the destination node replaced by a random node, so we want f(e') to be high.
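A minimal NumPy sketch of the loss, under stated assumptions: the text leaves the similarity function unspecified, so f is taken here to be the Euclidean distance between \(\Theta_{source}\) and \(\Theta_{relation} + \Theta_{destination}\) (a TransE-style choice), and the loss is assumed to take the usual pairwise margin form \(L = \sum_{e} \sum_{e'} \max(0,\ f(e) - f(e') + \lambda)\) with margin \(\lambda\). The embedding tables, sizes, margin, number of negatives, and all function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, num_relations, dim = 10, 2, 4
# Hypothetical embedding tables (the Θ parameters): one row per node / relation.
node_emb = rng.normal(size=(num_nodes, dim))
rel_emb = rng.normal(size=(num_relations, dim))

def score(src, rel, dst):
    """f(e): distance between θ_src and θ_rel + θ_dst (lower = more plausible)."""
    return np.linalg.norm(node_emb[src] - (rel_emb[rel] + node_emb[dst]))

def corrupt(edge):
    """Fake edge e': replace the source or the destination with a random node.

    This sketch does not filter out corruptions that happen to be real edges.
    """
    src, rel, dst = edge
    if rng.random() < 0.5:
        return (int(rng.integers(num_nodes)), rel, dst)
    return (src, rel, int(rng.integers(num_nodes)))

def pair_hinge(f_pos, f_neg, margin=1.0):
    # Zero once the fake edge scores at least `margin` higher than the real edge.
    return max(0.0, f_pos - f_neg + margin)

def hinge_loss(edges, num_neg=5, margin=1.0):
    total = 0.0
    for e in edges:
        f_e = score(*e)
        for _ in range(num_neg):
            total += pair_hinge(f_e, score(*corrupt(e)), margin)
    return total

edges = [(0, 0, 1), (1, 1, 2), (3, 0, 4)]  # (source, relation, destination)
loss = hinge_loss(edges)
```

Minimizing this loss pushes each real edge's score at least \(\lambda\) below the scores of its corrupted versions; once that gap is reached, a pair contributes nothing.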