Artificial Intelligence

Apply machine learning and deep learning

Graph Neural Networks in MATLAB

Deep neural networks like convolutional neural networks (CNNs) and long short-term memory (LSTM) networks can be applied to image- and sequence-based deep learning tasks. Graph neural networks (GNNs) extend deep learning to graphs, that is, structures that encode entities (nodes) and their relationships (edges).
This blog post provides a gentle introduction to GNNs and resources to get you started with GNNs in MATLAB.
 

What are Graphs?

Graphs model the connections in a network and are widely applicable to a variety of physical, biological, and information systems. The structure of a graph comprises “nodes” and “edges”. Each node represents an entity, and each edge represents a connection between two nodes. To learn more, see Directed and Undirected Graphs.
Nodes (vertices) represent the fundamental entities in a graph. Each node can correspond to an object, entity, or data point, depending on the application. In social networks, nodes represent users; in molecular graphs, they represent atoms.
Edges (connections) define the relationships or connections between pairs of nodes. In an undirected graph, edges have no direction (e.g., a chemical bond). In a directed graph, edges have a specific direction (e.g., a one-way street). Edges may also have attributes, such as weights to represent the strength or cost of a connection.
An adjacency matrix is a mathematical representation of a graph. It is a square matrix where  A(i, j)  indicates the presence (and optionally the weight) of an edge from node  i  to node  j . In an undirected graph, the matrix is symmetric. For directed graphs, the entries are asymmetric, reflecting the direction of edges. The adjacency matrix is a key tool in graph-based computations, including those performed in GNNs, as it encodes the graph’s structure for machine learning algorithms.
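As a quick illustration, here is a minimal sketch using MATLAB's built-in graph functions; the triangle graph is a toy example chosen only to show how the adjacency matrix reflects edge direction:

```matlab
% Build a small undirected graph and a directed graph, and compare their
% adjacency matrices.
G = graph([1 2 3], [2 3 1]);      % undirected triangle: edges 1-2, 2-3, 3-1
A = full(adjacency(G))            % symmetric: A(i,j) == A(j,i)

D = digraph([1 2 3], [2 3 1]);    % same edges, but directed
B = full(adjacency(D))            % asymmetric: reflects edge direction
```

The `adjacency` function returns a sparse matrix, which is convenient for large graphs; `full` is used here only to make the small matrices easy to inspect.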
Figure: Key concepts in graphs
 
Using graph theory, you can model many real-world systems and relationships, as they can be naturally represented as graphs. You can use graphs to model the neurons in a brain, molecular structures, the flight patterns of an airline, road and social networks, and much more. By modeling entities as nodes and their relationships as edges, graph theory provides a structured way to analyze and optimize networks, interactions, and dependencies. In real-world applications, the graph complexity can significantly scale up.
Figure: Examples of complex undirected and directed graphs
 
Graph theory came in handy in my own past research. I used graphs to model the propagation of neuronal activity in the cerebral cortex as subjects reacted to an auditory stimulus with the press of a button, and I systematically measured and explained the underlying physiological mechanisms for the variations in reaction time. Had GNNs been discovered a few years earlier, I wonder how I could have used them to model cortical activity.
Figure: Directed graph modeling cortical activity when subjects react to auditory stimulus (Paraskevopoulou et al.)
 

What Are Graph Neural Networks?

Graph Neural Networks (GNNs) are deep learning models that take graph-based representations as input. GNNs learn from both node features and the structure of the graph itself, enabling rich representation learning for relational data. This makes GNNs a great option for problems involving irregular, non-Euclidean data, such as recommendation systems, fraud detection, and drug discovery.
 

How GNNs Work

A GNN follows a similar paradigm to a CNN in terms of layers, but instead of spatial convolutions, GNNs perform message passing over the graph. GNNs work by iteratively propagating and aggregating information across the graph’s structure, enabling nodes to learn contextualized feature representations based on their neighbors.
Here’s a breakdown of the key steps in how GNNs learn:
  1. Message Passing: Each node aggregates feature information from its neighbors based on the graph structure.
  2. Aggregation: Aggregated features are combined (e.g., summed or averaged) to form a new representation.
  3. Feature Update: The node’s feature vector is updated using a learnable transformation, often followed by a non-linear activation function.
  4. Stacking layers: Multiple GNN layers are stacked to propagate information from farther parts of the graph.
  5. Output Layer: The final layer of a GNN typically produces per-node or whole-graph predictions, supporting both classification and regression, as well as generating latent embeddings for downstream tasks.
  6. Graph Pooling: Some GNNs dynamically modify the graph structure layer-by-layer, similarly to strided convolutions or pooling layers in CNNs, to downsample the graph and improve computational efficiency.
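On a small graph, steps 1-4 can be sketched with plain matrix operations, where multiplying node features by the adjacency matrix performs the neighbor aggregation. The graph, features, and random weights below are illustrative assumptions, not a trained model:

```matlab
A = [0 1 0; 1 0 1; 0 1 0];   % adjacency matrix of a 3-node path graph
X = [1 0; 0 1; 1 1];         % one 2-element feature vector per node (rows)
W = rand(2, 4);              % learnable weight matrix (random here)

M = A * X;                   % steps 1-2: sum each node's neighbor features
H = max(M * W, 0);           % step 3: linear transform + ReLU update
% Step 4: stacking another layer repeats the same pattern with H in place
% of X, letting information travel one hop farther per layer.
```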

Popular GNN Architectures

Graph Convolutional Network (GCN) - GCNs generalize the idea of convolutions to graphs by applying weighted aggregations of neighboring node features. GCNs work well for node classification and graph-level tasks, but scaling them to large graphs can require significantly more memory and computational resources.
Graph Attention Network (GAT) - GATs introduce attention mechanisms, allowing the model to learn which neighbors are more important by assigning learnable attention weights to each connection. This increases expressiveness but also computation time.
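A hypothetical sketch of GAT-style attention for a single node with two neighbors: a score is computed per edge, softmax-normalized, and then used to weight the aggregation. All feature vectors and the attention parameters here are toy values, not a trained model:

```matlab
hSelf = [1; 0];                       % transformed feature of the node
hNbrs = [0 1; 1 1];                   % neighbor features, one per column
a = [0.5; -0.3; 0.2; 0.1];            % learnable attention vector (toy)

scores = zeros(1, size(hNbrs, 2));
for k = 1:size(hNbrs, 2)
    e = a' * [hSelf; hNbrs(:, k)];    % raw score for edge to neighbor k
    scores(k) = max(0.2 * e, e);      % LeakyReLU with negative slope 0.2
end
alpha = exp(scores) ./ sum(exp(scores));  % softmax over the neighborhood
hNew = hNbrs * alpha';                    % attention-weighted aggregation
```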
GraphSAGE - GraphSAGE improves scalability by sampling a fixed number of neighbors for each node during training, rather than considering the entire graph.
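The sampling idea can be sketched as follows; the graph and sample size are toy choices meant only to show how each node aggregates over a bounded neighborhood:

```matlab
% GraphSAGE-style neighbor sampling: instead of aggregating over every
% neighbor, each node samples at most nSample of them.
A = [0 1 1 1; 1 0 1 0; 1 1 0 1; 1 0 1 0];   % 4-node undirected graph
nSample = 2;
for i = 1:size(A, 1)
    nbrs = find(A(i, :));                      % all neighbors of node i
    k = min(nSample, numel(nbrs));
    sampled = nbrs(randperm(numel(nbrs), k));  % fixed-size random subset
    % Aggregation for node i then uses only the features of 'sampled',
    % which keeps the cost per node bounded on large graphs.
end
```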
 

Examples to Get Started

Check out these examples to learn how to implement GCN and GAT networks in MATLAB.
  • Node Classification with GCN
  • Multilabel Graph Classification Using GAT
 

Why Use Graph Neural Networks?

Graph Neural Networks (GNNs) are specifically designed to handle data that is naturally represented as graphs, making them invaluable in domains where relationships and interactions between data points are as important as the data points themselves. In GNNs, these relationships are explicitly modeled as part of the learning process. For example, you can:
  • Use molecular graphs to predict molecular properties for drug discovery and material design.
  • Use knowledge graphs to improve recommendation engines by learning from interconnected entities.
  • Use traffic networks to optimize routing and predict congestion based on road network graphs.
By understanding the advantages and limitations of GNNs, you can make informed decisions about when and how to apply them effectively in your work.
 

Advantages of GNNs

  • Flexible Input Structures - Unlike CNNs or LSTMs, GNNs can handle graphs natively. Passing graph data to CNNs or LSTMs requires data manipulation, such as interpolating the graph onto a fixed grid, which can introduce interpolation error.
  • Relational Reasoning - GNNs natively capture dependencies and interactions between entities.
  • Node and Graph-Level Learning - GNNs support tasks at different levels, namely predicting attributes of individual nodes (node classification), relationships (link prediction), or entire graphs (graph classification).

Disadvantages of GNNs

  • Scalability - GNNs require message passing across nodes, which can become computationally expensive for large graphs.
  • Over-Smoothing - As the number of layers increases, node representations may become indistinguishable from one another, leading to performance degradation.
  • Data Preparation - Graph data often requires complex preprocessing and feature engineering, as relationships between nodes may not always be straightforward to define.

When to Choose a GNN

Of course, GNNs are not the answer to all machine learning tasks. The following table provides guidance on when to choose a GNN or another type of AI model.
| Task | Graph Neural Network | Other AI Model |
| --- | --- | --- |
| Image Classification | Not typically used | CNN |
| Time-Series Forecasting | Only if nodes in the time series are related | LSTM or transformer model |
| Social Network Analysis | GNN (to capture user relationships) | N/A |
| Drug Discovery | GNN (to leverage molecule structure) | Traditional machine learning model |

GNNs for Engineering Applications

Engineers deal with models of the physical world, and therefore, engineering problems can involve complex interconnected systems that can be represented as graphs. GNNs allow engineers to model relational dependencies that might be difficult to capture with traditional neural networks, by encoding spatial, temporal, and relational information explicitly.
Some engineering applications of GNNs are presented below.
| Engineering Application | Specific Examples | Why GNNs? |
| --- | --- | --- |
| Power Grids and Electrical Networks | Predict system failures, optimize energy distribution. | Power grids are inherently graphs, where nodes represent substations, transformers, or generators, and edges represent power lines. GNNs can model the dynamic dependencies and detect weak points in the network. |
| Fluid Dynamics and Simulation | Model fluid flow over networks of pipes or surfaces. | GNNs can model interactions between connected regions of fluid to predict flow rates, temperature changes, and pressure variations. |
| Robotics and Control Systems | Learn interactions between different subsystems in a robotic assembly or a control network. | In multi-robot coordination, GNNs can model dynamic relationships between robots to optimize collaboration and path planning. |
| Material Science | Predict material properties (e.g., conductivity, elasticity) from atomic structures. | Atoms and bonds form molecular graphs. GNNs excel at learning chemical and physical relationships from graph-based molecular representations. |
 

Heat Transfer with GNN

With MATLAB, you can also use GNNs for solving physical problems described by partial differential equations (PDEs). PDEs are typically solved with numerical methods like Finite Element Analysis (FEA), which approximate the solution on a discrete mesh. This mesh naturally forms a graph-like structure that is well suited to GNNs. Once trained, a GNN can perform inference much faster than FEA and other numerical methods.
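To see how a finite-element-style mesh induces a graph, here is a sketch that triangulates a set of 2-D points and treats each triangle edge as a graph edge. This toy mesh stands in for an actual FEA mesh, not the one used in the repository example:

```matlab
pts = rand(20, 2);                             % toy mesh node coordinates
T = delaunay(pts(:, 1), pts(:, 2));            % triangles as node-index rows
E = [T(:, [1 2]); T(:, [2 3]); T(:, [3 1])];   % all triangle edges
E = unique(sort(E, 2), 'rows');                % deduplicate shared edges
G = graph(E(:, 1), E(:, 2));                   % the mesh as a graph
% Node features (e.g., coordinates, boundary conditions) attach to G's
% nodes, and a GNN message-passes over this mesh graph.
```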

 

To get the code for solving a heat transfer problem using a GNN, go to this repository.

Figure: Heat transfer maps
 

Final Thoughts

GNNs extend the capabilities of neural networks by enabling relational reasoning in domains where data is naturally expressed as graphs. With MATLAB, you can combine graph theory, deep learning, and high-fidelity modeling for creating solutions for real-world engineering applications and solving physical problems.