Custom Data Types
Tensorfields are the unit representations of data values. A tensorfield owns how data values are validated, tensorized, embedded, decoded, trained, and written.
json2vec includes many built-in tensorfield types:
Category, DateParts, Number, EntitySet, Text, Vector
The built-in tensorfields do not cover every use case. If your requirements are not met by them, you can build your own tensorfield as a plugin. Contributions of new general-purpose tensorfields to the upstream repository are welcome.
Tensorfields are defined by several key components:
Request, TensorField, Embedder, Decoder, loss, write
The following example creates a small Bucket extension that turns a numeric Iris measurement into learned bucket tokens.
Requirements
The imports include the extension base classes from json2vec, plus TensorDict and tensorclass for storing encoded field tensors. This is an advanced guide; read Query Paths first if the request/query split is still unfamiliar.
from pathlib import Path
from typing import Literal
import lightning.pytorch as lit
import numpy as np
import polars as pl
import torch
from rich.pretty import pprint
from tensordict import TensorDict, tensorclass
import json2vec as j2v
from json2vec.data.processing import pad
from json2vec.structs.packages import Parcel, Prediction
Register A New Field Type
The plugin name becomes the tensorfield type. Reusing a name overrides the previous plugin and emits a warning, which is convenient in notebooks where cells may be re-run. The request class defines user-facing schema options.
# The plugin name ("bucket") becomes the tensorfield type; the components
# decorated with ``@bucket.register`` below attach to this plugin.
bucket = j2v.Plugin(name="bucket")


@bucket.register
class Request(j2v.RequestBase):
    """User-facing schema options for a ``bucket`` field.

    ``boundaries`` are the cut points between buckets. Values are assigned to
    buckets with ``np.digitize``, so ``n`` boundaries produce ``n + 1`` buckets.
    The interval labels produced by the ``write`` hook assume the boundaries are
    listed in increasing order.
    """

    type: Literal["bucket"] = "bucket"
    # A homogeneous list takes a single type argument; ``list[float, ...]`` is not a valid generic.
    boundaries: list[float]

    @property
    def n_bins(self) -> int:
        """Number of buckets: one more than the number of boundaries."""
        return len(self.boundaries) + 1


# `Bucket` is a user-facing alias for this extension's request. The model materializes the request into its corresponding components.
Bucket = Request
Tensorize Values
TensorField.new(...) converts raw query results into tensors. This example pads nested values, converts numeric values into bucket ids, tracks missing/padded state separately, and records hidden values as training targets when masking or pruning is applied.
@bucket.register
@tensorclass
class TensorField(j2v.TensorFieldBase):
    """Encoded tensors for one ``bucket`` field, batched along the first dimension."""

    content: torch.Tensor  # int64 bucket ids from np.digitize; set to 0 where a value is hidden
    state: torch.Tensor  # int64 j2v.Tokens codes per position (from `pad`, or Tokens.masked)
    trainable: torch.Tensor  # bool; True where `_hide` concealed a valued position
    targets: TensorDict[j2v.TensorKey, torch.Tensor]  # pre-hiding copies of state/content, filled lazily

    @classmethod
    def new(cls, values: list, address: j2v.Address, hyperparameters: j2v.Hyperparameters, strata: j2v.Strata):
        """Pad raw query results and convert them into bucket ids.

        ``strata`` is part of the hook signature but is not used here.
        """
        request: Request = hyperparameters.requests[address]
        field_shape = (len(values), *hyperparameters.shapes[address])
        padded, pad_state = pad(nested=values, shape=field_shape, dtype=np.float32, pad_value=np.nan)
        # Padded slots hold NaN; zero-fill them so digitize yields a valid id.
        # Their `state` entry (from `pad`) is what distinguishes them from real values.
        bucket_ids = np.digitize(np.nan_to_num(padded, nan=0.0), request.boundaries).astype(np.int64)
        state = torch.tensor(pad_state, dtype=torch.int64)
        content = torch.tensor(bucket_ids, dtype=torch.int64)
        return cls(
            content=content,
            state=state,
            trainable=torch.zeros_like(state, dtype=torch.bool),
            targets=TensorDict({}),
            batch_size=len(values),
        )

    @classmethod
    def empty(cls, batch_size: int, address: j2v.Address, hyperparameters: j2v.Hyperparameters):
        """Build a batch in which every position is marked ``Tokens.masked``."""
        full_shape = (batch_size, *hyperparameters.shapes[address])
        masked_state = torch.full(full_shape, int(j2v.Tokens.masked), dtype=torch.int64)
        return cls(
            content=torch.zeros(full_shape, dtype=torch.int64),
            state=masked_state,
            trainable=torch.zeros_like(masked_state, dtype=torch.bool),
            targets=TensorDict({}),
            batch_size=batch_size,
        )

    def _hide(self, selected: torch.Tensor) -> None:
        """Hide the selected valued positions and mark them as trainable."""
        hidden = selected & self.state.eq(j2v.Tokens.valued)
        # Snapshot the originals only the first time, so hiding twice (mask, then
        # target) still keeps the true values as targets.
        snapshots = ((j2v.TensorKey.state, self.state), (j2v.TensorKey.content, self.content))
        for key, original in snapshots:
            if key not in self.targets.keys():
                self.targets[key] = original.clone()
        self.content = self.content.masked_fill(hidden, 0)
        self.state = self.state.masked_fill(hidden, int(j2v.Tokens.masked))
        self.trainable |= hidden

    def mask(self, p_mask: float):
        """Independently hide each position with probability ``p_mask``."""
        draws = torch.rand_like(self.state, dtype=torch.float32)
        self._hide(draws < p_mask)

    def target(self, p_prune: float = 1.0):
        """Hide whole observations (first-dimension rows) with probability ``p_prune``."""
        # One draw per observation, broadcast across that observation's remaining dims.
        per_row_shape = (self.state.size(0),) + (1,) * (self.state.ndim - 1)
        draws = torch.rand(*per_row_shape, device=self.state.device)
        self._hide((draws < p_prune).expand_as(self.state))
Embed, Decode, And Train
The embedder maps field tensors into the shared model width. The decoder maps pooled context back to bucket logits. The loss trains only positions marked as trainable, and write converts prediction tensors into plain Python-friendly output.
@bucket.register
class Embedder(j2v.EmbedderBase):
    """Embed a bucket field as state-token embedding plus bucket-id embedding.

    The resulting ``Parcel`` is addressed from this field to its parent.
    """

    def __init__(self, hyperparameters: j2v.Hyperparameters, address: j2v.Address):
        super().__init__(hyperparameters=hyperparameters, address=address)
        request: Request = hyperparameters.requests[address]
        width = hyperparameters.d_model
        self.origin = address
        self.destination = request.parent.address
        # state_embedding is created before bucket_embedding; construction order
        # determines the random initialization and parameter ordering.
        self.state_embedding = torch.nn.Embedding(len(j2v.Tokens), width)
        self.bucket_embedding = torch.nn.Embedding(request.n_bins, width)

    def forward(self, inputs: j2v.TensorFieldBase) -> Parcel:
        """Sum both embeddings per position and wrap them in a ``Parcel``."""
        from_state = self.state_embedding(inputs.state)
        from_bucket = self.bucket_embedding(inputs.content)
        return Parcel(
            payload=from_state + from_bucket,
            origin=self.origin,
            destination=self.destination,
            batch_size=inputs.batch_size[0],
        )
@bucket.register
class Decoder(j2v.DecoderBase):
    """Project pooled context onto one logit per bucket."""

    def __init__(self, hyperparameters: j2v.Hyperparameters, address: j2v.Address):
        super().__init__(hyperparameters=hyperparameters, address=address)
        n_bins = hyperparameters.requests[address].n_bins
        self.linear = torch.nn.Linear(hyperparameters.d_model, n_bins)

    def decode(self, pooled: torch.Tensor) -> TensorDict[j2v.TensorKey, torch.Tensor]:
        """Return ``{TensorKey.content: logits}`` with ``n_bins`` logits in the last dimension."""
        logits = self.linear(pooled)
        return TensorDict({j2v.TensorKey.content: logits})
@bucket.register
def loss(module: j2v.Model, prediction: Prediction, batch: j2v.TensorFieldBase, strata: j2v.Strata):
    """Cross-entropy over the bucket positions hidden by ``mask``/``target``.

    Args:
        module: model used to record the metric via ``module.track``.
        prediction: decoder output; ``payload[TensorKey.content]`` holds bucket
            logits with the buckets in the last dimension.
        batch: field batch carrying ``trainable`` and the recorded ``targets``.
        strata: forwarded into the metric key passed to ``module.track``.

    Returns:
        The tracked cross-entropy loss. When no position is trainable, returns a
        zero that stays connected to the logits.
    """
    content_logits = prediction.payload[j2v.TensorKey.content]
    logits = content_logits.reshape(-1, content_logits.shape[-1])
    trainable = batch.trainable.reshape(-1)
    # Check before reading ``targets``. The field starts with an empty TensorDict and
    # only records targets once something is hidden, so a batch where nothing was
    # masked or pruned may have no content target at all.
    if not bool(trainable.any()):
        # Multiplying by zero keeps the result attached to the graph, so backward() still works.
        return logits.sum() * 0.0
    targets = batch.targets[j2v.TensorKey.content].reshape(-1)
    return module.track(
        (prediction.address, strata, j2v.Metric.loss, j2v.TensorKey.content),
        value=torch.nn.functional.cross_entropy(logits[trainable], targets[trainable]),
    )
@bucket.register
def write(module: j2v.Model, prediction: Prediction):
    """Convert bucket logits into interval labels, indices and probabilities.

    Returns a dict with:
        ``buckets``: interval label of the most probable bucket (object array),
        ``indices``: most probable bucket id,
        ``probability``: softmax over all buckets.
    """
    edges = module.hyperparameters.requests[prediction.address].boundaries
    # Pair each lower edge with its upper edge; -inf and inf close the outer buckets.
    # The first interval is open on the left, and every other interval is closed on the left.
    lowers = ["-inf", *edges]
    uppers = [*edges, "inf"]
    labels = np.array(
        [
            f"{'(' if position == 0 else '['}{lower}, {upper})"
            for position, (lower, upper) in enumerate(zip(lowers, uppers))
        ],
        dtype=object,
    )
    probabilities = prediction.payload[j2v.TensorKey.content].softmax(dim=-1)
    indices = probabilities.argmax(dim=-1).detach().cpu().numpy()
    return {
        "buckets": labels[indices],
        "indices": indices,
        "probability": probabilities.detach().cpu().numpy(),
    }
Use It In A Model
Once the plugin is registered, it behaves like any built-in field constructor. The example buckets the Iris petal_width, but the model sees only ordinary JSON records.
# Try the repository-relative location first; otherwise use ../data/iris.jsonl.
candidates = (Path("docs/data/iris.jsonl"), Path("../data/iris.jsonl"))
data_path = candidates[0] if candidates[0].exists() else candidates[1]
records = pl.read_ndjson(data_path)
records.head()
| sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|
| f64 | f64 | f64 | f64 | str |
| 5.1 | 3.5 | 1.4 | 0.2 | "setosa" |
| 7.0 | 3.2 | 4.7 | 1.4 | "versicolor" |
| 6.3 | 3.3 | 6.0 | 2.5 | "virginica" |
| 4.9 | 3.0 | 1.4 | 0.2 | "setosa" |
| 6.4 | 3.2 | 4.5 | 1.5 | "versicolor" |
The custom field appears in the Rich display as [bucket], with its own target marker.
# Two plain numeric inputs plus the custom bucket field, which is the prediction target.
schema = (
    j2v.Number("sepal_length"),
    j2v.Number("petal_length"),
    Bucket("petal_width", boundaries=[0.8, 1.8], target=True),
)
model = j2v.Model.from_schema(
    *schema,
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    optimizer=lambda module: torch.optim.AdamW(module.parameters(), lr=1e-2),
)
2026-06-09 20:47:34.579 | INFO | json2vec.architecture.root:__init__:167 - initialized Model module
# Evaluating the model as the last expression of a cell renders its schema tree, as shown below.
model
Model [model] batch_size=8 d_model=16 parameters=18,002 arrays=1 fields=3 targets=1 embeds=0
`-- record [root] attention=mha n_layers=1 n_heads=4 n_linear=1
|-- sepal_length [number] active query=[*].sepal_length
| pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
| jitter=0 n_bands=8 offset=4 objective=mae
|-- petal_length [number] active query=[*].petal_length
| pooling=query weight=1 p_mask=0 p_prune=0 n_heads=4 n_linear=1
| jitter=0 n_bands=8 offset=4 objective=mae
`-- petal_width [bucket] active target query=[*].petal_width
pooling=query weight=1 p_mask=0 p_prune=1 n_heads=4 n_linear=1
boundaries=[0.8, 1.8]
Training uses the same data module and Lightning loop as the built-in tensorfields. Extensions integrate by implementing the hooks above, not by changing the model training loop.
# The same Iris records serve as both the training and validation split; loading
# runs in the main process (no workers, no pinned memory).
datamodule = j2v.PolarsDataModule(
    model,
    train=records,
    validate=records,
    observation_buffer_size=32,
    sample_rate=1.0,
    num_workers=0,
    persistent_workers=False,
    pin_memory=False,
)

# A short, quiet run: 10 epochs capped at 10 train/val batches each, with the
# logger, progress bar, model summary and checkpointing disabled.
quiet = dict(
    logger=False,
    enable_progress_bar=False,
    enable_model_summary=False,
    enable_checkpointing=False,
)
trainer = lit.Trainer(max_epochs=10, limit_train_batches=10, limit_val_batches=10, **quiet)
trainer.fit(model=model, datamodule=datamodule)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead. /home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance. 2026-06-09 20:47:34.633 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 688.47 observations/s
/home/runner/work/json2vec/json2vec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
2026-06-09 20:47:34.809 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1459.77 observations/s
2026-06-09 20:47:34.810 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 461.25 observations/s
2026-06-09 20:47:34.983 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1535.10 observations/s
2026-06-09 20:47:34.984 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 462.79 observations/s
2026-06-09 20:47:35.147 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1560.73 observations/s
2026-06-09 20:47:35.148 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 490.68 observations/s
2026-06-09 20:47:35.313 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1547.45 observations/s
2026-06-09 20:47:35.314 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 485.58 observations/s
2026-06-09 20:47:35.477 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1559.94 observations/s
2026-06-09 20:47:35.478 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 490.73 observations/s
2026-06-09 20:47:35.640 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1571.86 observations/s
2026-06-09 20:47:35.642 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 491.95 observations/s
2026-06-09 20:47:35.806 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1544.19 observations/s
2026-06-09 20:47:35.807 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 485.73 observations/s
2026-06-09 20:47:35.970 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1567.12 observations/s
2026-06-09 20:47:35.971 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 490.15 observations/s
2026-06-09 20:47:36.134 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1559.64 observations/s
2026-06-09 20:47:36.135 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 490.65 observations/s
2026-06-09 20:47:36.298 | INFO | json2vec.logging.throughput:end:45 - validate epoch throughput: 1569.52 observations/s
2026-06-09 20:47:36.299 | INFO | json2vec.logging.throughput:end:45 - train epoch throughput: 490.35 observations/s
`Trainer.fit` stopped: `max_epochs=10` reached.
Prediction now goes through the custom write hook, so the output shape and names are controlled by the extension.
# Predict on the first three records; the output layout comes from the custom write hook.
batch = records.head(3).to_dicts()
pprint(model.predict(batch))
{ │ 'record/petal_width': { │ │ 'buckets': ['(-inf, 0.8)', '[0.8, 1.8)', '[0.8, 1.8)'], │ │ 'indices': [0, 1, 1], │ │ 'probability': [ │ │ │ [0.9897330403327942, 0.003750420641154051, 0.006516592111438513], │ │ │ [0.0036527756601572037, 0.6486832499504089, 0.34766390919685364], │ │ │ [0.003635462373495102, 0.6497473120689392, 0.3466172218322754] │ │ ] │ } }