Supervised Tabular Training
This notebook trains a compact flat classifier on the bundled Wine dataset. It follows the same API as Hello World, but uses a few more numeric inputs and exposes root embeddings during the same run. Use it as a tabular comparison point, not as the main nested-data story.
Import the runtime pieces used in the full training loop: Lightning for optimization and Polars for reading the bundled JSONL records.
import lightning.pytorch as lit
import polars as pl
import torch
from loguru import logger
from rich.pretty import pprint
import json2vec as j2v
logger.remove()
The Wine dataset is flat, but it has enough numeric variation to make a clearer supervised example than handmade records. The notebook keeps the first 48 records, and the model uses four chemistry fields to predict the cultivar.
records = pl.read_ndjson("docs/data/wine.jsonl").head(48)
records.head()
| alcohol | malic_acid | ash | alcalinity_of_ash | magnesium | total_phenols | flavanoids | nonflavanoid_phenols | proanthocyanins | color_intensity | hue | od280_od315_of_diluted_wines | proline | cultivar |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | str |
| 14.23 | 1.71 | 2.43 | 15.6 | 127.0 | 2.8 | 3.06 | 0.28 | 2.29 | 5.64 | 1.04 | 3.92 | 1065.0 | "class_0" |
| 12.37 | 0.94 | 1.36 | 10.6 | 88.0 | 1.98 | 0.57 | 0.28 | 0.42 | 1.95 | 1.05 | 1.82 | 520.0 | "class_1" |
| 12.86 | 1.35 | 2.32 | 18.0 | 122.0 | 1.51 | 1.25 | 0.21 | 0.94 | 4.1 | 0.76 | 1.29 | 630.0 | "class_2" |
| 13.2 | 1.78 | 2.14 | 11.2 | 100.0 | 2.65 | 2.76 | 0.26 | 1.28 | 4.38 | 1.05 | 3.4 | 1050.0 | "class_0" |
| 12.33 | 1.1 | 2.28 | 16.0 | 101.0 | 2.05 | 1.09 | 0.63 | 0.41 | 3.27 | 1.25 | 1.67 | 680.0 | "class_1" |
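To make the four-field selection concrete, here is a minimal standard-library sketch built from the five rows shown above (reduced to the relevant columns plus `ash` as an example of an unused one). The names `FEATURES`, `project`, and `ranges` are hypothetical helpers for illustration and are not part of json2vec.

```python
from collections import Counter

# The five displayed rows, abbreviated; "ash" stands in for the unused columns.
rows = [
    {"alcohol": 14.23, "malic_acid": 1.71, "ash": 2.43, "color_intensity": 5.64, "proline": 1065.0, "cultivar": "class_0"},
    {"alcohol": 12.37, "malic_acid": 0.94, "ash": 1.36, "color_intensity": 1.95, "proline": 520.0, "cultivar": "class_1"},
    {"alcohol": 12.86, "malic_acid": 1.35, "ash": 2.32, "color_intensity": 4.1, "proline": 630.0, "cultivar": "class_2"},
    {"alcohol": 13.2, "malic_acid": 1.78, "ash": 2.14, "color_intensity": 4.38, "proline": 1050.0, "cultivar": "class_0"},
    {"alcohol": 12.33, "malic_acid": 1.1, "ash": 2.28, "color_intensity": 3.27, "proline": 680.0, "cultivar": "class_1"},
]

# Hypothetical name: the four chemistry fields this tutorial feeds to the model.
FEATURES = ("alcohol", "malic_acid", "color_intensity", "proline")


def project(row):
    """Keep only the four input fields, dropping everything else."""
    return {name: row[name] for name in FEATURES}


# How often each cultivar appears among the displayed rows.
label_counts = Counter(row["cultivar"] for row in rows)

# (min, max) per input field, to show how different the scales are.
ranges = {
    name: (min(row[name] for row in rows), max(row[name] for row in rows))
    for name in FEATURES
}

print(project(rows[0]))
print(label_counts)
print(ranges)
```

Even in this small sample the inputs span very different scales (`proline` is in the hundreds while `malic_acid` stays below 2), which is the kind of numeric variation the paragraph above refers to.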
The schema is the architecture. Four Number requests feed the root encoder, cultivar is the categorical target, and embed=True asks the root node to return embeddings after training.
model = j2v.Model.from_schema(
    j2v.Number("alcohol"),
    j2v.Number("malic_acid"),
    j2v.Number("color_intensity"),
    j2v.Number("proline"),
    j2v.Category("cultivar", target=True, max_vocab_size=4, topk=[2]),
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    embed=True,
    optimizer=lambda module: torch.optim.AdamW(module.parameters(), lr=1e-2),
)
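The relationship between d_model and n_heads can be checked with a one-line calculation. This sketch assumes the root encoder follows the usual multi-head attention convention of splitting d_model evenly across heads; the notebook does not state this, and `head_dim` is a hypothetical helper, not a json2vec function.

```python
def head_dim(d_model: int, n_heads: int) -> int:
    """Width of each attention head under the even-split convention (an assumption)."""
    if d_model % n_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    return d_model // n_heads


# Values taken from the schema above: d_model=16, n_heads=4.
per_head = head_dim(16, 4)
print(per_head)
```

Under that assumption each head works with a 4-dimensional slice, which is small but in keeping with a model sized for a quick tutorial run.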
PolarsDataModule(...) reads the schema configuration from the model, so batch size, queries, targets, and tensorfield behavior stay tied to one object.
datamodule = j2v.PolarsDataModule(
    model=model,
    train=records,
    validate=records,
    num_workers=0,
    persistent_workers=False,
    pin_memory=False,
    observation_buffer_size=32,
    chunk_batch_size=32,
    sample_rate=1.0,
)
The tutorial trains for one tiny pass: a single epoch capped at one training batch and one validation batch. In a real experiment this is where you would scale epochs, validation splits, callbacks, logging, and checkpointing.
trainer = lit.Trainer(
    accelerator="cpu",
    max_epochs=1,
    logger=False,
    enable_progress_bar=False,
    enable_model_summary=False,
    enable_checkpointing=False,
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.fit(model=model, datamodule=datamodule)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.
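A little arithmetic shows how small this run is. The sketch below assumes every batch holds exactly batch_size records and that each batch produces one optimizer step (no gradient accumulation); the datamodule's buffering options may change the exact batching, so treat the numbers as an estimate rather than a guarantee.

```python
import math

# Values taken from the cells above.
n_records = 48
batch_size = 8
max_epochs = 1
limit_train_batches = 1

# Batches a full epoch would contain if nothing were capped.
batches_per_epoch = math.ceil(n_records / batch_size)

# The trainer stops each epoch after `limit_train_batches` batches.
batches_used = min(batches_per_epoch, limit_train_batches)

# Records the optimizer actually sees over the whole run.
records_seen = batches_used * batch_size * max_epochs
fraction_seen = records_seen / n_records

print(batches_per_epoch, batches_used, records_seen, round(fraction_seen, 3))
```

Under these assumptions the model takes a single optimizer step on 8 of the 48 records, so the predictions later in the notebook should be read as a smoke test of the pipeline rather than as a trained classifier.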
The Rich display confirms that the target and root embedding are both configured before inference.
model
Model [model] batch_size=8 d_model=16 parameters=26,006 arrays=1 fields=5 targets=1 embeds=1
`-- record [root] embed attention=mha n_layers=1 n_heads=4 n_linear=1
    |-- alcohol [number] active query=[*].alcohol
    |   pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
    |   jitter=0 n_bands=8 offset=4 objective=mae
    |-- malic_acid [number] active query=[*].malic_acid
    |   pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
    |   jitter=0 n_bands=8 offset=4 objective=mae
    |-- color_intensity [number] active query=[*].color_intensity
    |   pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
    |   jitter=0 n_bands=8 offset=4 objective=mae
    |-- proline [number] active query=[*].proline
    |   pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
    |   jitter=0 n_bands=8 offset=4 objective=mae
    `-- cultivar [category] active target query=[*].cultivar
        pooling=query weight=1 p_mask=0 p_prune=1 n_heads=4 n_linear=1
        max_vocab_size=4 p_unavailable=0.01 topk=[2]
After training, predict returns typed outputs for supervised targets and embeddings from nodes configured with embed=True, keyed by schema address.
batch = records.to_dicts()[:3]
pprint(model.predict(batch))
{
    'record': {
        'embedding': [
            [
                -0.5385760068893433,
                -0.32121503353118896,
                -0.23558999598026276,
                -0.11253997683525085,
                -0.04481128975749016,
                0.26967138051986694,
                0.40647092461586,
                0.07385101914405823,
                -0.192938432097435,
                -0.04606473445892334,
                0.392903208732605,
                0.25607749819755554,
                -0.12488939613103867,
                0.02011357806622982,
                0.08051127195358276,
                0.10679731518030167
            ],
            [
                -0.5507907271385193,
                -0.2999776303768158,
                -0.2469239979982376,
                -0.1592492014169693,
                -0.016163593158125877,
                0.26637452840805054,
                0.38987767696380615,
                0.05987614765763283,
                -0.15685449540615082,
                -0.059313200414180756,
                0.35128483176231384,
                0.27974316477775574,
                -0.1630808711051941,
                0.039333175867795944,
                0.09404842555522919,
                0.16351932287216187
            ],
            [
                -0.5559422373771667,
                -0.32212018966674805,
                -0.24519935250282288,
                -0.12109876424074173,
                -0.021834565326571465,
                0.2611331641674042,
                0.39810478687286377,
                0.06025782972574234,
                -0.19020524621009827,
                -0.04631790146231651,
                0.3558952510356903,
                0.2762376666069031,
                -0.11466831713914871,
                0.03663596138358116,
                0.07989107817411423,
                0.1393829584121704
            ]
        ]
    },
    'record/cultivar': {
        'state': {
            'valued': [0.6327182650566101, 0.6467277407646179, 0.6338515877723694],
            'null': [0.06148093566298485, 0.06085231527686119, 0.06135614216327667],
            'padded': [0.08102165162563324, 0.07384219765663147, 0.07954676449298859],
            'masked': [0.07567285746335983, 0.07698606699705124, 0.07719294726848602],
            'other': [0.1491062343120575, 0.14159157872200012, 0.14805245399475098]
        },
        'content': {
            'value': ['class_0', 'class_2', 'class_0'],
            'probability': [0.36934828758239746, 0.3584912419319153, 0.3644393980503082],
            'topk': [
                [
                    {'label': 'class_0', 'probability': 0.36934828758239746},
                    {'label': 'class_2', 'probability': 0.3443358540534973}
                ],
                [
                    {'label': 'class_2', 'probability': 0.3584912419319153},
                    {'label': 'class_0', 'probability': 0.3538542687892914}
                ],
                [
                    {'label': 'class_0', 'probability': 0.3644393980503082},
                    {'label': 'class_2', 'probability': 0.34715282917022705}
                ]
            ]
        }
    }
}
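One way to read the classification part of this output is to line it up with the known labels. The sketch below copies the `record/cultivar` content from the output above into a plain dict (embeddings and state omitted) and compares it with the cultivars of the first three rows of the table. It uses only the standard library, and the variable names are illustrative rather than part of json2vec.

```python
# Classification content copied from the printed output above.
predictions = {
    "record/cultivar": {
        "content": {
            "value": ["class_0", "class_2", "class_0"],
            "probability": [0.36934828758239746, 0.3584912419319153, 0.3644393980503082],
            "topk": [
                [{"label": "class_0", "probability": 0.36934828758239746},
                 {"label": "class_2", "probability": 0.3443358540534973}],
                [{"label": "class_2", "probability": 0.3584912419319153},
                 {"label": "class_0", "probability": 0.3538542687892914}],
                [{"label": "class_0", "probability": 0.3644393980503082},
                 {"label": "class_2", "probability": 0.34715282917022705}],
            ],
        }
    }
}

# Cultivars of the first three records, as shown in the table earlier.
true_labels = ["class_0", "class_1", "class_2"]

content = predictions["record/cultivar"]["content"]

# Per-record correctness and overall accuracy on this tiny batch.
correct = [pred == truth for pred, truth in zip(content["value"], true_labels)]
accuracy = sum(correct) / len(correct)

# Gap between the best and second-best label for each record.
margins = [top[0]["probability"] - top[1]["probability"] for top in content["topk"]]

for pred, truth, prob, margin in zip(content["value"], true_labels, content["probability"], margins):
    print(f"predicted={pred} actual={truth} probability={prob:.3f} margin={margin:.3f}")
print(f"accuracy={accuracy:.2f}")
```

Only the first record is classified correctly, and every top probability sits close to one third with a margin under 0.03, which is what a near-untrained three-class model would be expected to produce after a single batch.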