Field Importance
This notebook measures how much each input field contributes to a trained supervised model. The model is trained with p_prune set on its inputs, then each input leaf is temporarily deactivated with active=False at test time and the resulting test loss is compared with the baseline.
Start with a small Wine classifier. The numeric inputs use p_prune=0.20, so training includes examples where individual inputs are unavailable and the model has to predict from the remaining context.
import lightning.pytorch as lit
import polars as pl
import torch
from loguru import logger
import json2vec as j2v
logger.remove()
records = pl.read_ndjson("docs/data/wine.jsonl")
model = j2v.Model.from_schema(
    j2v.Number("alcohol", p_prune=0.20),
    j2v.Number("malic_acid", p_prune=0.20),
    j2v.Number("color_intensity", p_prune=0.20),
    j2v.Number("proline", p_prune=0.20),
    j2v.Category("cultivar", target=True, max_vocab_size=4, topk=[2]),
    d_model=32,
    n_layers=2,
    n_heads=4,
    batch_size=64,
    embed=True,
    optimizer=lambda module: torch.optim.AdamW(module.parameters(), lr=1e-3),
)
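To make the effect of p_prune=0.20 concrete, here is a minimal standalone sketch of random input pruning. It is not json2vec's implementation: `prune_fields` is a hypothetical helper, the record values are illustrative rather than read from wine.jsonl, and the assumption that each field is hidden independently per example is mine.

```python
import random


def prune_fields(record, p_prune, rng):
    """Hypothetical helper: hide each field independently with probability p_prune."""
    return {
        name: (None if rng.random() < p_prune else value)
        for name, value in record.items()
    }


rng = random.Random(0)
# Illustrative values only, not taken from the dataset.
record = {"alcohol": 13.0, "malic_acid": 2.0, "color_intensity": 5.0, "proline": 750.0}

# Draw many pruned copies and count how often each field was hidden.
n_draws = 10_000
hidden_counts = {name: 0 for name in record}
for _ in range(n_draws):
    pruned = prune_fields(record, p_prune=0.20, rng=rng)
    for name, value in pruned.items():
        if value is None:
            hidden_counts[name] += 1

hidden_rates = {name: count / n_draws for name, count in hidden_counts.items()}
print(hidden_rates)
```

Under this assumption, each field is missing in roughly one training example out of five, so the model regularly sees records where it must predict the cultivar from the remaining inputs.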
PolarsDataModule(...) keeps the train, validation, and test encoders tied to the same mutable schema object. That is what makes temporary schema overrides visible to trainer.test(...).
datamodule = j2v.PolarsDataModule(
    model,
    train=records,
    validate=records,
    test=records,
    num_workers=0,
    persistent_workers=False,
    pin_memory=False,
    observation_buffer_size=256,
    chunk_batch_size=32,
    sample_rate=1.0,
)
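The role of the shared mutable schema can be illustrated with a toy sketch. The classes below (`ToyField`, `ToyEncoder`) and the `override` context manager are hypothetical stand-ins, not json2vec's API. They only show why encoders that hold a reference to the same schema object see a temporary change, and how that change can be undone on exit.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ToyField:
    name: str
    active: bool = True


class ToyEncoder:
    """Hypothetical encoder that reads field state at call time."""

    def __init__(self, fields):
        self.fields = fields  # shared reference, not a copy

    def encode(self, record):
        # Inactive fields are skipped when the record is encoded.
        return {f.name: record[f.name] for f in self.fields if f.active}


@contextmanager
def override(field, **changes):
    """Temporarily set attributes on a field and restore them on exit."""
    previous = {key: getattr(field, key) for key in changes}
    for key, value in changes.items():
        setattr(field, key, value)
    try:
        yield field
    finally:
        for key, value in previous.items():
            setattr(field, key, value)


fields = [ToyField("alcohol"), ToyField("proline")]
train_encoder = ToyEncoder(fields)
test_encoder = ToyEncoder(fields)
record = {"alcohol": 13.0, "proline": 750.0}  # illustrative values

with override(fields[0], active=False):
    inside = test_encoder.encode(record)  # "alcohol" is hidden here
after = test_encoder.encode(record)  # original state is restored

print(inside)
print(after)
```

If each encoder copied the schema at construction time instead, the override would change nothing in the test encoder, which is the failure mode the shared object avoids.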
Training is intentionally short. The point is the field importance workflow, not benchmark accuracy.
trainer = lit.Trainer(
    max_epochs=32,
    logger=False,
    enable_progress_bar=False,
    enable_model_summary=False,
    enable_checkpointing=False,
)
trainer.fit(model=model, datamodule=datamodule)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=32` reached.
For scoring, model.update(...) first sets input p_prune to 0.0 so the baseline uses all active inputs. Each model.override(...) block then deactivates one leaf for a single test run and restores it on exit.
model.update(j2v.where("type") == "number", p_prune=0.0)
print(f"baseline: {trainer.test(model=model, datamodule=datamodule, verbose=False)[0]['loss/test']}")
for leaf in model.select(j2v.where("type") == "number"):
    with model.override(j2v.where("address") == leaf.address, active=False):
        print(f"{leaf.name}: {trainer.test(model=model, datamodule=datamodule, verbose=False)[0]['loss/test']}")
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
baseline: 0.5348756313323975
alcohol: 0.6325852870941162
malic_acid: 0.5995643734931946
color_intensity: 0.8460686802864075
proline: 0.6477322578430176
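Higher loss after deactivation means the model depended more on that field in this training run. Here, deactivating color_intensity raises test loss the most (from about 0.535 to 0.846), while deactivating malic_acid raises it the least (to about 0.600).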
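The printed losses can be turned into an explicit ranking by subtracting the baseline from each ablated loss. The sketch below copies the values from the output above. The variable names and the choice of loss increase as the importance score are illustrative, not part of json2vec.

```python
# Test losses copied from the run above.
baseline = 0.5348756313323975
ablated = {
    "alcohol": 0.6325852870941162,
    "malic_acid": 0.5995643734931946,
    "color_intensity": 0.8460686802864075,
    "proline": 0.6477322578430176,
}

# Importance score: how much the test loss rises when the field is deactivated.
importance = {name: loss - baseline for name, loss in ablated.items()}

# Sort fields from most to least important under this score.
ranking = sorted(importance, key=importance.get, reverse=True)

for name in ranking:
    delta = importance[name]
    print(f"{name:>16}: +{delta:.4f} ({delta / baseline:+.1%} vs. baseline)")
```

Reporting the increase relative to the baseline makes scores easier to compare across runs whose baseline losses differ.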