{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training a GNN on the MANTRA Dataset\n", "\n", "In this tutorial, we walk through an example use case for the MANTRA dataset: \n", "training a GNN to predict the orientability of a manifold from random node features. \n", "\n", "The `torch-geometric` interface for the MANTRA dataset can be installed with \n", "pip via the command \n", "```bash\n", "pip install mantra\n", "```\n", "\n", "As a preprocessing step, we apply three transforms to the base dataset.\n", "Since the dataset does not have intrinsic coordinates attached to the vertices, \n", "we need a transform that generates random node features.\n", "Each manifold in MANTRA comes as a list of triples, where the integers in each \n", "triple are vertex IDs. The IDs in each manifold start at $1$ and have to be \n", "converted to the $0$-based indexing that torch-geometric expects.\n", "Finally, GNNs are typically trained on graphs, so the `FaceToEdge` transform converts\n", "each triangulated manifold to a graph. \n", "\n", "Each transform is implemented as a single class, and the transforms are applied\n", "successively to form the final transformed dataset. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading https://github.com/aidos-lab/MANTRADataset/releases/latest/download/2_manifolds.json.gz\n", "Extracting data/simplicial/raw/2_manifolds.json.gz\n", "Processing...\n", "Done!\n" ] } ], "source": [ "# Load all required packages. 
\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.utils.data import random_split\n", "\n", "from torch_geometric.loader import DataLoader\n", "from torch_geometric.nn import GCNConv, global_mean_pool\n", "from torch_geometric.transforms import Compose, FaceToEdge\n", "\n", "# Load the MANTRA dataset and its transforms.\n", "from mantra.datasets import ManifoldTriangulations\n", "from mantra.transforms import NodeIndex, RandomNodeFeatures\n", "\n", "# Instantiate the dataset. Following the `torch-geometric` API, we download the\n", "# dataset into the root directory.\n", "dataset = ManifoldTriangulations(\n", "    root=\"./data\",\n", "    manifold=\"2\",\n", "    version=\"latest\",\n", "    transform=Compose(\n", "        [\n", "            NodeIndex(),\n", "            RandomNodeFeatures(),\n", "            FaceToEdge(remove_faces=True),\n", "        ]\n", "    ),\n", ")\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Split the dataset into 80% training and 20% test data.\n", "train_dataset, test_dataset = random_split(\n", "    dataset,\n", "    [0.8, 0.2],\n", ")  # type: ignore\n", "\n", "train_dataloader = DataLoader(train_dataset, batch_size=32)\n", "test_dataloader = DataLoader(test_dataset, batch_size=32)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class GCN(nn.Module):\n", "    '''\n", "    A minimal graph convolutional network with three stages (graph\n", "    convolution, mean-pool readout, linear classifier) for predicting\n", "    the orientability of the manifold.\n", "    Note that since our model only uses the edge information, it is not\n", "    able to learn the orientability. It therefore only serves as\n", "    an example.\n", "    '''\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "\n", "        # 8 input features per node, 16 hidden channels.\n", "        self.conv_input = GCNConv(8, 16)\n", "        # A single output logit for the binary orientability label.\n", "        self.final_linear = nn.Linear(16, 1)\n", "\n", "    def forward(self, batch):\n", "        x, edge_index, batch = batch.x, batch.edge_index, batch.batch\n", "\n", "        # 1. 
Obtain node embeddings\n", "        x = self.conv_input(x, edge_index)\n", "        # 2. Readout layer\n", "        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]\n", "        # 3. Apply a final classifier\n", "        x = F.dropout(x, p=0.5, training=self.training)\n", "        x = self.final_linear(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, 0.10134144872426987\n", "Epoch 1, 0.09631027281284332\n", "Epoch 2, 0.08791007101535797\n", "Epoch 3, 0.08849794417619705\n", "Epoch 4, 0.08793244510889053\n", "Epoch 5, 0.08987350761890411\n", "Epoch 6, 0.08849993348121643\n", "Epoch 7, 0.08913585543632507\n", "Epoch 8, 0.08902224153280258\n", "Epoch 9, 0.09069301933050156\n" ] } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model = GCN().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n", "loss_fn = nn.BCEWithLogitsLoss()\n", "\n", "model.train()\n", "for epoch in range(10):\n", "    for batch in train_dataloader:\n", "        batch = batch.to(device)\n", "        # BCEWithLogitsLoss expects floating-point targets.\n", "        target = batch.orientable.to(torch.float)\n", "        optimizer.zero_grad()\n", "        out = model(batch)\n", "        # squeeze(-1) keeps the batch dimension even for a batch of size one.\n", "        loss = loss_fn(out.squeeze(-1), target)\n", "        loss.backward()\n", "        optimizer.step()\n", "    # Report the loss of the last batch of the epoch.\n", "    print(f\"Epoch {epoch}, {loss.item()}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "correct = 0\n", "total = 0\n", "model.eval()\n", "with torch.no_grad():\n", "    for testbatch in test_dataloader:\n", "        testbatch = testbatch.to(device)\n", "        pred = model(testbatch)\n", "        # A positive logit means the model predicts an orientable manifold.\n", "        predicted = pred.squeeze(-1) > 0\n", "        correct += (predicted == testbatch.orientable.bool()).sum().item()\n", "        # Count graphs, not nodes.\n", "        total += testbatch.num_graphs\n", "\n", "acc = correct / total\n", "print(f'Accuracy: {acc:.4f}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { 
"kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }