
NOTE: This notebook is not up to date. TBD

Learning and optimizing over directions

This notebook provides examples of generating and working with directions in our package. We give a conceptual overview of how directions are generated and how one might learn a particular set of directions for a given problem.

%load_ext autoreload
%autoreload 2

import torch
from torch import nn
from torch_geometric.data import Data, Batch
import matplotlib.pyplot as plt


from dect.directions import generate_2d_directions
from dect.nn import ECTLayer, ECTConfig

# Set up hardware configuration. (Not used below; this example runs on the CPU.)
DEVICE = "cuda:0"

The direction vectors always have the format [ndims, num_thetas], where ndims is the dimension of the ambient space and num_thetas is the number of directions used. Optionally, if one is interested in multiple sets of directions, they can be provided as a tensor of size [num_channels, ndims, num_thetas], in which case the layer returns a set of images, with each channel corresponding to the ECT computed with that channel's directions. An example of that use case is shown at the end of this notebook.

# Generate 128 uniformly spaced directions in two dimensions.
num_thetas = 128
v = torch.vstack(
    [
        torch.sin(torch.linspace(0, torch.pi, num_thetas)),
        torch.cos(torch.linspace(0, torch.pi, num_thetas)),
    ]
)

# Output shape
print(v.shape)
torch.Size([2, 128])
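
The packaged helper imported above (and used for the ground-truth layer below) should produce a comparable set of uniformly spaced 2D directions in the same [ndims, num_thetas] layout. A quick check; the expected shape is an assumption based on the format described above:

# Presumably equivalent to the manual construction above.
v = generate_2d_directions(num_thetas=128)
print(v.shape)  # expected: torch.Size([2, 128])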

Learning Directions

Now we wish to learn the directions (the uniform angles) that generated the ECT in the previous example.

To do so, we initialize all directions at a single point and make them learnable. As a loss function we use the MSE between the ground-truth ECT and its approximation.
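
Concretely, writing $\mathrm{ECT}_{v}$ for the discretized ECT image computed with directions $v$ and $v^{*}$ for the ground-truth directions, nn.MSELoss averages the squared error over all pixels of the resolution-by-num_thetas image:

$$\mathcal{L}(v) = \frac{1}{R\,T} \sum_{i=1}^{R} \sum_{j=1}^{T} \left( \mathrm{ECT}_{v}(x)_{ij} - \mathrm{ECT}_{v^{*}}(x)_{ij} \right)^{2}$$

with $R$ the resolution and $T$ the number of directions.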

# Basic dataset with two points.
points_coordinates = torch.tensor([[0.3, 0.0], [0.7, 0.0]])
point_cloud = Data(x=points_coordinates)

batch = Batch.from_data_list([point_cloud])
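
In torch_geometric fashion, Batch.from_data_list concatenates the point clouds and records which points belong to which graph; here the batch holds the single two-point cloud:

print(batch.x.shape)  # torch.Size([2, 2]): two points in two dimensions
print(batch.batch)    # tensor([0, 0]): both points belong to graph 0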

# Initialize the ground truth layer
fixed_layer = ECTLayer(
    ECTConfig(
        ect_type="points",  # The type of ECT
        resolution=32,      # Number of discretization steps
        scale=8,            # Sharpness of the sigmoid approximation
        radius=1.0,         # Radius of the ECT
        normalized=False,   # Whether to scale the ECT to the interval [0, 1]
        fixed=True,         # Fixed (non-learnable) directions
    ),
    v=generate_2d_directions(num_thetas=32),
)


# Initialize all directions at the same spot, (0, 1).
num_thetas = 32
v = torch.vstack(
    [
        torch.zeros(num_thetas),
        torch.ones(num_thetas),
    ]
)
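
Every column of v is now the same unit vector (0, 1), which we can verify directly:

# All 32 columns collapse to a single unique direction.
print(torch.unique(v, dim=1))  # tensor([[0.], [1.]])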


# Initialize the learnable layer
learnable_layer = ECTLayer(
    ECTConfig(
        ect_type="points",  # The type of ECT
        resolution=32,      # Number of discretization steps
        scale=8,            # Sharpness of the sigmoid approximation
        radius=1.0,         # Radius of the ECT
        normalized=False,   # Whether to scale the ECT to the interval [0, 1]
        fixed=False,        # Learnable directions
    ),
    v=v,
)
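
If fixed=False indeed registers the directions as a trainable parameter (as the training loop below assumes), the two layers should differ in their trainable parameter counts. A quick sanity check, assuming ECTLayer behaves like a standard nn.Module:

# Expected: the fixed layer exposes no trainable parameters, while the
# learnable layer exposes the [2, 32] direction tensor (64 entries).
print(sum(p.numel() for p in fixed_layer.parameters() if p.requires_grad))
print(sum(p.numel() for p in learnable_layer.parameters() if p.requires_grad))
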
fixedv = fixed_layer.v.numpy()
# The directions are stored as [ndims, num_thetas], so row 0 holds the
# x-coordinates and row 1 the y-coordinates.
plt.scatter(fixedv[0], fixedv[1])
[Figure: scatter plot of the fixed directions on the unit circle]

learnablev = learnable_layer.v.detach().numpy()
plt.scatter(learnablev[0], learnablev[1])
[Figure: scatter plot of the learnable directions, all initialized at (0, 1)]


# Compute the ECT and plot
ect_groundtruth = fixed_layer(batch)
ect = learnable_layer(batch)

fig, axes = plt.subplots(1, 2)

ax = axes[0]
ax.set_title("ground truth")
ax.imshow(ect_groundtruth.squeeze().detach().cpu().numpy())

ax = axes[1]
ax.set_title("initialization")
ax.imshow(ect.squeeze().detach().cpu().numpy())
[Figure: ground-truth ECT (left) next to the ECT of the untrained learnable layer (right)]

Now we can optimize the directions with PyTorch.

optimizer = torch.optim.Adam(learnable_layer.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for epoch in range(500):
    optimizer.zero_grad()
    ect_pred = learnable_layer(batch)
    loss = loss_fn(ect_pred, ect_groundtruth)
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f"Epoch: {epoch}, Loss: {loss.item():.2f}")
Epoch: 0, Loss: 0.24
Epoch: 50, Loss: 0.06
Epoch: 100, Loss: 0.01
Epoch: 150, Loss: 0.00
Epoch: 200, Loss: 0.00
Epoch: 250, Loss: 0.00
Epoch: 300, Loss: 0.00
Epoch: 350, Loss: 0.00
Epoch: 400, Loss: 0.00
Epoch: 450, Loss: 0.00

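Before re-plotting the ECTs, we can look at where the directions ended up. Assuming .v keeps the [ndims, num_thetas] layout, the learned directions should have spread out from (0, 1) toward the uniform ground-truth configuration (depending on the implementation, they need not remain exactly unit-norm):

# Compare the learned directions against the ground-truth directions.
learnedv = learnable_layer.v.detach().numpy()
plt.scatter(fixedv[0], fixedv[1], label="ground truth")
plt.scatter(learnedv[0], learnedv[1], label="learned")
plt.legend()
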
# Compute the ECT and plot
ect_groundtruth = fixed_layer(batch)
ect = learnable_layer(batch)

fig, axes = plt.subplots(1, 2)

ax = axes[0]
ax.set_title("ground truth")
ax.imshow(ect_groundtruth.squeeze().detach().cpu().numpy())

ax = axes[1]
ax.set_title("after training")
ax.imshow(ect.squeeze().detach().cpu().numpy())
[Figure: ground-truth ECT (left) and the ECT of the trained learnable layer (right)]
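
Finally, the multi-channel use case mentioned at the start: stacking several direction sets into a [num_channels, ndims, num_thetas] tensor should yield one ECT channel per direction set. The construction below is a sketch based on that description; half_circle reuses the first example, and passing v_multi to ECTLayer is assumed to follow the same pattern as above:

# Stack two direction sets into shape [num_channels, ndims, num_thetas].
t = torch.linspace(0, torch.pi, 32)
half_circle = torch.vstack([torch.sin(t), torch.cos(t)])  # as in the first example

v_multi = torch.stack([generate_2d_directions(num_thetas=32), half_circle])
print(v_multi.shape)  # expected: torch.Size([2, 2, 32])

# Per the format described above, this layer should return a two-channel
# ECT image, one channel per direction set (assumed, not verified here).
multichannel_layer = ECTLayer(
    ECTConfig(ect_type="points", resolution=32, scale=8, radius=1.0, normalized=False, fixed=True),
    v=v_multi,
)
ect_multi = multichannel_layer(batch)
print(ect_multi.shape)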