🔌 Integrations

RINGS integrates into an existing GNN training pipeline without taking over the training loop.

The rings.integrations module ships two small utilities:

  • SeparabilityStudy: a collector that iterates over every perturbation × seed combination, applies each transform the PyG-idiomatic way, records scalar scores from your evaluator, and returns a pairwise separability DataFrame. Use it with plain PyG, Lightning, or any other framework.

  • SeparabilityCallback: a PyTorch Lightning-specific callback that automatically records a logged test metric into a study at the end of trainer.test().

Plain PyG

from rings import Original, EmptyGraph, RandomFeatures, CompleteFeatures
from rings.integrations import SeparabilityStudy

# base_dataset (a PyG dataset), max_nodes, and train_and_eval are defined by your code.

study = SeparabilityStudy(
    perturbations={
        "Original":         Original(),
        "EmptyGraph":       EmptyGraph(),
        "RandomFeatures":   RandomFeatures(shuffle=True),
        "CompleteFeatures": CompleteFeatures(max_nodes=max_nodes),
    },
    num_seeds=5,
    comparator="ks",   # or "wilcoxon"
    alpha=0.05,
)

for name, transform, seed in study.runs():
    perturbed = study.apply(base_dataset, transform)
    score = train_and_eval(perturbed, seed=seed)   # your code
    study.record(name, score)

results = study.evaluate(n_permutations=1000)

Lightning

import pytorch_lightning as pl
from rings.integrations import SeparabilityStudy, SeparabilityCallback

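# `study` is a SeparabilityStudy configured as in the Plain PyG example.
for name, transform, seed in study.runs():
    pl.seed_everything(seed, workers=True)
    dm = make_datamodule(study.apply(base_dataset, transform), seed=seed)  # your code
    model = make_model()  # your code: build a fresh model so runs do not share weights
    trainer = pl.Trainer(
        max_epochs=20,
        callbacks=[SeparabilityCallback(study, perturbation_name=name)],
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)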

results = study.evaluate()

Your LightningModule.test_step must call self.log("test_acc", acc) (or whatever metric_key you pass to SeparabilityCallback).

Custom evaluators

study.record(name, score) accepts any scalar — plug in GraphBench, OGB evaluators, or your own metric. See examples/integrations/graphbench.py.
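
Evaluators differ in what they hand back: a plain number, a zero-dimensional tensor, or a dict of named metrics. The sketch below shows one way to normalise such values to a float before calling study.record. The helper name to_scalar and its key argument are hypothetical and not part of rings.

```python
from numbers import Real


def to_scalar(value, key=None):
    """Coerce an evaluator result to a plain float (hypothetical helper, not in rings)."""
    # Some evaluators return a dict of named metrics; pick one explicitly.
    if isinstance(value, dict):
        if key is None:
            raise ValueError("evaluator returned a dict; pass key= to select a metric")
        value = value[key]
    # Zero-dimensional tensors and NumPy scalars expose .item().
    if hasattr(value, "item"):
        value = value.item()
    # bool is a Real subclass, but a True/False score is almost always a mistake.
    if isinstance(value, bool) or not isinstance(value, Real):
        raise TypeError(f"expected a scalar metric, got {type(value).__name__}")
    return float(value)
```

In a loop this would read study.record(name, to_scalar(result, key="acc")), where "acc" stands for whichever metric name your evaluator uses.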

Runnable recipes

uv run -m examples.integrations.pyg
uv run --with lightning -m examples.integrations.lightning
uv run --with graphbench-lib -m examples.integrations.graphbench

API reference

SeparabilityStudy

Lightweight collector for running RINGS separability studies inside an existing pipeline.

SeparabilityStudy is intentionally framework-agnostic: it does not own the training loop, the dataset loader, or the evaluator. The user drives those — the study just holds the perturbation set, hands out (name, transform, seed) triples to iterate over, applies a transform to a PyG Data or Dataset, records scalar scores, and runs SeparabilityFunctor over the collected distributions.

class rings.integrations.study.SeparabilityStudy(perturbations: Dict[str, Callable], num_seeds: int = 5, comparator: str | Callable = 'ks', alpha: float = 0.01, n_jobs: int = 1)

Bases: object

Collect per-perturbation, per-seed scores from a user-driven training loop and compute pairwise separability across perturbations.

Parameters:
  • perturbations (Dict[str, Callable]) – Mapping of perturbation name to a PyG BaseTransform (e.g. Original(), EmptyGraph()).

  • num_seeds (int, default=5) – Number of seeds to iterate per perturbation. The seed values yielded are range(num_seeds); the user is responsible for using them to seed any framework RNGs inside their loop.

  • comparator (str or Callable, default="ks") – Either "ks" / "wilcoxon" or a comparator instance passed straight to SeparabilityFunctor.

  • alpha (float, default=0.01) – Family-wise significance level for the separability test.

  • n_jobs (int, default=1) – Forwarded to SeparabilityFunctor for parallel pairwise comparison.
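
Because the study only yields seed values and leaves seeding to the caller, a training loop typically starts each run by seeding its RNGs. The following is a minimal sketch of such a helper. The name seed_all is hypothetical, and NumPy and PyTorch are seeded only if they happen to be installed.

```python
import random


def seed_all(seed):
    """Seed the RNGs available in this environment; return the names that were seeded.

    Hypothetical helper, not part of rings.
    """
    random.seed(seed)
    seeded = ["random"]
    try:
        import numpy as np

        np.random.seed(seed)  # legacy global NumPy generator
        seeded.append("numpy")
    except ImportError:
        pass
    try:
        import torch

        torch.manual_seed(seed)  # torch's default generators
        seeded.append("torch")
    except ImportError:
        pass
    return seeded
```

Calling seed_all(seed) as the first statement inside the study.runs() loop keeps each (perturbation, seed) run reproducible. Lightning users can rely on pl.seed_everything instead.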

Examples

>>> from rings import Original, EmptyGraph
>>> from rings.integrations import SeparabilityStudy
>>> study = SeparabilityStudy(
...     perturbations={"Original": Original(), "EmptyGraph": EmptyGraph()},
...     num_seeds=5,
... )
>>> for name, transform, seed in study.runs():
...     dataset = study.apply(base_dataset, transform)
...     score = my_train_and_eval(dataset, seed=seed)
...     study.record(name, score)
>>> results = study.evaluate()

__init__(perturbations: Dict[str, Callable], num_seeds: int = 5, comparator: str | Callable = 'ks', alpha: float = 0.01, n_jobs: int = 1)

runs() → Iterator[Tuple[str, Callable, int]]

Yield (perturbation_name, transform, seed) for every perturbation × seed.
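
Conceptually this is a Cartesian product of the perturbation mapping with range(num_seeds). The sketch below illustrates the idea with a stand-alone generator. The function name iter_runs is hypothetical, and the perturbation-major ordering it produces is an assumption, since the documentation does not state the order in which runs() yields.

```python
from itertools import product


def iter_runs(perturbations, num_seeds):
    """Yield (name, transform, seed) for every perturbation x seed pair.

    Illustrative stand-in, not the rings implementation. Ordering is assumed
    to be perturbation-major: all seeds of one perturbation, then the next.
    """
    for (name, transform), seed in product(perturbations.items(), range(num_seeds)):
        yield name, transform, seed


# Two perturbations x three seeds gives six runs.
runs = list(iter_runs({"Original": lambda d: d, "EmptyGraph": lambda d: d}, num_seeds=3))
```

A complete study therefore records len(perturbations) * num_seeds scores, num_seeds per perturbation.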

static apply(data: Any, transform: Callable) → Any

Apply transform to a single Data object or to a PyG Dataset.

For a Dataset, this sets dataset.transform so that the transform is applied lazily on each __getitem__ call — the PyG-idiomatic pattern. For a single Data object, the transform is called directly and the result is returned. Other inputs are passed straight through transform(data) as a fallback.
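
The difference between the lazy and the direct path can be sketched without PyG. ToyDataset below is a hypothetical stand-in for a PyG Dataset that applies its transform attribute on item access, and apply_sketch mirrors the dispatch described above. Neither is the rings implementation, and setting the attribute in place rather than on a copy is an assumption.

```python
class ToyDataset:
    """Stand-in for a PyG Dataset: applies self.transform on each item access."""

    def __init__(self, items, transform=None):
        self._items = list(items)
        self.transform = transform

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        item = self._items[idx]
        return item if self.transform is None else self.transform(item)


def apply_sketch(data, transform):
    """Dataset: set the transform lazily. Anything else: call the transform now."""
    if isinstance(data, ToyDataset):
        data.transform = transform  # assumption: mutates in place, no copy
        return data
    return transform(data)


calls = []


def double(x):
    calls.append(x)  # track when the transform actually runs
    return 2 * x


lazy = apply_sketch(ToyDataset([1, 2, 3]), double)  # nothing transformed yet
calls_before_access = list(calls)
second_item = lazy[1]  # the transform runs here, on access
direct = apply_sketch(5, double)  # the transform runs immediately
```

If the real apply also assigns dataset.transform in place, any transform already set on the dataset would be replaced, so a dataset with its own preprocessing transform may need that step composed into each perturbation.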

record(name: str, score: float) → None

Record a scalar score for name.

property scores: Dict[str, ndarray]

Recorded scores keyed by perturbation name.

evaluate(n_permutations: int = 10000, random_state: int | None = 42, as_dataframe: bool = True)

Run pairwise separability on the recorded score distributions and return the results, as a DataFrame by default (as_dataframe=True).
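
To give a feel for what a single pairwise comparison involves, here is a minimal pure-Python sketch of a permutation test on the two-sample Kolmogorov-Smirnov statistic. It is an illustration under assumptions, not the SeparabilityFunctor implementation: the function names are hypothetical, and the library's exact statistic, permutation scheme, and multiple-comparison handling are not described in this page.

```python
import random


def ks_statistic(a, b):
    """Two-sample KS statistic: the largest gap between the two empirical CDFs."""
    a, b = sorted(a), sorted(b)
    i = j = 0
    gap = 0.0
    for x in sorted(set(a) | set(b)):
        # Advance past every value <= x so that i / len(a) is the ECDF of a at x.
        while i < len(a) and a[i] <= x:
            i += 1
        while j < len(b) and b[j] <= x:
            j += 1
        gap = max(gap, abs(i / len(a) - j / len(b)))
    return gap


def permutation_pvalue(a, b, n_permutations=1000, random_state=42):
    """Estimate how often a random relabelling gives a statistic >= the observed one."""
    rng = random.Random(random_state)
    observed = ks_statistic(a, b)
    pooled = list(a) + list(b)
    hits = 0
    for _ in range(n_permutations):
        rng.shuffle(pooled)
        if ks_statistic(pooled[: len(a)], pooled[len(a):]) >= observed:
            hits += 1
    # Add-one correction keeps the estimate strictly above zero.
    return (hits + 1) / (n_permutations + 1)


# Made-up scores for two perturbations, five seeds each.
original = [0.90, 0.91, 0.92, 0.93, 0.94]
empty_graph = [0.50, 0.51, 0.52, 0.53, 0.54]
p_value = permutation_pvalue(original, empty_graph, n_permutations=500)
```

With five scores per group and no ties, only 2 of the 252 possible relabellings reach the maximal statistic, so an exact permutation p-value for this kind of test cannot fall below 2/252 ≈ 0.008. That floor is worth keeping in mind when choosing num_seeds and alpha.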

SeparabilityCallback

PyTorch Lightning callback for recording test-time metrics into a SeparabilityStudy.

class rings.integrations.lightning.SeparabilityCallback(study: SeparabilityStudy, perturbation_name: str, metric_key: str = 'test_acc')

Bases: Callback

Record a Lightning test metric into a SeparabilityStudy once per trainer.test() call.

Attach one of these per perturbation run. In on_test_end it reads trainer.callback_metrics[metric_key] and appends the scalar value to the study under perturbation_name. After looping over all perturbation × seed combinations, call study.evaluate() to get the separability DataFrame.

Parameters:
  • study (SeparabilityStudy) – The study to record into.

  • perturbation_name (str) – Which perturbation this run corresponds to. Must already be a key in study.perturbations.

  • metric_key (str, default="test_acc") – Key under which the test metric is logged in trainer.callback_metrics. The user’s LightningModule.test_step (or on_test_epoch_end; test_epoch_end in Lightning releases before 2.0) must call self.log(metric_key, value) for this to work.

Examples

>>> import pytorch_lightning as pl
>>> from rings import Original
>>> from rings.integrations import SeparabilityStudy, SeparabilityCallback
>>> study = SeparabilityStudy(perturbations={"Original": Original(), ...})
>>> for name, transform, seed in study.runs():
...     dm = build_data_module(study.apply(base_dataset, transform), seed)
...     trainer = pl.Trainer(callbacks=[SeparabilityCallback(study, name)])
...     trainer.fit(model, dm)
...     trainer.test(model, dm)
>>> results = study.evaluate()

__init__(study: SeparabilityStudy, perturbation_name: str, metric_key: str = 'test_acc')

on_test_end(trainer, pl_module) → None

Called when the test ends.
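
The recording step can be sketched without Lightning installed. The classes below are hypothetical stand-ins: ToyStudy replaces SeparabilityStudy, a SimpleNamespace replaces the trainer, and RecordingCallbackSketch only mimics the behaviour described above (read callback_metrics[metric_key], record it under perturbation_name). The explicit error for a missing key is an addition for illustration, not documented library behaviour.

```python
from types import SimpleNamespace


class ToyStudy:
    """Stand-in for SeparabilityStudy: keeps a list of scores per perturbation."""

    def __init__(self):
        self.scores = {}

    def record(self, name, score):
        self.scores.setdefault(name, []).append(float(score))


class RecordingCallbackSketch:
    """Mimics the documented on_test_end behaviour; not the rings class."""

    def __init__(self, study, perturbation_name, metric_key="test_acc"):
        self.study = study
        self.perturbation_name = perturbation_name
        self.metric_key = metric_key

    def on_test_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        if self.metric_key not in metrics:
            # Usually means test_step never called self.log(metric_key, ...).
            raise KeyError(f"{self.metric_key!r} was not logged; got {sorted(metrics)}")
        self.study.record(self.perturbation_name, float(metrics[self.metric_key]))


study = ToyStudy()
callback = RecordingCallbackSketch(study, "Original")
fake_trainer = SimpleNamespace(callback_metrics={"test_acc": 0.75})
callback.on_test_end(fake_trainer, pl_module=None)
```

Because one score is appended per trainer.test() call, calling test twice with the same callback would record two scores for that run.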