{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Adding new tasks in MANTRA\n", "\n", "While we released MANTRA together with a set of predefined prediction tasks, it can also be used to benchmark neural networks on other topological tasks involving triangulations, such as Euler characteristic prediction or synthetic generation of triangulations. In this tutorial, we'll show you how to extend the base library of MANTRA to use the Euler characteristic of each input triangulation as a label. The process is straightforward and opens up new research possibilities!\n", "\n", "The first step is to compute the Euler characteristic of each input manifold. To do this, we create a new transform that computes the Euler characteristic as the alternating sum of the Betti numbers and stores it on the data object as a `torch.Tensor`." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "class EulerCharacteristic:\n", "    def __call__(self, data):\n", "        # Euler characteristic = alternating sum of the Betti numbers.\n", "        data.euler_characteristic = torch.tensor([sum(((-1)**i) * betti_number for i, betti_number in enumerate(data.betti_numbers))])\n", "        return data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we simply add the transform to the dataset initialization, passing it as an instance. MANTRA is compatible with the `torch_geometric` dataset API, so you can use transforms, filters, and forced reloads." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "from mantra.datasets import ManifoldTriangulations\n", "from mantra.transforms import NodeIndex, RandomNodeFeatures\n", "from torch_geometric.transforms import Compose, FaceToEdge\n", "\n", "# The transform must be an instance (note the parentheses), not the class itself.\n", "dataset = ManifoldTriangulations(root=\"./data\", manifold=\"2\", version=\"latest\",\n", "                                 transform=EulerCharacteristic())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's train a simple GNN to predict the Euler characteristic.
For more information on training models in MANTRA, see the `train_gnn.ipynb` notebook." ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "# Load all required packages.\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.utils.data import random_split\n", "\n", "from torch_geometric.loader import DataLoader\n", "from torch_geometric.transforms import Compose, FaceToEdge\n", "\n", "from torch_geometric.nn import GCNConv, global_mean_pool\n", "\n", "# Load the MANTRA dataset.\n", "from mantra.datasets import ManifoldTriangulations\n", "from mantra.transforms import NodeIndex, RandomNodeFeatures\n", "\n", "dataset = ManifoldTriangulations(root=\"./data\", manifold=\"2\", version=\"latest\",\n", "                                 transform=Compose([\n", "                                     NodeIndex(),\n", "                                     RandomNodeFeatures(),\n", "                                     FaceToEdge(remove_faces=True),  # turn triangles into edge_index\n", "                                     EulerCharacteristic(),  # attach the regression target\n", "                                 ]),\n", "                                 )\n", "\n", "# 80/20 train/test split.\n", "train_dataset, test_dataset = random_split(\n", "    dataset,\n", "    [0.8, 0.2],\n", ")\n", "\n", "train_dataloader = DataLoader(train_dataset, batch_size=32)\n", "test_dataloader = DataLoader(test_dataset, batch_size=32)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "class GCN(nn.Module):\n", "    '''\n", "    A graph convolutional network with a single GCN layer, global mean pooling, and a linear readout for\n", "    predicting the Euler characteristic of the manifold.
\n", "    '''\n", "    def __init__(self):\n", "        super().__init__()\n", "        # 8 input channels: must match the node feature dimension produced by RandomNodeFeatures.\n", "        self.conv_input = GCNConv(8, 16)\n", "        self.final_linear = nn.Linear(16, 1)\n", "\n", "    def forward(self, batch):\n", "        x, edge_index, batch_index = batch.x, batch.edge_index, batch.batch\n", "        x = self.conv_input(x, edge_index)  # node embeddings\n", "        x = global_mean_pool(x, batch_index)  # one embedding per graph\n", "        x = F.dropout(x, p=0.5, training=self.training)\n", "        return self.final_linear(x)\n" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, 0.6587832570075989\n", "Epoch 1, 0.6587832570075989\n", "Epoch 2, 0.6926756501197815\n", "Epoch 3, 0.6598934531211853\n", "Epoch 4, 0.6575506329536438\n", "Epoch 5, 0.6574848890304565\n", "Epoch 6, 0.6762682199478149\n", "Epoch 7, 0.6555275321006775\n", "Epoch 8, 0.6584398746490479\n", "Epoch 9, 0.6660548448562622\n" ] } ], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model = GCN().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n", "loss_fn = nn.MSELoss()\n", "\n", "model.train()\n", "for epoch in range(10):\n", "    for batch in train_dataloader:\n", "        batch = batch.to(device)\n", "        # MSELoss expects floating-point targets.\n", "        target = batch.euler_characteristic.to(torch.float)\n", "        optimizer.zero_grad()\n", "        out = model(batch)\n", "        # squeeze(-1) maps [batch_size, 1] to [batch_size], even for a batch of size one.\n", "        loss = loss_fn(out.squeeze(-1), target)\n", "        loss.backward()\n", "        optimizer.step()\n", "    # Report the loss of the last batch in the epoch.\n", "    print(f\"Epoch {epoch}, {loss.item()}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "mantra-benchmarks", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }