🔌 Integrations
RINGS can be plugged into an existing GNN training pipeline without handing over your training loop, data loading, or evaluation.
The rings.integrations module ships two small utilities:
- SeparabilityStudy: a collector that iterates over every perturbation × seed combination, applies each transform in the PyG-idiomatic way, records the scalar scores produced by your evaluator, and returns a DataFrame of pairwise separability results. It works with plain PyG, Lightning, or any other framework.
- SeparabilityCallback: a PyTorch Lightning callback that automatically records a logged test metric into a study at the end of trainer.test().
Plain PyG
```python
from rings import Original, EmptyGraph, RandomFeatures, CompleteFeatures
from rings.integrations import SeparabilityStudy

# base_dataset, max_nodes and train_and_eval are placeholders for your own objects.
study = SeparabilityStudy(
    perturbations={
        "Original": Original(),
        "EmptyGraph": EmptyGraph(),
        "RandomFeatures": RandomFeatures(shuffle=True),
        "CompleteFeatures": CompleteFeatures(max_nodes=max_nodes),
    },
    num_seeds=5,
    comparator="ks",  # or "wilcoxon"
    alpha=0.05,
)

for name, transform, seed in study.runs():
    perturbed = study.apply(base_dataset, transform)
    score = train_and_eval(perturbed, seed=seed)  # your code
    study.record(name, score)

results = study.evaluate(n_permutations=1000)
```
Lightning
```python
import pytorch_lightning as pl
from rings.integrations import SeparabilityStudy, SeparabilityCallback

# `study` is the SeparabilityStudy built above; base_dataset, build_model and
# make_datamodule are placeholders for your own code.
for name, transform, seed in study.runs():
    pl.seed_everything(seed, workers=True)
    model = build_model()  # fresh weights per run, so runs stay independent
    dm = make_datamodule(study.apply(base_dataset, transform), seed=seed)
    trainer = pl.Trainer(
        max_epochs=20,
        callbacks=[SeparabilityCallback(study, perturbation_name=name)],
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)

results = study.evaluate()
```
Your LightningModule.test_step must call self.log("test_acc", acc) (or whatever metric_key you pass to SeparabilityCallback).
Custom evaluators
study.record(name, score) accepts any scalar, so you can plug in GraphBench, OGB evaluators, or your own metric. See examples/integrations/graphbench.py.
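Because record only needs a scalar, any metric function will do. As an illustration, the following hypothetical `balanced_accuracy` helper is written in plain Python. It is not part of RINGS, GraphBench, or OGB. Its float return value is what you would pass as `score` to `study.record(name, score)`.

```python
from collections import defaultdict
from typing import Sequence


def balanced_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Mean per-class recall; a hypothetical custom metric for illustration."""
    hits = defaultdict(int)    # correct predictions per true class
    totals = defaultdict(int)  # number of examples per true class
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    if not totals:
        raise ValueError("balanced_accuracy needs at least one example")
    recalls = [hits[c] / totals[c] for c in totals]
    return sum(recalls) / len(recalls)


score = balanced_accuracy([0, 0, 1, 1], [0, 1, 1, 1])
print(score)  # 0.75, ready for study.record(name, score)
```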
Runnable recipes
```shell
uv run -m examples.integrations.pyg
uv run --with lightning -m examples.integrations.lightning
uv run --with graphbench-lib -m examples.integrations.graphbench
```
API reference
SeparabilityStudy
Lightweight collector for running RINGS separability studies inside an existing pipeline.
SeparabilityStudy is intentionally framework-agnostic: it does not own the training loop, the dataset loader, or the evaluator. The user drives those. The study only holds the perturbation set, hands out (name, transform, seed) triples to iterate over, applies a transform to a PyG Data or Dataset, records scalar scores, and runs SeparabilityFunctor over the collected distributions.
- class rings.integrations.study.SeparabilityStudy(perturbations: Dict[str, Callable], num_seeds: int = 5, comparator: str | Callable = 'ks', alpha: float = 0.01, n_jobs: int = 1)
Bases: object
Collect per-perturbation, per-seed scores from a user-driven training loop and compute pairwise separability across perturbations.
- Parameters:
  - perturbations (Dict[str, Callable]) – Mapping of perturbation name to a PyG BaseTransform (e.g. Original(), EmptyGraph()).
  - num_seeds (int, default=5) – Number of seeds to iterate per perturbation. The seed values yielded are range(num_seeds); the user is responsible for using them to seed any framework RNGs inside their loop.
  - comparator (str or Callable, default="ks") – Either "ks" or "wilcoxon", or a comparator instance passed straight to SeparabilityFunctor.
  - alpha (float, default=0.01) – Family-wise significance level for the separability test.
  - n_jobs (int, default=1) – Forwarded to SeparabilityFunctor for parallel pairwise comparison.
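To make comparator and alpha more concrete, here is a minimal, framework-free sketch of one way a pairwise separability test can work. It uses a two-sample Kolmogorov–Smirnov statistic, a permutation p-value, and a Bonferroni correction so that alpha bounds the family-wise error across all pairs. This is an assumption for illustration only and is not SeparabilityFunctor's actual implementation. The function names are hypothetical.

```python
import bisect
import itertools
import random
from typing import Dict, List, Sequence


def ks_statistic(x: Sequence[float], y: Sequence[float]) -> float:
    """Largest absolute gap between the empirical CDFs of x and y."""
    xs, ys = sorted(x), sorted(y)
    gap = 0.0
    for v in xs + ys:
        fx = bisect.bisect_right(xs, v) / len(xs)  # fraction of x <= v
        fy = bisect.bisect_right(ys, v) / len(ys)  # fraction of y <= v
        gap = max(gap, abs(fx - fy))
    return gap


def permutation_pvalue(x, y, n_permutations: int = 1000, seed: int = 0) -> float:
    """Share of random relabelings whose KS statistic reaches the observed one."""
    rng = random.Random(seed)
    observed = ks_statistic(x, y)
    pooled = list(x) + list(y)
    count = 0
    for _ in range(n_permutations):
        rng.shuffle(pooled)
        # Small tolerance guards against float round-off in the CDF differences.
        if ks_statistic(pooled[: len(x)], pooled[len(x):]) >= observed - 1e-12:
            count += 1
    # +1 in numerator and denominator keeps the estimate strictly positive.
    return (count + 1) / (n_permutations + 1)


def pairwise_separability(
    scores: Dict[str, Sequence[float]], alpha: float = 0.01, n_permutations: int = 1000
) -> List[dict]:
    pairs = list(itertools.combinations(sorted(scores), 2))
    if not pairs:
        return []
    threshold = alpha / len(pairs)  # Bonferroni: family-wise error rate <= alpha
    rows = []
    for a, b in pairs:
        p = permutation_pvalue(scores[a], scores[b], n_permutations)
        rows.append({"pair": (a, b), "p_value": p, "separable": p < threshold})
    return rows


scores = {
    "Original": [0.91, 0.92, 0.90, 0.93, 0.92],
    "EmptyGraph": [0.61, 0.63, 0.60, 0.62, 0.64],
}
rows = pairwise_separability(scores, alpha=0.05)
for row in rows:
    print(row)
```

With five seeds per perturbation, a permutation test cannot produce very small p-values, which is one reason the number of seeds and alpha interact.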
Examples
```python
>>> from rings import Original, EmptyGraph
>>> from rings.integrations import SeparabilityStudy
>>> study = SeparabilityStudy(
...     perturbations={"Original": Original(), "EmptyGraph": EmptyGraph()},
...     num_seeds=5,
... )
>>> for name, transform, seed in study.runs():
...     dataset = study.apply(base_dataset, transform)
...     score = my_train_and_eval(dataset, seed=seed)
...     study.record(name, score)
>>> results = study.evaluate()
```
- __init__(perturbations: Dict[str, Callable], num_seeds: int = 5, comparator: str | Callable = 'ks', alpha: float = 0.01, n_jobs: int = 1)
- runs() → Iterator[Tuple[str, Callable, int]]
Yield (perturbation_name, transform, seed) for every perturbation × seed combination.
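The iteration order of runs() is not documented here. The following sketch reproduces the documented contract: one (perturbation_name, transform, seed) triple per perturbation × seed, with seeds taken from range(num_seeds). It assumes perturbation-major order. `iter_runs` is a hypothetical stand-in, not the library's code, and the identity function stands in for real transforms.

```python
import itertools
from typing import Callable, Dict, Iterator, Tuple


def iter_runs(
    perturbations: Dict[str, Callable], num_seeds: int
) -> Iterator[Tuple[str, Callable, int]]:
    # Assumed order: every seed of the first perturbation, then the next one.
    for (name, transform), seed in itertools.product(perturbations.items(), range(num_seeds)):
        yield name, transform, seed


def identity(data):
    return data  # placeholder for a real perturbation transform


runs = list(iter_runs({"Original": identity, "EmptyGraph": identity}, num_seeds=3))
print([(name, seed) for name, _, seed in runs])
```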
- static apply(data: Any, transform: Callable) → Any
Apply transform to a single Data object or to a PyG Dataset.
For a Dataset, this sets dataset.transform so that the transform is applied lazily on each __getitem__ call (the PyG-idiomatic pattern). For a single Data object, the transform is called directly and the result is returned. Other inputs are passed straight through transform(data) as a fallback.
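The lazy pattern that apply relies on can be shown without PyG. In the following sketch, ToyDataset is a hypothetical stand-in that mimics how a PyG Dataset calls its transform attribute inside __getitem__. The stored items are never modified, and the transform runs each time an item is read.

```python
from typing import Any, Callable, List, Optional


class ToyDataset:
    """Hypothetical stand-in for a PyG Dataset with a lazy `transform` hook."""

    def __init__(self, items: List[Any]):
        self._items = items
        self.transform: Optional[Callable] = None

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, idx: int) -> Any:
        item = self._items[idx]
        # The transform is applied on access; the stored item stays untouched.
        return item if self.transform is None else self.transform(item)


ds = ToyDataset([{"x": 1}, {"x": 2}])
ds.transform = lambda d: {**d, "x": 0}  # e.g. a feature-zeroing perturbation
print(ds[0], ds._items[0])  # {'x': 0} {'x': 1}
```

In this toy, assigning a new transform replaces the previous one rather than stacking on top of it.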
- property scores: Dict[str, ndarray]
Recorded scores keyed by perturbation name.
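The following sketch gives a rough picture of what scores holds. The hypothetical ScoreCollector mirrors the documented shape: record(name, score) appends a scalar under a perturbation name, and scores returns one NumPy array per name. It is an illustrative stand-in. How SeparabilityStudy stores or validates values internally is not described here.

```python
from collections import defaultdict
from typing import Dict, List

import numpy as np


class ScoreCollector:
    """Hypothetical stand-in for the record/scores part of a study."""

    def __init__(self) -> None:
        self._scores: Dict[str, List[float]] = defaultdict(list)

    def record(self, name: str, score: float) -> None:
        self._scores[name].append(float(score))  # accept any scalar, store as float

    @property
    def scores(self) -> Dict[str, np.ndarray]:
        return {name: np.asarray(vals) for name, vals in self._scores.items()}


c = ScoreCollector()
for acc in [0.81, 0.79, 0.80]:  # e.g. one test accuracy per seed
    c.record("Original", acc)
print(c.scores["Original"].shape)  # (3,)
```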
SeparabilityCallback
PyTorch Lightning callback for recording test-time metrics into a SeparabilityStudy.
- class rings.integrations.lightning.SeparabilityCallback(study: SeparabilityStudy, perturbation_name: str, metric_key: str = 'test_acc')
Bases: Callback
Record a Lightning test metric into a SeparabilityStudy once per trainer.test() call.
Attach one of these per perturbation run. On on_test_end it reads trainer.callback_metrics[metric_key] and appends the scalar value to the study under perturbation_name. After looping over all perturbation × seed combinations, call study.evaluate() to get the separability DataFrame.
- Parameters:
  - study (SeparabilityStudy) – The study to record into.
  - perturbation_name (str) – Which perturbation this run corresponds to. Must already be a key in study.perturbations.
  - metric_key (str, default="test_acc") – Key under which the test metric is logged in trainer.callback_metrics. The user’s LightningModule.test_step (or test_epoch_end, which Lightning 2.x replaces with on_test_epoch_end) must call self.log(metric_key, value) for this to work.
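The recording step described above comes down to a few lines. The following sketch imitates on_test_end with a stand-in trainer (types.SimpleNamespace), so it runs without Lightning. `record_test_metric` is a hypothetical helper, and the real callback may differ in details such as error handling.

```python
from types import SimpleNamespace
from typing import Dict, List


def record_test_metric(
    trainer, store: Dict[str, List[float]], perturbation_name: str, metric_key: str = "test_acc"
) -> None:
    # Mirrors the documented behaviour: read the logged metric, append it as a scalar.
    if metric_key not in trainer.callback_metrics:
        raise KeyError(f"{metric_key!r} was not logged; call self.log({metric_key!r}, value) in test_step")
    value = trainer.callback_metrics[metric_key]
    # Lightning keeps logged values as tensors; .item() turns a 0-d tensor into a number.
    store.setdefault(perturbation_name, []).append(float(value.item() if hasattr(value, "item") else value))


fake_trainer = SimpleNamespace(callback_metrics={"test_acc": 0.87})
store: Dict[str, List[float]] = {}
record_test_metric(fake_trainer, store, "EmptyGraph")
print(store)  # {'EmptyGraph': [0.87]}
```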
Examples
```python
>>> from rings.integrations import SeparabilityStudy, SeparabilityCallback
>>> study = SeparabilityStudy(perturbations={"Original": Original(), ...})
>>> for name, transform, seed in study.runs():
...     dm = build_data_module(study.apply(base_dataset, transform), seed)
...     trainer = pl.Trainer(callbacks=[SeparabilityCallback(study, name)])
...     trainer.fit(model, dm)
...     trainer.test(model, dm)
>>> results = study.evaluate()
```
- __init__(study: SeparabilityStudy, perturbation_name: str, metric_key: str = 'test_acc')