json2vec Hello World
This notebook trains the smallest useful json2vec model: two numeric Iris measurements predict the Iris species. The point is not accuracy. The point is to show the complete loop: records, schema, training, prediction, and embeddings. This example is intentionally flat; the next tutorials show why arrays matter.
Start with the usual training dependencies. The Iris data ships with the docs as a JSONL file. The logger.remove() call drops loguru's default handler, which keeps logging noise out of the rendered docs so they stay focused on model behavior.
```python
import lightning.pytorch as lit
import polars as pl
import torch
from loguru import logger
from rich.pretty import pprint

import json2vec as j2v

logger.remove()  # drop loguru's default handler to keep log noise out of the docs
```
Load a tiny, class-balanced slice of Iris: the first 36 rows. The schema field names match the DataFrame column names, so json2vec can infer each field's query from its name.
```python
records = pl.read_ndjson("docs/data/iris.jsonl").head(36)
records.head()
```
| sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|
| f64 | f64 | f64 | f64 | str |
| 5.1 | 3.5 | 1.4 | 0.2 | "setosa" |
| 7.0 | 3.2 | 4.7 | 1.4 | "versicolor" |
| 6.3 | 3.3 | 6.0 | 2.5 | "virginica" |
| 4.9 | 3.0 | 1.4 | 0.2 | "setosa" |
| 6.4 | 3.2 | 4.5 | 1.5 | "versicolor" |
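JSONL stores one JSON object per line. That is why a line-oriented reader such as pl.read_ndjson can load the file directly. The stdlib-only sketch below parses the five rows shown above with json.loads and counts their species. It is not how polars or json2vec read the file. The 36-row count also rests on an assumption: that the file keeps cycling setosa, versicolor, virginica as these five rows do. That pattern has not been checked against docs/data/iris.jsonl.

```python
import json
from collections import Counter

# The five rows shown in the table above, written as JSONL: one JSON object per line.
JSONL = """\
{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2, "species": "setosa"}
{"sepal_length": 7.0, "sepal_width": 3.2, "petal_length": 4.7, "petal_width": 1.4, "species": "versicolor"}
{"sepal_length": 6.3, "sepal_width": 3.3, "petal_length": 6.0, "petal_width": 2.5, "species": "virginica"}
{"sepal_length": 4.9, "sepal_width": 3.0, "petal_length": 1.4, "petal_width": 0.2, "species": "setosa"}
{"sepal_length": 6.4, "sepal_width": 3.2, "petal_length": 4.5, "petal_width": 1.5, "species": "versicolor"}
"""

# Parse each non-empty line independently, as any JSONL reader does.
rows = [json.loads(line) for line in JSONL.splitlines() if line.strip()]
visible_counts = Counter(row["species"] for row in rows)
print(visible_counts)

# Assumption: the file keeps cycling setosa -> versicolor -> virginica like the
# rows above. Under that assumption, head(36) holds 12 rows of each class.
CYCLE = ["setosa", "versicolor", "virginica"]
assumed_head_36 = Counter(CYCLE[i % len(CYCLE)] for i in range(36))
print(assumed_head_36)
```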
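The schema declares exactly what the model reads. Only the two Number fields are read, and each becomes a numeric tensorfield. The schema does not declare sepal_width or petal_width, so the model never sees them. The species Category is a supervised target: target=True hides it from the input and asks the model to decode it. topk=[2] asks each prediction to also report the two most likely labels.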
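The schema splits each record into two parts: what the model reads and what it must predict. The plain-Python sketch below shows that split. split_record, INPUT_FIELDS, and TARGET_FIELD are hypothetical names used only for illustration. They are not json2vec internals, which tensorize fields rather than copying dicts.

```python
# Hypothetical illustration of the schema's input/target split (not json2vec code).
INPUT_FIELDS = ("sepal_length", "petal_length")  # the two j2v.Number fields
TARGET_FIELD = "species"                         # j2v.Category(..., target=True)


def split_record(record: dict) -> tuple[dict, str]:
    """Return (inputs, target): undeclared columns are dropped and the target is held out."""
    inputs = {name: record[name] for name in INPUT_FIELDS}
    return inputs, record[TARGET_FIELD]


row = {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2, "species": "setosa"}
inputs, target = split_record(row)
print(inputs)  # {'sepal_length': 5.1, 'petal_length': 1.4}
print(target)  # setosa
```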
```python
model = j2v.Model.from_schema(
    j2v.Number("sepal_length"),
    j2v.Number("petal_length"),
    # target=True: decoded by the model, hidden from its input
    j2v.Category("species", target=True, max_vocab_size=4, topk=[2]),
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    embed=True,  # add a root record embedding to model.predict(...) output
    optimizer=lambda module: torch.optim.AdamW(module.parameters(), lr=1e-2),
)
```
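The model tree printed at the end of this notebook reports attention=mha. This sketch assumes json2vec follows the standard multi-head attention layout; the notebook does not say so. In that layout, d_model is split evenly across heads, so d_model=16 with n_heads=4 gives each head a 4-dimensional slice. PyTorch's torch.nn.MultiheadAttention enforces the same rule: embed_dim must be divisible by num_heads. head_dim is a hypothetical helper.

```python
def head_dim(d_model: int, n_heads: int) -> int:
    """Width of each attention head when d_model is split evenly across n_heads."""
    if d_model % n_heads:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    return d_model // n_heads


print(head_dim(16, 4))  # 4: each of the 4 heads works on a 4-dimensional slice
```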
```python
datamodule = j2v.PolarsDataModule(
    model=model,
    train=records,
    validate=records,  # hello-world shortcut: validate on the training rows
    num_workers=0,
    persistent_workers=False,
    pin_memory=False,
    observation_buffer_size=32,
    sample_rate=1.0,
)
```
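Train for one deliberately small epoch. limit_train_batches=1 and limit_val_batches=1 cap that epoch at a single training batch and a single validation batch. The tutorials hardcode these batch and epoch counts so the example stays quick to run in documentation builds.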
```python
trainer = lit.Trainer(
    accelerator="cpu",
    max_epochs=1,
    logger=False,
    enable_progress_bar=False,
    enable_model_summary=False,
    enable_checkpointing=False,
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.fit(model=model, datamodule=datamodule)
```
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.
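The "1 batch per epoch" log line means this epoch touched only a fraction of the data. The quick arithmetic below makes two assumptions. First, each batch holds batch_size=8 of the 36 records. Second, a full epoch would keep the final partial batch. The datamodule's observation_buffer_size and sample_rate settings are ignored here, and they may change what a batch actually contains.

```python
import math

n_records = 36           # records = pl.read_ndjson(...).head(36)
batch_size = 8           # Model.from_schema(..., batch_size=8)
limit_train_batches = 1  # lit.Trainer(..., limit_train_batches=1)

# Assumption: full batches of batch_size records, with a final partial batch kept.
full_epoch_batches = math.ceil(n_records / batch_size)
batches_run = min(limit_train_batches, full_epoch_batches)
records_seen = batches_run * batch_size
print(full_epoch_batches, records_seen, f"{records_seen / n_records:.0%}")  # 5 8 22%
```

Under these assumptions, a full epoch would take 5 batches, and this run trains on only about a fifth of the slice.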
Prediction uses the same nested batch shape as training: each outer item is one observation, and each observation contains one record. Here the batch holds the first three rows as plain dicts.
```python
batch = records.to_dicts()[:3]  # first three rows: setosa, versicolor, virginica
pprint(model.predict(batch))
```
```text
{
    'record': {
        'embedding': [
            [
                -0.17086760699748993, 0.17301328480243683, -0.10288766026496887, -0.12322474271059036,
                -0.1786295771598816, 0.3295104503631592, 0.012937950901687145, 0.19766977429389954,
                -0.12257715314626694, 0.37761804461479187, -0.5092102289199829, 0.3248288333415985,
                -0.24244019389152527, 0.2886931300163269, -0.26312825083732605, 0.039841122925281525
            ],
            [
                -0.08700835704803467, 0.18561790883541107, -0.12381113320589066, -0.03233317658305168,
                -0.20933884382247925, 0.4027082026004791, -0.08733399957418442, 0.18443424999713898,
                -0.1095127984881401, 0.3441852331161499, -0.5101417303085327, 0.33367598056793213,
                -0.22456493973731995, 0.23830066621303558, -0.29038679599761963, 0.019471624866127968
            ],
            [
                -0.11521251499652863, 0.18316768109798431, -0.13563768565654755, -0.08014936000108719,
                -0.2040325105190277, 0.36332935094833374, -0.07828681170940399, 0.21006029844284058,
                -0.09831167012453079, 0.35792550444602966, -0.4815004765987396, 0.35289523005485535,
                -0.2326280027627945, 0.2601192593574524, -0.296287477016449, 0.027292318642139435
            ]
        ]
    },
    'record/species': {
        'state': {
            'valued': [0.5176917314529419, 0.5305065512657166, 0.525334358215332],
            'null': [0.11410070210695267, 0.1144195944070816, 0.11281796544790268],
            'padded': [0.08085322380065918, 0.08053679764270782, 0.08090561628341675],
            'masked': [0.22876456379890442, 0.2176964432001114, 0.22396528720855713],
            'other': [0.05858968570828438, 0.05684061720967293, 0.05697666481137276]
        },
        'content': {
            'value': ['virginica', 'virginica', 'virginica'],
            'probability': [0.444672554731369, 0.445686012506485, 0.4495390057563782],
            'topk': [
                [
                    {'label': 'virginica', 'probability': 0.444672554731369},
                    {'label': 'versicolor', 'probability': 0.3779090940952301}
                ],
                [
                    {'label': 'virginica', 'probability': 0.445686012506485},
                    {'label': 'versicolor', 'probability': 0.39173898100852966}
                ],
                [
                    {'label': 'virginica', 'probability': 0.4495390057563782},
                    {'label': 'versicolor', 'probability': 0.3822685778141022}
                ]
            ]
        }
    }
}
```
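The prediction output is plain nested Python data, so it can be post-processed directly. The sketch below uses only values copied from the output above. For each row it picks the highest-scoring state, measures how far the top label leads the runner-up, and compares the top label with the species from the table. Reading 'state' as a distribution over field states is an assumption. It rests on the key names and on each row's five scores summing to about 1; the notebook does not document it.

```python
# Values copied verbatim from the 'record/species' output above.
state = {
    "valued": [0.5176917314529419, 0.5305065512657166, 0.525334358215332],
    "null": [0.11410070210695267, 0.1144195944070816, 0.11281796544790268],
    "padded": [0.08085322380065918, 0.08053679764270782, 0.08090561628341675],
    "masked": [0.22876456379890442, 0.2176964432001114, 0.22396528720855713],
    "other": [0.05858968570828438, 0.05684061720967293, 0.05697666481137276],
}
topk = [
    [{"label": "virginica", "probability": 0.444672554731369},
     {"label": "versicolor", "probability": 0.3779090940952301}],
    [{"label": "virginica", "probability": 0.445686012506485},
     {"label": "versicolor", "probability": 0.39173898100852966}],
    [{"label": "virginica", "probability": 0.4495390057563782},
     {"label": "versicolor", "probability": 0.3822685778141022}],
]
# Species of the first three rows in the table near the top of this notebook.
true_species = ["setosa", "versicolor", "virginica"]

summary = []
for i, (best, runner_up) in enumerate(topk):
    scores = {name: values[i] for name, values in state.items()}
    summary.append(
        {
            "state": max(scores, key=scores.get),  # highest-scoring field state
            "state_total": round(sum(scores.values()), 4),
            "label": best["label"],
            "margin": round(best["probability"] - runner_up["probability"], 3),
            "correct": best["label"] == true_species[i],
        }
    )

for row in summary:
    print(row)
```

On all three rows the model picks virginica, with a margin of roughly 0.05 to 0.07 over versicolor. Only the third (virginica) row is therefore correct. That is expected after one training batch and matches the introduction's point that accuracy is not the goal here.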
Embeddings are opt-in. Passing embed=True when constructing the model adds a root record vector of d_model=16 values to the model.predict(...) output for each input observation. The schema fields themselves do not change.
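A common use of these vectors is comparing records. The stdlib sketch below computes cosine similarity between the three root embeddings printed above, with the values copied verbatim. Cosine similarity is a generic technique, not a json2vec API, and cosine is a hypothetical helper name.

```python
import math

# The three root embeddings printed above (d_model=16 values each).
embeddings = [
    [-0.17086760699748993, 0.17301328480243683, -0.10288766026496887, -0.12322474271059036,
     -0.1786295771598816, 0.3295104503631592, 0.012937950901687145, 0.19766977429389954,
     -0.12257715314626694, 0.37761804461479187, -0.5092102289199829, 0.3248288333415985,
     -0.24244019389152527, 0.2886931300163269, -0.26312825083732605, 0.039841122925281525],
    [-0.08700835704803467, 0.18561790883541107, -0.12381113320589066, -0.03233317658305168,
     -0.20933884382247925, 0.4027082026004791, -0.08733399957418442, 0.18443424999713898,
     -0.1095127984881401, 0.3441852331161499, -0.5101417303085327, 0.33367598056793213,
     -0.22456493973731995, 0.23830066621303558, -0.29038679599761963, 0.019471624866127968],
    [-0.11521251499652863, 0.18316768109798431, -0.13563768565654755, -0.08014936000108719,
     -0.2040325105190277, 0.36332935094833374, -0.07828681170940399, 0.21006029844284058,
     -0.09831167012453079, 0.35792550444602966, -0.4815004765987396, 0.35289523005485535,
     -0.2326280027627945, 0.2601192593574524, -0.296287477016449, 0.027292318642139435],
]


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Compare every distinct pair of records.
for i in range(len(embeddings)):
    for j in range(i + 1, len(embeddings)):
        print(i, j, round(cosine(embeddings[i], embeddings[j]), 3))
```

These vectors come from a model trained on a single batch, so their similarities say little about Iris yet. Each printed vector also has an L2 norm of about 1.0. If json2vec normalizes its embeddings, cosine similarity reduces to a plain dot product. The notebook does not state that it does.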
The Rich display is the quickest way to verify what was built: array nodes, tensorfield nodes, targets, embeddings, and inferred queries all appear in the same tree.
```python
model
```

```text
Model [model] batch_size=8 d_model=16 parameters=18,120 arrays=1 fields=3 targets=1 embeds=1
`-- record [root] embed attention=mha n_layers=1 n_heads=4 n_linear=1
    |-- sepal_length [number] active query=[*].sepal_length
    |       pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
    |       jitter=0 n_bands=8 offset=4 objective=mae
    |-- petal_length [number] active query=[*].petal_length
    |       pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
    |       jitter=0 n_bands=8 offset=4 objective=mae
    `-- species [category] active target query=[*].species
            pooling=query weight=1 p_mask=0 p_prune=1 n_heads=4 n_linear=1
            max_vocab_size=4 p_unavailable=0.01 topk=[2]
```
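Every field in the tree carries an inferred query, such as [*].sepal_length, derived from the field name. Read as a JSONPath-style selector, [*] ranges over the elements of an array, and .sepal_length reads that key from each element. The toy evaluator below handles only that one pattern, applied to a list of records. infer_query and select are hypothetical helpers. json2vec's real query syntax and semantics are not shown in this notebook and may be richer.

```python
def infer_query(field_name: str) -> str:
    # Hypothetical: builds the query shape shown in the tree, e.g. "[*].sepal_length".
    return f"[*].{field_name}"


def select(query: str, records: list[dict]) -> list:
    """Toy evaluator for the single pattern "[*].<key>": read <key> from every record."""
    prefix = "[*]."
    if not query.startswith(prefix):
        raise ValueError(f"unsupported query: {query!r}")
    key = query[len(prefix):]
    return [record[key] for record in records]


# The declared columns of the first three rows from the table above.
rows = [
    {"sepal_length": 5.1, "petal_length": 1.4, "species": "setosa"},
    {"sepal_length": 7.0, "petal_length": 4.7, "species": "versicolor"},
    {"sepal_length": 6.3, "petal_length": 6.0, "species": "virginica"},
]
for name in ("sepal_length", "petal_length", "species"):
    query = infer_query(name)
    print(query, select(query, rows))
```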