Pretraining
This notebook uses the bundled Digits JSONL buffer as a self-supervised reconstruction task. No supervised label is used: the digit column is present in the data but is not part of the model schema. Instead, pixel intensity values are masked at random and the model learns to reconstruct them from the surrounding pixel context.
Only local dependencies are used. The digit records are already buffered as nested JSONL, so the notebook does not download data or reshape arrays at runtime.
import lightning.pytorch as lit
import polars as pl
import torch
from loguru import logger
import json2vec as j2v
logger.remove()
Each image is represented as a nested pixels array. Every pixel is a small JSON object with row, column, and intensity fields, which is closer to the structured data json2vec is meant to read than a flat 8x8 array of numbers would be.
records = pl.read_ndjson("docs/data/digits.jsonl").head(24)
records.head()
| pixels | digit |
|---|---|
| list[struct[3]] | str |
| [{0,0,0.0}, {0,1,0.0}, … {7,7,0.0}] | "0" |
| [{0,0,0.0}, {0,1,0.0}, … {7,7,0.0}] | "1" |
| [{0,0,0.0}, {0,1,0.0}, … {7,7,0.0}] | "2" |
| [{0,0,0.0}, {0,1,0.0}, … {7,7,0.0}] | "3" |
| [{0,0,0.0}, {0,1,0.0}, … {7,7,0.0}] | "4" |
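The nested layout shown in the table can be illustrated with a short sketch. The notebook itself does no reshaping, because the buffer is already stored this way; the helper name `image_to_record` and the tiny 2x2 image are hypothetical and exist only to show how a grid of intensities might map onto one JSONL line with `row`, `column`, and `intensity` fields.

```python
import json


def image_to_record(image, digit):
    """Flatten a 2-D grid of intensities into one nested record."""
    pixels = [
        {"row": r, "column": c, "intensity": float(value)}
        for r, row in enumerate(image)
        for c, value in enumerate(row)
    ]
    # The digit is kept as a string, matching the `str` column in the table.
    return {"pixels": pixels, "digit": str(digit)}


# Hypothetical 2x2 image; the real Digits images are 8x8 (64 pixels).
tiny = [[0.0, 5.0], [12.0, 0.0]]
record = image_to_record(tiny, digit=0)

# One record serializes to one line of a JSONL file.
line = json.dumps(record)
print(line)
```

Each pixel carries its own position, so the model can treat the image as a sequence of small structured objects rather than relying on array order.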
The pixels array creates a local context encoder. Row and column are categorical position hints with light masking (p_mask=0.05), and intensity is numeric content with p_mask=0.50, so roughly half of the intensities are hidden and reconstructed during training. The root and the pixels array also request embeddings (embed=True) so the notebook can show representation output after training.
model = j2v.Model.from_schema(
    j2v.Array(
        j2v.Category("row", max_vocab_size=8),
        j2v.Category("column", max_vocab_size=8),
        j2v.Number("intensity"),
        name="pixels",
        max_length=64,
        embed=True,
    ),
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    embed=True,
    optimizer=lambda module: torch.optim.AdamW(module.parameters(), lr=1e-2),
)
model.update(j2v.where("name") == "intensity", p_mask=0.50)
model.update(j2v.where("type") == "category", p_mask=0.05)
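To make the masking objective concrete, here is a minimal sketch in plain Python. It is not json2vec's implementation: the helpers `mask_values` and `masked_mae` are hypothetical, the "model" is a stand-in that predicts the mean of the visible values, and the sketch assumes each value is hidden independently with probability `p_mask` and that the loss is a mean absolute error over the hidden positions only.

```python
import random


def mask_values(values, p_mask, rng):
    """Hide each value independently with probability p_mask.

    Returns the corrupted sequence (None marks a hidden slot) and a
    boolean mask marking the positions that must be reconstructed.
    """
    mask = [rng.random() < p_mask for _ in values]
    corrupted = [None if hidden else v for v, hidden in zip(values, mask)]
    return corrupted, mask


def masked_mae(predictions, targets, mask):
    """Mean absolute error over the hidden positions only."""
    errors = [
        abs(p - t)
        for p, t, hidden in zip(predictions, targets, mask)
        if hidden
    ]
    return sum(errors) / len(errors) if errors else 0.0


rng = random.Random(0)
# Stand-in for the 64 intensities of one 8x8 image (values in 0..16).
intensities = [float(i % 17) for i in range(64)]
corrupted, mask = mask_values(intensities, p_mask=0.5, rng=rng)

# Stand-in "model": predict the mean of the visible intensities everywhere.
visible = [v for v in corrupted if v is not None]
guess = sum(visible) / len(visible)
loss = masked_mae([guess] * len(intensities), intensities, mask)
print(f"hidden={sum(mask)} of {len(mask)}, loss={loss:.3f}")
```

Visible positions contribute nothing to the loss in this sketch, which is why no separate target field is needed: the hidden inputs are themselves the targets.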
The data module streams the nested records through the model schema. No target field is required because masking supplies the training signal.
datamodule = j2v.PolarsDataModule(
    model=model,
    train=records,
    validate=records,
    num_workers=0,
    persistent_workers=False,
    pin_memory=False,
    observation_buffer_size=16,
    chunk_batch_size=32,
    sample_rate=1.0,
)
A single epoch, limited to one training batch and one validation batch, is enough for the documentation example. The important behavior is that the same training loop works whether the task is supervised or mask-based pretraining.
trainer = lit.Trainer(
    accelerator="cpu",
    max_epochs=1,
    logger=False,
    enable_progress_bar=False,
    enable_model_summary=False,
    enable_checkpointing=False,
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.fit(model=model, datamodule=datamodule)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.
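The amount of training this run performs follows from the settings in the notebook. The sketch below is only back-of-the-envelope arithmetic, not a call into Lightning or json2vec. It assumes records are grouped into consecutive batches of `batch_size` and that gradient accumulation is left at Lightning's default of one batch per optimizer step.

```python
import math

n_records = 24            # records.head(24)
batch_size = 8            # Model.from_schema(batch_size=8)
limit_train_batches = 1   # Trainer(limit_train_batches=1)
max_epochs = 1            # Trainer(max_epochs=1)

# Assumption: consecutive batches of `batch_size` records.
batches_available = math.ceil(n_records / batch_size)
batches_per_epoch = min(batches_available, limit_train_batches)

# Assumption: one optimizer step per batch (no gradient accumulation).
optimizer_steps = batches_per_epoch * max_epochs
records_seen = min(n_records, batches_per_epoch * batch_size)

print(batches_available, batches_per_epoch, optimizer_steps, records_seen)
# prints: 3 1 1 8
```

Under these assumptions the model takes a single optimizer step on at most eight records, so the run demonstrates the workflow rather than producing a trained representation.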
The Rich display shows a root record encoder, a nested pixel encoder, and the masked numeric intensity field that drives the pretraining loss. The final lines fetch root and pixel embeddings from the same address-keyed prediction output used by supervised examples.
model
Model [model] batch_size=8 d_model=16 parameters=27,699 arrays=2 fields=3 targets=0 embeds=2
`-- record [root] embed attention=mha n_layers=1 n_heads=4 n_linear=1
    `-- pixels [array] embed max_length=64 overflow=head attention=mha n_layers=1 n_heads=4 n_linear=1
        |-- row [category] active query=[*].pixels[*].row
        |   pooling=query weight=1 p_mask=0.05 p_prune=0 n_heads=4 n_linear=1
        |   max_vocab_size=8 p_unavailable=0.01 topk=[]
        |-- column [category] active query=[*].pixels[*].column
        |   pooling=query weight=1 p_mask=0.05 p_prune=0 n_heads=4 n_linear=1
        |   max_vocab_size=8 p_unavailable=0.01 topk=[]
        `-- intensity [number] active query=[*].pixels[*].intensity
            pooling=query weight=1 p_mask=0.5 p_prune=0 n_heads=4 n_linear=1
            jitter=0 n_bands=8 offset=4 objective=mae
predictions = model.predict(records.to_dicts()[:2])
record = predictions[j2v.Address("record")]["embedding"]
pixels = predictions[j2v.Address("record", "pixels")]["embedding"]
record[:1], pixels[:1]
([[0.35334512591362, -0.1747734397649765, -0.0918232649564743, -0.18089556694030762, -0.19774794578552246, 0.2640850841999054, -0.2617112100124359, 0.03843586891889572, 0.0944351851940155, 0.22415359318256378, 0.19069485366344452, 0.010278591886162758, 0.48065608739852905, -0.41741475462913513, -0.002361574210226536, -0.3517493009567261]], [[0.5565179586410522, 0.12459292262792587, 0.058676306158304214, -0.12925276160240173, -0.09895370155572891, 0.11117955297231674, -0.1297907829284668, -0.017674975097179413, 0.2913339138031006, 0.2819664478302002, -0.0029879009816795588, -0.1100824624300003, -0.26059725880622864, -0.5942274332046509, 0.0408138707280159, -0.12693989276885986]])
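One common way to use embeddings like these is to compare records by cosine similarity. The sketch below is a generic illustration and not part of json2vec: `cosine_similarity` is a hypothetical helper, the two short vectors are made-up stand-ins for the 16-dimensional outputs above, and it assumes embeddings arrive as plain lists of floats, as in the printed output.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention for a zero vector
    return dot / (norm_a * norm_b)


# Hypothetical short vectors; the real embeddings have d_model=16 entries.
first = [0.35, -0.17, -0.09, -0.18]
second = [0.30, -0.20, -0.05, -0.10]
similarity = cosine_similarity(first, second)
print(round(similarity, 3))
```

With the notebook's objects, the same helper could be applied as `cosine_similarity(record[0], record[1])`. After a single training batch the resulting value says little about the digits themselves.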