
API Reference

This page is generated from public docstrings and is meant as a lookup companion to the tutorials. Start with the notebooks when learning the workflow, then use this page to inspect constructor options, mutation methods, and extension base classes.

Common Entry Points

  • Model.from_schema(...) builds the model tree from field constructors and arrays.
  • Array(...) declares a repeated nested context.
  • Overflow enumerates Array overflow policies: head, tail, and error.
  • Number, Category, Set, DateParts, Entity, Vector, and Text declare typed fields.
  • CustomDataModule(...) wraps user-provided PyTorch iterable datasets.
  • PolarsDataModule(...) builds data loaders from a configured model.
  • StreamingDataModule(...) streams local or S3-backed files into Lightning loops.
  • Model.predict(...) returns configured target predictions and embeddings.
  • Writer(...) writes batch prediction output from Trainer.predict(...).
  • Postprocessor reshapes predictions after decoding; see Postprocessors.
  • Deployment wraps a checkpoint or model instance for serving; install the serving extra for FastAPI-backed deployment paths.

Learning-oriented entry points:

Package

json2vec

Public json2vec SDK surface.

The top-level package exports the constructors and helpers used by most applications: Model.from_schema(...) for model construction, tensorfield request constructors such as Category and Number, data modules, schema mutation predicates, and the @preprocess decorator.

OptimizerConfig module-attribute

OptimizerConfig = Optimizer | Callable[["Model"], Optimizer]

SchedulerConfig module-attribute

SchedulerConfig = Any | Callable[['Model', Optimizer], Any]

MASK_LITERAL module-attribute

MASK_LITERAL = '<MASK>'

MaskLiteral module-attribute

MaskLiteral: TypeAlias = Literal['<MASK>']

Postprocessor module-attribute

Postprocessor: TypeAlias = Callable[
    [dict[str, Any], dict[Address, dict[str, Any]]],
    dict[Address, dict[str, Any]] | None,
]
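A Postprocessor is any callable with this shape. This page does not document the role of the first argument, so the sketch below makes some assumptions. It treats that argument as a per-call context dictionary, uses a tuple of strings as a stand-in for Address, and takes a `None` return to mean "leave the predictions unchanged". The `keep_listed` function and the `"keep"` and `"prediction"` keys are hypothetical.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in: json2vec defines its own Address type.
Address = tuple[str, ...]

# Same shape as the Postprocessor alias above.
Postprocessor = Callable[
    [dict[str, Any], dict[Address, dict[str, Any]]],
    Optional[dict[Address, dict[str, Any]]],
]


def keep_listed(
    context: dict[str, Any],
    predictions: dict[Address, dict[str, Any]],
) -> Optional[dict[Address, dict[str, Any]]]:
    """Keep only the addresses named in context["keep"] (a hypothetical key)."""
    keep = context.get("keep")
    if keep is None:
        # Returning None is permitted by the alias; this sketch uses it to mean "no change".
        return None
    return {address: fields for address, fields in predictions.items() if address in keep}


postprocess: Postprocessor = keep_listed
```

Suppose the predictions are keyed by `("record", "label")` and `("record", "segment")`. Then `keep_listed({"keep": {("record", "label")}}, predictions)` returns only the label entry.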

PREPROCESSORS module-attribute

PREPROCESSORS: dict[str, Preprocessor] = {}

SchemaField module-attribute

SchemaField: TypeAlias = Array | Leaf

TENSORFIELDS module-attribute

TENSORFIELDS: dict[str, 'Plugin'] = {}

RequestBase module-attribute

RequestBase: TypeAlias = Leaf

Input module-attribute

Input: TypeAlias = TensorDict[Address, TensorFieldBase]

ModelSource module-attribute

ModelSource: TypeAlias = str | Path | Model

UpdateOperation module-attribute

UpdateOperation: TypeAlias = tuple[
    tuple[
        NodePredicate
        | NodeAttribute
        | Callable[[Node], bool],
        ...,
    ],
    dict[str, Any],
]
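An UpdateOperation pairs a tuple of predicates with a dictionary of attribute values. The sketch below assumes these operations are applied the same way `Model.update` is documented: a node is selected only when every predicate matches it, and selected nodes then receive the values. `Node` is a hypothetical dataclass stand-in, and the sketch models only plain `Callable[[Node], bool]` predicates, not `NodePredicate` or `NodeAttribute` objects.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class Node:
    """Hypothetical stand-in for a json2vec schema node."""

    name: str
    type: str
    attrs: dict[str, Any] = field(default_factory=dict)


Predicate = Callable[[Node], bool]
UpdateOperation = tuple[tuple[Predicate, ...], dict[str, Any]]


def apply_operations(nodes: list[Node], operations: list[UpdateOperation]) -> list[Node]:
    """Apply each operation to the nodes that satisfy every one of its predicates."""
    for predicates, values in operations:
        matched = [node for node in nodes if all(predicate(node) for predicate in predicates)]
        for node in matched:
            node.attrs.update(values)
    return nodes
```

In this example, `((lambda n: n.type == "category", lambda n: n.name == "label"), {"p_prune": 1.0})` updates only the node that is both a category and named `label`.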

RollbackCheckpoint

RollbackCheckpoint(*args: Any, **kwargs: Any)

Bases: ModelCheckpoint

Checkpoint the best model during fit and restore it into the module at fit end.

Source code in src/json2vec/architecture/checkpoint.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)
    if self.save_weights_only:
        raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
    if self.save_top_k == 0:
        raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")

on_fit_end

on_fit_end(
    trainer: Trainer, pl_module: LightningModule
) -> None
Source code in src/json2vec/architecture/checkpoint.py
def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
    from json2vec.architecture.root import Model

    super().on_fit_end(trainer=trainer, pl_module=pl_module)
    if not isinstance(pl_module, Model):
        raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")

    best_model_path = self.best_model_path
    if not best_model_path:
        raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")

    strategy = getattr(trainer, "strategy", None)
    if strategy is not None:
        strategy.barrier("rollback_checkpoint_load")
        checkpoint = strategy.checkpoint_io.load_checkpoint(
            best_model_path,
            map_location=pl_module.device,
            weights_only=False,
        )
    else:
        checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)

    pl_module.restore_checkpoint_state(checkpoint)
    logger.bind(
        component="checkpoint",
        checkpoint=best_model_path,
        score=self.best_model_score,
    ).info("rolled back Model to best checkpoint")
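RollbackCheckpoint relies on ModelCheckpoint to remember the best-scoring checkpoint during fit. At fit end, it loads that checkpoint back into the module. The framework-free sketch below shows the same keep-best-then-restore policy, using plain dictionaries in place of model state. It assumes a "min" monitoring mode, for example validation loss. `BestStateTracker` is hypothetical and is not how Lightning implements this.

```python
import copy
from typing import Any, Optional


class BestStateTracker:
    """Keep a copy of the lowest-scoring state and restore it on request (mode='min')."""

    def __init__(self) -> None:
        self.best_score: Optional[float] = None
        self.best_state: Optional[dict[str, Any]] = None

    def observe(self, score: float, state: dict[str, Any]) -> None:
        # Deep-copy so later training updates cannot mutate the saved snapshot.
        if self.best_score is None or score < self.best_score:
            self.best_score = score
            self.best_state = copy.deepcopy(state)

    def rollback(self, state: dict[str, Any]) -> None:
        if self.best_state is None:
            # Mirrors the RuntimeError raised when no best checkpoint exists.
            raise RuntimeError("no best state to restore")
        state.clear()
        state.update(copy.deepcopy(self.best_state))
```

The constructor checks above exist because this restore step needs at least one full checkpoint on disk. A weights-only checkpoint, or `save_top_k=0`, would leave nothing complete to roll back to.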

MutationLockCallback

Bases: Callback

Prevent runtime schema mutations while Lightning owns an active loop.

locks class-attribute instance-attribute

locks: tuple[Strata, ...] = (train, validate, test, predict)

on_train_start class-attribute instance-attribute

on_train_start = partialmethod(_on_loop_start, strata=train)

on_train_end class-attribute instance-attribute

on_train_end = partialmethod(_on_loop_end, strata=train)

on_validation_start class-attribute instance-attribute

on_validation_start = partialmethod(
    _on_loop_start, strata=validate
)

on_validation_end class-attribute instance-attribute

on_validation_end = partialmethod(
    _on_loop_end, strata=validate
)

on_test_start class-attribute instance-attribute

on_test_start = partialmethod(_on_loop_start, strata=test)

on_test_end class-attribute instance-attribute

on_test_end = partialmethod(_on_loop_end, strata=test)

on_predict_start class-attribute instance-attribute

on_predict_start = partialmethod(
    _on_loop_start, strata=predict
)

on_predict_end class-attribute instance-attribute

on_predict_end = partialmethod(_on_loop_end, strata=predict)

on_exception

on_exception(
    trainer: Trainer,
    pl_module: "Model",
    exception: BaseException,
) -> None
Source code in src/json2vec/architecture/mutations.py
def on_exception(
    self,
    trainer: lit.Trainer,
    pl_module: "Model",
    exception: BaseException,
) -> None:  # ty:ignore[invalid-method-override]
    for lock in self.locks:
        pl_module.locks.pop(lock, None)
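The loop hooks above are `partialmethod` bindings of `_on_loop_start` and `_on_loop_end`, whose bodies are not shown on this page. The sketch below assumes those hooks increment and decrement `pl_module.locks[strata]`, and that a schema mutation is refused while any lock count is positive. Only the `Counter` type and the `on_exception` cleanup, which pops every lock, come from the source. `LockedSchema` and its methods are hypothetical.

```python
from collections import Counter


class LockedSchema:
    """Hypothetical stand-in for a model whose schema mutations respect loop locks."""

    def __init__(self) -> None:
        self.locks: Counter = Counter()
        self.fields: list[str] = []

    def loop_start(self, strata: str) -> None:
        self.locks[strata] += 1

    def loop_end(self, strata: str) -> None:
        self.locks[strata] -= 1
        if self.locks[strata] <= 0:
            del self.locks[strata]

    def on_exception(self, strata_names: tuple[str, ...]) -> None:
        # Same cleanup as MutationLockCallback.on_exception: drop every lock outright.
        for strata in strata_names:
            self.locks.pop(strata, None)

    def extend(self, name: str) -> None:
        if any(count > 0 for count in self.locks.values()):
            raise RuntimeError(f"schema is locked by: {sorted(self.locks)}")
        self.fields.append(name)
```

Popping the locks in `on_exception`, rather than decrementing them, means a crashed loop cannot leave the model permanently locked.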

RuntimePlacementCallback

Bases: Callback

Move late-created modules onto the Lightning module's active device.

on_train_start class-attribute instance-attribute

on_train_start = partialmethod(_on_loop_start, strata=train)

on_validation_start class-attribute instance-attribute

on_validation_start = partialmethod(
    _on_loop_start, strata=validate
)

on_test_start class-attribute instance-attribute

on_test_start = partialmethod(_on_loop_start, strata=test)

on_predict_start class-attribute instance-attribute

on_predict_start = partialmethod(
    _on_loop_start, strata=predict
)

Model

Model(
    hyperparameters: Hyperparameters,
    *,
    batch_size: int = 1,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
)

Bases: LightningModule, Renderable

Neural model generated from a json2vec schema tree.

Model owns the schema hyperparameters, tensorfield embedders, array encoders, decoders, and convenience methods for prediction, checkpointing, schema display and mutation.

Example
import json2vec as j2v

model = j2v.Model.from_schema(
    j2v.Category("segment", max_vocab_size=32),
    j2v.Category("label", target=True, max_vocab_size=4),
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    embed=True,
)
Source code in src/json2vec/architecture/root.py
@beartype
def __init__(
    self,
    hyperparameters: Hyperparameters,
    *,
    batch_size: int = 1,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
):
    super().__init__()
    if batch_size <= 0:
        raise ValueError("batch_size must be > 0")

    self.hyperparameters: Hyperparameters = hyperparameters
    self.batch_size: int = batch_size
    self.optimizer: OptimizerConfig | None = optimizer
    self.scheduler: SchedulerConfig | None = scheduler
    self.locks: Counter[str | Strata] = Counter()
    self.nodes: torch.nn.ModuleDict = torch.nn.ModuleDict()
    self.schema: SchemaEditor = SchemaEditor(self)
    self._contract_generation: int = 0
    self._contract_scheduler: ContractScheduler = ContractScheduler()

    self._build()

    logger.bind(
        component="model",
        batch_size=self.batch_size,
        requests=len(self.hyperparameters.active_requests),
        arrays=len(self.hyperparameters.arrays),
        embeds=len(self.hyperparameters.embed),
    ).info("initialized Model module")

hyperparameters instance-attribute

hyperparameters: Hyperparameters = hyperparameters

batch_size instance-attribute

batch_size: int = batch_size

optimizer instance-attribute

optimizer: OptimizerConfig | None = optimizer

scheduler instance-attribute

scheduler: SchedulerConfig | None = scheduler

locks instance-attribute

locks: Counter[str | Strata] = Counter()

nodes instance-attribute

nodes: ModuleDict = ModuleDict()

schema instance-attribute

schema: SchemaEditor = SchemaEditor(self)

interprocess_encoding_context property

interprocess_encoding_context: dict[Address, Any]

from_checkpoint class-attribute instance-attribute

from_checkpoint = load

training_step class-attribute instance-attribute

training_step = partialmethod(step, strata=train)

validation_step class-attribute instance-attribute

validation_step = partialmethod(step, strata=validate)

test_step class-attribute instance-attribute

test_step = partialmethod(step, strata=test)

predict_step class-attribute instance-attribute

predict_step = partialmethod(step, strata=predict)

from_schema classmethod

from_schema(
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    batch_size: int = 1,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: int = 1,
    dropout: Rate | None = None,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
) -> Self

Build a model directly from schema fields.

Parameters:

  • *field_args (SchemaField, default ()): Field constructors such as Category, Number, or nested Array nodes.
  • d_model (int, required): Shared model width.
  • n_layers (int, required): Number of encoder layers on generated array nodes.
  • n_heads (int, required): Attention heads used by generated nodes.
  • batch_size (int, default 1): Batch size used by data modules, examples, and mocked Lightning input arrays.
  • fields (Sequence[SchemaField] | None, default None): Optional sequence form of field_args.
  • name (str, default 'record'): Root array name.
  • description (str | None, default None): Optional description on the generated root array.
  • embed (bool, default False): Configure the generated root array as an embedding output.
  • attention (AttentionMode | str, default AttentionMode.mha): Attention mode for the generated root array.
  • n_linear (int, default 1): Feed-forward block count on the generated root array.
  • dropout (Rate | None, default None): Optional dropout rate on the generated root array.
  • optimizer (OptimizerConfig | None, default None): Optimizer instance or factory used by Lightning training.
  • scheduler (SchedulerConfig | None, default None): Optional scheduler config or factory.

Returns:

  • Self: A compiled Model with modules built for the schema.

Source code in src/json2vec/architecture/root.py
@classmethod
def from_schema(
    cls,
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    batch_size: int = 1,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: int = 1,
    dropout: Rate | None = None,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
) -> Self:
    """Build a model directly from schema fields.

    Args:
        *field_args: Field constructors such as `Category`, `Number`, or
            nested `Array` nodes.
        d_model: Shared model width.
        n_layers: Number of encoder layers on generated array nodes.
        n_heads: Attention heads used by generated nodes.
        batch_size: Batch size used by data modules, examples, and mocked
            Lightning input arrays.
        fields: Optional sequence form of `field_args`.
        name: Root array name. Defaults to `record`.
        description: Optional description on the generated root array.
        embed: Configure the generated root array as an embedding output.
        attention: Attention mode for the generated root array.
        n_linear: Feed-forward block count on the generated root array.
        dropout: Optional dropout rate on the generated root array.
        optimizer: Optimizer instance or factory used by Lightning training.
        scheduler: Optional scheduler config or factory.

    Returns:
        A compiled `Model` with modules built for the schema.
    """
    hyperparameters = Hyperparameters.from_schema(
        *field_args,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        fields=fields,
        name=name,
        description=description,
        embed=embed,
        attention=attention,
        n_linear=n_linear,
        dropout=dropout,
    )
    return cls(
        hyperparameters=hyperparameters,
        batch_size=batch_size,
        optimizer=optimizer,
        scheduler=scheduler,
    )

select

select(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]

Return schema nodes that satisfy every predicate.

Source code in src/json2vec/architecture/root.py
def select(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]:
    """Return schema nodes that satisfy every predicate."""
    return self.schema.select(*predicates, include_root=include_root, use_cache=use_cache)

update

update(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None

Mutate selected schema nodes and rebuild compatible modules.

target=True is shorthand for p_prune=1.0; target=False clears target behavior by setting p_prune=0.0.

Parameters:

  • *predicates (NodePredicate | NodeAttribute | Callable[[Node], bool], default ()): Predicates used to select nodes.
  • strict (bool, default True): Raise when a selected node cannot accept one of values.
  • allow_extra (bool, default False): Permit updates to extra metadata fields on models that allow unknown fields.
  • include_root (bool, default True): Include the root node in predicate matching.
  • validate (bool, default True): Validate each node after applying candidate values.
  • use_cache (bool, default False): Permit cached selector results. Mutations default this to False so updates always evaluate against the current schema state.
  • **values (Any, default {}): Schema attributes to update.
Source code in src/json2vec/architecture/root.py
def update(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None:
    """Mutate selected schema nodes and rebuild compatible modules.

    `target=True` is shorthand for `p_prune=1.0`; `target=False` clears
    target behavior by setting `p_prune=0.0`.

    Args:
        *predicates: Predicates used to select nodes.
        strict: Raise when a selected node cannot accept one of `values`.
        allow_extra: Permit updates to extra metadata fields on models that
            allow unknown fields.
        include_root: Include the root node in predicate matching.
        validate: Validate each node after applying candidate values.
        use_cache: Permit cached selector results. Mutations default this to
            `False` so updates always evaluate against current schema state.
        **values: Schema attributes to update.
    """
    self.schema.update(
        *predicates,
        strict=strict,
        allow_extra=allow_extra,
        include_root=include_root,
        validate=validate,
        use_cache=use_cache,
        **values,
    )
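The docstring defines `target=True` as shorthand for `p_prune=1.0` and `target=False` as setting `p_prune=0.0`. This page does not show the library's own translation code. The sketch below shows what that normalization step amounts to, using a hypothetical helper `normalize_target`.

```python
from typing import Any


def normalize_target(values: dict[str, Any]) -> dict[str, Any]:
    """Rewrite the `target` shorthand as an explicit `p_prune` (hypothetical helper)."""
    values = dict(values)  # copy so the caller's kwargs are not mutated
    if "target" in values:
        values["p_prune"] = 1.0 if values.pop("target") else 0.0
    return values
```

For example, `{"target": True}` becomes `{"p_prune": 1.0}`, and any other keys pass through unchanged.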

extend

extend(
    *args: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool]
    | SchemaField,
    include_root: bool = True,
    use_cache: bool = True,
) -> None

Append new schema fields under one selected array node and rebuild modules.

Source code in src/json2vec/architecture/root.py
def extend(
    self,
    *args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
    include_root: bool = True,
    use_cache: bool = True,
) -> None:
    """Append new schema fields under one selected array node and rebuild modules."""
    self.schema.extend(*args, include_root=include_root, use_cache=use_cache)

delete

delete(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    include_root: bool = False,
    use_cache: bool = True,
) -> None

Permanently remove selected schema nodes and rebuild modules.

Source code in src/json2vec/architecture/root.py
def delete(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    include_root: bool = False,
    use_cache: bool = True,
) -> None:
    """Permanently remove selected schema nodes and rebuild modules."""
    self.schema.delete(*predicates, include_root=include_root, use_cache=use_cache)

reset

reset(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
    descendants: bool = False,
) -> None

Reinitialize selected runtime node modules while preserving schema values.

Source code in src/json2vec/architecture/root.py
def reset(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
    descendants: bool = False,
) -> None:
    """Reinitialize selected runtime node modules while preserving schema values."""
    self.schema.reset(
        *predicates,
        include_root=include_root,
        use_cache=use_cache,
        descendants=descendants,
    )

override

override(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]

Temporarily mutate selected schema nodes and keep runtime modules synchronized.

Source code in src/json2vec/architecture/root.py
@contextmanager
def override(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]:
    """Temporarily mutate selected schema nodes and keep runtime modules synchronized."""
    with self.schema.override(
        *predicates,
        strict=strict,
        allow_extra=allow_extra,
        include_root=include_root,
        validate=validate,
        use_cache=use_cache,
        **values,
    ):
        yield
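`override` delegates to `SchemaEditor.override`, whose body is not shown here. The docstring describes a familiar pattern: apply the values, run the block, then put the previous values back. The sketch below reproduces that pattern with `contextlib` on a plain attribute dictionary. Restoring the values inside `finally`, so they come back even if the block raises, is an assumption about what "temporarily" implies. `temporary_values` is hypothetical.

```python
from contextlib import contextmanager
from typing import Any, Iterator

_MISSING = object()


@contextmanager
def temporary_values(node: dict[str, Any], **values: Any) -> Iterator[None]:
    """Apply `values` to `node` for the duration of the block, then restore it."""
    previous = {key: node.get(key, _MISSING) for key in values}
    node.update(values)
    try:
        yield
    finally:
        # Restore even if the block raised; drop keys that did not exist before.
        for key, old in previous.items():
            if old is _MISSING:
                node.pop(key, None)
            else:
                node[key] = old
```

For example, inside `with temporary_values(label, p_prune=1.0):` the node acts as a target, and on exit it reverts to its earlier `p_prune`.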

configure_callbacks

configure_callbacks() -> list[Callback]
Source code in src/json2vec/architecture/root.py
def configure_callbacks(self) -> list[Callback]:
    callbacks: list[Callback] = []
    factories: set[Any] = set()
    trainer = getattr(self, "_trainer", None)
    attached_callback_types = {type(callback) for callback in getattr(trainer, "callbacks", ())}

    if RuntimePlacementCallback not in attached_callback_types:
        callbacks.append(RuntimePlacementCallback())
    if MutationLockCallback not in attached_callback_types:
        callbacks.append(MutationLockCallback())
    if ThroughputLogger not in attached_callback_types:
        callbacks.append(ThroughputLogger())

    for request in self.hyperparameters.active_requests.values():
        plugin: Plugin = TENSORFIELDS[request.type]
        for factory in plugin.callback_factories:
            if factory in factories:
                continue

            factories.add(factory)
            callback = factory()
            if type(callback) not in attached_callback_types:
                callbacks.append(callback)

    # Callbacks may perform distributed work, so register them in a
    # deterministic order on every rank. Use class paths instead of Python's
    # salted hash or schema traversal order.
    callbacks.sort(
        key=lambda callback: (
            type(callback).__module__,
            type(callback).__qualname__,
        )
    )

    return callbacks
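`configure_callbacks` instantiates each plugin factory once and skips callback types already attached to the trainer. It then sorts the result by `(module, qualname)`, so every DDP rank registers callbacks in the same order. The sketch below isolates those three rules. The callback classes and `collect_callbacks` are hypothetical, and each class doubles as its own factory.

```python
from typing import Callable, Iterable


# Hypothetical callback classes standing in for plugin-provided callbacks.
class ZetaCallback: ...
class AlphaCallback: ...
class MidCallback: ...


def collect_callbacks(
    factories_per_request: Iterable[Iterable[Callable[[], object]]],
    attached_types: set[type],
) -> list[object]:
    """Instantiate each factory once, skip already-attached types, sort by class path."""
    seen: set[Callable[[], object]] = set()
    callbacks: list[object] = []
    for factories in factories_per_request:
        for factory in factories:
            if factory in seen:
                continue
            seen.add(factory)
            callback = factory()
            if type(callback) not in attached_types:
                callbacks.append(callback)
    # Class-path order is identical on every rank, unlike hash or traversal order.
    callbacks.sort(key=lambda cb: (type(cb).__module__, type(cb).__qualname__))
    return callbacks
```

The output order does not depend on the order in which requests list their factories.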

track

track(
    names: tuple[str, ...], /, value: Tensor
) -> torch.Tensor
Source code in src/json2vec/architecture/root.py
def track(self, names: tuple[str, ...], /, value: torch.Tensor) -> torch.Tensor:
    def groupname(names: tuple[str, ...]) -> str:
        assert len(names) > 1

        group, *keys = tuple(map(lambda x: x.replace("/", ":").lower(), names))
        key = ":".join(list(keys))

        return f"{group}/{key}"

    # These metrics are emitted from data-dependent branches, so DDP ranks cannot
    # safely synchronize every log call as a collective. rank_zero_only keeps
    # Lightning from running a sync while still marking the metric as handled.
    self.log(
        name=groupname(names),
        value=value.detach(),
        on_step=False,
        on_epoch=True,
        sync_dist=True,
        rank_zero_only=True,
        batch_size=self.batch_size,
    )

    return value
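The nested `groupname` helper builds the logged metric key from the name tuple. It replaces every `/` with `:` and lower-cases each name. The first name becomes the group, and the remaining names are joined with `:` to form the key. Here is the helper copied out as a standalone function; the example names are illustrative.

```python
def groupname(names: tuple[str, ...]) -> str:
    # At least a group and one key are required.
    assert len(names) > 1
    group, *keys = (name.replace("/", ":").lower() for name in names)
    return f"{group}/{':'.join(keys)}"


print(groupname(("Loss", "record/label", "Train")))  # loss/record:label:train
```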

save

save(pathname: str | Path) -> str | Path

Save model weights and schema hyperparameters to a checkpoint.

Source code in src/json2vec/architecture/root.py
@beartype
def save(self, pathname: str | Path) -> str | Path:
    """Save model weights and schema hyperparameters to a checkpoint."""
    CheckpointState.save(self, pathname)

    return pathname

forward

forward(
    inputs: TensorDict[Address, TensorFieldBase],
    *,
    strata: Strata | str,
    dataloader_idx: int = 0,
) -> list[Prediction]
Source code in src/json2vec/architecture/root.py
@immutable("forward")
@beartype
def forward(
    self,
    inputs: TensorDict[Address, TensorFieldBase],
    *,
    strata: Strata | str,
    dataloader_idx: int = 0,
) -> list[Prediction]:
    return ModelRuntime.forward(self, inputs, strata=strata, dataloader_idx=dataloader_idx)

configure_optimizers

configure_optimizers()
Source code in src/json2vec/architecture/root.py
@beartype
def configure_optimizers(self):
    if self.optimizer is None:
        raise ValueError("optimizer must be passed to Model before fitting")

    if isinstance(self.optimizer, torch.optim.Optimizer):
        optimizer = self.optimizer
    else:
        optimizer = self.optimizer(self)

    scheduler = self.scheduler(self, optimizer) if callable(self.scheduler) else self.scheduler

    if scheduler is None:
        return optimizer

    return dict(optimizer=optimizer, lr_scheduler=scheduler)
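`configure_optimizers` accepts two forms of optimizer: a ready optimizer instance, or a factory that is called with the model. The scheduler is optional and may be an object or a factory called with `(model, optimizer)`. Factories are usually the practical choice, because building an optimizer instance requires the model's parameters. With PyTorch, that typically looks like `optimizer=lambda model: torch.optim.AdamW(model.parameters(), lr=1e-3)`. The framework-free sketch below mirrors the resolution order using hypothetical stand-in classes.

```python
from typing import Any, Callable, Union


class FakeOptimizer:
    """Stand-in for torch.optim.Optimizer; it only records its parameters."""

    def __init__(self, params: list[str]) -> None:
        self.params = params


class FakeModel:
    """Stand-in for a Model; only parameters() is modeled."""

    def parameters(self) -> list[str]:
        return ["weight", "bias"]


def resolve_optimizers(
    model: Any,
    optimizer: Union[FakeOptimizer, Callable[[Any], FakeOptimizer], None],
    scheduler: Any = None,
) -> Union[FakeOptimizer, dict[str, Any]]:
    # Same branching as configure_optimizers above.
    if optimizer is None:
        raise ValueError("optimizer must be passed to Model before fitting")
    opt = optimizer if isinstance(optimizer, FakeOptimizer) else optimizer(model)
    # Any callable scheduler is treated as a factory taking (model, optimizer).
    sched = scheduler(model, opt) if callable(scheduler) else scheduler
    if sched is None:
        return opt
    return dict(optimizer=opt, lr_scheduler=sched)
```

The scheduler check uses `callable(...)`, so a scheduler object that defines `__call__` is treated as a factory and called with `(model, optimizer)`.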

on_save_checkpoint

on_save_checkpoint(checkpoint)
Source code in src/json2vec/architecture/root.py
def on_save_checkpoint(self, checkpoint):
    CheckpointState.dump(self, checkpoint)

restore_checkpoint_state

restore_checkpoint_state(
    checkpoint: dict[str, Any],
) -> None

Restore this model in place from a json2vec checkpoint dictionary.

Source code in src/json2vec/architecture/root.py
def restore_checkpoint_state(self, checkpoint: dict[str, Any]) -> None:
    """Restore this model in place from a `json2vec` checkpoint dictionary."""
    CheckpointState.restore(self, checkpoint)

load classmethod

load(checkpoint: str | Path) -> Self

Load a Model checkpoint written by Model.save(...).

Source code in src/json2vec/architecture/root.py
@classmethod
def load(cls, checkpoint: str | Path) -> Self:
    """Load a `Model` checkpoint written by `Model.save(...)`."""
    return cast(Self, CheckpointState.load(cls, checkpoint))

write

write(
    predictions: list[Prediction],
) -> dict[Address, dict[str, Any]]
Source code in src/json2vec/architecture/root.py
def write(self, predictions: list[Prediction]) -> dict[Address, dict[str, Any]]:
    return ModelRuntime.write(self, predictions)

encode

encode(
    batch: EncodedBatch | list[dict[str, Any]],
    preprocess: Preprocessor | None = None,
    strata: Strata | str = Strata.predict,
    mask: bool = True,
) -> EncodedInput

Return encoded tensorfield inputs for raw or processed observations.

Source code in src/json2vec/architecture/root.py
@immutable("inference")
def encode(
    self,
    batch: EncodedBatch | list[dict[str, Any]],
    preprocess: Preprocessor | None = None,
    strata: Strata | str = Strata.predict,
    mask: bool = True,
) -> EncodedInput:
    """Return encoded tensorfield inputs for raw or processed observations."""
    return ModelRuntime.encode(
        self,
        batch=batch,
        preprocess=preprocess,
        strata=strata,
        mask=mask,
    )

predict

predict(
    batch: EncodedBatch | list[dict[str, Any]],
    preprocess: Preprocessor | None = None,
    postprocess: Postprocessor | None = None,
) -> dict[Address, dict[str, Any]]

Return typed predictions and configured embeddings for a raw or encoded batch.

Source code in src/json2vec/architecture/root.py
@immutable("inference")
def predict(
    self,
    batch: EncodedBatch | list[dict[str, Any]],
    preprocess: Preprocessor | None = None,
    postprocess: Postprocessor | None = None,
) -> dict[Address, dict[str, Any]]:
    """Return typed predictions and configured embeddings for a raw or encoded batch."""
    return ModelRuntime.predict(
        self,
        batch=batch,
        preprocess=preprocess,
        postprocess=postprocess,
    )

CustomDataModule

CustomDataModule(
    model: Model,
    train: IterableDataset | None = None,
    validate: IterableDataset | None = None,
    test: IterableDataset | None = None,
    predict: IterableDataset | None = None,
    preprocessor: str
    | Callable[..., Any]
    | Preprocessor
    | None = None,
    datasets: DatasetMap | None = None,
    num_workers: NonNegativeInt
    | None
    | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    observation_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    **kwargs: Any,
)

Bases: LightningDataModule

Lightning data module for user-provided iterable datasets.

Source code in src/json2vec/data/datasets/custom.py
def __init__(
    self,
    model: Model,
    train: IterableDataset | None = None,
    validate: IterableDataset | None = None,
    test: IterableDataset | None = None,
    predict: IterableDataset | None = None,
    preprocessor: str | Callable[..., Any] | Preprocessor | None = None,
    datasets: DatasetMap | None = None,
    num_workers: NonNegativeInt | None | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    observation_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    **kwargs: Any,
):
    super().__init__()

    _validate_loader_configuration(
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        pin_memory=pin_memory,
        observation_buffer_size=observation_buffer_size,
        sample_rate=sample_rate,
    )

    if datasets is not None and any(dataset is not None for dataset in (train, validate, test, predict)):
        raise ValueError("pass either datasets or named splits, not both")

    if datasets is None:
        split_datasets = {}
        for strata, dataset in {
            Strata.train: train,
            Strata.validate: validate,
            Strata.test: test,
            Strata.predict: predict,
        }.items():
            if dataset is None:
                continue
            if not isinstance(dataset, IterableDataset):
                raise TypeError(f"dataset for strata '{strata}' must be an IterableDataset")
            split_datasets[strata] = dataset
        if not split_datasets:
            raise ValueError("at least one dataset split is required")
    else:
        split_datasets = _datasets_by_strata(datasets)

    self.datasets = split_datasets
    self.preprocessor = PreprocessorConfig.normalize(preprocessor)
    self.preprocessor_kwargs = dict(kwargs)
    try:
        self._model_ref = weakref.ref(model)
    except TypeError:
        self._model_ref = None
    self._hyperparameters = model.hyperparameters
    self._interprocess_encoding_context = model.interprocess_encoding_context
    self._batch_size = model.batch_size
    self.num_workers = Strata.expand(num_workers, default=None)
    self.persistent_workers = Strata.expand(persistent_workers, default=True)
    self.pin_memory = Strata.expand(pin_memory, default=True)
    self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
    self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
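Most loader options accept either a single value, which applies to every strata, or a per-strata `StrataMap`. This page does not document `Strata.expand`. The sketch below assumes it broadcasts a scalar to all four strata and fills any strata missing from a mapping with `default`. The `Strata` enum here is a hypothetical stand-in with the four members used throughout this page.

```python
from enum import Enum
from typing import Any, Mapping


class Strata(str, Enum):
    """Hypothetical stand-in for json2vec's Strata."""

    train = "train"
    validate = "validate"
    test = "test"
    predict = "predict"


def expand(value: Any, default: Any) -> dict[Strata, Any]:
    """Broadcast a scalar to every strata, or fill a partial mapping with `default`."""
    if isinstance(value, Mapping):
        return {strata: value.get(strata, default) for strata in Strata}
    return {strata: value for strata in Strata}
```

Under this assumption, `sample_rate={Strata.train: 0.25}` subsamples only training data and leaves the other strata at the default rate of `1.0`.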

datasets instance-attribute

datasets = split_datasets

preprocessor instance-attribute

preprocessor = normalize(preprocessor)

preprocessor_kwargs instance-attribute

preprocessor_kwargs = dict(kwargs)

num_workers instance-attribute

num_workers = expand(num_workers, default=None)

persistent_workers instance-attribute

persistent_workers = expand(
    persistent_workers, default=True
)

pin_memory instance-attribute

pin_memory = expand(pin_memory, default=True)

observation_buffer_size instance-attribute

observation_buffer_size = expand(
    observation_buffer_size, default=1
)

sample_rate instance-attribute

sample_rate = {
    strata: float(rate)
    for strata, rate in expand(sample_rate, default=1.0).items()
}

hyperparameters property writable

hyperparameters: Hyperparameters

batch_size property writable

batch_size: int

interprocess_encoding_context property writable

interprocess_encoding_context: InterprocessEncodingContext

train_dataloader class-attribute instance-attribute

train_dataloader = partialmethod(
    dataloader, strata=train, required=False
)

val_dataloader class-attribute instance-attribute

val_dataloader = partialmethod(
    dataloader, strata=validate, required=False
)

test_dataloader class-attribute instance-attribute

test_dataloader = partialmethod(
    dataloader, strata=test, required=False
)

predict_dataloader class-attribute instance-attribute

predict_dataloader = partialmethod(
    dataloader, strata=predict, required=False
)

dataloader

dataloader(
    strata: Strata, required: bool = True
) -> DataLoader | None
Source code in src/json2vec/data/datasets/custom.py
def dataloader(self, strata: Strata, required: bool = True) -> DataLoader | None:
    strata = Strata.normalize(strata)
    trainer = getattr(self, "trainer", None)
    global_rank = getattr(trainer, "global_rank", None)
    world_size = getattr(trainer, "world_size", None)
    if strata not in self.datasets:
        if not required:
            return None
        raise ValueError(f"no dataset configured for strata: {strata}")

    workers = self.num_workers[strata]
    if workers is None:
        workers = os.cpu_count() or 0

    interprocess_encoding_context = self.interprocess_encoding_context
    if strata == Strata.train and workers > 0:
        share_interprocess_encoding_context(interprocess_encoding_context)

    return custom_dataloader(
        hyperparameters=self.hyperparameters,
        dataset=self.datasets[strata],
        preprocessor=self.preprocessor,
        preprocessor_kwargs=self.preprocessor_kwargs,
        interprocess_encoding_context=interprocess_encoding_context,
        batch_size=self.batch_size,
        strata=strata,
        num_workers=workers,
        persistent_workers=self.persistent_workers[strata],
        pin_memory=self.pin_memory[strata],
        observation_buffer_size=self.observation_buffer_size[strata],
        sample_rate=self.sample_rate[strata],
        global_rank=global_rank,
        world_size=world_size,
    )
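The three data modules (CustomDataModule, PolarsDataModule, StreamingDataModule) resolve workers the same way in dataloader(...). num_workers=None falls back to os.cpu_count(), or 0 when the CPU count is unknown, and share_interprocess_encoding_context(...) is only called for the train split when worker processes are used. The standalone sketch below mirrors just that decision logic; resolve_workers and needs_shared_context are hypothetical helper names, not json2vec functions.

```python
import os
from typing import Optional


def resolve_workers(num_workers: Optional[int]) -> int:
    """Mirror of the dataloader() fallback: None means 'use every CPU'."""
    if num_workers is None:
        # os.cpu_count() may return None; loading then stays in the main process.
        return os.cpu_count() or 0
    return num_workers


def needs_shared_context(strata: str, workers: int) -> bool:
    """Only the train split with worker processes shares the encoding context."""
    return strata == "train" and workers > 0


print(resolve_workers(2), needs_shared_context("train", 2))  # 2 True
print(needs_shared_context("validate", 4))  # False
```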

PolarsDataModule

PolarsDataModule(
    model: Model,
    train: DataFrame | None = None,
    validate: DataFrame | None = None,
    test: DataFrame | None = None,
    predict: DataFrame | None = None,
    preprocessor: str
    | Callable[..., Any]
    | Preprocessor
    | None = None,
    dataframe: DataFrame | DataFrameMap | None = None,
    num_workers: NonNegativeInt
    | None
    | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy
    | str
    | StrataMap[
        ShardingStrategy | str
    ] = ShardingStrategy.chunk,
    chunk_batch_size: PositiveInt
    | StrataMap[PositiveInt] = 4096,
    observation_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] = False,
    **kwargs: Any,
)

Bases: LightningDataModule

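Lightning data module for in-memory Polars DataFrames. Supply splits either through the named train, validate, test, and predict arguments or through the single dataframe argument, not both; when dataframe is omitted, at least one named split is required.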

Source code in src/json2vec/data/datasets/polars.py
@beartype
def __init__(
    self,
    model: Model,
    train: pl.DataFrame | None = None,
    validate: pl.DataFrame | None = None,
    test: pl.DataFrame | None = None,
    predict: pl.DataFrame | None = None,
    preprocessor: str | Callable[..., Any] | Preprocessor | None = None,
    dataframe: pl.DataFrame | DataFrameMap | None = None,
    num_workers: NonNegativeInt | None | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy | str | StrataMap[ShardingStrategy | str] = ShardingStrategy.chunk,
    chunk_batch_size: PositiveInt | StrataMap[PositiveInt] = 4096,
    observation_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] = False,
    **kwargs: Any,
):
    super().__init__()

    if dataframe is not None and any(frame is not None for frame in (train, validate, test, predict)):
        raise ValueError("pass either dataframe or named splits, not both")

    if dataframe is None:
        dataframes = {
            strata: frame
            for strata, frame in {
                Strata.train: train,
                Strata.validate: validate,
                Strata.test: test,
                Strata.predict: predict,
            }.items()
            if frame is not None
        }
        if not dataframes:
            raise ValueError("at least one dataframe split is required")
    else:
        dataframes = _dataframes_by_strata(dataframe)

    self.dataframes = dataframes
    self.preprocessor = PreprocessorConfig.normalize(preprocessor)
    self.preprocessor_kwargs = dict(kwargs)
    try:
        self._model_ref = weakref.ref(model)
    except TypeError:
        self._model_ref = None
    self._hyperparameters = model.hyperparameters
    self._interprocess_encoding_context = model.interprocess_encoding_context
    self._batch_size = model.batch_size
    self.num_workers = Strata.expand(num_workers, default=None)
    self.persistent_workers = Strata.expand(persistent_workers, default=True)
    self.pin_memory = Strata.expand(pin_memory, default=True)
    self.sharding = ShardingStrategy.expand(sharding, default=ShardingStrategy.chunk)
    self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
    self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
    self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
    self.replacement = Strata.expand(replacement, default=False)
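The split-selection rules at the top of the constructor can be sketched in plain Python, with arbitrary objects standing in for pl.DataFrame. This is a simplified mirror under stated assumptions: select_named_splits is a hypothetical name, and the dataframe= branch (handled by the private _dataframes_by_strata helper, whose behavior is not documented here) is deliberately not reproduced.

```python
from typing import Any, Optional


def select_named_splits(
    dataframe: Optional[Any] = None,
    *,
    train: Optional[Any] = None,
    validate: Optional[Any] = None,
    test: Optional[Any] = None,
    predict: Optional[Any] = None,
) -> dict[str, Any]:
    """Mirror of the named-split branch of PolarsDataModule.__init__."""
    named = {"train": train, "validate": validate, "test": test, "predict": predict}
    # Mixing the two input styles is rejected before anything else.
    if dataframe is not None and any(frame is not None for frame in named.values()):
        raise ValueError("pass either dataframe or named splits, not both")
    if dataframe is not None:
        # The real module hands this case to the private _dataframes_by_strata helper.
        raise NotImplementedError("the dataframe= branch is not mirrored here")
    # Splits left as None are dropped; their *_dataloader hooks then return None.
    splits = {strata: frame for strata, frame in named.items() if frame is not None}
    if not splits:
        raise ValueError("at least one dataframe split is required")
    return splits


print(sorted(select_named_splits(train="train-frame", test="test-frame")))
# ['test', 'train']
```

Because train_dataloader, val_dataloader, and the other hooks pass required=False, a split omitted here simply yields no loader instead of raising.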

dataframes instance-attribute

dataframes = dataframes

preprocessor instance-attribute

preprocessor = normalize(preprocessor)

preprocessor_kwargs instance-attribute

preprocessor_kwargs = dict(kwargs)

num_workers instance-attribute

num_workers = expand(num_workers, default=None)

persistent_workers instance-attribute

persistent_workers = expand(
    persistent_workers, default=True
)

pin_memory instance-attribute

pin_memory = expand(pin_memory, default=True)

sharding instance-attribute

sharding = expand(sharding, default=chunk)

chunk_batch_size instance-attribute

chunk_batch_size = expand(chunk_batch_size, default=4096)

observation_buffer_size instance-attribute

observation_buffer_size = expand(
    observation_buffer_size, default=1
)

sample_rate instance-attribute

sample_rate = {
    strata: float(rate)
    for strata, rate in expand(sample_rate, default=1.0).items()
}

replacement instance-attribute

replacement = expand(replacement, default=False)

hyperparameters property writable

hyperparameters: Hyperparameters

batch_size property writable

batch_size: int

interprocess_encoding_context property writable

interprocess_encoding_context: InterprocessEncodingContext

train_dataloader class-attribute instance-attribute

train_dataloader = partialmethod(
    dataloader, strata=train, required=False
)

val_dataloader class-attribute instance-attribute

val_dataloader = partialmethod(
    dataloader, strata=validate, required=False
)

test_dataloader class-attribute instance-attribute

test_dataloader = partialmethod(
    dataloader, strata=test, required=False
)

predict_dataloader class-attribute instance-attribute

predict_dataloader = partialmethod(
    dataloader, strata=predict, required=False
)

dataloader

dataloader(
    strata: Strata, required: bool = True
) -> DataLoader | None
Source code in src/json2vec/data/datasets/polars.py
def dataloader(self, strata: Strata, required: bool = True) -> DataLoader | None:
    strata = Strata.normalize(strata)
    trainer = getattr(self, "trainer", None)
    global_rank = getattr(trainer, "global_rank", None)
    world_size = getattr(trainer, "world_size", None)
    if strata not in self.dataframes:
        if not required:
            return None
        raise ValueError(f"no dataframe configured for strata: {strata}")

    workers = self.num_workers[strata]
    if workers is None:
        workers = os.cpu_count() or 0

    interprocess_encoding_context = self.interprocess_encoding_context
    if strata == Strata.train and workers > 0:
        share_interprocess_encoding_context(interprocess_encoding_context)

    return polars_dataloader(
        hyperparameters=self.hyperparameters,
        dataframe=self.dataframes[strata],
        preprocessor=self.preprocessor,
        preprocessor_kwargs=self.preprocessor_kwargs,
        interprocess_encoding_context=interprocess_encoding_context,
        batch_size=self.batch_size,
        strata=strata,
        num_workers=workers,
        persistent_workers=self.persistent_workers[strata],
        pin_memory=self.pin_memory[strata],
        sharding=self.sharding[strata],
        chunk_batch_size=self.chunk_batch_size[strata],
        observation_buffer_size=self.observation_buffer_size[strata],
        sample_rate=self.sample_rate[strata],
        replacement=self.replacement[strata],
        global_rank=global_rank,
        world_size=world_size,
    )

StreamingDataModule

StreamingDataModule(
    model: Model,
    root: str | Path,
    suffix: Suffix | str,
    train: PatternInput | None = None,
    validate: PatternInput | None = None,
    test: PatternInput | None = None,
    predict: PatternInput | None = None,
    preprocessor: str
    | Callable[..., Any]
    | Preprocessor
    | None = None,
    num_workers: NonNegativeInt
    | None
    | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy
    | str
    | StrataMap[
        ShardingStrategy | str
    ] = ShardingStrategy.file,
    chunk_batch_size: PositiveInt
    | StrataMap[PositiveInt] = 4096,
    file_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    observation_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] | None = None,
    **kwargs: Any,
)

Bases: LightningDataModule

Lightning data module for streaming records from files.

Reads file-backed records, applies an optional preprocessor, batches observations, and encodes them with model hyperparameters.

Source code in src/json2vec/data/datasets/streaming.py
@beartype
def __init__(
    self,
    model: Model,
    root: str | Path,
    suffix: Suffix | str,
    train: PatternInput | None = None,
    validate: PatternInput | None = None,
    test: PatternInput | None = None,
    predict: PatternInput | None = None,
    preprocessor: str | Callable[..., Any] | Preprocessor | None = None,
    num_workers: NonNegativeInt | None | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy | str | StrataMap[ShardingStrategy | str] = ShardingStrategy.file,
    chunk_batch_size: PositiveInt | StrataMap[PositiveInt] = 4096,
    file_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    observation_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] | None = None,
    **kwargs: Any,
):
    super().__init__()

    self.root = root
    self.suffix = Suffix(suffix)
    self.train = _compile_pattern(train) if train is not None else None
    self.validate = _compile_pattern(validate) if validate is not None else None
    self.test = _compile_pattern(test) if test is not None else None
    self.predict = _compile_pattern(predict) if predict is not None else None
    self.preprocessor = PreprocessorConfig.normalize(preprocessor)
    self.preprocessor_kwargs = dict(kwargs)
    try:
        self._model_ref = weakref.ref(model)
    except TypeError:
        self._model_ref = None
    self._hyperparameters = model.hyperparameters
    self._interprocess_encoding_context = model.interprocess_encoding_context
    self._batch_size = model.batch_size
    self.num_workers = Strata.expand(num_workers, default=None)
    self.persistent_workers = Strata.expand(persistent_workers, default=True)
    self.pin_memory = Strata.expand(pin_memory, default=True)
    self.sharding = ShardingStrategy.expand(sharding, default=ShardingStrategy.file)
    self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
    self.file_buffer_size = Strata.expand(file_buffer_size, default=1)
    self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
    self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
    self.replacement = (
        {strata: strata == Strata.train for strata in Strata}
        if replacement is None
        else Strata.expand(replacement, default=False)
    )
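Unlike PolarsDataModule, whose replacement defaults to False, StreamingDataModule defaults replacement to None. The constructor resolves None to True for the train split and False for every other split, while an explicit bool or per-split mapping is expanded with Strata.expand(..., default=False). The mirror below covers only the scalar cases and uses plain split names instead of Strata members, which is a simplification.

```python
from typing import Optional

SPLITS = ("train", "validate", "test", "predict")


def expand_replacement(replacement: Optional[bool]) -> dict[str, bool]:
    """Mirror of StreamingDataModule's replacement handling for scalar inputs."""
    if replacement is None:
        # Unset: only the train split gets replacement=True.
        return {strata: strata == "train" for strata in SPLITS}
    # An explicit bool applies to every split.
    return {strata: replacement for strata in SPLITS}


print(expand_replacement(None))
# {'train': True, 'validate': False, 'test': False, 'predict': False}
print(expand_replacement(False)["train"])  # False
```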

root instance-attribute

root = root

suffix instance-attribute

suffix = Suffix(suffix)

train instance-attribute

train = (
    _compile_pattern(train) if train is not None else None
)

validate instance-attribute

validate = (
    _compile_pattern(validate)
    if validate is not None
    else None
)

test instance-attribute

test = _compile_pattern(test) if test is not None else None

predict instance-attribute

predict = (
    _compile_pattern(predict)
    if predict is not None
    else None
)

preprocessor instance-attribute

preprocessor = normalize(preprocessor)

preprocessor_kwargs instance-attribute

preprocessor_kwargs = dict(kwargs)

num_workers instance-attribute

num_workers = expand(num_workers, default=None)

persistent_workers instance-attribute

persistent_workers = expand(
    persistent_workers, default=True
)

pin_memory instance-attribute

pin_memory = expand(pin_memory, default=True)

sharding instance-attribute

sharding = expand(sharding, default=file)

chunk_batch_size instance-attribute

chunk_batch_size = expand(chunk_batch_size, default=4096)

file_buffer_size instance-attribute

file_buffer_size = expand(file_buffer_size, default=1)

observation_buffer_size instance-attribute

observation_buffer_size = expand(
    observation_buffer_size, default=1
)

sample_rate instance-attribute

sample_rate = {
    strata: float(rate)
    for strata, rate in expand(sample_rate, default=1.0).items()
}

replacement instance-attribute

replacement = (
    {strata: strata == Strata.train for strata in Strata}
    if replacement is None
    else expand(replacement, default=False)
)

hyperparameters property writable

hyperparameters: Hyperparameters

batch_size property writable

batch_size: int

interprocess_encoding_context property writable

interprocess_encoding_context: InterprocessEncodingContext

train_dataloader class-attribute instance-attribute

train_dataloader = partialmethod(
    dataloader, strata=train, required=False
)

val_dataloader class-attribute instance-attribute

val_dataloader = partialmethod(
    dataloader, strata=validate, required=False
)

test_dataloader class-attribute instance-attribute

test_dataloader = partialmethod(
    dataloader, strata=test, required=False
)

predict_dataloader class-attribute instance-attribute

predict_dataloader = partialmethod(
    dataloader, strata=predict, required=False
)

dataloader

dataloader(
    strata: Strata, required: bool = True
) -> DataLoader | None
Source code in src/json2vec/data/datasets/streaming.py
def dataloader(self, strata: Strata, required: bool = True) -> DataLoader | None:
    strata = Strata.normalize(strata)
    pattern = getattr(self, strata.value)
    if pattern is None:
        if not required:
            return None
        raise ValueError(f"no file pattern configured for strata: {strata}")

    trainer = getattr(self, "trainer", None)
    global_rank = getattr(trainer, "global_rank", None)
    world_size = getattr(trainer, "world_size", None)

    workers = self.num_workers[strata]
    if workers is None:
        workers = os.cpu_count() or 0

    interprocess_encoding_context = self.interprocess_encoding_context
    if strata == Strata.train and workers > 0:
        share_interprocess_encoding_context(interprocess_encoding_context)

    return dataloader(
        hyperparameters=self.hyperparameters,
        root=self.root,
        suffix=self.suffix,
        pattern=pattern,
        preprocessor=self.preprocessor,
        preprocessor_kwargs=self.preprocessor_kwargs,
        interprocess_encoding_context=interprocess_encoding_context,
        batch_size=self.batch_size,
        strata=strata,
        num_workers=workers,
        persistent_workers=self.persistent_workers[strata],
        pin_memory=self.pin_memory[strata],
        sharding=self.sharding[strata],
        chunk_batch_size=self.chunk_batch_size[strata],
        file_buffer_size=self.file_buffer_size[strata],
        observation_buffer_size=self.observation_buffer_size[strata],
        sample_rate=self.sample_rate[strata],
        replacement=self.replacement[strata],
        global_rank=global_rank,
        world_size=world_size,
    )

Writer

Writer(
    path: PathLike | str,
    flush_every_n_batches: int | None = None,
    postprocessor: Postprocessor | None = None,
)

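Bases: BasePredictionWriter

Prediction writer callback for Trainer.predict(...). After every predict batch it turns the model output into per-address prediction columns with pl_module.write(...), applies the optional postprocessor, and appends a table with an inputs column (the batch metadata) and a predictions struct column (one field per address) to path/rank-{local_rank}.parquet, creating the path directory on first write. When flush_every_n_batches is set, the Parquet writer is flushed every N batches if it supports flushing; the file is closed in on_predict_end.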

Source code in src/json2vec/inference/callback.py
def __init__(
    self,
    path: os.PathLike | str,
    flush_every_n_batches: int | None = None,
    postprocessor: Postprocessor | None = None,
):
    super().__init__(write_interval="batch")

    self.path = Path(path)
    self.flush_every_n_batches: int | None = flush_every_n_batches
    self.postprocessor: Postprocessor | None = postprocessor
    self.schema: pa.Schema | None = None
    self.writer: pq.ParquetWriter | None = None

path instance-attribute

path = Path(path)

flush_every_n_batches instance-attribute

flush_every_n_batches: int | None = flush_every_n_batches

postprocessor instance-attribute

postprocessor: Postprocessor | None = postprocessor

schema instance-attribute

schema: Schema | None = None

writer instance-attribute

writer: ParquetWriter | None = None

write_on_batch_end

write_on_batch_end(
    trainer: Trainer,
    pl_module: Model,
    output: dict[str, list[Prediction]],
    batch_indices: list[int] | None,
    batch: TensorDict[Address, TensorFieldBase],
    batch_idx: int,
    dataloader_idx: int,
) -> None
Source code in src/json2vec/inference/callback.py
def write_on_batch_end(
    self,
    trainer: lit.Trainer,
    pl_module: Model,
    output: dict[str, list[Prediction]],
    batch_indices: list[int] | None,
    batch: TensorDict[Address, TensorFieldBase],
    batch_idx: int,
    dataloader_idx: int,
) -> None:  # ty:ignore[invalid-method-override]
    num_rows = len(batch[TensorKey.metadata])

    predictions: dict[Address, dict[str, Any]] = pl_module.write(predictions=output["predictions"])
    postprocessor = self.postprocessor

    if postprocessor is not None:
        context = {
            "input": batch,
            "batch": batch,
            TensorKey.metadata: batch[TensorKey.metadata],
            "batch_indices": batch_indices,
            "batch_idx": batch_idx,
            "dataloader_idx": dataloader_idx,
        }
        processed = postprocessor(context, predictions)

        if processed is not None:
            predictions = processed

    if len(predictions) == 0:
        predictions_frame = pl.DataFrame({"predictions": [None] * num_rows})
    else:
        columns: list[pl.DataFrame] = []
        for address, values in predictions.items():
            field_frame = pl.DataFrame(data=values)
            columns.append(field_frame.select(pl.struct(pl.all()).alias(name=address)))

        nested: pl.DataFrame = pl.concat(items=columns, how="horizontal")
        predictions_frame = nested.select(pl.struct(pl.all()).alias(name="predictions"))

    items = [
        pl.from_records(data=batch[TensorKey.metadata], schema=["inputs"], orient="row"),
        predictions_frame,
    ]

    table: pa.Table = pl.concat(items=items, how="horizontal").to_arrow()

    if self.writer is None:
        self.path.mkdir(parents=True, exist_ok=True)
        self.schema = table.schema

        self.writer = pq.ParquetWriter(
            where=self.path / f"rank-{trainer.local_rank}.parquet",
            schema=self.schema,
        )

    if table.schema != self.schema:
        table = table.cast(self.schema)

    self.writer.write_table(table)

    flush = getattr(self.writer, "flush", None)
    if self.flush_every_n_batches and (batch_idx + 1) % self.flush_every_n_batches == 0 and callable(flush):
        flush()

on_predict_end

on_predict_end(
    trainer: Trainer, pl_module: LightningModule
) -> None
Source code in src/json2vec/inference/callback.py
def on_predict_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
    if self.writer:
        self.writer.close()
        self.writer = None
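A postprocessor receives a context dict (keys "input", "batch", "metadata", "batch_indices", "batch_idx", and "dataloader_idx") plus the predictions keyed by address, and may return a replacement mapping or None to keep the predictions unchanged. The sketch below is a hypothetical postprocessor factory, drop_output_keys, that removes named outputs from every field. It assumes plain string addresses and per-field dicts of column lists, which is what the pl.DataFrame(data=values) call in write_on_batch_end suggests; it is an illustration, not part of json2vec.

```python
from typing import Any, Callable, Optional

# Addresses are plain strings here; in json2vec they are Address values.
Predictions = dict[str, dict[str, Any]]
PostprocessorFn = Callable[[dict[str, Any], Predictions], Optional[Predictions]]


def drop_output_keys(*keys: str) -> PostprocessorFn:
    """Build a hypothetical postprocessor that removes the named outputs from every field."""

    def postprocessor(context: dict[str, Any], predictions: Predictions) -> Optional[Predictions]:
        trimmed = {
            address: {name: column for name, column in outputs.items() if name not in keys}
            for address, outputs in predictions.items()
        }
        # Returning None tells the Writer to keep the predictions unchanged.
        return trimmed if trimmed != predictions else None

    return postprocessor


post = drop_output_keys("embedding")
print(post({"batch_idx": 0}, {"amount": {"value": [1.5], "embedding": [[0.1, 0.2]]}}))
# {'amount': {'value': [1.5]}}
```

Assuming your fields emit an embedding output, the factory would be attached as Writer(path="predictions", postprocessor=drop_output_keys("embedding")) and passed to Trainer(callbacks=[...]).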

Preprocessor

Bases: BaseModel

Registered observation preprocessor.

A transformation preprocessor returns one dict. A generator preprocessor yields dict objects or returns a list of them, and each dict becomes one processed observation. Any other output type raises TypeError.
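The two modes differ only in what the wrapped function hands back for one raw observation. The sketch below uses two hypothetical preprocessors, normalize_amount and explode_items, and a simplified eager stand-in, collect_outputs, that applies the same type checks as Preprocessor.outputs (the real method yields results lazily as single-element lists).

```python
from collections.abc import Iterator
from typing import Any, Callable


def normalize_amount(observation: dict[str, Any]) -> dict[str, Any]:
    # Transformation: one raw observation in, exactly one dict out.
    return {**observation, "amount": float(observation["amount"])}


def explode_items(observation: dict[str, Any]) -> Iterator[dict[str, Any]]:
    # Generator: one raw observation in, one dict per line item out.
    for item in observation["items"]:
        yield {"order_id": observation["order_id"], **item}


def collect_outputs(
    func: Callable[..., Any], observation: dict[str, Any], *, generator: bool
) -> list[dict[str, Any]]:
    """Simplified, eager stand-in for Preprocessor.outputs."""
    result = func(observation)
    if not generator:
        produced: Any = [result]
    elif isinstance(result, (list, Iterator)):
        produced = result
    else:
        raise TypeError(f"generator preprocessor must yield or return dicts, got {type(result).__name__}")
    checked = []
    for output in produced:
        # Every processed observation must be a dict, whatever the mode.
        if not isinstance(output, dict):
            raise TypeError(f"preprocessor must produce dict objects, got {type(output).__name__}")
        checked.append(output)
    return checked


order = {"order_id": 7, "items": [{"sku": "a"}, {"sku": "b"}]}
print(collect_outputs(explode_items, order, generator=True))
# [{'order_id': 7, 'sku': 'a'}, {'order_id': 7, 'sku': 'b'}]
print(collect_outputs(normalize_amount, {"amount": "3"}, generator=False))
# [{'amount': 3.0}]
```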

model_config class-attribute instance-attribute

model_config = ConfigDict(
    arbitrary_types_allowed=True, frozen=True
)

name instance-attribute

name: str

func instance-attribute

func: Callable[..., Any]

mode instance-attribute

mode: PreprocessorMode

accepted_kwargs cached staticmethod

accepted_kwargs(
    func: Callable[..., Any],
) -> tuple[bool, frozenset[str]]
Source code in src/json2vec/preprocessors/base.py
@staticmethod
@cache
def accepted_kwargs(func: Callable[..., Any]) -> tuple[bool, frozenset[str]]:
    signature = inspect.signature(func)
    accepts_variadic_kwargs = any(
        parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values()
    )
    accepted = frozenset(signature.parameters.keys())
    return accepts_variadic_kwargs, accepted

filter_supported_kwargs classmethod

filter_supported_kwargs(
    func: Callable[..., Any], kwargs: dict[str, Any]
) -> dict[str, Any]
Source code in src/json2vec/preprocessors/base.py
@classmethod
def filter_supported_kwargs(cls, func: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    accepts_variadic_kwargs, accepted = cls.accepted_kwargs(func)
    if accepts_variadic_kwargs:
        return kwargs

    return {key: value for key, value in kwargs.items() if key in accepted}
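filter_supported_kwargs keeps only the keyword arguments a function declares by name, unless the function accepts **kwargs, in which case everything passes through unfiltered. This is presumably how the extra **kwargs stored as preprocessor_kwargs on the data modules reach preprocessors with different signatures. The mirror below reproduces the two classmethods as plain functions; scale and passthrough are hypothetical preprocessors.

```python
import inspect
from functools import cache
from typing import Any, Callable


@cache
def accepted_kwargs(func: Callable[..., Any]) -> tuple[bool, frozenset[str]]:
    # Cached per function, as in the classmethod, so the signature is inspected once.
    signature = inspect.signature(func)
    accepts_variadic = any(
        parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values()
    )
    return accepts_variadic, frozenset(signature.parameters)


def filter_supported_kwargs(func: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    accepts_variadic, accepted = accepted_kwargs(func)
    if accepts_variadic:
        return kwargs
    return {key: value for key, value in kwargs.items() if key in accepted}


def scale(observation: dict, factor: float = 1.0) -> dict:  # hypothetical preprocessor
    return {key: value * factor for key, value in observation.items()}


def passthrough(observation: dict, **kwargs: Any) -> dict:  # hypothetical preprocessor
    return observation


kwargs = {"factor": 2.0, "locale": "en"}
print(filter_supported_kwargs(scale, kwargs))  # {'factor': 2.0}
print(filter_supported_kwargs(passthrough, kwargs))  # {'factor': 2.0, 'locale': 'en'}
```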

register classmethod

register(
    func: Callable[..., Any], *, mode: PreprocessorMode
) -> Callable[..., Any]
Source code in src/json2vec/preprocessors/base.py
@classmethod
def register(cls, func: Callable[..., Any], *, mode: PreprocessorMode) -> Callable[..., Any]:
    name = getattr(func, "__name__", type(func).__name__)
    PREPROCESSORS[name] = cls(name=name, func=func, mode=mode)
    return func
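register stores a Preprocessor in the module-level PREPROCESSORS mapping under the function's __name__, falling back to the class name for callable objects that have no __name__, and returns the original function unchanged. Because registration is a plain dict assignment, registering a second callable with the same name replaces the first. A minimal mirror with a local dict and a simplified Registered record standing in for the pydantic Preprocessor model (Tokenize and clean are hypothetical):

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Registered:
    # Simplified stand-in for the frozen pydantic Preprocessor model.
    name: str
    func: Callable[..., Any]
    mode: str


PREPROCESSORS: dict[str, Registered] = {}


def register(func: Callable[..., Any], *, mode: str) -> Callable[..., Any]:
    name = getattr(func, "__name__", type(func).__name__)
    PREPROCESSORS[name] = Registered(name=name, func=func, mode=mode)  # last registration wins
    return func


class Tokenize:  # instances of this callable class have no __name__
    def __call__(self, observation: dict) -> dict:
        return observation


def clean(observation: dict) -> dict:
    return observation


register(clean, mode="transformation")
register(Tokenize(), mode="transformation")
print(sorted(PREPROCESSORS))  # ['Tokenize', 'clean']
```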

outputs

outputs(
    observation: dict, **kwargs
) -> Iterator[list[dict[str, Any]]]

Yield normalized processed observations for one raw observation.

Source code in src/json2vec/preprocessors/base.py
def outputs(self, observation: dict, **kwargs) -> Iterator[list[dict[str, Any]]]:
    """Yield normalized processed observations for one raw observation."""
    result = self(observation, **kwargs)

    if self.mode == PreprocessorMode.transformation:
        yield [self.require_object(result, mode=self.mode)]
        return

    if self.mode == PreprocessorMode.generator:
        if isinstance(result, list):
            iterable: list[Any] | Iterator[Any] = result
        elif isinstance(result, Iterator):
            iterable = result
        else:
            raise TypeError(
                f"generator preprocessor '{self.name}' must yield dict objects or return a list of dict objects, "
                f"got {type(result).__name__}"
            )

        for output in iterable:
            yield [self.require_object(output, mode=self.mode)]
        return

    raise ValueError(f"unsupported preprocessor mode: {self.mode}")

require_object

require_object(
    output: Any, *, mode: PreprocessorMode
) -> dict[str, Any]
Source code in src/json2vec/preprocessors/base.py
def require_object(self, output: Any, *, mode: PreprocessorMode) -> dict[str, Any]:
    if not isinstance(output, dict):
        raise TypeError(f"{mode} preprocessor '{self.name}' must produce dict objects, got {type(output).__name__}")

    return output

PreprocessorMode

Bases: StrEnum

Execution mode for a registered preprocessor.

generator class-attribute instance-attribute

generator = 'generator'

transformation class-attribute instance-attribute

transformation = 'transformation'

from_yields classmethod

from_yields(yields: bool) -> 'PreprocessorMode'
Source code in src/json2vec/preprocessors/base.py
@classmethod
def from_yields(cls, yields: bool) -> "PreprocessorMode":
    if not isinstance(yields, bool):
        raise TypeError("yields must be a boolean")

    return cls.generator if yields else cls.transformation

AttentionMode

Bases: StrEnum

mha class-attribute instance-attribute

mha = 'mha'

gqa class-attribute instance-attribute

gqa = 'gqa'

mqa class-attribute instance-attribute

mqa = 'mqa'

none class-attribute instance-attribute

none = 'none'

normalize classmethod

normalize(value: 'AttentionMode | str') -> 'AttentionMode'
Source code in src/json2vec/structs/enums.py
@classmethod
def normalize(cls, value: "AttentionMode | str") -> "AttentionMode":
    if isinstance(value, cls):
        return value

    return cls(value.strip().lower())

kv_heads

kv_heads(n_heads: int) -> int
Source code in src/json2vec/structs/enums.py
def kv_heads(self, n_heads: int) -> int:
    match self:
        case AttentionMode.mha:
            return n_heads
        case AttentionMode.gqa:
            return max(1, n_heads // 2)
        case AttentionMode.mqa:
            return 1
        case AttentionMode.none:
            raise ValueError("attention mode 'none' does not define key/value heads")
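kv_heads maps the configured number of attention heads to the number of key/value heads: one per query head for mha, half as many (but at least one) for gqa, a single shared head for mqa, and an error for none. A standalone mirror using a local enum in place of the one in json2vec.structs.enums:

```python
from enum import Enum


class AttentionMode(str, Enum):
    # Local stand-in mirroring the documented members.
    mha = "mha"
    gqa = "gqa"
    mqa = "mqa"
    none = "none"

    def kv_heads(self, n_heads: int) -> int:
        if self is AttentionMode.mha:
            return n_heads
        if self is AttentionMode.gqa:
            return max(1, n_heads // 2)
        if self is AttentionMode.mqa:
            return 1
        raise ValueError("attention mode 'none' does not define key/value heads")


for mode in ("mha", "gqa", "mqa"):
    print(mode, AttentionMode(mode).kv_heads(8))
# mha 8
# gqa 4
# mqa 1
```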

Component

Bases: StrEnum

Request class-attribute instance-attribute

Request = 'Request'

Embedder class-attribute instance-attribute

Embedder = 'Embedder'

Decoder class-attribute instance-attribute

Decoder = 'Decoder'

TensorField class-attribute instance-attribute

TensorField = 'TensorField'

loss class-attribute instance-attribute

loss = 'loss'

write class-attribute instance-attribute

write = 'write'

Metric

Bases: StrEnum

accuracy class-attribute instance-attribute

accuracy = 'accuracy'

precision class-attribute instance-attribute

precision = 'precision'

recall class-attribute instance-attribute

recall = 'recall'

loss class-attribute instance-attribute

loss = 'loss'

sigma class-attribute instance-attribute

sigma = 'sigma'

throughput class-attribute instance-attribute

throughput = 'throughput'

mae class-attribute instance-attribute

mae = 'mae'

rmse class-attribute instance-attribute

rmse = 'rmse'

Overflow

Bases: StrEnum

head class-attribute instance-attribute

head = 'head'

tail class-attribute instance-attribute

tail = 'tail'

error class-attribute instance-attribute

error = 'error'

ShardingStrategy

Bases: StrEnum

file class-attribute instance-attribute

file = 'file'

chunk class-attribute instance-attribute

chunk = 'chunk'

record class-attribute instance-attribute

record = 'record'

normalize classmethod

normalize(
    value: "ShardingStrategy | str",
) -> "ShardingStrategy"
Source code in src/json2vec/structs/enums.py
@classmethod
def normalize(cls, value: "ShardingStrategy | str") -> "ShardingStrategy":
    if isinstance(value, cls):
        return value

    return cls(value.strip().lower())

expand classmethod

expand(
    value: "ShardingStrategy | str | Mapping[Strata | str, ShardingStrategy | str]",
    *,
    default: "ShardingStrategy",
) -> dict[Strata, "ShardingStrategy"]
Source code in src/json2vec/structs/enums.py
@classmethod
def expand(
    cls,
    value: "ShardingStrategy | str | Mapping[Strata | str, ShardingStrategy | str]",
    *,
    default: "ShardingStrategy",
) -> dict[Strata, "ShardingStrategy"]:
    return {strata: cls.normalize(strategy) for strata, strategy in Strata.expand(value, default=default).items()}

Strata

Bases: StrEnum

train class-attribute instance-attribute

train = 'train'

validate class-attribute instance-attribute

validate = 'validate'

test class-attribute instance-attribute

test = 'test'

predict class-attribute instance-attribute

predict = 'predict'

normalize classmethod

normalize(value: 'Strata | str') -> 'Strata'
Source code in src/json2vec/structs/enums.py
@classmethod
def normalize(cls, value: "Strata | str") -> "Strata":
    if isinstance(value, cls):
        return value

    return cls(str(value).strip().lower())

expand classmethod

expand(
    value: T | Mapping[Strata | str, T],
    *,
    default: DefaultT,
) -> dict[Strata, T | DefaultT]
Source code in src/json2vec/structs/enums.py
@classmethod
def expand(cls, value: T | Mapping[Strata | str, T], *, default: DefaultT) -> dict[Strata, T | DefaultT]:
    if isinstance(value, Mapping):
        normalized: dict[Strata, T | DefaultT] = {strata: default for strata in cls}
        for key, item in cast(Mapping[Strata | str, T], value).items():
            normalized[cls.normalize(key)] = item
        return normalized

    return {strata: value for strata in cls}
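Every per-split constructor option (num_workers, pin_memory, sample_rate, and so on) goes through Strata.expand. A scalar is broadcast to all four splits and default is ignored; a mapping overrides only the splits it names (keys are stripped and lowercased) and leaves the rest at default. An unknown split name raises ValueError. ShardingStrategy.expand runs the same expansion and then normalizes each value into a ShardingStrategy. A standalone mirror with a local enum:

```python
from collections.abc import Mapping
from enum import Enum
from typing import Any


class Strata(str, Enum):
    train = "train"
    validate = "validate"
    test = "test"
    predict = "predict"

    @classmethod
    def normalize(cls, value: Any) -> "Strata":
        return value if isinstance(value, cls) else cls(str(value).strip().lower())

    @classmethod
    def expand(cls, value: Any, *, default: Any) -> dict["Strata", Any]:
        if isinstance(value, Mapping):
            # Start from the default everywhere, then apply the named overrides.
            normalized = {strata: default for strata in cls}
            for key, item in value.items():
                normalized[cls.normalize(key)] = item
            return normalized
        # A scalar applies to every split; default is not used.
        return {strata: value for strata in cls}


workers = Strata.expand({"Train": 8, "validate": 2}, default=None)
print({strata.value: count for strata, count in workers.items()})
# {'train': 8, 'validate': 2, 'test': None, 'predict': None}
```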

Suffix

Bases: StrEnum

feather class-attribute instance-attribute

feather = 'feather'

parquet class-attribute instance-attribute

parquet = 'parquet'

ndjson class-attribute instance-attribute

ndjson = 'ndjson'

avro class-attribute instance-attribute

avro = 'avro'

csv class-attribute instance-attribute

csv = 'csv'

orc class-attribute instance-attribute

orc = 'orc'

json class-attribute instance-attribute

json = 'json'

TensorKey

Bases: StrEnum

value class-attribute instance-attribute

value = 'value'

content class-attribute instance-attribute

content = 'content'

state class-attribute instance-attribute

state = 'state'

trainable class-attribute instance-attribute

trainable = 'trainable'

targets class-attribute instance-attribute

targets = 'targets'

metadata class-attribute instance-attribute

metadata = 'metadata'

intervals class-attribute instance-attribute

intervals = 'intervals'

probability class-attribute instance-attribute

probability = 'probability'

topk class-attribute instance-attribute

topk = 'topk'

embedding class-attribute instance-attribute

embedding = 'embedding'

Tokens

Bases: IntEnum

valued class-attribute instance-attribute

valued = 0

null class-attribute instance-attribute

null = 1

padded class-attribute instance-attribute

padded = 2

masked class-attribute instance-attribute

masked = 3

other class-attribute instance-attribute

other = 4

Hyperparameters

Bases: Node

Serializable schema and training metadata used to build a Model.

model_config class-attribute instance-attribute

model_config = ConfigDict(extra='forbid')

name class-attribute instance-attribute

name: Literal["hyperparameters"] = Field(
    default="hyperparameters", exclude=True
)

type class-attribute instance-attribute

type: Literal["hyperparameters"] = Field(
    default="hyperparameters", exclude=True
)

description class-attribute instance-attribute

description: Literal[None] = Field(
    default=None, exclude=True
)

d_model instance-attribute

d_model: Annotated[int, Field(gt=0, default=128)]

fields instance-attribute

fields: Array

embed property

embed: list[Address]

dropout class-attribute

dropout: None = None

target property

target: list[Address]

arrays cached property

arrays: dict[Address, Array]

requests cached property

requests: dict[Address, RequestTypes]

active_requests cached property

active_requests: dict[Address, RequestTypes]

shapes cached property

shapes: dict[Address, tuple[int, ...]]

depthwise cached property

depthwise: list[list[Address]]

update_values classmethod

update_values(values: Mapping[str, Any]) -> dict[str, Any]
Source code in src/json2vec/structs/experiment.py
@classmethod
def update_values(cls, values: Mapping[str, Any]) -> dict[str, Any]:
    normalized = dict(values)
    target = normalized.get("target", None)

    if target is None:
        return normalized

    if not isinstance(target, bool):
        raise ValueError("target must be a boolean")

    if target:
        if normalized.get("p_prune") not in (None, 1.0):
            raise ValueError("target=True is shorthand for p_prune=1.0")
    else:
        if "p_prune" in normalized and normalized["p_prune"] not in (None, 0.0):
            raise ValueError("target=False is shorthand for p_prune=0.0")

    return normalized
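update_values treats a boolean target flag as shorthand for p_prune: target=True requires p_prune to be absent, None, or 1.0, and target=False requires it to be absent, None, or 0.0. Non-boolean targets and conflicting combinations raise ValueError; consistent values are returned unchanged, since this method does not itself fill in p_prune. A standalone mirror of the checks on plain dicts:

```python
from collections.abc import Mapping
from typing import Any


def update_values(values: Mapping[str, Any]) -> dict[str, Any]:
    normalized = dict(values)
    target = normalized.get("target", None)
    if target is None:
        return normalized
    if not isinstance(target, bool):
        raise ValueError("target must be a boolean")
    if target:
        # A missing p_prune reads as None here, so it is accepted.
        if normalized.get("p_prune") not in (None, 1.0):
            raise ValueError("target=True is shorthand for p_prune=1.0")
    elif "p_prune" in normalized and normalized["p_prune"] not in (None, 0.0):
        raise ValueError("target=False is shorthand for p_prune=0.0")
    return normalized


print(update_values({"target": True}))  # {'target': True}
```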

jmespath_member classmethod

jmespath_member(value: str) -> str
Source code in src/json2vec/structs/experiment.py
@classmethod
def jmespath_member(cls, value: str) -> str:
    if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", value):
        return value
    return json.dumps(value)

query_for_source classmethod

query_for_source(
    array_path: tuple[str, ...], source: str
) -> str

Infer a request-level query for a leaf source field.

The encoder prepends the outer batch selector during search. Inferred queries therefore start at the processed-observation level: [*].amount, not [*][*].amount.

Source code in src/json2vec/structs/experiment.py
@classmethod
def query_for_source(cls, array_path: tuple[str, ...], source: str) -> str:
    """Infer a request-level query for a leaf source field.

    The encoder prepends the outer batch selector during search. Inferred
    queries therefore start at the processed-observation level: `[*].amount`,
    not `[*][*].amount`.
    """
    selectors = "".join(f".{cls.jmespath_member(array)}[*]" for array in array_path)
    return f"[*]{selectors}.{cls.jmespath_member(source)}"
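The query inference can be reproduced outside the library. The sketch below mirrors jmespath_member and query_for_source as plain module-level functions, not the json2vec classmethods themselves. The field names items, amount, line-items, and unit price are illustrative only.

```python
import json
import re


def jmespath_member(value: str) -> str:
    # Bare identifiers pass through; anything else becomes a JSON-quoted JMESPath identifier.
    if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", value):
        return value
    return json.dumps(value)


def query_for_source(array_path: tuple[str, ...], source: str) -> str:
    # One ".<array>[*]" projection per enclosing array, starting at the observation level.
    selectors = "".join(f".{jmespath_member(array)}[*]" for array in array_path)
    return f"[*]{selectors}.{jmespath_member(source)}"


print(query_for_source((), "amount"))                    # [*].amount
print(query_for_source(("items",), "amount"))            # [*].items[*].amount
print(query_for_source(("line-items",), "unit price"))   # [*]."line-items"[*]."unit price"
```

Names that are not plain identifiers are quoted with json.dumps, which produces the double-quoted identifier form that JMESPath accepts.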

request_from_leaf classmethod

request_from_leaf(leaf: Leaf) -> RequestTypes
Source code in src/json2vec/structs/experiment.py
@classmethod
def request_from_leaf(cls, leaf: Leaf) -> RequestTypes:
    from json2vec.tensorfields.base import TENSORFIELDS

    request_cls = getattr(TENSORFIELDS[leaf.type], "Request")
    return request_cls.model_validate(leaf.model_dump(mode="python", round_trip=True))

from_schema_node classmethod

from_schema_node(
    node: SchemaField, *, array_path: tuple[str, ...] = ()
) -> Array | RequestTypes
Source code in src/json2vec/structs/experiment.py
@classmethod
def from_schema_node(cls, node: SchemaField, *, array_path: tuple[str, ...] = ()) -> Array | RequestTypes:
    if isinstance(node, Leaf):
        source = node.name
        node_name = Node.sanitize_name(source)
        updates: dict[str, Any] = {}

        if node_name != source:
            updates["name"] = node_name
            if node.description is None:
                updates["description"] = source

        if node.query is None:
            updates["query"] = cls.query_for_source(array_path, source)

        return cls.request_from_leaf(node.model_copy(update=updates))

    if isinstance(node, Array):
        child_path = (*array_path, node.name)
        fields = [cls.from_schema_node(field, array_path=child_path) for field in node.fields]
        payload = node.model_dump(mode="python", round_trip=True, exclude={"fields", "masks"})
        return Array(*fields, masks=list(node.masks), **payload)

    raise TypeError("schema fields must be Array, Leaf, or concrete request instances")
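from_schema_node recurses through Array nodes and extends array_path by one name per level, so each leaf without an explicit query gets one that projects through every enclosing array. A minimal sketch of that walk follows. ToyLeaf and ToyArray are hypothetical stand-ins for Leaf and Array, and name sanitization and request construction are skipped.

```python
import json
import re
from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class ToyLeaf:  # hypothetical stand-in for json2vec's Leaf
    name: str
    query: Optional[str] = None


@dataclass
class ToyArray:  # hypothetical stand-in for json2vec's Array
    name: str
    fields: list = field(default_factory=list)


def member(value: str) -> str:
    # Same quoting rule as jmespath_member.
    return value if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", value) else json.dumps(value)


def infer_queries(node: Union[ToyLeaf, ToyArray], array_path: tuple = ()) -> dict:
    """Map each leaf name to its explicit or inferred query (keyed by name for brevity)."""
    if isinstance(node, ToyLeaf):
        if node.query is not None:  # explicit queries are kept as given
            return {node.name: node.query}
        selectors = "".join(f".{member(name)}[*]" for name in array_path)
        return {node.name: f"[*]{selectors}.{member(node.name)}"}
    child_path = (*array_path, node.name)  # each nested array adds one projection
    queries: dict = {}
    for child in node.fields:
        queries.update(infer_queries(child, array_path=child_path))
    return queries


schema = [ToyLeaf("amount"), ToyArray("items", [ToyLeaf("sku"), ToyLeaf("qty", query="[*].items[*].quantity")])]
queries: dict = {}
for top_level in schema:  # from_schema starts every top-level field with array_path=()
    queries.update(infer_queries(top_level))
print(queries)
# {'amount': '[*].amount', 'sku': '[*].items[*].sku', 'qty': '[*].items[*].quantity'}
```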

from_schema classmethod

from_schema(
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: Annotated[int, Field(gt=0)] = 1,
    dropout: Rate | None = None,
) -> Self

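Build hyperparameters from schema fields. Fields passed through fields= come before positional fields, at least one field is required, and top-level source names must be unique. The fields are wrapped in a generated root array (named record by default) with max_length=1 and overflow=Overflow.error.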

Source code in src/json2vec/structs/experiment.py
@classmethod
def from_schema(
    cls,
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: Annotated[int, pydantic.Field(gt=0)] = 1,
    dropout: Rate | None = None,
) -> Self:
    """Build hyperparameters from schema fields."""
    normalized = [*(fields or ()), *field_args]
    if not normalized:
        raise ValueError("from_schema requires at least one field")

    seen_sources: set[str] = set()
    root_fields: list[Array | RequestTypes] = []

    for field in normalized:
        if not isinstance(field, (Array, Leaf)):
            raise TypeError("schema fields must be Array, Leaf, or concrete request instances")

        source = field.name
        if source in seen_sources:
            raise ValueError(f"duplicate schema source field: {source}")
        seen_sources.add(source)

        root_fields.append(cls.from_schema_node(field))

    array = Array(
        name=name,
        description=description,
        embed=embed,
        attention=attention,
        n_layers=n_layers,
        n_heads=n_heads,
        n_linear=n_linear,
        max_length=1,
        overflow=Overflow.error,
        dropout=dropout,
        fields=root_fields,
    )
    return cls(d_model=d_model, fields=array)

model_post_init

model_post_init(__context)
Source code in src/json2vec/structs/experiment.py
def model_post_init(self, __context):
    def materialize(array: Array) -> Array:
        fields: list[Array | RequestTypes] = []
        for field in list(array.fields):
            field.parent = None

            if isinstance(field, Array):
                fields.append(materialize(field))
            elif type(field) is Leaf:
                fields.append(self.request_from_leaf(field))
            else:
                fields.append(field)

        array.fields = fields
        for field in array.fields:
            field.parent = array

        return array

    self.fields = materialize(self.fields)
    self.fields.max_length = 1
    self.fields.overflow = Overflow.error
    self.fields.parent: Self = self
    self._post_bind_validate()

overflows

overflows(address: Address) -> tuple[Overflow, ...]
Source code in src/json2vec/structs/experiment.py
def overflows(self, address: Address) -> tuple[Overflow, ...]:
    return (Overflow.error, *self.requests[Address(str(address))].overflows)

array_masks_for

array_masks_for(
    address: Address,
) -> tuple[tuple[Address, Mask], ...]
Source code in src/json2vec/structs/experiment.py
def array_masks_for(self, address: Address) -> tuple[tuple[Address, Mask], ...]:
    request = self.requests[Address(str(address))]
    applications: list[tuple[Address, Mask]] = []
    for array in [node for node in request.path if isinstance(node, Array)]:
        for mask in array.masks:
            if any(leaf is request for leaf in array.excluded_leaves(mask)):
                continue

            applications.append((array.address, mask))

    return tuple(applications)
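array_masks_for walks the arrays on a request's path and keeps every (array address, mask) pair except masks whose exclude selectors match that request. The stdlib sketch below uses hypothetical ToyMask and ToyArray types in which exclude is a set of leaf names rather than json2vec node predicates; the addresses are illustrative.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToyMask:  # hypothetical: exclude holds leaf names, not json2vec node predicates
    name: str
    exclude: frozenset = frozenset()


@dataclass
class ToyArray:  # hypothetical enclosing array with an illustrative address
    address: str
    masks: list = field(default_factory=list)


def masks_for(leaf_name: str, enclosing: list) -> tuple:
    """Collect (array address, mask) pairs that apply to one leaf, outermost array first."""
    applications = []
    for array in enclosing:
        for mask in array.masks:
            if leaf_name in mask.exclude:  # this mask opts the leaf out
                continue
            applications.append((array.address, mask))
    return tuple(applications)


orders = ToyArray("orders", [ToyMask("drop-orders")])
items = ToyArray("orders/items", [ToyMask("drop-items", exclude=frozenset({"sku"}))])
print([(address, mask.name) for address, mask in masks_for("sku", [orders, items])])
# [('orders', 'drop-orders')]
print([(address, mask.name) for address, mask in masks_for("qty", [orders, items])])
# [('orders', 'drop-orders'), ('orders/items', 'drop-items')]
```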

clear_selection_cache

clear_selection_cache() -> None
Source code in src/json2vec/structs/experiment.py
def clear_selection_cache(self) -> None:
    self._selection_cache.clear()

refresh_selection_cache

refresh_selection_cache() -> None
Source code in src/json2vec/structs/experiment.py
def refresh_selection_cache(self) -> None:
    self._selection_cache = {
        key: entry.model_copy(
            update={
                "nodes": tuple(
                    node
                    for node in PreOrderIter(self.fields)
                    if (entry.include_root or node is not self.fields)
                    if entry.predicate(node)
                )
            }
        )
        for key, entry in self._selection_cache.items()
    }

select

select(
    *predicates: NodeSelector,
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]
Source code in src/json2vec/structs/experiment.py
def select(
    self,
    *predicates: NodeSelector,
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]:
    if predicates:
        normalized = tuple(NodePredicate.from_selector(item) for item in predicates)
        combined = NodePredicate(
            func=lambda node: all(item(node) for item in normalized),
            key=("and", tuple(item.key for item in normalized)),
            cacheable=all(item.cacheable for item in normalized),
        )
    else:
        combined = NodePredicate(func=lambda node: True, key=("all",))

    key = ("select", include_root, combined.key)

    if use_cache and combined.cacheable and key in self._selection_cache:
        return Selection(self._selection_cache[key].nodes)

    nodes = tuple(
        node for node in PreOrderIter(self.fields) if (include_root or node is not self.fields) if combined(node)
    )

    if use_cache and combined.cacheable:
        self._selection_cache[key] = SelectionCacheEntry(
            key=key,
            predicate=combined,
            include_root=include_root,
            nodes=nodes,
        )

    return Selection(nodes)
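select combines its predicates with a logical AND, builds a cache key from the individual predicate keys, and stores the matching nodes only when every predicate is cacheable. The simplified sketch below shows that pattern. ToyPredicate and ToySelector are hypothetical names, and plain address strings stand in for schema nodes.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyPredicate:  # hypothetical analogue of NodePredicate
    func: Callable[[str], bool]
    key: tuple
    cacheable: bool = True

    def __call__(self, node: str) -> bool:
        return self.func(node)


class ToySelector:  # hypothetical; nodes are plain address strings
    def __init__(self, nodes):
        self.nodes = list(nodes)
        self.cache: dict = {}
        self.evaluations = 0  # number of nodes tested, to make cache hits visible

    def select(self, *predicates: ToyPredicate) -> tuple:
        combined = ToyPredicate(
            func=lambda node: all(p(node) for p in predicates),  # logical AND
            key=("and", tuple(p.key for p in predicates)),
            cacheable=all(p.cacheable for p in predicates),
        )
        key = ("select", combined.key)
        if combined.cacheable and key in self.cache:
            return self.cache[key]
        self.evaluations += len(self.nodes)
        result = tuple(node for node in self.nodes if combined(node))
        if combined.cacheable:
            self.cache[key] = result
        return result


ends_with_amount = ToyPredicate(lambda n: n.endswith("amount"), key=("matches", "address", "amount$"))
selector = ToySelector(["amount", "items/sku", "items/amount"])
print(selector.select(ends_with_amount))  # ('amount', 'items/amount')
print(selector.select(ends_with_amount), selector.evaluations)  # ('amount', 'items/amount') 3
```

Because cached results are keyed by predicate keys rather than by tree contents, any structural change must refresh them, which is why update, extend, delete, and override all end with refresh_selection_cache().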

update

update(
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None

Mutate matching schema nodes.

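target=True is normalized to p_prune=1.0; target=False clears the target prune rate by setting p_prune=0.0. Passing a p_prune that conflicts with target raises ValueError, as does calling update with no values.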

Source code in src/json2vec/structs/experiment.py
def update(
    self,
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None:
    """Mutate matching schema nodes.

    `target=True` is normalized to `p_prune=1.0`; `target=False` clears the
    target prune rate by setting `p_prune=0.0`.
    """
    values = self.update_values(values)
    if not values:
        raise ValueError("update requires at least one field value")

    nodes = self.select(*predicates, include_root=include_root, use_cache=use_cache)
    for node in nodes:
        can_apply_extra = allow_extra and getattr(type(node), "model_config", {}).get("extra") == "allow"
        missing = [name for name in values if not _has_model_attribute(node, name) and not can_apply_extra]
        if missing and strict:
            label = str(node.address) or node.name
            raise AttributeError(f"{label} has no attribute(s): {missing}")

        applicable_values = {
            name: value for name, value in values.items() if _has_model_attribute(node, name) or can_apply_extra
        }

        if validate and applicable_values:
            payload = node.model_dump(mode="python", round_trip=True)
            if isinstance(node, Array) and "masks" not in applicable_values:
                payload["masks"] = list(node.masks)
            if "target" in applicable_values and "p_prune" not in applicable_values:
                payload.pop("p_prune", None)
            payload.update(applicable_values)
            validated = type(node).model_validate(payload)
            applicable_values = {name: getattr(validated, name) for name in applicable_values}

        for name, value in applicable_values.items():
            setattr(node, name, value)
            if name in getattr(type(node), "model_fields", {}):
                node.model_fields_set.add(name)

    self._clear_tree_caches()
    self._post_bind_validate()
    self.refresh_selection_cache()
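The strict and allow_extra flags decide which values reach each node. Unknown names raise AttributeError when strict=True and are skipped when strict=False, unless allow_extra is set and the node's model accepts extra fields. A stdlib sketch of that filtering step follows; ToyNode is a hypothetical class whose allows_extra flag stands in for model_config["extra"] == "allow", and validation and cache refresh are omitted.

```python
class ToyNode:  # hypothetical; json2vec nodes are pydantic models
    declared = ("name", "p_mask")  # stands in for the model's declared fields

    def __init__(self, name: str, allows_extra: bool = False):
        self.name = name
        self.p_mask = 0.0
        self.allows_extra = allows_extra  # stands in for model_config["extra"] == "allow"


def applicable_values(node: ToyNode, values: dict, *, strict: bool = True, allow_extra: bool = False) -> dict:
    can_apply_extra = allow_extra and node.allows_extra
    missing = [name for name in values if name not in node.declared and not can_apply_extra]
    if missing and strict:
        raise AttributeError(f"{node.name} has no attribute(s): {missing}")
    # Keep declared fields, plus everything else when extras are permitted.
    return {name: value for name, value in values.items() if name in node.declared or can_apply_extra}


leaf = ToyNode("amount", allows_extra=True)
print(applicable_values(leaf, {"p_mask": 0.2, "note": "x"}, strict=False))     # {'p_mask': 0.2}
print(applicable_values(leaf, {"p_mask": 0.2, "note": "x"}, allow_extra=True))  # {'p_mask': 0.2, 'note': 'x'}
try:
    applicable_values(leaf, {"note": "x"})
except AttributeError as error:
    print(error)  # amount has no attribute(s): ['note']
```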

extend

extend(
    *args: ExtendArg,
    include_root: bool = True,
    use_cache: bool = True,
) -> None

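Append new schema fields under the single array selected by predicates. Predicates must come before the new fields, exactly one Array must match, and new field names may not repeat each other or an existing sibling. If post-bind validation fails, the array's original fields are restored and the error is re-raised.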

Source code in src/json2vec/structs/experiment.py
def extend(
    self,
    *args: ExtendArg,
    include_root: bool = True,
    use_cache: bool = True,
) -> None:
    """Append new schema fields under the single array selected by predicates."""
    predicates: list[NodeSelector] = []
    fields: list[SchemaField] = []
    reading_fields = False

    for item in args:
        if isinstance(item, (Array, Leaf)):
            reading_fields = True
            fields.append(item)
            continue

        if reading_fields:
            raise TypeError("extend predicates must come before new schema fields")

        predicates.append(item)

    if not fields:
        raise ValueError("extend requires at least one schema field")

    candidates = [
        node
        for node in self.select(*predicates, include_root=include_root, use_cache=use_cache)
        if isinstance(node, Array)
    ]

    if len(candidates) != 1:
        raise ValueError(f"extend requires exactly one matching array node, found {len(candidates)}")

    parent = candidates[0]
    array_path = tuple(node.name for node in parent.path[2:] if isinstance(node, Array))
    new_fields = [self.from_schema_node(field, array_path=array_path) for field in fields]
    existing_names = {field.name for field in parent.fields}
    duplicate_names = sorted({field.name for field in new_fields if field.name in existing_names})
    duplicate_names.extend(
        sorted(
            {
                field.name
                for index, field in enumerate(new_fields)
                if any(other.name == field.name for other in new_fields[index + 1 :])
            }
        )
    )
    if duplicate_names:
        raise ValueError(f"duplicate field name(s): {sorted(set(duplicate_names))}")

    original_fields = list(parent.fields)
    try:
        parent.fields.extend(new_fields)
        for field in new_fields:
            field.parent = parent

        self._clear_tree_caches()
        self._post_bind_validate()
    except Exception:
        parent.fields = original_fields
        for field in new_fields:
            field.parent = None
        self._clear_tree_caches()
        self._post_bind_validate()
        self.refresh_selection_cache()
        raise

    self.refresh_selection_cache()
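extend rejects a batch when a new name collides with an existing sibling or appears more than once in the batch itself. The check reduces to the standalone function below, a sketch over plain name lists rather than the json2vec method.

```python
def duplicate_names(existing: list, new: list) -> list:
    """Return sorted names that clash with existing siblings or repeat within `new`."""
    existing_set = set(existing)
    clashes = {name for name in new if name in existing_set}
    repeats = {name for index, name in enumerate(new) if name in new[index + 1:]}
    return sorted(clashes | repeats)


print(duplicate_names(["amount", "sku"], ["sku", "qty", "qty"]))  # ['qty', 'sku']
print(duplicate_names(["amount"], ["qty"]))  # []
```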

delete

delete(
    *predicates: NodeSelector,
    include_root: bool = False,
    use_cache: bool = True,
) -> None

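Permanently remove selected schema nodes, together with their descendants, from the tree. At least one predicate is required and it must match at least one node. The root array cannot be deleted, and the call fails if it would remove every request or leave an array with no request descendants.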

Source code in src/json2vec/structs/experiment.py
def delete(
    self,
    *predicates: NodeSelector,
    include_root: bool = False,
    use_cache: bool = True,
) -> None:
    """Permanently remove selected schema nodes from the tree."""
    if not predicates:
        raise ValueError("delete requires at least one predicate")

    selected = self.select(*predicates, include_root=include_root, use_cache=use_cache)
    if not selected:
        raise ValueError("delete matched no nodes")
    if self.fields in selected:
        raise ValueError("delete cannot remove the root array")

    selected_ids = {id(node) for node in selected}
    roots = [
        node
        for node in selected
        if not any(
            id(ancestor) in selected_ids for ancestor in getattr(node, "ancestors", ()) if ancestor is not self
        )
    ]
    removed_by_id = {id(node): node for node in roots}
    for node in roots:
        removed_by_id.update({id(descendant): descendant for descendant in getattr(node, "descendants", ())})
    removed_addresses = {node.address for node in removed_by_id.values()}

    remaining_request_addresses = {address for address in self.requests if address not in removed_addresses}
    if not remaining_request_addresses:
        raise ValueError("delete would remove every request")

    remaining_array_addresses = {address for address in self.arrays if address not in removed_addresses}
    for address in remaining_array_addresses:
        prefix = f"{address}/"
        if not any(str(request_address).startswith(prefix) for request_address in remaining_request_addresses):
            raise ValueError(f"delete would leave array '{address}' without request descendants")

    for node in roots:
        parent = node.parent
        if not isinstance(parent, Array):
            raise ValueError(f"delete cannot remove '{node.address}' because it has no array parent")
        parent.fields = [field for field in parent.fields if field is not node]
        node.parent = None

    self._clear_tree_caches()
    self._post_bind_validate()
    self.refresh_selection_cache()
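delete first reduces the selection to its topmost nodes, since a node whose ancestor is also selected is removed along with that ancestor anyway. It then refuses to leave a surviving array without request descendants. The sketch below reproduces both checks over slash-delimited address strings; the addresses and helper names are hypothetical.

```python
def ancestors(address: str) -> set:
    parts = address.split("/")
    return {"/".join(parts[:i]) for i in range(1, len(parts))}


def topmost(selected: set) -> set:
    """Keep only selected addresses that have no selected ancestor."""
    return {address for address in selected if not (ancestors(address) & selected)}


def removed_by(roots: set, all_addresses: set) -> set:
    """Every address at or below one of the roots."""
    return {a for a in all_addresses if any(a == r or a.startswith(f"{r}/") for r in roots)}


def emptied_arrays(arrays: set, requests: set, removed: set) -> list:
    """Surviving arrays that would keep no surviving request beneath them."""
    remaining = {r for r in requests if r not in removed}
    return sorted(a for a in arrays if a not in removed and not any(r.startswith(f"{a}/") for r in remaining))


arrays = {"items"}
requests = {"amount", "items/sku", "items/qty"}

roots = topmost({"items", "items/sku"})
print(roots)  # {'items'}
print(emptied_arrays(arrays, requests, removed_by(roots, arrays | requests)))  # []

roots = topmost({"items/sku", "items/qty"})
print(emptied_arrays(arrays, requests, removed_by(roots, arrays | requests)))  # ['items']
```

A non-empty result corresponds to the case where the real method raises ValueError instead of deleting.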

override

override(
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]
Source code in src/json2vec/structs/experiment.py
@contextmanager
def override(
    self,
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]:
    nodes = self.select(*predicates, include_root=include_root, use_cache=use_cache)
    normalized_values = self.update_values(values)
    snapshot = [
        (
            node,
            "p_prune" if name == "target" else name,
            getattr(node, "p_prune" if name == "target" else name, _MISSING),
            ("p_prune" if name == "target" else name) in getattr(node, "model_fields_set", set()),
        )
        for node in nodes
        for name in normalized_values
        if _has_model_attribute(node, name)
        or (allow_extra and getattr(type(node), "model_config", {}).get("extra") == "allow")
    ]

    self.update(
        *predicates,
        strict=strict,
        allow_extra=allow_extra,
        include_root=include_root,
        validate=validate,
        use_cache=use_cache,
        **normalized_values,
    )

    try:
        yield
    finally:
        for node, name, original, was_set in snapshot:
            if original is _MISSING:
                if getattr(node, name, _MISSING) is _MISSING:
                    continue
                delattr(node, name)
            else:
                setattr(node, name, original)
                if name in getattr(type(node), "model_fields", {}):
                    if was_set:
                        node.model_fields_set.add(name)
                    else:
                        node.model_fields_set.discard(name)

        self._clear_tree_caches()
        self._post_bind_validate()
        self.refresh_selection_cache()
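override is a snapshot-and-restore context manager: it records each attribute it is about to change, applies the update, and restores the originals in a finally block, so values revert even if the body raises. The stdlib sketch below shows the same pattern on plain objects. The temporarily helper is hypothetical and skips json2vec's model_fields_set bookkeeping and the target to p_prune aliasing.

```python
from contextlib import contextmanager
from types import SimpleNamespace

_MISSING = object()


@contextmanager
def temporarily(nodes, **values):
    """Set attributes on every node for the duration of the block, then restore them."""
    snapshot = [(node, name, getattr(node, name, _MISSING)) for node in nodes for name in values]
    for node in nodes:
        for name, value in values.items():
            setattr(node, name, value)
    try:
        yield
    finally:
        for node, name, original in snapshot:
            if original is _MISSING:
                if hasattr(node, name):
                    delattr(node, name)  # the attribute did not exist before the override
            else:
                setattr(node, name, original)


leaf = SimpleNamespace(name="amount", p_prune=0.0)
with temporarily([leaf], p_prune=1.0):
    print(leaf.p_prune)  # 1.0
print(leaf.p_prune)  # 0.0
```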

NodeAttribute

Bases: BaseModel

Queryable schema node attribute returned by where(...).

model_config class-attribute instance-attribute

model_config = ConfigDict(frozen=True)

name class-attribute instance-attribute

name: str = Field(
    description="Queryable node attribute. Built-ins include name, type, address, parent, children, ancestors, descendants, and target. Pydantic fields and extra metadata fields are also queryable."
)

named classmethod

named(name: str) -> 'NodeAttribute'
Source code in src/json2vec/structs/selectors.py
@classmethod
def named(cls, name: str) -> "NodeAttribute":
    return cls(name=name)

get

get(node: Node, default: Any = None) -> Any
Source code in src/json2vec/structs/selectors.py
def get(self, node: Node, default: Any = None) -> Any:
    if self.name == "address":
        return str(node.address)
    if self.name == "parent":
        parent = getattr(node, "parent", None)
        return None if parent is None or not getattr(parent, "address", None) else str(parent.address)
    if self.name == "children":
        return tuple(str(child.address) for child in getattr(node, "children", ()))
    if self.name == "ancestors":
        return tuple(str(parent.address) for parent in getattr(node, "ancestors", ()) if parent.address)
    if self.name == "descendants":
        return tuple(str(child.address) for child in getattr(node, "descendants", ()))
    if self.name == "target":
        return isinstance(node, Leaf) and node.active and node.target

    extra = getattr(node, "model_extra", None) or {}
    if self.name in extra:
        return extra[self.name]

    return getattr(node, self.name, default)

exists

exists() -> NodePredicate
Source code in src/json2vec/structs/selectors.py
def exists(self) -> NodePredicate:
    return NodePredicate(
        func=lambda node: _has_model_attribute(node, self.name),
        key=("exists", self.name),
    )

is_in

is_in(values: Iterable[Any]) -> NodePredicate
Source code in src/json2vec/structs/selectors.py
def is_in(self, values: Iterable[Any]) -> NodePredicate:
    cached_values = tuple(values)
    return NodePredicate(
        func=lambda node: self.get(node) in cached_values,
        key=(
            "is_in",
            self.name,
            tuple(sorted((_cache_value(value) for value in cached_values), key=repr)),
        ),
    )

matches

matches(pattern: str | Pattern[str]) -> NodePredicate
Source code in src/json2vec/structs/selectors.py
def matches(self, pattern: str | re.Pattern[str]) -> NodePredicate:
    regex = re.compile(pattern) if isinstance(pattern, str) else pattern
    return NodePredicate(
        func=lambda node: regex.search(str(self.get(node, ""))) is not None,
        key=("matches", self.name, regex.pattern),
    )

contains

contains(value: Any) -> NodePredicate
Source code in src/json2vec/structs/selectors.py
def contains(self, value: Any) -> NodePredicate:
    return NodePredicate(
        func=lambda node: value in (self.get(node) or ()),
        key=("contains", self.name, _cache_value(value)),
    )

is_null

is_null() -> NodePredicate
Source code in src/json2vec/structs/selectors.py
def is_null(self) -> NodePredicate:
    return NodePredicate(
        func=lambda node: self.get(node) is None,
        key=("is_null", self.name),
    )

is_not_null

is_not_null() -> NodePredicate
Source code in src/json2vec/structs/selectors.py
def is_not_null(self) -> NodePredicate:
    return NodePredicate(
        func=lambda node: self.get(node) is not None,
        key=("is_not_null", self.name),
    )
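Each NodeAttribute method returns a NodePredicate whose key records the operation, the attribute name, and the operand, which is what lets select reuse cached results. The stdlib sketch below builds predicates the same way over plain dicts. ToyAttribute and ToyPredicate are hypothetical; the where(...) entry point and the _cache_value normalization are not reproduced.

```python
import re
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass(frozen=True)
class ToyPredicate:  # hypothetical analogue of NodePredicate
    func: Callable[[dict], bool]
    key: tuple

    def __call__(self, node: dict) -> bool:
        return self.func(node)


@dataclass(frozen=True)
class ToyAttribute:  # hypothetical analogue of NodeAttribute over plain dicts
    name: str

    def get(self, node: dict, default: Any = None) -> Any:
        return node.get(self.name, default)

    def is_in(self, values: Iterable[Any]) -> ToyPredicate:
        cached = tuple(values)  # materialize once so a generator is not exhausted
        key = ("is_in", self.name, tuple(sorted(cached, key=repr)))  # order-insensitive key
        return ToyPredicate(lambda node: self.get(node) in cached, key)

    def matches(self, pattern: str) -> ToyPredicate:
        regex = re.compile(pattern)
        return ToyPredicate(
            lambda node: regex.search(str(self.get(node, ""))) is not None,
            ("matches", self.name, regex.pattern),
        )

    def is_null(self) -> ToyPredicate:
        return ToyPredicate(lambda node: self.get(node) is None, ("is_null", self.name))


nodes = [{"name": "amount", "type": "number"}, {"name": "sku", "type": "category", "query": "[*].sku"}]
print([n["name"] for n in nodes if ToyAttribute("type").is_in(["number"])(n)])  # ['amount']
print([n["name"] for n in nodes if ToyAttribute("name").matches(r"^s")(n)])  # ['sku']
print([n["name"] for n in nodes if ToyAttribute("query").is_null()(n)])  # ['amount']
```

Sorting the operand by repr, as is_in does, gives the same cache key for ["number", "category"] and ["category", "number"].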

NodePredicate

Bases: BaseModel

Composable predicate used to select schema nodes.

model_config class-attribute instance-attribute

model_config = ConfigDict(
    frozen=True, arbitrary_types_allowed=True
)

func instance-attribute

func: Callable[[Node], bool]

key instance-attribute

key: SelectionKey

cacheable class-attribute instance-attribute

cacheable: bool = True

from_callable classmethod

from_callable(
    key: str | tuple[Any, ...], func: Callable[[Node], bool]
) -> "NodePredicate"
Source code in src/json2vec/structs/selectors.py
@classmethod
def from_callable(cls, key: str | tuple[Any, ...], func: Callable[[Node], bool]) -> "NodePredicate":
    cache_key = key if isinstance(key, tuple) else ("callable", key)
    return cls(func=func, key=cache_key)

from_selector classmethod

from_selector(value: 'NodeSelector') -> 'NodePredicate'
Source code in src/json2vec/structs/selectors.py
@classmethod
def from_selector(cls, value: "NodeSelector") -> "NodePredicate":
    if isinstance(value, cls):
        return value

    if isinstance(value, NodeAttribute):
        return cls(
            func=lambda node: _has_model_attribute(node, value.name) and value.get(node) is True,
            key=("truthy", value.name),
        )

    if not callable(value):
        raise TypeError("node predicates must be where(...) expressions or callables")

    return cls(
        func=value,
        key=("callable", id(value)),
        cacheable=True,
    )

Array

Array(*children: Self | RequestTypes | Leaf, **data)

Bases: Node

Repeated nested object group in a json2vec schema.

Positional children are treated as fields inside the array; passing children both positionally and through fields= raises TypeError.

Source code in src/json2vec/structs/structure.py
def __init__(self, *children: Self | RequestTypes | Leaf, **data):
    if children:
        if "fields" in data:
            raise TypeError("array children were provided both positionally and by keyword")
        data["fields"] = list(children)

    super().__init__(**data)

name instance-attribute

name: str

type class-attribute instance-attribute

type: Annotated[Literal["array"], Field(default="array")] = "array"

attention class-attribute instance-attribute

attention: AttentionMode = AttentionMode.mha

max_length class-attribute instance-attribute

max_length: Annotated[int, Field(gt=0, default=1)] = 1

overflow class-attribute instance-attribute

overflow: Overflow = Overflow.head

n_linear class-attribute instance-attribute

n_linear: Annotated[int, Field(gt=0, default=1)] = 1

n_layers class-attribute instance-attribute

n_layers: Annotated[int, Field(gt=0, default=1)] = 1

masks class-attribute instance-attribute

masks: list[Mask] = Field(default_factory=list)

fields class-attribute instance-attribute

fields: list[Self | RequestTypes | InstanceOf[Leaf]] = (
    Field(default_factory=list)
)

normalize_mask_shorthand classmethod

normalize_mask_shorthand(data: Any) -> Any
Source code in src/json2vec/structs/structure.py
@pydantic.model_validator(mode="before")
@classmethod
def normalize_mask_shorthand(cls, data: Any) -> Any:
    if not isinstance(data, dict):
        return data

    values = dict(data)
    mask = values.pop("mask", None)
    if mask is None:
        return values

    if "masks" in values:
        raise ValueError("pass either mask or masks, not both")

    values["masks"] = [mask]
    return values

model_post_init

model_post_init(__context)
Source code in src/json2vec/structs/structure.py
def model_post_init(self, __context):
    for field in self.fields:
        field.parent: Self = self

check_unique_child_names

check_unique_child_names()
Source code in src/json2vec/structs/structure.py
@pydantic.model_validator(mode="after")
def check_unique_child_names(self):
    seen: set[str] = set()
    for field in self.fields:
        if field.name in seen:
            raise ValueError(f"duplicate field name: {field.name}")
        seen.add(field.name)

    return self

post_bind_validate

post_bind_validate()
Source code in src/json2vec/structs/structure.py
def post_bind_validate(self):
    if len(self.masks) == 0:
        return None

    is_root = getattr(getattr(self, "parent", None), "type", None) == "hyperparameters"
    if is_root:
        raise ValueError("Mask on the generated root array is not supported")

    active_leaves = [
        descendant
        for descendant in getattr(self, "descendants", ())
        if isinstance(descendant, Leaf) and getattr(descendant, "active", True)
    ]
    if not active_leaves:
        raise ValueError(f"array '{self.address}' has masks but no active descendant leaves")

    names = [mask.name for mask in self.masks if mask.name is not None]
    duplicates = sorted({name for name in names if names.count(name) > 1})
    if duplicates:
        raise ValueError(f"array '{self.address}' has duplicate mask name(s): {duplicates}")

    for mask in self.masks:
        if mask.offset >= self.max_length:
            raise ValueError(f"array '{self.address}' mask offset must be less than max_length={self.max_length}")

        excluded = self.excluded_leaves(mask)
        if len(excluded) == len(active_leaves):
            label = f" '{mask.name}'" if mask.name is not None else ""
            raise ValueError(f"array '{self.address}' mask{label} excludes every active descendant leaf")

    return None

excluded_leaves

excluded_leaves(mask: Mask) -> tuple[Leaf, ...]
Source code in src/json2vec/structs/structure.py
def excluded_leaves(self, mask: Mask) -> tuple[Leaf, ...]:
    if mask.exclude is None:
        return ()

    from json2vec.structs.selectors import NodePredicate

    predicates = tuple(NodePredicate.from_selector(item) for item in mask.exclude)
    return tuple(
        descendant
        for descendant in getattr(self, "descendants", ())
        if isinstance(descendant, Leaf)
        if getattr(descendant, "active", True)
        if any(predicate(descendant) for predicate in predicates)
    )

Mask

Bases: BaseModel

Structured masking policy attached to an Array. Exactly one of rate or count must be set.

model_config class-attribute instance-attribute

model_config = ConfigDict(arbitrary_types_allowed=True)

name class-attribute instance-attribute

name: str | None = None

rate class-attribute instance-attribute

rate: Annotated[float, Field(ge=0.0, le=1.0)] | None = None

count class-attribute instance-attribute

count: Annotated[int, Field(ge=0)] | None = None

window class-attribute instance-attribute

window: Annotated[int, Field(gt=0)] | None = None

array class-attribute instance-attribute

array: bool = False

start class-attribute instance-attribute

start: bool = False

offset class-attribute instance-attribute

offset: Annotated[int, Field(ge=0)] = 0

exclude class-attribute instance-attribute

exclude: Any = None

check_rate_or_count

check_rate_or_count()
Source code in src/json2vec/structs/structure.py
@pydantic.model_validator(mode="after")
def check_rate_or_count(self):
    if (self.rate is None) == (self.count is None):
        raise ValueError("Mask requires exactly one of rate or count")

    return self
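check_rate_or_count compares rate is None with count is None; equal results mean both or neither were given, and that case is rejected. The stdlib sketch below reproduces the rule together with the field bounds. ToyMask is a hypothetical dataclass covering only rate and count.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyMask:  # hypothetical; only the rate/count constraints of Mask are reproduced
    rate: Optional[float] = None
    count: Optional[int] = None

    def __post_init__(self):
        if (self.rate is None) == (self.count is None):
            raise ValueError("Mask requires exactly one of rate or count")
        if self.rate is not None and not 0.0 <= self.rate <= 1.0:
            raise ValueError("rate must be in [0, 1]")
        if self.count is not None and self.count < 0:
            raise ValueError("count must be >= 0")


ToyMask(rate=0.15)  # accepted
ToyMask(count=2)    # accepted
for bad in ({}, {"rate": 0.1, "count": 1}):
    try:
        ToyMask(**bad)
    except ValueError as error:
        print(error)  # Mask requires exactly one of rate or count
```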

normalize_exclude classmethod

normalize_exclude(value: Any) -> Any
Source code in src/json2vec/structs/structure.py
@pydantic.field_validator("exclude", mode="before")
@classmethod
def normalize_exclude(cls, value: Any) -> Any:
    if value is None:
        return None

    if isinstance(value, (list, tuple)):
        values = tuple(value)
    else:
        values = (value,)

    if any(isinstance(item, dict) and {"func", "key"}.issubset(item) for item in values):
        from json2vec.structs.selectors import NodePredicate

        return tuple(
            NodePredicate.model_validate(item)
            if isinstance(item, dict) and {"func", "key"}.issubset(item)
            else item
            for item in values
        )

    return values

serialize_exclude

serialize_exclude(value: Any) -> Any
Source code in src/json2vec/structs/structure.py
@pydantic.field_serializer("exclude")
def serialize_exclude(self, value: Any) -> Any:
    return value

Address

Bases: str

Slash-delimited stable path to a schema node.

Leaf

Leaf(name: str | None = None, **data: Any)

Bases: Node

Base tensorfield request node.

Concrete tensorfield constructors such as Number and Category inherit from this class through their registered request models.

Source code in src/json2vec/structs/tree.py
def __init__(self, name: str | None = None, **data: Any):
    if name is not None:
        if "name" in data:
            raise TypeError("name was provided both positionally and by keyword")
        data["name"] = name
    super().__init__(**data)

model_config class-attribute instance-attribute

model_config = ConfigDict(extra='allow')

active class-attribute instance-attribute

active: bool = True

embed class-attribute instance-attribute

embed: bool = False

name instance-attribute

name: str

type instance-attribute

type: str

query class-attribute instance-attribute

query: str | None = None

pooling class-attribute instance-attribute

pooling: Literal['query', 'mean'] = 'query'

weight class-attribute instance-attribute

weight: Annotated[float, Field(gt=0.0, default=1.0)] = 1.0

p_mask class-attribute instance-attribute

p_mask: Rate = 0.0

p_prune class-attribute instance-attribute

p_prune: PruneRate = 0.0

n_linear class-attribute instance-attribute

n_linear: Annotated[int, Field(gt=0, default=1)] = 1

target property writable

target: bool

shape cached property

shape: tuple[int, ...]

overflows cached property

overflows: tuple[Overflow, ...]

resolve_role_shorthands classmethod

resolve_role_shorthands(data: Any) -> Any
Source code in src/json2vec/structs/tree.py
@pydantic.model_validator(mode="before")
@classmethod
def resolve_role_shorthands(cls, data: Any) -> Any:
    if not isinstance(data, Mapping):
        return data

    values = dict(data)
    target = values.pop("target", None)

    if target is None:
        return values

    if not isinstance(target, bool):
        raise ValueError("target must be a boolean")

    if target:
        if values.get("p_prune") not in (None, 1.0):
            raise ValueError("target=True is shorthand for p_prune=1.0")
        values["p_prune"] = 1.0
    else:
        if values.get("p_prune") not in (None, 0.0):
            raise ValueError("target=False is shorthand for p_prune=0.0")
        values["p_prune"] = 0.0

    return values
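resolve_role_shorthands rewrites target into p_prune before field validation and rejects a conflicting explicit p_prune. The mapping can be expressed as a standalone function, sketched here over plain dicts rather than as the pydantic validator itself.

```python
from collections.abc import Mapping
from typing import Any


def resolve_target(data: Mapping[str, Any]) -> dict:
    values = dict(data)
    target = values.pop("target", None)
    if target is None:
        return values
    if not isinstance(target, bool):
        raise ValueError("target must be a boolean")
    expected = 1.0 if target else 0.0  # True maps to 1.0, False to 0.0
    if values.get("p_prune") not in (None, expected):
        raise ValueError(f"target={target} is shorthand for p_prune={expected}")
    values["p_prune"] = expected
    return values


print(resolve_target({"name": "amount", "target": True}))  # {'name': 'amount', 'p_prune': 1.0}
print(resolve_target({"name": "amount", "p_prune": 0.3}))  # {'name': 'amount', 'p_prune': 0.3}
```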

merge_constructor_kwargs classmethod

merge_constructor_kwargs(data: Any) -> Any
Source code in src/json2vec/structs/tree.py
@pydantic.model_validator(mode="before")
@classmethod
def merge_constructor_kwargs(cls, data: Any) -> Any:
    if not isinstance(data, Mapping):
        return data

    values = dict(data)
    kwargs = values.pop("kwargs", None)

    if kwargs is None:
        return values
    if not isinstance(kwargs, Mapping):
        raise TypeError("kwargs must be a mapping")

    for key, value in kwargs.items():
        values.setdefault(key, value)

    return values
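Because merge_constructor_kwargs uses setdefault, keys passed directly to the constructor take precedence over the same keys inside kwargs. A stdlib sketch of that merge on plain dicts follows; the field names are illustrative.

```python
from collections.abc import Mapping
from typing import Any


def merge_kwargs(data: Mapping[str, Any]) -> dict:
    values = dict(data)
    kwargs = values.pop("kwargs", None)
    if kwargs is None:
        return values
    if not isinstance(kwargs, Mapping):
        raise TypeError("kwargs must be a mapping")
    for key, value in kwargs.items():
        values.setdefault(key, value)  # explicit keys win over kwargs entries
    return values


print(merge_kwargs({"name": "amount", "weight": 2.0, "kwargs": {"weight": 0.5, "p_mask": 0.1}}))
# {'name': 'amount', 'weight': 2.0, 'p_mask': 0.1}
```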

validate_type classmethod

validate_type(value: str) -> str
Source code in src/json2vec/structs/tree.py
@pydantic.field_validator("type")
@classmethod
def validate_type(cls, value: str) -> str:
    from json2vec.tensorfields import extensions as _extensions  # noqa: F401
    from json2vec.tensorfields.base import TENSORFIELDS

    if value not in TENSORFIELDS:
        raise ValueError(f"unknown tensor field type: {value}")

    return value

check_jmespath_query

check_jmespath_query()
Source code in src/json2vec/structs/tree.py
@pydantic.model_validator(mode="after")
def check_jmespath_query(self):
    if self.query is None:
        return self

    if not isinstance(self.query, str) or not self.query.strip():
        raise ValueError("query must be a non-empty string")

    try:
        jmespath.compile(self.query)
    except JMESPathError as e:
        raise ValueError(f"invalid jmespath query: {e}") from e

    return self

post_bind_validate

post_bind_validate()
Source code in src/json2vec/structs/tree.py
def post_bind_validate(self):
    if self.query is None:
        raise ValueError(f"request '{self.address}' must define query")

DecoderBase

DecoderBase(
    hyperparameters: Hyperparameters, address: Address
)

Bases: Module

Base class for tensorfield decoders.

Source code in src/json2vec/tensorfields/base.py
def __init__(self, hyperparameters: Hyperparameters, address: Address):
    super().__init__()

    self.address: Address = address
    self.sigma: torch.Tensor = torch.nn.Parameter(torch.zeros(1))

    request = hyperparameters.requests[address]
    n_context = 1
    for dimension in hyperparameters.shapes[address]:
        n_context *= dimension
    match request.pooling:
        case "query":
            self.pool = LearnedQueryCrossAttention(
                n_context=n_context,
                d_model=hyperparameters.d_model,
                nhead=request.n_heads,
                dropout=float(request.dropout or 0.0),
                n_linear=request.n_linear,
            )
        case "mean":
            self.pool = MeanPool(n_context=n_context)
        case _:
            raise ValueError(f"unsupported decoder pooling: {request.pooling}")

address instance-attribute

address: Address = address

sigma instance-attribute

sigma: Tensor = Parameter(zeros(1))

pool instance-attribute

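pool: LearnedQueryCrossAttention | MeanPool

Chosen from request.pooling: "query" builds LearnedQueryCrossAttention and "mean" builds MeanPool(n_context=n_context); any other value raises ValueError.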

decode

decode(
    pooled: Tensor,
) -> TensorDict[TensorKey, torch.Tensor]
Source code in src/json2vec/tensorfields/base.py
def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
    raise NotImplementedError("decoder must implement decode(pooled)")

forward

forward(
    parcels: list[Parcel], *, embed: bool = False
) -> Prediction
Source code in src/json2vec/tensorfields/base.py
def forward(self, parcels: list[Parcel], *, embed: bool = False) -> Prediction:
    if len(parcels) == 0:
        raise ValueError("decoder requires at least one parcel")

    N, *_, C = parcels[0].payload.shape
    stacked = torch.cat([parcel.payload.reshape(N, -1, C) for parcel in parcels], dim=1)
    pooled = self.pool(stacked)

    payload = self.decode(pooled)
    if embed:
        payload[TensorKey.embedding] = pooled

    return Prediction(
        payload=payload,
        address=self.address,
        batch_size=pooled.shape[0],
    )
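Subclasses only implement decode(pooled); forward handles concatenation, pooling, and the optional embedding. The sketch below mimics that control flow with plain Python lists so it runs without torch. ToyDecoderBase and SumDecoder are hypothetical stand-ins, the batch dimension is dropped for brevity, and the mean pool stands in for MeanPool.

```python
from abc import ABC, abstractmethod


class ToyDecoderBase(ABC):
    """Hypothetical stand-in that mimics DecoderBase.forward's control flow."""

    def pool(self, stacked: list[list[float]]) -> list[float]:
        # Mean-pool over the context positions (stand-in for MeanPool).
        n = len(stacked)
        return [sum(column) / n for column in zip(*stacked)]

    @abstractmethod
    def decode(self, pooled: list[float]) -> dict[str, object]:
        raise NotImplementedError("decoder must implement decode(pooled)")

    def forward(self, parcels: list[list[list[float]]], *, embed: bool = False) -> dict[str, object]:
        if not parcels:
            raise ValueError("decoder requires at least one parcel")
        # Concatenate the context positions of every parcel, then pool once.
        stacked = [position for parcel in parcels for position in parcel]
        pooled = self.pool(stacked)
        payload = self.decode(pooled)
        if embed:
            payload["embedding"] = pooled  # mirrors payload[TensorKey.embedding] = pooled
        return payload


class SumDecoder(ToyDecoderBase):
    def decode(self, pooled: list[float]) -> dict[str, object]:
        return {"prediction": sum(pooled)}


print(SumDecoder().forward([[[1.0, 2.0]], [[3.0, 4.0]]], embed=True))
# {'prediction': 5.0, 'embedding': [2.0, 3.0]}
```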

EmbedderBase

EmbedderBase(
    hyperparameters: Hyperparameters, address: Address
)

Bases: Module

Base class for tensorfield embedders.

Source code in src/json2vec/tensorfields/base.py
def __init__(self, hyperparameters: Hyperparameters, address: Address):
    super().__init__()

Plugin

Plugin(name: str)

Registry object for a tensorfield implementation.

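Register Request, TensorField, Embedder, Decoder, loss, and write components with @plugin.register. The component slot is inferred from the registered object's __name__, which must be a valid Component value; only write may be registered as None, via plugin.register(None, Component.write). Creating a plugin with a name that is already registered replaces the registry entry and emits a UserWarning.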

Source code in src/json2vec/tensorfields/base.py
def __init__(self, name: str):
    if not isinstance(name, str):
        raise TypeError("Plugin name must be a string")

    # must be non-empty and contain only lowercase letters, numbers, and underscores
    if not re.fullmatch(r"[a-z0-9_]+", name):
        raise ValueError("Plugin name must consist of lowercase letters, numbers, and underscores only")

    self.name: str = name
    self.components: dict[Component, ComponentValue | None] = {}
    self.callback_factories: list[CallbackFactory] = []

    if name in TENSORFIELDS:
        warnings.warn(
            f"Plugin '{name}' already registered; overriding existing tensorfield plugin",
            UserWarning,
            stacklevel=2,
        )

    TENSORFIELDS[name] = self
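The naming rule and override warning can be reproduced without json2vec. In the sketch below, ToyPlugin and REGISTRY are hypothetical stand-ins for Plugin and TENSORFIELDS. It uses re.fullmatch because re.match with a trailing $ also accepts a name ending in a newline.

```python
import re
import warnings

REGISTRY: dict[str, object] = {}  # hypothetical stand-in for TENSORFIELDS


class ToyPlugin:
    """Illustrative mirror of Plugin's name validation and registry override."""

    def __init__(self, name: str):
        if not isinstance(name, str):
            raise TypeError("Plugin name must be a string")
        # fullmatch rejects "abc\n", which re.match(r"^...$") would accept
        if not re.fullmatch(r"[a-z0-9_]+", name):
            raise ValueError("Plugin name must consist of lowercase letters, numbers, and underscores only")
        self.name = name
        if name in REGISTRY:
            warnings.warn(f"Plugin '{name}' already registered; overriding", UserWarning, stacklevel=2)
        REGISTRY[name] = self  # the newest plugin replaces any earlier entry


first = ToyPlugin("my_field")
```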

name instance-attribute

name: str = name

components instance-attribute

components: dict[Component, ComponentValue | None] = {}

callback_factories instance-attribute

callback_factories: list[CallbackFactory] = []

callbacks property

callbacks: list[Callback]

Instantiate all registered callback factories.

Request property

Request: type[RequestBase]

TensorField property

TensorField: type[TensorFieldBase]

Embedder property

Embedder: type[EmbedderBase]

Decoder property

Decoder: type[DecoderBase]

loss property

loss: Callable[..., Any]

write property

write: Callable[..., Any]

register

register(obj: None, component: Component | str) -> None
register(
    obj: RegisterT, component: Component | str | None = None
) -> RegisterT
register(
    obj: RegisterT | None,
    component: Component | str | None = None,
) -> RegisterT | None

Register one tensorfield component with this plugin.

Source code in src/json2vec/tensorfields/base.py
def register(
    self,
    obj: RegisterT | None,
    component: Component | str | None = None,
) -> RegisterT | None:
    """Register one tensorfield component with this plugin."""
    if obj is None:
        if component is None:
            raise TypeError("component must be provided when registering None")

        key = Component(component)
        if key != Component.write:
            raise TypeError("only write may be registered as None")

        if key in self.components:
            raise ValueError(f"Component '{key}' already registered in plugin '{self.name}'")

        self.components[key] = None
        return None

    if not hasattr(obj, "__name__"):
        raise NameError(f"Object {obj} does not have a name")

    name: str = str(obj.__name__)
    try:
        key = Component(name)
    except ValueError:
        raise ValueError(f"Component '{name}' is not a valid Component enum value") from None

    if key in self.components:
        raise ValueError(f"Component '{key}' already registered in plugin '{self.name}'")

    match key:
        case Component.Request:
            if not isinstance(obj, type):
                raise TypeError("Request must be a class type")

            if not issubclass(obj, Node):
                raise TypeError("Request must be a subclass of Node")

        case Component.TensorField:
            if not isinstance(obj, type):
                raise TypeError("TensorField must be a class type")

            if not issubclass(obj, TensorFieldBase):
                raise TypeError("TensorField must be a subclass of TensorFieldBase")

        case Component.Embedder:
            if not isinstance(obj, type):
                raise TypeError("Embedder must be a class type")

            if not issubclass(obj, EmbedderBase):
                raise TypeError("Embedder must be a subclass of EmbedderBase")

            # confirm the init method is expecting hyperparameters and address
            init_params = list(obj.__init__.__annotations__.keys())
            if "hyperparameters" not in init_params or "address" not in init_params:
                raise TypeError("Embedder __init__ method must accept 'hyperparameters' and 'address' parameters")

        case Component.Decoder:
            if not isinstance(obj, type):
                raise TypeError("Decoder must be a class type")

            if not issubclass(obj, DecoderBase):
                raise TypeError("Decoder must be a subclass of DecoderBase")

            init_params = list(obj.__init__.__annotations__.keys())
            if "hyperparameters" not in init_params or "address" not in init_params:
                raise TypeError("Decoder __init__ method must accept 'hyperparameters' and 'address' parameters")

        case Component.loss:
            if not callable(obj):
                raise TypeError("Loss must be a callable function")

            expected_params: list[str] = ["module", "prediction", "batch", "strata"]
            func_params: list[str] = list(obj.__annotations__.keys())

            if not set(expected_params).issubset(set(func_params)):
                raise TypeError(
                    f"Loss function must accept the following parameters: {expected_params}, got {func_params}"
                )

        case Component.write:
            if not callable(obj):
                raise TypeError("Write must be a callable function")

            # check the signature of the function; __annotations__ also holds "return"
            expected_params: list[str] = ["module", "prediction"]
            func_params: list[str] = [param for param in obj.__annotations__ if param != "return"]

            if func_params != expected_params:
                raise TypeError(
                    f"Write function must accept the following parameters: {expected_params}, got {func_params}"
                )

    self.components[key] = obj

    return obj
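Registration is name-driven: the object's __name__ selects the component slot, and a duplicate slot is rejected. The sketch below illustrates that dispatch with a hypothetical two-member ToyComponent enum (the real Component values are not shown on this page) and checks the write signature with inspect.signature, which, unlike __annotations__, does not include the return annotation and does not require parameters to be annotated.

```python
import inspect
from enum import Enum
from typing import Any, Callable


class ToyComponent(str, Enum):
    # Hypothetical subset standing in for json2vec's Component enum.
    Embedder = "Embedder"
    write = "write"


def register(components: dict, obj: Callable[..., Any]) -> Callable[..., Any]:
    key = ToyComponent(obj.__name__)  # slot inferred from the object's name; ValueError if unknown
    if key in components:
        raise ValueError(f"Component '{key.value}' already registered")
    if key is ToyComponent.write:
        params = list(inspect.signature(obj).parameters)
        if params != ["module", "prediction"]:
            raise TypeError(f"write must accept ['module', 'prediction'], got {params}")
    components[key] = obj
    return obj


components: dict = {}


def write(module, prediction) -> None:  # return annotation is ignored by inspect.signature
    return None


register(components, write)
print([key.value for key in components])  # ['write']
```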

callback

callback(factory: CallbackFactory) -> CallbackFactory
callback(
    factory: CallbackFactory, *factories: CallbackFactory
) -> tuple[CallbackFactory, ...]
callback(
    factory: CallbackFactory, *factories: CallbackFactory
)

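Register one or more Lightning callback factories for this tensorfield. Each factory is called once during registration to check that it returns a Lightning Callback; the callbacks property later instantiates fresh callbacks from the stored factories.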

Source code in src/json2vec/tensorfields/base.py
def callback(self, factory: CallbackFactory, *factories: CallbackFactory):
    """Register one or more Lightning callback factories for this tensorfield."""
    registered = (factory, *factories)
    for callback_factory in registered:
        callback = callback_factory()
        if not isinstance(callback, Callback):
            raise TypeError(f"Plugin callback factory for '{self.name}' must produce a Lightning Callback")

    self.callback_factories.extend(registered)
    return factory if len(registered) == 1 else registered
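The validate-then-store behavior of callback can be sketched with a stand-in Callback class so the example runs without Lightning. Callback, PrintCallback, and register_callbacks are hypothetical names; the return rule (a single factory, or a tuple when several are passed) follows the listing above.

```python
from typing import Callable


class Callback:  # hypothetical stand-in for a Lightning Callback base class
    pass


class PrintCallback(Callback):
    pass


def register_callbacks(store: list, factory: Callable[[], object], *factories: Callable[[], object]):
    registered = (factory, *factories)
    for callback_factory in registered:
        # Each factory is invoked once up front only to validate its output.
        if not isinstance(callback_factory(), Callback):
            raise TypeError("callback factory must produce a Callback")
    store.extend(registered)
    return factory if len(registered) == 1 else registered


store: list = []
single = register_callbacks(store, PrintCallback)
pair = register_callbacks(store, PrintCallback, PrintCallback)
```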

TensorFieldBase

Bases: Renderable

Tensorized field values plus trainable target state.

STATE_PREVIEW_LIMIT class-attribute instance-attribute

STATE_PREVIEW_LIMIT: int = 80

STATE_LABELS class-attribute instance-attribute

STATE_LABELS: dict[int, str] = {
    value: "V",
    value: "N",
    value: "P",
    value: "M",
    value: "O",
}

STATE_STYLES class-attribute instance-attribute

STATE_STYLES: dict[int, str] = {
    value: "bold green",
    value: "bold yellow",
    value: "dim",
    value: "bold magenta",
    value: "bold cyan",
}

content instance-attribute

content: Tensor

state instance-attribute

state: Tensor

trainable instance-attribute

trainable: Tensor

targets instance-attribute

targets: TensorDict[TensorKey, Tensor]

new abstractmethod classmethod

new(
    values: list,
    address: Address,
    hyperparameters: Hyperparameters,
    strata: Strata,
) -> "TensorFieldBase"
Source code in src/json2vec/tensorfields/base.py
@classmethod
@abstractmethod
def new(
    cls,
    values: list,
    address: Address,
    hyperparameters: Hyperparameters,
    strata: Strata,
) -> "TensorFieldBase":
    raise NotImplementedError

empty abstractmethod classmethod

empty(
    batch_size: int,
    address: Address,
    hyperparameters: Hyperparameters,
) -> "TensorFieldBase"
Source code in src/json2vec/tensorfields/base.py
@classmethod
@abstractmethod
def empty(
    cls,
    batch_size: int,
    address: Address,
    hyperparameters: Hyperparameters,
) -> "TensorFieldBase":
    raise NotImplementedError

mask abstractmethod

mask(p_mask: float = 0.0, **kwargs: Any)
Source code in src/json2vec/tensorfields/base.py
@abstractmethod
def mask(self, p_mask: float = 0.0, **kwargs: Any):
    raise NotImplementedError

target abstractmethod

target(p_prune: float = 1.0)
Source code in src/json2vec/tensorfields/base.py
@abstractmethod
def target(self, p_prune: float = 1.0):
    raise NotImplementedError

hide

hide(
    selected: Tensor,
    *,
    cache_targets: bool = True,
    trainable: bool = True,
) -> None
Source code in src/json2vec/tensorfields/base.py
def hide(self, selected: torch.Tensor, *, cache_targets: bool = True, trainable: bool = True) -> None:
    raise NotImplementedError

Category

Bases: RequestBase

Categorical scalar tensorfield request backed by an online vocabulary.

type class-attribute instance-attribute

type: Literal['category'] = 'category'

max_vocab_size class-attribute instance-attribute

max_vocab_size: Annotated[
    int, Field(gt=0, default=10000)
] = 10000

p_unavailable class-attribute instance-attribute

p_unavailable: Annotated[
    float, Field(ge=0.0, le=1.0, default=0.01)
] = 0.01

topk class-attribute instance-attribute

topk: list[int] | None = None

reject_removed_options classmethod

reject_removed_options(data: Any) -> Any
Source code in src/json2vec/tensorfields/extensions/category.py
@pydantic.model_validator(mode="before")
@classmethod
def reject_removed_options(cls, data: Any) -> Any:
    if isinstance(data, Mapping) and "n_bands" in data:
        raise ValueError("Category does not support n_bands")

    return data

check_topk

check_topk()
Source code in src/json2vec/tensorfields/extensions/category.py
@pydantic.model_validator(mode="after")
def check_topk(self):
    if self.topk is None:
        self.topk = []

    # deduplicate and sort
    self.topk = sorted(set(self.topk))

    for topk in self.topk:
        if not isinstance(topk, int):
            raise ValueError("topk values must be integers")

        if topk <= 0:
            raise ValueError("topk values must be positive")

        if topk == 1:
            raise ValueError("topk values must not be 1")

        if topk >= self.max_vocab_size:
            raise ValueError("topk values must be less than max_vocab_size")

    return self

DateParts

Bases: RequestBase

Date/time tensorfield request that extracts configured calendar parts.

type class-attribute instance-attribute

type: Literal['dateparts'] = 'dateparts'

dateparts instance-attribute

dateparts: list[DatePart]

pattern class-attribute instance-attribute

pattern: Annotated[str | None, Field(default=None)] = None

check_dateparts classmethod

check_dateparts(v)
Source code in src/json2vec/tensorfields/extensions/dateparts.py
@pydantic.field_validator("dateparts", check_fields=False)
@classmethod
def check_dateparts(cls, v):
    if not v:
        raise ValueError("dateparts cannot be empty")

    if not len(v) == len(set(v)):
        raise ValueError("dateparts must be unique")

    return v

check_date_pattern classmethod

check_date_pattern(v)
Source code in src/json2vec/tensorfields/extensions/dateparts.py
@pydantic.field_validator("pattern", check_fields=False)
@classmethod
def check_date_pattern(cls, v):
    if v is None:
        return v

    regex: re.Pattern = re.compile(r"(?:%%|%[aAwdbBmyYHIpMSfzZjUWcxXGuV]|[^%])+")

    if not regex.fullmatch(v):
        raise ValueError(f"{v} is not a valid format pattern")

    return v
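The pattern check only accepts literal characters, %% escapes, and the listed strftime directives. The sketch below reuses the same directive set (without re.VERBOSE, which only served to ignore a stray space) so you can test a candidate pattern before building a DateParts request; is_valid_pattern is a hypothetical helper. Passing the check does not guarantee that a given value parses, so the last line tries datetime.strptime directly.

```python
import re
from datetime import datetime

# Same directive set as check_date_pattern: %% or a supported directive, or any non-% character.
DATE_PATTERN = re.compile(r"(?:%%|%[aAwdbBmyYHIpMSfzZjUWcxXGuV]|[^%])+")


def is_valid_pattern(pattern: str) -> bool:
    return DATE_PATTERN.fullmatch(pattern) is not None


print(is_valid_pattern("%Y-%m-%d %H:%M"))  # True
print(is_valid_pattern("%Y-%q"))  # False: %q is not a supported directive
print(datetime.strptime("2024-03-15 09:30", "%Y-%m-%d %H:%M"))  # 2024-03-15 09:30:00
```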

Entity

Bases: RequestBase

Per-observation entity tensorfield request for local identity matching.

type class-attribute instance-attribute

type: Literal['entity'] = 'entity'

topk class-attribute instance-attribute

topk: list[int] | None = None

check_topk

check_topk()
Source code in src/json2vec/tensorfields/extensions/entity.py
@pydantic.model_validator(mode="after")
def check_topk(self):
    if self.topk is None:
        self.topk = []

    for topk in self.topk:
        if not isinstance(topk, int):
            raise ValueError("topk values must be integers")

        if topk <= 0:
            raise ValueError("topk values must be positive")

        if topk == 1:
            raise ValueError("topk values must not be 1")

    return self

post_bind_validate

post_bind_validate()
Source code in src/json2vec/tensorfields/extensions/entity.py
def post_bind_validate(self):
    per_observation_count: int = math.prod(self.shape)
    if per_observation_count <= 1:
        raise ValueError(
            f"entity field at '{self.address}' requires at least 2 elements per observation, "
            f"but configured count is {per_observation_count}"
        )
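Entity needs at least two candidates per observation to match against. The standalone check below mirrors post_bind_validate; it assumes shape is a tuple of the repeated-context sizes bound to the field, and check_entity_shape is a hypothetical name. Because math.prod(()) is 1, an entity field outside any Array always fails.

```python
import math


def check_entity_shape(address: str, shape: tuple[int, ...]) -> int:
    """Illustrative mirror of Entity.post_bind_validate."""
    count = math.prod(shape)  # math.prod(()) == 1, so a scalar field is rejected
    if count <= 1:
        raise ValueError(
            f"entity field at '{address}' requires at least 2 elements per observation, "
            f"but configured count is {count}"
        )
    return count


print(check_entity_shape("users.id", (4,)))  # 4
```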

Number

Bases: RequestBase

Numeric scalar tensorfield request.

type class-attribute instance-attribute

type: Literal['number'] = 'number'

jitter class-attribute instance-attribute

jitter: Annotated[float, Field(ge=0.0, default=0.0)] = 0.0

n_bands class-attribute instance-attribute

n_bands: Annotated[int, Field(gt=0, default=8)] = 8

offset class-attribute instance-attribute

offset: Annotated[int, Field(gt=0, default=4)] = 4

alpha class-attribute instance-attribute

alpha: Annotated[
    float | None, Field(gt=0.0, lt=1.0, default=None)
] = None

objective class-attribute instance-attribute

objective: Objective = Objective.mae

Set

Bases: RequestBase

Multi-label set tensorfield request backed by an online vocabulary.

type class-attribute instance-attribute

type: Literal['set'] = 'set'

max_vocab_size class-attribute instance-attribute

max_vocab_size: Annotated[
    int, Field(gt=0, default=10000)
] = 10000

p_unavailable class-attribute instance-attribute

p_unavailable: Annotated[
    float, Field(ge=0.0, le=1.0, default=0.01)
] = 0.01

threshold class-attribute instance-attribute

threshold: Annotated[
    float | None, Field(ge=0.0, le=1.0, default=None)
] = None

Text

Bases: RequestBase

Text tensorfield request encoded by a frozen Hugging Face model.

type class-attribute instance-attribute

type: Literal['text'] = 'text'

model_name instance-attribute

model_name: str

max_length class-attribute instance-attribute

max_length: Annotated[int, Field(gt=0, default=128)] = 128

encoder_batch_size class-attribute instance-attribute

encoder_batch_size: Annotated[
    int, Field(gt=0, default=32)
] = 32

encoder_pooling class-attribute instance-attribute

encoder_pooling: Pooling = Pooling.cls

objective class-attribute instance-attribute

objective: Objective = Objective.l2

revision class-attribute instance-attribute

revision: str | None = None

local_files_only class-attribute instance-attribute

local_files_only: bool = False

normalize_model_name classmethod

normalize_model_name(value: str)
Source code in src/json2vec/tensorfields/extensions/text.py
@pydantic.field_validator("model_name", mode="before")
@classmethod
def normalize_model_name(cls, value: str):
    if not isinstance(value, str):
        raise ValueError("model_name must be a string")

    return value.strip()

normalize_revision classmethod

normalize_revision(value: str | None)
Source code in src/json2vec/tensorfields/extensions/text.py
@pydantic.field_validator("revision", mode="before")
@classmethod
def normalize_revision(cls, value: str | None):
    if value is None:
        return None

    if not isinstance(value, str):
        raise ValueError("revision must be a string when provided")

    normalized = value.strip()
    return normalized or None

check_model_name

check_model_name()
Source code in src/json2vec/tensorfields/extensions/text.py
@pydantic.model_validator(mode="after")
def check_model_name(self):
    if not self.model_name:
        raise ValueError("model_name must be a non-empty string")

    return self

Vector

Bases: RequestBase

Fixed-width numeric vector tensorfield request.

type class-attribute instance-attribute

type: Literal['vector'] = 'vector'

n_dim instance-attribute

n_dim: Annotated[int, Field(gt=0)]

objective class-attribute instance-attribute

objective: Objective = Objective.l2

VocabularySyncCallback

Bases: Callback

Synchronize online vocabularies registered by tensorfield extensions.

on_fit_start class-attribute instance-attribute

on_fit_start = partialmethod(sync, reason='fit_start')

on_train_epoch_end class-attribute instance-attribute

on_train_epoch_end = partialmethod(
    sync, reason="train_epoch_end"
)

on_fit_end

on_fit_end(trainer: Trainer, pl_module: Model) -> None
Source code in src/json2vec/tensorfields/shared/vocabulary.py
def on_fit_end(self, trainer: Trainer, pl_module: Model) -> None:  # ty:ignore[invalid-method-override]
    for vocabulary in OnlineVocabularyModel.from_model(pl_module).values():
        vocabulary.freeze()

Accelerator

Bases: StrEnum

auto class-attribute instance-attribute

auto = 'auto'

cpu class-attribute instance-attribute

cpu = 'cpu'

cuda class-attribute instance-attribute

cuda = 'cuda'

mps class-attribute instance-attribute

mps = 'mps'

Deployment

Bases: BaseSettings

Serving configuration for a json2vec checkpoint or model instance.

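Deployment collects request and response schemas, an optional preprocessor, an optional postprocessor, and queued update(...) mutations; the model is loaded, and the queued mutations applied, when the FastAPI application starts. Every setting except model can also be read from the environment, with or without the JSON2VEC_ prefix (for example JSON2VEC_PORT or PORT).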

model_config class-attribute instance-attribute

model_config = SettingsConfigDict(
    extra="ignore",
    case_sensitive=False,
    validate_by_name=True,
    validate_by_alias=True,
    arbitrary_types_allowed=True,
)

checkpoint class-attribute instance-attribute

checkpoint: ModelSource = Field(
    default="model.ckpt",
    validation_alias=AliasChoices(
        "JSON2VEC_CHECKPOINT", "CHECKPOINT"
    ),
)

model class-attribute instance-attribute

model: Model | None = Field(default=None, exclude=True)

max_batch_size class-attribute instance-attribute

max_batch_size: int = Field(
    default=128,
    ge=1,
    validation_alias=AliasChoices(
        "JSON2VEC_MAX_BATCH_SIZE", "MAX_BATCH_SIZE"
    ),
)

batch_timeout class-attribute instance-attribute

batch_timeout: float = Field(
    default=0.0,
    ge=0.0,
    validation_alias=AliasChoices(
        "JSON2VEC_BATCH_TIMEOUT", "BATCH_TIMEOUT"
    ),
)

workers class-attribute instance-attribute

workers: int = Field(
    default=1,
    ge=1,
    validation_alias=AliasChoices(
        "JSON2VEC_WORKERS", "WORKERS"
    ),
)

accelerator class-attribute instance-attribute

accelerator: Accelerator = Field(
    default=Accelerator.auto,
    validation_alias=AliasChoices(
        "JSON2VEC_ACCELERATOR", "ACCELERATOR"
    ),
)

host class-attribute instance-attribute

host: str = Field(
    default="0.0.0.0",
    validation_alias=AliasChoices("JSON2VEC_HOST", "HOST"),
)

port class-attribute instance-attribute

port: int = Field(
    default=8000,
    ge=1,
    le=65535,
    validation_alias=AliasChoices("JSON2VEC_PORT", "PORT"),
)

log_level class-attribute instance-attribute

log_level: str = Field(
    default="info",
    validation_alias=AliasChoices(
        "JSON2VEC_LOG_LEVEL", "LOG_LEVEL"
    ),
)

monitor_queries class-attribute instance-attribute

monitor_queries: bool = Field(
    default=False,
    validation_alias=AliasChoices(
        "JSON2VEC_MONITOR_QUERIES", "MONITOR_QUERIES"
    ),
)

query_monitor_every class-attribute instance-attribute

query_monitor_every: int = Field(
    default=1000,
    gt=0,
    validation_alias=AliasChoices(
        "JSON2VEC_QUERY_MONITOR_EVERY",
        "QUERY_MONITOR_EVERY",
    ),
)

json_backend class-attribute instance-attribute

json_backend: JSONBackend = Field(
    default=JSONBackend.orjson,
    validation_alias=AliasChoices(
        "JSON2VEC_JSON_BACKEND", "JSON_BACKEND"
    ),
)
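Each setting lists two environment names in AliasChoices, and the settings are case-insensitive. The sketch below models the lookup order this implies, assuming pydantic-settings tries the AliasChoices entries in the listed order so the JSON2VEC_-prefixed name wins when both are set. resolve_setting is a hypothetical helper, not part of json2vec or pydantic-settings.

```python
from collections.abc import Mapping


def resolve_setting(env: Mapping[str, str], aliases: tuple[str, ...], default: str) -> str:
    """Illustrative model of alias precedence with case-insensitive names."""
    lowered = {key.lower(): value for key, value in env.items()}
    for alias in aliases:  # first alias present wins (assumed AliasChoices order)
        if alias.lower() in lowered:
            return lowered[alias.lower()]
    return default


aliases = ("JSON2VEC_PORT", "PORT")
print(resolve_setting({"PORT": "9000", "JSON2VEC_PORT": "8080"}, aliases, "8000"))  # 8080
```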

strip_checkpoint classmethod

strip_checkpoint(value: Any) -> Any
Source code in src/json2vec/inference/deployment.py
@field_validator("checkpoint", mode="before")
@classmethod
def strip_checkpoint(cls, value: Any) -> Any:
    if isinstance(value, str):
        stripped = value.strip()
        if stripped == "":
            raise ValueError("checkpoint must not be blank")
        return stripped

    return value

check_model_source

check_model_source() -> Deployment
Source code in src/json2vec/inference/deployment.py
@model_validator(mode="after")
def check_model_source(self) -> "Deployment":
    if self.model is not None and "checkpoint" in self.model_fields_set:
        raise ValueError("pass either checkpoint or model, not both")

    return self

forge

forge(
    request: type[BaseModel] | None = None,
    response: type[BaseModel] | None = None,
) -> Deployment

Attach optional Pydantic request and response signatures.

Source code in src/json2vec/inference/deployment.py
@beartype
def forge(
    self,
    request: type[pydantic.BaseModel] | None = None,
    response: type[pydantic.BaseModel] | None = None,
) -> "Deployment":
    """Attach optional Pydantic request and response signatures."""
    self._request_signature = request
    self._response_signature = response

    return self

preprocess

preprocess(preprocessor, **kwargs: Any) -> Deployment

Attach an optional request preprocessor.

If this method is not called, request objects are encoded unchanged. Keyword arguments, if given, are bound to the preprocessor with functools.partial.

Source code in src/json2vec/inference/deployment.py
@beartype
def preprocess(self, preprocessor, **kwargs: Any) -> "Deployment":
    """Attach an optional request preprocessor.

    If this method is not called, request objects are encoded unchanged.
    """
    self._preprocessor = functools.partial(preprocessor, **kwargs) if kwargs else preprocessor

    return self

postprocess

postprocess(postprocessor, **kwargs: Any) -> Deployment

Attach an optional response postprocessor. Keyword arguments, if given, are bound to the postprocessor with functools.partial.

Source code in src/json2vec/inference/deployment.py
@beartype
def postprocess(self, postprocessor, **kwargs: Any) -> "Deployment":
    """Attach an optional response postprocessor."""
    self._postprocessor = functools.partial(postprocessor, **kwargs) if kwargs else postprocessor

    return self
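Both preprocess and postprocess wrap the callable with functools.partial only when keyword arguments are supplied. The sketch below reproduces that binding rule with a hypothetical scale_amount preprocessor and bind helper, assuming the preprocessor receives one request dict; with the real class the equivalent call would be deployment.preprocess(scale_amount, factor=100.0).

```python
import functools
from typing import Any, Callable


def scale_amount(record: dict, *, factor: float = 1.0) -> dict:
    # Hypothetical preprocessor: coerce and rescale one field.
    return {**record, "amount": float(record["amount"]) * factor}


def bind(func: Callable[..., Any], **kwargs: Any) -> Callable[..., Any]:
    # Same rule as Deployment.preprocess/postprocess: wrap only when kwargs are given.
    return functools.partial(func, **kwargs) if kwargs else func


bound = bind(scale_amount, factor=100.0)
print(bound({"amount": "1.5"}))  # {'amount': 150.0}
```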

update

update(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    **values: Any,
) -> Deployment

Queue a model schema mutation to apply during server startup.

This mirrors Model.update(...) and is useful for serving-time changes such as target=False.

Source code in src/json2vec/inference/deployment.py
@beartype
def update(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    **values: Any,
) -> "Deployment":
    """Queue a model schema mutation to apply during server startup.

    This mirrors `Model.update(...)` and is useful for serving-time changes
    such as `target=False`.
    """
    self._update_operations.append(
        (
            tuple(predicates),
            {
                "strict": strict,
                "allow_extra": allow_extra,
                "include_root": include_root,
                "validate": validate,
                **values,
            },
        )
    )

    return self

app

app() -> fastapi.FastAPI

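Build a FastAPI app for the configured checkpoint or model. The app serves GET /health and POST /predict; /predict accepts a JSON object or an array of JSON objects, returns 400 for a body that cannot be parsed, and returns 422 for any other body shape.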

Source code in src/json2vec/inference/deployment.py
def app(self) -> fastapi.FastAPI:
    """Build a FastAPI app for the configured checkpoint or model."""
    runtime = FastAPIRuntime(
        checkpoint=self.model if self.model is not None else self.checkpoint,
        accelerator=self.accelerator,
        preprocessor=self._preprocessor,
        postprocessor=self._postprocessor,
        update_operations=self._update_operations,
        request_signature=self._request_signature,
        response_signature=self._response_signature,
        monitor_queries=self.monitor_queries,
        query_monitor_every=self.query_monitor_every,
    )
    batcher = FastAPIBatcher(
        runtime=runtime,
        max_batch_size=self.max_batch_size,
        batch_timeout=self.batch_timeout,
    )

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        await batcher.start()
        app.state.json2vec_runtime = runtime
        app.state.json2vec_batcher = batcher
        try:
            yield
        finally:
            await batcher.stop()

    app = fastapi.FastAPI(lifespan=lifespan)

    @app.get("/health")
    async def health() -> dict[str, str]:
        return {"status": "ok"}

    def json_response(content: Any, status_code: int = 200) -> fastapi.Response:
        if self.json_backend == JSONBackend.orjson:
            return fastapi.Response(
                content=orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS),
                status_code=status_code,
                media_type="application/json",
            )

        return JSONResponse(content=content, status_code=status_code)

    @app.post("/predict")
    async def predict(request: fastapi.Request) -> fastapi.Response:
        try:
            if self.json_backend == JSONBackend.orjson:
                payload = orjson.loads(await request.body())
            else:
                payload = await request.json()
        except Exception as exception:
            return json_response(
                status_code=400,
                content={
                    "predictions": {},
                    "error": {
                        "status_code": 400,
                        "message": str(exception),
                    },
                },
            )

        if isinstance(payload, dict):
            response = await batcher.submit(cast(dict[str, Any], payload))
            return json_response(content=response)

        if isinstance(payload, list):
            for index, item in enumerate(payload):
                if not isinstance(item, dict):
                    return json_response(
                        status_code=422,
                        content={
                            "predictions": {},
                            "error": {
                                "status_code": 422,
                                "message": (
                                    "request body must be a JSON object or an array of JSON objects; "
                                    f"item {index} is {type(item).__name__}"
                                ),
                            },
                        },
                    )

            responses = await batcher.submit_many(cast(list[dict[str, Any]], payload))
            return json_response(content=responses)

        return json_response(
            status_code=422,
            content={
                "predictions": {},
                "error": {
                    "status_code": 422,
                    "message": f"request body must be a JSON object or an array of JSON objects, got {type(payload).__name__}",
                },
            },
        )

    return app
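The /predict handler routes on the shape of the parsed body. The function below is a standalone model of those checks (classify_payload is a hypothetical name) that returns the status code and either the batcher method that would be called or the error detail. Note that an empty array passes the item check and goes to submit_many.

```python
from typing import Any


def classify_payload(payload: Any) -> tuple[int, str]:
    """Model of /predict's body checks: (status_code, route or error detail)."""
    if isinstance(payload, dict):
        return 200, "submit"
    if isinstance(payload, list):
        for index, item in enumerate(payload):
            if not isinstance(item, dict):
                return 422, f"item {index} is {type(item).__name__}"
        return 200, "submit_many"
    return 422, f"got {type(payload).__name__}"


print(classify_payload([{"amount": 1}, 3]))  # (422, 'item 1 is int')
```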

serve

serve() -> None

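Start the FastAPI server for the configured checkpoint or model. With workers > 1, uvicorn loads the app in each worker from the import string json2vec.inference.deployment:create_app, so settings are passed through JSON2VEC_* environment variables (restored when uvicorn returns) and checkpoint must be a path rather than an in-memory model.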

Source code in src/json2vec/inference/deployment.py
def serve(self) -> None:
    """Start the FastAPI server for the configured checkpoint or model."""
    if self.workers > 1:
        if self.model is not None or isinstance(self.checkpoint, Model):
            raise ValueError("workers > 1 requires a checkpoint path, not an in-memory model")

        env_updates = {
            "JSON2VEC_CHECKPOINT": str(self.checkpoint),
            "JSON2VEC_MAX_BATCH_SIZE": str(self.max_batch_size),
            "JSON2VEC_BATCH_TIMEOUT": str(self.batch_timeout),
            "JSON2VEC_ACCELERATOR": self.accelerator.value,
            "JSON2VEC_MONITOR_QUERIES": str(self.monitor_queries),
            "JSON2VEC_QUERY_MONITOR_EVERY": str(self.query_monitor_every),
            "JSON2VEC_JSON_BACKEND": self.json_backend.value,
        }
        previous = {key: os.environ.get(key) for key in env_updates}
        try:
            os.environ.update(env_updates)
            uvicorn.run(
                "json2vec.inference.deployment:create_app",
                factory=True,
                workers=self.workers,
                host=self.host,
                port=self.port,
                log_level=self.log_level,
            )
        finally:
            for key, value in previous.items():
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value
        return

    uvicorn.run(
        self.app(),
        host=self.host,
        port=self.port,
        log_level=self.log_level,
    )
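The multi-worker branch of serve sets environment variables, runs uvicorn, and then restores the previous values, removing variables that did not exist before. The same save-and-restore technique can be packaged as a context manager; temporary_env is a hypothetical helper shown for illustration.

```python
import os
from collections.abc import Iterator, Mapping
from contextlib import contextmanager


@contextmanager
def temporary_env(updates: Mapping[str, str]) -> Iterator[None]:
    previous = {key: os.environ.get(key) for key in updates}  # None marks "was unset"
    try:
        os.environ.update(updates)
        yield
    finally:
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)  # remove variables we introduced
            else:
                os.environ[key] = value  # restore the original value
```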

JSONBackend

Bases: StrEnum

orjson class-attribute instance-attribute

orjson = 'orjson'

stdlib class-attribute instance-attribute

stdlib = 'stdlib'

preprocess

preprocess(
    func: Callable[..., Any] | None = None,
    *,
    yields: bool | None = None,
    **kwargs: Any,
) -> Callable[..., Any]

Register a callable as a json2vec preprocessor.

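Parameters:

  • func (Callable[..., Any] | None, default None): Callable to register when used as bare @preprocess. Leave it as None when calling @preprocess(...) with keyword arguments.
  • yields (bool | None, default None): Set to True for generator preprocessors. None is treated as False.
  • **kwargs (Any, default {}): Accepts yield as an alias for yields. Because yield is a Python keyword, it can only be passed as **{"yield": True}. Passing both yield and yields, or any other keyword, raises TypeError.

Returns:

  • Callable[..., Any]: The original callable, after registering it in PREPROCESSORS. When func is None, a decorator is returned instead.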

Example
import json2vec as j2v

@j2v.preprocess
def normalize(record: dict) -> dict:
    return {**record, "amount": float(record["amount"])}
Source code in src/json2vec/preprocessors/base.py
def preprocess(
    func: Callable[..., Any] | None = None,
    *,
    yields: bool | None = None,
    **kwargs: Any,
) -> Callable[..., Any]:
    """Register a callable as a `json2vec` preprocessor.

    Args:
        func: Callable to register when used as `@preprocess`.
        yields: Set to `True` for generator preprocessors.
        **kwargs: Reserved for validation of unsupported decorator arguments.

    Returns:
        The original callable, after registering it in `PREPROCESSORS`.

    Example:
        ```python
        import json2vec as j2v

        @j2v.preprocess
        def normalize(record: dict) -> dict:
            return {**record, "amount": float(record["amount"])}
        ```
    """
    if "yield" in kwargs:
        if yields is not None:
            raise TypeError("use either 'yields' or 'yield', not both")
        yields = kwargs.pop("yield")

    if kwargs:
        unexpected = ", ".join(sorted(kwargs))
        raise TypeError(f"unexpected preprocess keyword argument(s): {unexpected}")

    if yields is None:
        yields = False

    mode = PreprocessorMode.from_yields(yields)

    def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
        return Preprocessor.register(inner, mode=mode)

    if func is None:
        return decorator

    if not callable(func):
        raise TypeError("preprocess can only decorate callables")

    return decorator(func)
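`preprocess` works both as bare `@preprocess` and as `@preprocess(...)`. It also accepts `yield` as an alias for `yields` through `**kwargs`, since `yield` cannot be a parameter name. The sketch below reproduces only that argument handling. It uses a hypothetical in-memory `REGISTRY` dict in place of json2vec's `Preprocessor.register` and `PreprocessorMode`.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any

# Hypothetical registry standing in for json2vec's PREPROCESSORS.
REGISTRY: dict[str, tuple[Callable[..., Any], bool]] = {}


def preprocess(
    func: Callable[..., Any] | None = None,
    *,
    yields: bool | None = None,
    **kwargs: Any,
) -> Callable[..., Any]:
    # `yield` is a reserved word, so callers pass it as **{"yield": True}.
    if "yield" in kwargs:
        if yields is not None:
            raise TypeError("use either 'yields' or 'yield', not both")
        yields = kwargs.pop("yield")
    if kwargs:
        raise TypeError(f"unexpected keyword argument(s): {', '.join(sorted(kwargs))}")

    is_generator = bool(yields)  # None defaults to a plain (non-generator) preprocessor

    def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
        REGISTRY[inner.__name__] = (inner, is_generator)
        return inner

    # Called as @preprocess(...): hand back the decorator itself.
    if func is None:
        return decorator
    if not callable(func):
        raise TypeError("preprocess can only decorate callables")
    # Used bare as @preprocess: decorate immediately.
    return decorator(func)


@preprocess
def normalize(record: dict) -> dict:
    return {**record, "amount": float(record["amount"])}


@preprocess(**{"yield": True})
def explode(record: dict):
    for item in record["items"]:
        yield {"item": item}
```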

predicate

predicate(
    key: str | tuple[Any, ...], func: Callable[[Node], bool]
) -> NodePredicate

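Create a cacheable node predicate from a callable. The key identifies the predicate in selection caches, so callables with different behavior need different keys.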

Source code in src/json2vec/structs/selectors.py
def predicate(key: str | tuple[Any, ...], func: Callable[[Node], bool]) -> NodePredicate:
    """Create a cacheable node predicate from a callable."""
    return NodePredicate.from_callable(key=key, func=func)

where

where(name: str) -> NodeAttribute

Start a schema predicate against a node attribute.

Source code in src/json2vec/structs/selectors.py
def where(name: str) -> NodeAttribute:
    """Start a schema predicate against a node attribute."""
    return NodeAttribute.named(name)

Model

Model

Model(
    hyperparameters: Hyperparameters,
    *,
    batch_size: int = 1,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
)

Bases: LightningModule, Renderable

Neural model generated from a json2vec schema tree.

Model owns the schema hyperparameters, tensorfield embedders, array encoders, and decoders. It also provides convenience methods for prediction, checkpointing, and schema display and mutation.

Example
import json2vec as j2v

model = j2v.Model.from_schema(
    j2v.Category("segment", max_vocab_size=32),
    j2v.Category("label", target=True, max_vocab_size=4),
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    embed=True,
)
Source code in src/json2vec/architecture/root.py
@beartype
def __init__(
    self,
    hyperparameters: Hyperparameters,
    *,
    batch_size: int = 1,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
):
    super().__init__()
    if batch_size <= 0:
        raise ValueError("batch_size must be > 0")

    self.hyperparameters: Hyperparameters = hyperparameters
    self.batch_size: int = batch_size
    self.optimizer: OptimizerConfig | None = optimizer
    self.scheduler: SchedulerConfig | None = scheduler
    self.locks: Counter[str | Strata] = Counter()
    self.nodes: torch.nn.ModuleDict = torch.nn.ModuleDict()
    self.schema: SchemaEditor = SchemaEditor(self)
    self._contract_generation: int = 0
    self._contract_scheduler: ContractScheduler = ContractScheduler()

    self._build()

    logger.bind(
        component="model",
        batch_size=self.batch_size,
        requests=len(self.hyperparameters.active_requests),
        arrays=len(self.hyperparameters.arrays),
        embeds=len(self.hyperparameters.embed),
    ).info("initialized Model module")

from_schema classmethod

from_schema(
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    batch_size: int = 1,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: int = 1,
    dropout: Rate | None = None,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
) -> Self

Build a model directly from schema fields.

Parameters:

  • *field_args (SchemaField, default ()): Field constructors such as Category, Number, or nested Array nodes. At least one field is required across field_args and fields.
  • d_model (int, required): Shared model width.
  • n_layers (int, required): Number of encoder layers on generated array nodes.
  • n_heads (int, required): Attention heads used by generated nodes.
  • batch_size (int, default 1): Batch size used by data modules, examples, and mocked Lightning input arrays. Must be > 0.
  • fields (Sequence[SchemaField] | None, default None): Optional sequence form of field_args. These fields are placed before any positional fields, and top-level names must be unique across both.
  • name (str, default 'record'): Root array name.
  • description (str | None, default None): Optional description on the generated root array.
  • embed (bool, default False): Configure the generated root array as an embedding output.
  • attention (AttentionMode | str, default AttentionMode.mha): Attention mode for the generated root array.
  • n_linear (int, default 1): Feed-forward block count on the generated root array.
  • dropout (Rate | None, default None): Optional dropout rate on the generated root array.
  • optimizer (OptimizerConfig | None, default None): Optimizer used by Lightning training. Pass either an instance or a factory that receives the Model.
  • scheduler (SchedulerConfig | None, default None): Optional scheduler config. Pass either a config or a factory that receives the Model and its optimizer.

Returns:

  • Self: A compiled Model with modules built for the schema.

Source code in src/json2vec/architecture/root.py
@classmethod
def from_schema(
    cls,
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    batch_size: int = 1,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: int = 1,
    dropout: Rate | None = None,
    optimizer: OptimizerConfig | None = None,
    scheduler: SchedulerConfig | None = None,
) -> Self:
    """Build a model directly from schema fields.

    Args:
        *field_args: Field constructors such as `Category`, `Number`, or
            nested `Array` nodes.
        d_model: Shared model width.
        n_layers: Number of encoder layers on generated array nodes.
        n_heads: Attention heads used by generated nodes.
        batch_size: Batch size used by data modules, examples, and mocked
            Lightning input arrays.
        fields: Optional sequence form of `field_args`.
        name: Root array name. Defaults to `record`.
        description: Optional description on the generated root array.
        embed: Configure the generated root array as an embedding output.
        attention: Attention mode for the generated root array.
        n_linear: Feed-forward block count on the generated root array.
        dropout: Optional dropout rate on the generated root array.
        optimizer: Optimizer instance or factory used by Lightning training.
        scheduler: Optional scheduler config or factory.

    Returns:
        A compiled `Model` with modules built for the schema.
    """
    hyperparameters = Hyperparameters.from_schema(
        *field_args,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        fields=fields,
        name=name,
        description=description,
        embed=embed,
        attention=attention,
        n_linear=n_linear,
        dropout=dropout,
    )
    return cls(
        hyperparameters=hyperparameters,
        batch_size=batch_size,
        optimizer=optimizer,
        scheduler=scheduler,
    )

select

select(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]

Return schema nodes that satisfy every predicate.

Source code in src/json2vec/architecture/root.py
def select(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]:
    """Return schema nodes that satisfy every predicate."""
    return self.schema.select(*predicates, include_root=include_root, use_cache=use_cache)

update

update(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None

Mutate selected schema nodes and rebuild compatible modules.

target=True is shorthand for p_prune=1.0; target=False clears target behavior by setting p_prune=0.0.

Parameters:

Name Type Description Default
*predicates NodePredicate | NodeAttribute | Callable[[Node], bool]

Predicates used to select nodes.

()
strict bool

Raise when a selected node cannot accept one of values.

True
allow_extra bool

Permit updates to extra metadata fields on models that allow unknown fields.

False
include_root bool

Include the root node in predicate matching.

True
validate bool

Validate each node after applying candidate values.

True
use_cache bool

Permit cached selector results. Mutations default this to False so updates always evaluate against current schema state.

False
**values Any

Schema attributes to update.

{}
Source code in src/json2vec/architecture/root.py
def update(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None:
    """Mutate selected schema nodes and rebuild compatible modules.

    `target=True` is shorthand for `p_prune=1.0`; `target=False` clears
    target behavior by setting `p_prune=0.0`.

    Args:
        *predicates: Predicates used to select nodes.
        strict: Raise when a selected node cannot accept one of `values`.
        allow_extra: Permit updates to extra metadata fields on models that
            allow unknown fields.
        include_root: Include the root node in predicate matching.
        validate: Validate each node after applying candidate values.
        use_cache: Permit cached selector results. Mutations default this to
            `False` so updates always evaluate against current schema state.
        **values: Schema attributes to update.
    """
    self.schema.update(
        *predicates,
        strict=strict,
        allow_extra=allow_extra,
        include_root=include_root,
        validate=validate,
        use_cache=use_cache,
        **values,
    )

extend

extend(
    *args: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool]
    | SchemaField,
    include_root: bool = True,
    use_cache: bool = True,
) -> None

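Append new schema fields under one selected array node and rebuild modules. Pass the selectors first, then the new fields. The selectors must match exactly one Array, otherwise ValueError is raised.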

Source code in src/json2vec/architecture/root.py
def extend(
    self,
    *args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
    include_root: bool = True,
    use_cache: bool = True,
) -> None:
    """Append new schema fields under one selected array node and rebuild modules."""
    self.schema.extend(*args, include_root=include_root, use_cache=use_cache)

delete

delete(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    include_root: bool = False,
    use_cache: bool = True,
) -> None

Permanently remove selected schema nodes and rebuild modules.

Source code in src/json2vec/architecture/root.py
def delete(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    include_root: bool = False,
    use_cache: bool = True,
) -> None:
    """Permanently remove selected schema nodes and rebuild modules."""
    self.schema.delete(*predicates, include_root=include_root, use_cache=use_cache)

reset

reset(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
    descendants: bool = False,
) -> None

Reinitialize selected runtime node modules while preserving schema values.

Source code in src/json2vec/architecture/root.py
def reset(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    include_root: bool = True,
    use_cache: bool = True,
    descendants: bool = False,
) -> None:
    """Reinitialize selected runtime node modules while preserving schema values."""
    self.schema.reset(
        *predicates,
        include_root=include_root,
        use_cache=use_cache,
        descendants=descendants,
    )

override

override(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]

Temporarily mutate selected schema nodes and keep runtime modules synchronized.

Source code in src/json2vec/architecture/root.py
@contextmanager
def override(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]:
    """Temporarily mutate selected schema nodes and keep runtime modules synchronized."""
    with self.schema.override(
        *predicates,
        strict=strict,
        allow_extra=allow_extra,
        include_root=include_root,
        validate=validate,
        use_cache=use_cache,
        **values,
    ):
        yield

save

save(pathname: str | Path) -> str | Path

Save model weights and schema hyperparameters to a checkpoint.

Source code in src/json2vec/architecture/root.py
@beartype
def save(self, pathname: str | Path) -> str | Path:
    """Save model weights and schema hyperparameters to a checkpoint."""
    CheckpointState.save(self, pathname)

    return pathname

load classmethod

load(checkpoint: str | Path) -> Self

Load a Model checkpoint written by Model.save(...).

Source code in src/json2vec/architecture/root.py
@classmethod
def load(cls, checkpoint: str | Path) -> Self:
    """Load a `Model` checkpoint written by `Model.save(...)`."""
    return cast(Self, CheckpointState.load(cls, checkpoint))

predict

predict(
    batch: EncodedBatch | list[dict[str, Any]],
    preprocess: Preprocessor | None = None,
    postprocess: Postprocessor | None = None,
) -> dict[Address, dict[str, Any]]

Return typed predictions and configured embeddings for a raw or encoded batch.

Source code in src/json2vec/architecture/root.py
@immutable("inference")
def predict(
    self,
    batch: EncodedBatch | list[dict[str, Any]],
    preprocess: Preprocessor | None = None,
    postprocess: Postprocessor | None = None,
) -> dict[Address, dict[str, Any]]:
    """Return typed predictions and configured embeddings for a raw or encoded batch."""
    return ModelRuntime.predict(
        self,
        batch=batch,
        preprocess=preprocess,
        postprocess=postprocess,
    )

Schema

Array

Array(*children: Self | RequestTypes | Leaf, **data)

Bases: Node

Repeated nested object group in a json2vec schema.

Positional children are treated as fields inside the array. Passing children both positionally and through fields= raises TypeError.

Source code in src/json2vec/structs/structure.py
def __init__(self, *children: Self | RequestTypes | Leaf, **data):
    if children:
        if "fields" in data:
            raise TypeError("array children were provided both positionally and by keyword")
        data["fields"] = list(children)

    super().__init__(**data)

name instance-attribute

name: str

type class-attribute instance-attribute

type: Annotated[Literal["array"], Field(default="array")] = "array"

attention class-attribute instance-attribute

attention: AttentionMode = AttentionMode.mha

max_length class-attribute instance-attribute

max_length: Annotated[int, Field(gt=0, default=1)] = 1

overflow class-attribute instance-attribute

overflow: Overflow = Overflow.head

n_linear class-attribute instance-attribute

n_linear: Annotated[int, Field(gt=0, default=1)] = 1

n_layers class-attribute instance-attribute

n_layers: Annotated[int, Field(gt=0, default=1)] = 1

masks class-attribute instance-attribute

masks: list[Mask] = Field(default_factory=list)

fields class-attribute instance-attribute

fields: list[Self | RequestTypes | InstanceOf[Leaf]] = (
    Field(default_factory=list)
)

normalize_mask_shorthand classmethod

normalize_mask_shorthand(data: Any) -> Any
Source code in src/json2vec/structs/structure.py
@pydantic.model_validator(mode="before")
@classmethod
def normalize_mask_shorthand(cls, data: Any) -> Any:
    if not isinstance(data, dict):
        return data

    values = dict(data)
    mask = values.pop("mask", None)
    if mask is None:
        return values

    if "masks" in values:
        raise ValueError("pass either mask or masks, not both")

    values["masks"] = [mask]
    return values

model_post_init

model_post_init(__context)
Source code in src/json2vec/structs/structure.py
def model_post_init(self, __context):
    for field in self.fields:
        field.parent: Self = self

check_unique_child_names

check_unique_child_names()
Source code in src/json2vec/structs/structure.py
@pydantic.model_validator(mode="after")
def check_unique_child_names(self):
    seen: set[str] = set()
    for field in self.fields:
        if field.name in seen:
            raise ValueError(f"duplicate field name: {field.name}")
        seen.add(field.name)

    return self

post_bind_validate

post_bind_validate()
Source code in src/json2vec/structs/structure.py
def post_bind_validate(self):
    if len(self.masks) == 0:
        return None

    is_root = getattr(getattr(self, "parent", None), "type", None) == "hyperparameters"
    if is_root:
        raise ValueError("Mask on the generated root array is not supported")

    active_leaves = [
        descendant
        for descendant in getattr(self, "descendants", ())
        if isinstance(descendant, Leaf) and getattr(descendant, "active", True)
    ]
    if not active_leaves:
        raise ValueError(f"array '{self.address}' has masks but no active descendant leaves")

    names = [mask.name for mask in self.masks if mask.name is not None]
    duplicates = sorted({name for name in names if names.count(name) > 1})
    if duplicates:
        raise ValueError(f"array '{self.address}' has duplicate mask name(s): {duplicates}")

    for mask in self.masks:
        if mask.offset >= self.max_length:
            raise ValueError(f"array '{self.address}' mask offset must be less than max_length={self.max_length}")

        excluded = self.excluded_leaves(mask)
        if len(excluded) == len(active_leaves):
            label = f" '{mask.name}'" if mask.name is not None else ""
            raise ValueError(f"array '{self.address}' mask{label} excludes every active descendant leaf")

    return None
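The checks in `post_bind_validate` run in a fixed order:

  • no masks means nothing to check;
  • the generated root array rejects masks;
  • there must be active leaves;
  • mask names must be unique;
  • each mask's offset must be below `max_length`;
  • no mask may exclude every active leaf.

The sketch below restates those checks over a hypothetical `MaskStub` dataclass and a set of active leaf names. It leaves out json2vec's `Mask`, node addresses, and selector-based exclusion (`excluded_leaves`). It is meant only to make the conditions easy to read.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class MaskStub:
    """Hypothetical mask carrying only the fields the checks below read."""

    offset: int = 0
    name: str | None = None
    excluded: frozenset[str] = frozenset()  # names of active leaves this mask excludes


def validate_masks(
    masks: list[MaskStub],
    max_length: int,
    active_leaves: set[str],
    is_root: bool = False,
) -> None:
    if not masks:
        return  # arrays without masks are not checked
    if is_root:
        raise ValueError("Mask on the generated root array is not supported")
    if not active_leaves:
        raise ValueError("array has masks but no active descendant leaves")
    names = [mask.name for mask in masks if mask.name is not None]
    duplicates = sorted({name for name in names if names.count(name) > 1})
    if duplicates:
        raise ValueError(f"duplicate mask name(s): {duplicates}")
    for mask in masks:
        if mask.offset >= max_length:
            raise ValueError(f"mask offset must be less than max_length={max_length}")
        # Reject a mask that would exclude every active leaf.
        if active_leaves <= mask.excluded:
            raise ValueError("mask excludes every active descendant leaf")
```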

excluded_leaves

excluded_leaves(mask: Mask) -> tuple[Leaf, ...]
Source code in src/json2vec/structs/structure.py
def excluded_leaves(self, mask: Mask) -> tuple[Leaf, ...]:
    if mask.exclude is None:
        return ()

    from json2vec.structs.selectors import NodePredicate

    predicates = tuple(NodePredicate.from_selector(item) for item in mask.exclude)
    return tuple(
        descendant
        for descendant in getattr(self, "descendants", ())
        if isinstance(descendant, Leaf)
        if getattr(descendant, "active", True)
        if any(predicate(descendant) for predicate in predicates)
    )

Hyperparameters

Bases: Node

Serializable schema and training metadata used to build a Model.

target property

target: list[Address]

embed property

embed: list[Address]

from_schema classmethod

from_schema(
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: Annotated[int, Field(gt=0)] = 1,
    dropout: Rate | None = None,
) -> Self

Build hyperparameters from schema fields.

Source code in src/json2vec/structs/experiment.py
@classmethod
def from_schema(
    cls,
    *field_args: SchemaField,
    d_model: int,
    n_layers: int,
    n_heads: int,
    fields: Sequence[SchemaField] | None = None,
    name: str = "record",
    description: str | None = None,
    embed: bool = False,
    attention: AttentionMode | str = AttentionMode.mha,
    n_linear: Annotated[int, pydantic.Field(gt=0)] = 1,
    dropout: Rate | None = None,
) -> Self:
    """Build hyperparameters from schema fields."""
    normalized = [*(fields or ()), *field_args]
    if not normalized:
        raise ValueError("from_schema requires at least one field")

    seen_sources: set[str] = set()
    root_fields: list[Array | RequestTypes] = []

    for field in normalized:
        if not isinstance(field, (Array, Leaf)):
            raise TypeError("schema fields must be Array, Leaf, or concrete request instances")

        source = field.name
        if source in seen_sources:
            raise ValueError(f"duplicate schema source field: {source}")
        seen_sources.add(source)

        root_fields.append(cls.from_schema_node(field))

    array = Array(
        name=name,
        description=description,
        embed=embed,
        attention=attention,
        n_layers=n_layers,
        n_heads=n_heads,
        n_linear=n_linear,
        max_length=1,
        overflow=Overflow.error,
        dropout=dropout,
        fields=root_fields,
    )
    return cls(d_model=d_model, fields=array)
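`Hyperparameters.from_schema` does three things before wrapping the fields in a generated root `Array` (`max_length=1`, `overflow=Overflow.error`):

  • puts any `fields=` sequence ahead of the positional fields;
  • requires at least one field;
  • rejects duplicate top-level names.

A minimal sketch of just that normalization step follows. It uses a hypothetical `FieldStub` dataclass in place of real `Array`/`Leaf` fields and skips `from_schema_node`.

```python
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class FieldStub:
    """Hypothetical stand-in for a json2vec schema field; only the name matters here."""

    name: str


def normalize_fields(*field_args: FieldStub, fields: Sequence[FieldStub] | None = None) -> list[FieldStub]:
    # Keyword `fields` come first, followed by positional fields.
    normalized = [*(fields or ()), *field_args]
    if not normalized:
        raise ValueError("from_schema requires at least one field")
    seen: set[str] = set()
    for field in normalized:
        if field.name in seen:
            raise ValueError(f"duplicate schema source field: {field.name}")
        seen.add(field.name)
    return normalized
```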

select

select(
    *predicates: NodeSelector,
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]
Source code in src/json2vec/structs/experiment.py
def select(
    self,
    *predicates: NodeSelector,
    include_root: bool = True,
    use_cache: bool = True,
) -> list[Node]:
    if predicates:
        normalized = tuple(NodePredicate.from_selector(item) for item in predicates)
        combined = NodePredicate(
            func=lambda node: all(item(node) for item in normalized),
            key=("and", tuple(item.key for item in normalized)),
            cacheable=all(item.cacheable for item in normalized),
        )
    else:
        combined = NodePredicate(func=lambda node: True, key=("all",))

    key = ("select", include_root, combined.key)

    if use_cache and combined.cacheable and key in self._selection_cache:
        return Selection(self._selection_cache[key].nodes)

    nodes = tuple(
        node for node in PreOrderIter(self.fields) if (include_root or node is not self.fields) if combined(node)
    )

    if use_cache and combined.cacheable:
        self._selection_cache[key] = SelectionCacheEntry(
            key=key,
            predicate=combined,
            include_root=include_root,
            nodes=nodes,
        )

    return Selection(nodes)
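`select` combines every predicate into one conjunction whose cache key is `("and", <child keys>)`. A result is cached only when every child predicate is cacheable. The sketch below models this with a hypothetical `Pred` dataclass, a `combine` function, and a flat `Selector` over plain dict nodes. It does not reproduce json2vec's `NodePredicate`, `Selection`, or tree traversal.

```python
from __future__ import annotations

from collections.abc import Callable, Hashable, Iterable
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Pred:
    """Hypothetical predicate: a callable plus a hashable cache key."""

    func: Callable[[Any], bool]
    key: Hashable
    cacheable: bool = True

    def __call__(self, node: Any) -> bool:
        return self.func(node)


def combine(*preds: Pred) -> Pred:
    """AND together predicates; the combined key is derived from the child keys."""
    if not preds:
        # No predicates selects every node.
        return Pred(func=lambda node: True, key=("all",))
    return Pred(
        func=lambda node: all(pred(node) for pred in preds),
        key=("and", tuple(pred.key for pred in preds)),
        cacheable=all(pred.cacheable for pred in preds),
    )


class Selector:
    """Hypothetical flat node list with a selection cache keyed by predicate keys."""

    def __init__(self, nodes: Iterable[Any]) -> None:
        self.nodes = tuple(nodes)
        self.cache: dict[Hashable, tuple[Any, ...]] = {}
        self.evaluations = 0  # counts full scans, to make cache hits visible

    def select(self, *preds: Pred, use_cache: bool = True) -> list[Any]:
        combined = combine(*preds)
        key = ("select", combined.key)
        if use_cache and combined.cacheable and key in self.cache:
            return list(self.cache[key])
        self.evaluations += 1
        matched = tuple(node for node in self.nodes if combined(node))
        if use_cache and combined.cacheable:
            self.cache[key] = matched
        return list(matched)
```

The cache here is keyed only by predicate keys. Two predicates with the same key but different logic would therefore share a cache entry in this sketch, which is one reason to give each `predicate(key, func)` a distinct key.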

update

update(
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None

Mutate matching schema nodes.

target=True is normalized to p_prune=1.0; target=False clears the target prune rate by setting p_prune=0.0.

Source code in src/json2vec/structs/experiment.py
def update(
    self,
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> None:
    """Mutate matching schema nodes.

    `target=True` is normalized to `p_prune=1.0`; `target=False` clears the
    target prune rate by setting `p_prune=0.0`.
    """
    values = self.update_values(values)
    if not values:
        raise ValueError("update requires at least one field value")

    nodes = self.select(*predicates, include_root=include_root, use_cache=use_cache)
    for node in nodes:
        can_apply_extra = allow_extra and getattr(type(node), "model_config", {}).get("extra") == "allow"
        missing = [name for name in values if not _has_model_attribute(node, name) and not can_apply_extra]
        if missing and strict:
            label = str(node.address) or node.name
            raise AttributeError(f"{label} has no attribute(s): {missing}")

        applicable_values = {
            name: value for name, value in values.items() if _has_model_attribute(node, name) or can_apply_extra
        }

        if validate and applicable_values:
            payload = node.model_dump(mode="python", round_trip=True)
            if isinstance(node, Array) and "masks" not in applicable_values:
                payload["masks"] = list(node.masks)
            if "target" in applicable_values and "p_prune" not in applicable_values:
                payload.pop("p_prune", None)
            payload.update(applicable_values)
            validated = type(node).model_validate(payload)
            applicable_values = {name: getattr(validated, name) for name in applicable_values}

        for name, value in applicable_values.items():
            setattr(node, name, value)
            if name in getattr(type(node), "model_fields", {}):
                node.model_fields_set.add(name)

    self._clear_tree_caches()
    self._post_bind_validate()
    self.refresh_selection_cache()
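The docstring states that `target=True`/`False` is normalized to `p_prune=1.0`/`0.0`. The helper that does this (`update_values`) is not shown on this page, so the sketch below assumes a simple rewrite. It models nodes as plain dicts and shows how `strict` decides between raising and skipping. Validation, `allow_extra`, and cache refreshes are left out, and `normalize_values` and `apply_update` are hypothetical names.

```python
from __future__ import annotations

from typing import Any


def normalize_values(values: dict[str, Any]) -> dict[str, Any]:
    """Assumed normalization: rewrite target=True/False into p_prune=1.0/0.0."""
    values = dict(values)
    if "target" in values:
        values["p_prune"] = 1.0 if values.pop("target") else 0.0
    return values


def apply_update(nodes: list[dict[str, Any]], strict: bool = True, **values: Any) -> None:
    values = normalize_values(values)
    if not values:
        raise ValueError("update requires at least one field value")
    for node in nodes:
        missing = [name for name in values if name not in node]
        if missing and strict:
            raise AttributeError(f"{node.get('name')} has no attribute(s): {missing}")
        # In non-strict mode, attributes a node does not define are skipped.
        for name, value in values.items():
            if name in node:
                node[name] = value
```

As in the listing above, a strict failure is raised mid-loop. Nodes visited earlier in the same call may already have been updated.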

extend

extend(
    *args: ExtendArg,
    include_root: bool = True,
    use_cache: bool = True,
) -> None

Append new schema fields under the single array selected by predicates.

Source code in src/json2vec/structs/experiment.py
def extend(
    self,
    *args: ExtendArg,
    include_root: bool = True,
    use_cache: bool = True,
) -> None:
    """Append new schema fields under the single array selected by predicates."""
    predicates: list[NodeSelector] = []
    fields: list[SchemaField] = []
    reading_fields = False

    for item in args:
        if isinstance(item, (Array, Leaf)):
            reading_fields = True
            fields.append(item)
            continue

        if reading_fields:
            raise TypeError("extend predicates must come before new schema fields")

        predicates.append(item)

    if not fields:
        raise ValueError("extend requires at least one schema field")

    candidates = [
        node
        for node in self.select(*predicates, include_root=include_root, use_cache=use_cache)
        if isinstance(node, Array)
    ]

    if len(candidates) != 1:
        raise ValueError(f"extend requires exactly one matching array node, found {len(candidates)}")

    parent = candidates[0]
    array_path = tuple(node.name for node in parent.path[2:] if isinstance(node, Array))
    new_fields = [self.from_schema_node(field, array_path=array_path) for field in fields]
    existing_names = {field.name for field in parent.fields}
    duplicate_names = sorted({field.name for field in new_fields if field.name in existing_names})
    duplicate_names.extend(
        sorted(
            {
                field.name
                for index, field in enumerate(new_fields)
                if any(other.name == field.name for other in new_fields[index + 1 :])
            }
        )
    )
    if duplicate_names:
        raise ValueError(f"duplicate field name(s): {sorted(set(duplicate_names))}")

    original_fields = list(parent.fields)
    try:
        parent.fields.extend(new_fields)
        for field in new_fields:
            field.parent = parent

        self._clear_tree_caches()
        self._post_bind_validate()
    except Exception:
        parent.fields = original_fields
        for field in new_fields:
            field.parent = None
        self._clear_tree_caches()
        self._post_bind_validate()
        self.refresh_selection_cache()
        raise

    self.refresh_selection_cache()
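`extend` reads its positional arguments as leading selectors followed by new fields. It rejects a selector that appears after a field. It also rejects names that clash with existing children or repeat within the new batch. The sketch below covers those two parsing steps with a hypothetical `FieldStub` class. It omits the single-Array match, `from_schema_node`, and the rollback on validation failure.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class FieldStub:
    """Hypothetical stand-in for a schema field (Array or Leaf)."""

    name: str


def split_extend_args(*args: Any) -> tuple[list[Any], list[FieldStub]]:
    """Split extend(*args) into leading selectors and trailing fields."""
    predicates: list[Any] = []
    fields: list[FieldStub] = []
    for item in args:
        if isinstance(item, FieldStub):
            fields.append(item)
        elif fields:
            # Selectors are only accepted before the first field.
            raise TypeError("extend predicates must come before new schema fields")
        else:
            predicates.append(item)
    if not fields:
        raise ValueError("extend requires at least one schema field")
    return predicates, fields


def duplicate_names(existing: list[FieldStub], new: list[FieldStub]) -> list[str]:
    """Names that clash with existing children or repeat within the new batch."""
    existing_names = {field.name for field in existing}
    clashes = {field.name for field in new if field.name in existing_names}
    seen: set[str] = set()
    for field in new:
        if field.name in seen:
            clashes.add(field.name)
        seen.add(field.name)
    return sorted(clashes)
```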

delete

delete(
    *predicates: NodeSelector,
    include_root: bool = False,
    use_cache: bool = True,
) -> None

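Permanently remove selected schema nodes, together with their descendants, from the tree. The root array cannot be deleted. A deletion is also rejected if it would remove every request or leave an array without request descendants.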

Source code in src/json2vec/structs/experiment.py
def delete(
    self,
    *predicates: NodeSelector,
    include_root: bool = False,
    use_cache: bool = True,
) -> None:
    """Permanently remove selected schema nodes from the tree."""
    if not predicates:
        raise ValueError("delete requires at least one predicate")

    selected = self.select(*predicates, include_root=include_root, use_cache=use_cache)
    if not selected:
        raise ValueError("delete matched no nodes")
    if self.fields in selected:
        raise ValueError("delete cannot remove the root array")

    selected_ids = {id(node) for node in selected}
    roots = [
        node
        for node in selected
        if not any(
            id(ancestor) in selected_ids for ancestor in getattr(node, "ancestors", ()) if ancestor is not self
        )
    ]
    removed_by_id = {id(node): node for node in roots}
    for node in roots:
        removed_by_id.update({id(descendant): descendant for descendant in getattr(node, "descendants", ())})
    removed_addresses = {node.address for node in removed_by_id.values()}

    remaining_request_addresses = {address for address in self.requests if address not in removed_addresses}
    if not remaining_request_addresses:
        raise ValueError("delete would remove every request")

    remaining_array_addresses = {address for address in self.arrays if address not in removed_addresses}
    for address in remaining_array_addresses:
        prefix = f"{address}/"
        if not any(str(request_address).startswith(prefix) for request_address in remaining_request_addresses):
            raise ValueError(f"delete would leave array '{address}' without request descendants")

    for node in roots:
        parent = node.parent
        if not isinstance(parent, Array):
            raise ValueError(f"delete cannot remove '{node.address}' because it has no array parent")
        parent.fields = [field for field in parent.fields if field is not node]
        node.parent = None

    self._clear_tree_caches()
    self._post_bind_validate()
    self.refresh_selection_cache()
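Before detaching anything, `delete` reduces the selection to its "roots": the selected nodes that have no selected ancestor. Detaching a root already drops its whole subtree. The sketch below shows that reduction on a hypothetical `TreeNode` class with parent links. It leaves out the request and array coverage checks described above.

```python
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field


@dataclass(eq=False)
class TreeNode:
    """Hypothetical tree node with a parent link; eq=False keeps identity comparison."""

    name: str
    parent: TreeNode | None = None
    children: list[TreeNode] = field(default_factory=list)

    def add(self, name: str) -> TreeNode:
        child = TreeNode(name, parent=self)
        self.children.append(child)
        return child

    def ancestors(self) -> Iterator[TreeNode]:
        node = self.parent
        while node is not None:
            yield node
            node = node.parent


def removal_roots(selected: list[TreeNode]) -> list[TreeNode]:
    """Keep only selected nodes that have no selected ancestor."""
    selected_ids = {id(node) for node in selected}
    return [node for node in selected if not any(id(a) in selected_ids for a in node.ancestors())]


def delete_selected(selected: list[TreeNode]) -> None:
    roots = removal_roots(selected)
    if any(node.parent is None for node in roots):
        raise ValueError("cannot delete the root node")
    for node in roots:
        # Detaching a root drops its subtree, so selected descendants need no separate pass.
        node.parent.children = [child for child in node.parent.children if child is not node]
        node.parent = None
```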

override

override(
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]
Source code in src/json2vec/structs/experiment.py
@contextmanager
def override(
    self,
    *predicates: NodeSelector,
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    use_cache: bool = False,
    **values: Any,
) -> Iterator[None]:
    nodes = self.select(*predicates, include_root=include_root, use_cache=use_cache)
    normalized_values = self.update_values(values)
    snapshot = [
        (
            node,
            "p_prune" if name == "target" else name,
            getattr(node, "p_prune" if name == "target" else name, _MISSING),
            ("p_prune" if name == "target" else name) in getattr(node, "model_fields_set", set()),
        )
        for node in nodes
        for name in normalized_values
        if _has_model_attribute(node, name)
        or (allow_extra and getattr(type(node), "model_config", {}).get("extra") == "allow")
    ]

    self.update(
        *predicates,
        strict=strict,
        allow_extra=allow_extra,
        include_root=include_root,
        validate=validate,
        use_cache=use_cache,
        **normalized_values,
    )

    try:
        yield
    finally:
        for node, name, original, was_set in snapshot:
            if original is _MISSING:
                if getattr(node, name, _MISSING) is _MISSING:
                    continue
                delattr(node, name)
            else:
                setattr(node, name, original)
                if name in getattr(type(node), "model_fields", {}):
                    if was_set:
                        node.model_fields_set.add(name)
                    else:
                        node.model_fields_set.discard(name)

        self._clear_tree_caches()
        self._post_bind_validate()
        self.refresh_selection_cache()
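The snapshot/restore pattern that override relies on can be illustrated with a plain object. This is a simplified sketch, not the json2vec implementation. The temporarily helper and the Node class are hypothetical. The sketch leaves out node selection, validation, model_fields_set bookkeeping, and cache refreshes.

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Iterator

_MISSING = object()  # sentinel for "attribute did not exist before the override"


@contextmanager
def temporarily(obj: Any, **values: Any) -> Iterator[None]:
    """Set attributes on obj inside a with block, then restore them."""
    # Record each original value (or _MISSING) before mutating anything.
    snapshot = [(name, getattr(obj, name, _MISSING)) for name in values]
    for name, value in values.items():
        setattr(obj, name, value)
    try:
        yield
    finally:
        # Restoring in finally means an exception inside the block still rolls back.
        for name, original in snapshot:
            if original is _MISSING:
                if hasattr(obj, name):
                    delattr(obj, name)
            else:
                setattr(obj, name, original)


class Node:
    def __init__(self) -> None:
        self.p_unavailable = 0.01


node = Node()
with temporarily(node, p_unavailable=0.0, extra=True):
    print(node.p_unavailable, node.extra)  # 0.0 True
print(node.p_unavailable, hasattr(node, "extra"))  # 0.01 False
```

Taking the snapshot before calling update(...), as override does, means a failure partway through still leaves a record of every original value.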

where

where(name: str) -> NodeAttribute

Start a schema predicate against a node attribute.

Source code in src/json2vec/structs/selectors.py
def where(name: str) -> NodeAttribute:
    """Start a schema predicate against a node attribute."""
    return NodeAttribute.named(name)

Tensorfield Constructors

Number

Bases: RequestBase

Numeric scalar tensorfield request.

type class-attribute instance-attribute

type: Literal['number'] = 'number'

jitter class-attribute instance-attribute

jitter: Annotated[float, Field(ge=0.0, default=0.0)] = 0.0

n_bands class-attribute instance-attribute

n_bands: Annotated[int, Field(gt=0, default=8)] = 8

offset class-attribute instance-attribute

offset: Annotated[int, Field(gt=0, default=4)] = 4

alpha class-attribute instance-attribute

alpha: Annotated[
    float | None, Field(gt=0.0, lt=1.0, default=None)
] = None

objective class-attribute instance-attribute

objective: Objective = mae

Category

Bases: RequestBase

Categorical scalar tensorfield request backed by an online vocabulary.

type class-attribute instance-attribute

type: Literal['category'] = 'category'

max_vocab_size class-attribute instance-attribute

max_vocab_size: Annotated[
    int, Field(gt=0, default=10000)
] = 10000

p_unavailable class-attribute instance-attribute

p_unavailable: Annotated[
    float, Field(ge=0.0, le=1.0, default=0.01)
] = 0.01

topk class-attribute instance-attribute

topk: list[int] | None = None

reject_removed_options classmethod

reject_removed_options(data: Any) -> Any
Source code in src/json2vec/tensorfields/extensions/category.py
@pydantic.model_validator(mode="before")
@classmethod
def reject_removed_options(cls, data: Any) -> Any:
    if isinstance(data, Mapping) and "n_bands" in data:
        raise ValueError("Category does not support n_bands")

    return data

check_topk

check_topk()
Source code in src/json2vec/tensorfields/extensions/category.py
@pydantic.model_validator(mode="after")
def check_topk(self):
    if self.topk is None:
        self.topk = []

    # enforce uniqueness
    self.topk = sorted(set(self.topk))

    for topk in self.topk:
        if not isinstance(topk, int):
            raise ValueError("topk values must be integers")

        if topk <= 0:
            raise ValueError("topk values must be positive")

        if topk == 1:
            raise ValueError("topk values must not be 1")

        if topk >= self.max_vocab_size:
            raise ValueError("topk values must be less than max_vocab_size")

    return self
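check_topk turns None into an empty list, deduplicates and sorts the values, and then rejects anything that is not an integer, is not positive, equals 1, or is not below max_vocab_size. The standalone function below reproduces those rules on a plain list so their effect is easy to see. The function name is hypothetical.

```python
from __future__ import annotations


def normalize_topk(topk: list[int] | None, max_vocab_size: int) -> list[int]:
    """Replicate Category.check_topk on a plain list (illustrative only)."""
    values = sorted(set(topk or []))  # None -> [], then deduplicate and sort
    for k in values:
        if not isinstance(k, int):
            raise ValueError("topk values must be integers")
        if k <= 0:
            raise ValueError("topk values must be positive")
        if k == 1:
            raise ValueError("topk values must not be 1")
        if k >= max_vocab_size:
            raise ValueError("topk values must be less than max_vocab_size")
    return values


print(normalize_topk([10, 5, 10], max_vocab_size=10000))  # [5, 10]
print(normalize_topk(None, max_vocab_size=100))  # []
```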

Set

Bases: RequestBase

Multi-label set tensorfield request backed by an online vocabulary.

type class-attribute instance-attribute

type: Literal['set'] = 'set'

max_vocab_size class-attribute instance-attribute

max_vocab_size: Annotated[
    int, Field(gt=0, default=10000)
] = 10000

p_unavailable class-attribute instance-attribute

p_unavailable: Annotated[
    float, Field(ge=0.0, le=1.0, default=0.01)
] = 0.01

threshold class-attribute instance-attribute

threshold: Annotated[
    float | None, Field(ge=0.0, le=1.0, default=None)
] = None

DateParts

Bases: RequestBase

Date/time tensorfield request that extracts configured calendar parts.

type class-attribute instance-attribute

type: Literal['dateparts'] = 'dateparts'

dateparts instance-attribute

dateparts: list[DatePart]

pattern class-attribute instance-attribute

pattern: Annotated[str | None, Field(default=None)] = None

check_dateparts classmethod

check_dateparts(v)
Source code in src/json2vec/tensorfields/extensions/dateparts.py
@pydantic.field_validator("dateparts", check_fields=False)
@classmethod
def check_dateparts(cls, v):
    if not v:
        raise ValueError("dateparts cannot be empty")

    if not len(v) == len(set(v)):
        raise ValueError("dateparts must be unique")

    return v

check_date_pattern classmethod

check_date_pattern(v)
Source code in src/json2vec/tensorfields/extensions/dateparts.py
@pydantic.field_validator("pattern", check_fields=False)
@classmethod
def check_date_pattern(cls, v):
    if v is None:
        return v

    regex: re.Pattern = re.compile(r"^(?:%%| %(?:[aAwdbBmyYHIpMSfzZjUWcxXGuV])|[^%])+$", re.VERBOSE)

    if not bool(regex.fullmatch(v)):
        raise ValueError(f"{v} is not a valid format pattern")

    return v
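The two DateParts validators can be exercised with the same expression outside pydantic. Under re.VERBOSE, the literal space before the second alternative is ignored. A valid pattern is therefore any non-empty run of "%%", "%" followed by one of the listed directive letters, or any character other than "%". In the sketch below, the helper names are hypothetical, and plain strings stand in for DatePart members.

```python
from __future__ import annotations

import re

# Same expression and flag as check_date_pattern.
DATE_PATTERN = re.compile(r"^(?:%%| %(?:[aAwdbBmyYHIpMSfzZjUWcxXGuV])|[^%])+$", re.VERBOSE)


def is_valid_date_pattern(pattern: str | None) -> bool:
    # None means "no explicit pattern" and is accepted unchanged.
    return pattern is None or DATE_PATTERN.fullmatch(pattern) is not None


def check_dateparts(parts: list[str]) -> list[str]:
    # Mirrors check_dateparts: non-empty and free of duplicates.
    if not parts:
        raise ValueError("dateparts cannot be empty")
    if len(parts) != len(set(parts)):
        raise ValueError("dateparts must be unique")
    return parts


print(is_valid_date_pattern("%Y-%m-%d %H:%M:%S"))  # True
print(is_valid_date_pattern("100%%"))  # True: "%%" is an escaped percent sign
print(is_valid_date_pattern("%Q"))  # False: Q is not an accepted directive
print(is_valid_date_pattern("50%"))  # False: trailing lone "%"
```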

Entity

Bases: RequestBase

Per-observation entity tensorfield request for local identity matching.

type class-attribute instance-attribute

type: Literal['entity'] = 'entity'

topk class-attribute instance-attribute

topk: list[int] | None = None

check_topk

check_topk()
Source code in src/json2vec/tensorfields/extensions/entity.py
@pydantic.model_validator(mode="after")
def check_topk(self):
    if self.topk is None:
        self.topk = []

    for topk in self.topk:
        if not isinstance(topk, int):
            raise ValueError("topk values must be integers")

        if topk <= 0:
            raise ValueError("topk values must be positive")

        if topk == 1:
            raise ValueError("topk values must not be 1")

    return self

post_bind_validate

post_bind_validate()
Source code in src/json2vec/tensorfields/extensions/entity.py
def post_bind_validate(self):
    per_observation_count: int = math.prod(self.shape)
    if per_observation_count <= 1:
        raise ValueError(
            f"entity field at '{self.address}' requires at least 2 elements per observation, "
            f"but configured count is {per_observation_count}"
        )
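post_bind_validate multiplies the field's shape to get the number of entity slots per observation and requires at least two. Because math.prod(()) is 1, an empty shape is rejected as well. The sketch assumes that shape is a tuple of ints. The helper name, the address, and the example shapes are hypothetical. Note also that, unlike Category.check_topk, Entity.check_topk neither sorts nor deduplicates topk and has no max_vocab_size bound.

```python
from __future__ import annotations

import math


def check_entity_shape(address: str, shape: tuple[int, ...]) -> int:
    """Mirror Entity.post_bind_validate for a given per-observation shape."""
    count = math.prod(shape)  # math.prod(()) == 1, so an empty shape fails below
    if count <= 1:
        raise ValueError(
            f"entity field at '{address}' requires at least 2 elements per observation, "
            f"but configured count is {count}"
        )
    return count


print(check_entity_shape("orders/customer_id", (2, 3)))  # 6
```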

Vector

Bases: RequestBase

Fixed-width numeric vector tensorfield request.

type class-attribute instance-attribute

type: Literal['vector'] = 'vector'

n_dim instance-attribute

n_dim: Annotated[int, Field(gt=0)]

objective class-attribute instance-attribute

objective: Objective = l2

Text

Bases: RequestBase

Text tensorfield request encoded by a frozen Hugging Face model.

type class-attribute instance-attribute

type: Literal['text'] = 'text'

model_name instance-attribute

model_name: str

max_length class-attribute instance-attribute

max_length: Annotated[int, Field(gt=0, default=128)] = 128

encoder_batch_size class-attribute instance-attribute

encoder_batch_size: Annotated[
    int, Field(gt=0, default=32)
] = 32

encoder_pooling class-attribute instance-attribute

encoder_pooling: Pooling = cls

objective class-attribute instance-attribute

objective: Objective = l2

revision class-attribute instance-attribute

revision: str | None = None

local_files_only class-attribute instance-attribute

local_files_only: bool = False

normalize_model_name classmethod

normalize_model_name(value: str)
Source code in src/json2vec/tensorfields/extensions/text.py
@pydantic.field_validator("model_name", mode="before")
@classmethod
def normalize_model_name(cls, value: str):
    if not isinstance(value, str):
        raise ValueError("model_name must be a string")

    return value.strip()

normalize_revision classmethod

normalize_revision(value: str | None)
Source code in src/json2vec/tensorfields/extensions/text.py
@pydantic.field_validator("revision", mode="before")
@classmethod
def normalize_revision(cls, value: str | None):
    if value is None:
        return None

    if not isinstance(value, str):
        raise ValueError("revision must be a string when provided")

    normalized = value.strip()
    return normalized or None

check_model_name

check_model_name()
Source code in src/json2vec/tensorfields/extensions/text.py
@pydantic.model_validator(mode="after")
def check_model_name(self):
    if not self.model_name:
        raise ValueError("model_name must be a non-empty string")

    return self
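Taken together, the Text validators strip surrounding whitespace from model_name and reject a result that is empty. They also strip revision and turn a blank or whitespace-only value into None. The sketch below folds the before-mode field validators and the after-mode model validator into two plain functions. The function names are taken from the validators, but this is not the pydantic code path, and "org/model-name" is a placeholder.

```python
from __future__ import annotations


def normalize_model_name(value: object) -> str:
    if not isinstance(value, str):
        raise ValueError("model_name must be a string")
    name = value.strip()
    # In json2vec this emptiness check runs later, in check_model_name.
    if not name:
        raise ValueError("model_name must be a non-empty string")
    return name


def normalize_revision(value: object) -> str | None:
    if value is None:
        return None
    if not isinstance(value, str):
        raise ValueError("revision must be a string when provided")
    return value.strip() or None  # blank or whitespace-only revisions become None


print(normalize_model_name("  org/model-name  "))  # org/model-name
print(normalize_revision("   "))  # None
```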

Data

Use Data Modules for the workflow-level guide to CustomDataModule, PolarsDataModule, and StreamingDataModule.

CustomDataModule

CustomDataModule(
    model: Model,
    train: IterableDataset | None = None,
    validate: IterableDataset | None = None,
    test: IterableDataset | None = None,
    predict: IterableDataset | None = None,
    preprocessor: str
    | Callable[..., Any]
    | Preprocessor
    | None = None,
    datasets: DatasetMap | None = None,
    num_workers: NonNegativeInt
    | None
    | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    observation_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    **kwargs: Any,
)

Bases: LightningDataModule

Lightning data module for user-provided iterable datasets.

Source code in src/json2vec/data/datasets/custom.py
def __init__(
    self,
    model: Model,
    train: IterableDataset | None = None,
    validate: IterableDataset | None = None,
    test: IterableDataset | None = None,
    predict: IterableDataset | None = None,
    preprocessor: str | Callable[..., Any] | Preprocessor | None = None,
    datasets: DatasetMap | None = None,
    num_workers: NonNegativeInt | None | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    observation_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    **kwargs: Any,
):
    super().__init__()

    _validate_loader_configuration(
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        pin_memory=pin_memory,
        observation_buffer_size=observation_buffer_size,
        sample_rate=sample_rate,
    )

    if datasets is not None and any(dataset is not None for dataset in (train, validate, test, predict)):
        raise ValueError("pass either datasets or named splits, not both")

    if datasets is None:
        split_datasets = {}
        for strata, dataset in {
            Strata.train: train,
            Strata.validate: validate,
            Strata.test: test,
            Strata.predict: predict,
        }.items():
            if dataset is None:
                continue
            if not isinstance(dataset, IterableDataset):
                raise TypeError(f"dataset for strata '{strata}' must be an IterableDataset")
            split_datasets[strata] = dataset
        if not split_datasets:
            raise ValueError("at least one dataset split is required")
    else:
        split_datasets = _datasets_by_strata(datasets)

    self.datasets = split_datasets
    self.preprocessor = PreprocessorConfig.normalize(preprocessor)
    self.preprocessor_kwargs = dict(kwargs)
    try:
        self._model_ref = weakref.ref(model)
    except TypeError:
        self._model_ref = None
    self._hyperparameters = model.hyperparameters
    self._interprocess_encoding_context = model.interprocess_encoding_context
    self._batch_size = model.batch_size
    self.num_workers = Strata.expand(num_workers, default=None)
    self.persistent_workers = Strata.expand(persistent_workers, default=True)
    self.pin_memory = Strata.expand(pin_memory, default=True)
    self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
    self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}

train_dataloader class-attribute instance-attribute

train_dataloader = partialmethod(
    dataloader, strata=train, required=False
)

val_dataloader class-attribute instance-attribute

val_dataloader = partialmethod(
    dataloader, strata=validate, required=False
)

test_dataloader class-attribute instance-attribute

test_dataloader = partialmethod(
    dataloader, strata=test, required=False
)

predict_dataloader class-attribute instance-attribute

predict_dataloader = partialmethod(
    dataloader, strata=predict, required=False
)

dataloader

dataloader(
    strata: Strata, required: bool = True
) -> DataLoader | None
Source code in src/json2vec/data/datasets/custom.py
def dataloader(self, strata: Strata, required: bool = True) -> DataLoader | None:
    strata = Strata.normalize(strata)
    trainer = getattr(self, "trainer", None)
    global_rank = getattr(trainer, "global_rank", None)
    world_size = getattr(trainer, "world_size", None)
    if strata not in self.datasets:
        if not required:
            return None
        raise ValueError(f"no dataset configured for strata: {strata}")

    workers = self.num_workers[strata]
    if workers is None:
        workers = os.cpu_count() or 0

    interprocess_encoding_context = self.interprocess_encoding_context
    if strata == Strata.train and workers > 0:
        share_interprocess_encoding_context(interprocess_encoding_context)

    return custom_dataloader(
        hyperparameters=self.hyperparameters,
        dataset=self.datasets[strata],
        preprocessor=self.preprocessor,
        preprocessor_kwargs=self.preprocessor_kwargs,
        interprocess_encoding_context=interprocess_encoding_context,
        batch_size=self.batch_size,
        strata=strata,
        num_workers=workers,
        persistent_workers=self.persistent_workers[strata],
        pin_memory=self.pin_memory[strata],
        observation_buffer_size=self.observation_buffer_size[strata],
        sample_rate=self.sample_rate[strata],
        global_rank=global_rank,
        world_size=world_size,
    )
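CustomDataModule accepts either a datasets mapping or the named train/validate/test/predict arguments, never both, and at least one split must be provided. PolarsDataModule enforces the same rule for its dataframe argument. The sketch below shows that rule with a hypothetical Split enum standing in for Strata. It skips the IterableDataset type check and the mapping normalization done by _datasets_by_strata.

```python
from __future__ import annotations

from enum import Enum
from typing import Any


class Split(str, Enum):  # hypothetical stand-in for json2vec's Strata
    train = "train"
    validate = "validate"
    test = "test"
    predict = "predict"


def select_splits(datasets: dict[Split, Any] | None = None, **named: Any) -> dict[Split, Any]:
    """Either a datasets mapping or named splits, never both."""
    if datasets is not None and any(value is not None for value in named.values()):
        raise ValueError("pass either datasets or named splits, not both")
    if datasets is not None:
        return dict(datasets)
    # Named splits left as None are simply skipped.
    chosen = {Split(name): value for name, value in named.items() if value is not None}
    if not chosen:
        raise ValueError("at least one dataset split is required")
    return chosen


print(select_splits(train=[1, 2], validate=None))  # {<Split.train: 'train'>: [1, 2]}
```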

PolarsDataModule

PolarsDataModule(
    model: Model,
    train: DataFrame | None = None,
    validate: DataFrame | None = None,
    test: DataFrame | None = None,
    predict: DataFrame | None = None,
    preprocessor: str
    | Callable[..., Any]
    | Preprocessor
    | None = None,
    dataframe: DataFrame | DataFrameMap | None = None,
    num_workers: NonNegativeInt
    | None
    | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy
    | str
    | StrataMap[
        ShardingStrategy | str
    ] = ShardingStrategy.chunk,
    chunk_batch_size: PositiveInt
    | StrataMap[PositiveInt] = 4096,
    observation_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] = False,
    **kwargs: Any,
)

Bases: LightningDataModule

Lightning data module for in-memory Polars DataFrames.

Source code in src/json2vec/data/datasets/polars.py
@beartype
def __init__(
    self,
    model: Model,
    train: pl.DataFrame | None = None,
    validate: pl.DataFrame | None = None,
    test: pl.DataFrame | None = None,
    predict: pl.DataFrame | None = None,
    preprocessor: str | Callable[..., Any] | Preprocessor | None = None,
    dataframe: pl.DataFrame | DataFrameMap | None = None,
    num_workers: NonNegativeInt | None | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy | str | StrataMap[ShardingStrategy | str] = ShardingStrategy.chunk,
    chunk_batch_size: PositiveInt | StrataMap[PositiveInt] = 4096,
    observation_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] = False,
    **kwargs: Any,
):
    super().__init__()

    if dataframe is not None and any(frame is not None for frame in (train, validate, test, predict)):
        raise ValueError("pass either dataframe or named splits, not both")

    if dataframe is None:
        dataframes = {
            strata: frame
            for strata, frame in {
                Strata.train: train,
                Strata.validate: validate,
                Strata.test: test,
                Strata.predict: predict,
            }.items()
            if frame is not None
        }
        if not dataframes:
            raise ValueError("at least one dataframe split is required")
    else:
        dataframes = _dataframes_by_strata(dataframe)

    self.dataframes = dataframes
    self.preprocessor = PreprocessorConfig.normalize(preprocessor)
    self.preprocessor_kwargs = dict(kwargs)
    try:
        self._model_ref = weakref.ref(model)
    except TypeError:
        self._model_ref = None
    self._hyperparameters = model.hyperparameters
    self._interprocess_encoding_context = model.interprocess_encoding_context
    self._batch_size = model.batch_size
    self.num_workers = Strata.expand(num_workers, default=None)
    self.persistent_workers = Strata.expand(persistent_workers, default=True)
    self.pin_memory = Strata.expand(pin_memory, default=True)
    self.sharding = ShardingStrategy.expand(sharding, default=ShardingStrategy.chunk)
    self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
    self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
    self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
    self.replacement = Strata.expand(replacement, default=False)

train_dataloader class-attribute instance-attribute

train_dataloader = partialmethod(
    dataloader, strata=train, required=False
)

val_dataloader class-attribute instance-attribute

val_dataloader = partialmethod(
    dataloader, strata=validate, required=False
)

test_dataloader class-attribute instance-attribute

test_dataloader = partialmethod(
    dataloader, strata=test, required=False
)

predict_dataloader class-attribute instance-attribute

predict_dataloader = partialmethod(
    dataloader, strata=predict, required=False
)

dataloader

dataloader(
    strata: Strata, required: bool = True
) -> DataLoader | None
Source code in src/json2vec/data/datasets/polars.py
def dataloader(self, strata: Strata, required: bool = True) -> DataLoader | None:
    strata = Strata.normalize(strata)
    trainer = getattr(self, "trainer", None)
    global_rank = getattr(trainer, "global_rank", None)
    world_size = getattr(trainer, "world_size", None)
    if strata not in self.dataframes:
        if not required:
            return None
        raise ValueError(f"no dataframe configured for strata: {strata}")

    workers = self.num_workers[strata]
    if workers is None:
        workers = os.cpu_count() or 0

    interprocess_encoding_context = self.interprocess_encoding_context
    if strata == Strata.train and workers > 0:
        share_interprocess_encoding_context(interprocess_encoding_context)

    return polars_dataloader(
        hyperparameters=self.hyperparameters,
        dataframe=self.dataframes[strata],
        preprocessor=self.preprocessor,
        preprocessor_kwargs=self.preprocessor_kwargs,
        interprocess_encoding_context=interprocess_encoding_context,
        batch_size=self.batch_size,
        strata=strata,
        num_workers=workers,
        persistent_workers=self.persistent_workers[strata],
        pin_memory=self.pin_memory[strata],
        sharding=self.sharding[strata],
        chunk_batch_size=self.chunk_batch_size[strata],
        observation_buffer_size=self.observation_buffer_size[strata],
        sample_rate=self.sample_rate[strata],
        replacement=self.replacement[strata],
        global_rank=global_rank,
        world_size=world_size,
    )
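All three data modules handle the split-level details of dataloader(...) in the same way. A missing split returns None when required=False, which is how train_dataloader and the other partialmethod hooks call it, and raises otherwise. A num_workers of None becomes os.cpu_count() or 0. Only the training loader with worker processes shares the interprocess encoding context. The helper names below are hypothetical, and plain strings stand in for Strata.

```python
from __future__ import annotations

import os
from typing import Any


def resolve_workers(configured: int | None) -> int:
    # None means "use os.cpu_count() workers"; cpu_count() can itself return None.
    if configured is None:
        return os.cpu_count() or 0
    return configured


def pick_split(configured: dict[str, Any], split: str, required: bool = True) -> Any | None:
    # The *_dataloader hooks pass required=False, so a missing split yields None.
    if split not in configured:
        if not required:
            return None
        raise ValueError(f"no dataframe configured for strata: {split}")
    return configured[split]


def shares_encoding_context(split: str, workers: int) -> bool:
    # Only the training loader running worker processes shares the encoding context.
    return split == "train" and workers > 0
```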

StreamingDataModule

StreamingDataModule(
    model: Model,
    root: str | Path,
    suffix: Suffix | str,
    train: PatternInput | None = None,
    validate: PatternInput | None = None,
    test: PatternInput | None = None,
    predict: PatternInput | None = None,
    preprocessor: str
    | Callable[..., Any]
    | Preprocessor
    | None = None,
    num_workers: NonNegativeInt
    | None
    | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy
    | str
    | StrataMap[
        ShardingStrategy | str
    ] = ShardingStrategy.file,
    chunk_batch_size: PositiveInt
    | StrataMap[PositiveInt] = 4096,
    file_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    observation_buffer_size: PositiveInt
    | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] | None = None,
    **kwargs: Any,
)

Bases: LightningDataModule

Lightning data module for streaming records from files.

Reads file-backed records, applies an optional preprocessor, batches observations, and encodes them with model hyperparameters.

Source code in src/json2vec/data/datasets/streaming.py
@beartype
def __init__(
    self,
    model: Model,
    root: str | Path,
    suffix: Suffix | str,
    train: PatternInput | None = None,
    validate: PatternInput | None = None,
    test: PatternInput | None = None,
    predict: PatternInput | None = None,
    preprocessor: str | Callable[..., Any] | Preprocessor | None = None,
    num_workers: NonNegativeInt | None | StrataMap[NonNegativeInt | None] = None,
    persistent_workers: bool | StrataMap[bool] = True,
    pin_memory: bool | StrataMap[bool] = True,
    sharding: ShardingStrategy | str | StrataMap[ShardingStrategy | str] = ShardingStrategy.file,
    chunk_batch_size: PositiveInt | StrataMap[PositiveInt] = 4096,
    file_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    observation_buffer_size: PositiveInt | StrataMap[PositiveInt] = 1,
    sample_rate: SampleRate | StrataMap[SampleRate] = 1.0,
    replacement: bool | StrataMap[bool] | None = None,
    **kwargs: Any,
):
    super().__init__()

    self.root = root
    self.suffix = Suffix(suffix)
    self.train = _compile_pattern(train) if train is not None else None
    self.validate = _compile_pattern(validate) if validate is not None else None
    self.test = _compile_pattern(test) if test is not None else None
    self.predict = _compile_pattern(predict) if predict is not None else None
    self.preprocessor = PreprocessorConfig.normalize(preprocessor)
    self.preprocessor_kwargs = dict(kwargs)
    try:
        self._model_ref = weakref.ref(model)
    except TypeError:
        self._model_ref = None
    self._hyperparameters = model.hyperparameters
    self._interprocess_encoding_context = model.interprocess_encoding_context
    self._batch_size = model.batch_size
    self.num_workers = Strata.expand(num_workers, default=None)
    self.persistent_workers = Strata.expand(persistent_workers, default=True)
    self.pin_memory = Strata.expand(pin_memory, default=True)
    self.sharding = ShardingStrategy.expand(sharding, default=ShardingStrategy.file)
    self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
    self.file_buffer_size = Strata.expand(file_buffer_size, default=1)
    self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
    self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
    self.replacement = (
        {strata: strata == Strata.train for strata in Strata}
        if replacement is None
        else Strata.expand(replacement, default=False)
    )

train_dataloader class-attribute instance-attribute

train_dataloader = partialmethod(
    dataloader, strata=train, required=False
)

val_dataloader class-attribute instance-attribute

val_dataloader = partialmethod(
    dataloader, strata=validate, required=False
)

test_dataloader class-attribute instance-attribute

test_dataloader = partialmethod(
    dataloader, strata=test, required=False
)

predict_dataloader class-attribute instance-attribute

predict_dataloader = partialmethod(
    dataloader, strata=predict, required=False
)

dataloader

dataloader(
    strata: Strata, required: bool = True
) -> DataLoader | None
Source code in src/json2vec/data/datasets/streaming.py
def dataloader(self, strata: Strata, required: bool = True) -> DataLoader | None:
    strata = Strata.normalize(strata)
    pattern = getattr(self, strata.value)
    if pattern is None:
        if not required:
            return None
        raise ValueError(f"no file pattern configured for strata: {strata}")

    trainer = getattr(self, "trainer", None)
    global_rank = getattr(trainer, "global_rank", None)
    world_size = getattr(trainer, "world_size", None)

    workers = self.num_workers[strata]
    if workers is None:
        workers = os.cpu_count() or 0

    interprocess_encoding_context = self.interprocess_encoding_context
    if strata == Strata.train and workers > 0:
        share_interprocess_encoding_context(interprocess_encoding_context)

    return dataloader(
        hyperparameters=self.hyperparameters,
        root=self.root,
        suffix=self.suffix,
        pattern=pattern,
        preprocessor=self.preprocessor,
        preprocessor_kwargs=self.preprocessor_kwargs,
        interprocess_encoding_context=interprocess_encoding_context,
        batch_size=self.batch_size,
        strata=strata,
        num_workers=workers,
        persistent_workers=self.persistent_workers[strata],
        pin_memory=self.pin_memory[strata],
        sharding=self.sharding[strata],
        chunk_batch_size=self.chunk_batch_size[strata],
        file_buffer_size=self.file_buffer_size[strata],
        observation_buffer_size=self.observation_buffer_size[strata],
        sample_rate=self.sample_rate[strata],
        replacement=self.replacement[strata],
        global_rank=global_rank,
        world_size=world_size,
    )
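StreamingDataModule's defaults differ from PolarsDataModule's in two ways. Sharding defaults to ShardingStrategy.file instead of ShardingStrategy.chunk. And replacement=None enables replacement only for the training split, whereas PolarsDataModule defaults replacement to False everywhere. The sketch shows the None branch with a hypothetical Split enum. It assumes that a scalar bool applies to every split, which simplifies what Strata.expand does with scalars and per-split mappings.

```python
from __future__ import annotations

from enum import Enum


class Split(str, Enum):  # hypothetical stand-in for json2vec's Strata
    train = "train"
    validate = "validate"
    test = "test"
    predict = "predict"


def default_replacement(replacement: bool | None) -> dict[Split, bool]:
    if replacement is None:
        # Streaming default: replacement is on for train and off for every other split.
        return {split: split == Split.train for split in Split}
    # Assumption: a scalar is broadcast to all splits (per-split mappings not shown).
    return {split: replacement for split in Split}


print(default_replacement(None)[Split.train], default_replacement(None)[Split.predict])  # True False
```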

Batch Inference

Use Batch Inference for the workflow-level guide to Trainer.predict(...), Writer, and postprocessed Parquet output.

Writer

Writer(
    path: PathLike | str,
    flush_every_n_batches: int | None = None,
    postprocessor: Postprocessor | None = None,
)

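Bases: BasePredictionWriter

Lightning prediction-writer callback used with Trainer.predict(...). After each predict batch, it converts the model output with pl_module.write(...), applies the optional postprocessor, and appends the rows to <path>/rank-<local_rank>.parquet. Each row has an inputs column holding the batch metadata and a predictions struct column keyed by field address. The first batch fixes the Parquet schema, and later batches are cast to it. When flush_every_n_batches is set, the writer is flushed after every n-th batch if it exposes flush(). The file is closed in on_predict_end.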

Source code in src/json2vec/inference/callback.py
def __init__(
    self,
    path: os.PathLike | str,
    flush_every_n_batches: int | None = None,
    postprocessor: Postprocessor | None = None,
):
    super().__init__(write_interval="batch")

    self.path = Path(path)
    self.flush_every_n_batches: int | None = flush_every_n_batches
    self.postprocessor: Postprocessor | None = postprocessor
    self.schema: pa.Schema | None = None
    self.writer: pq.ParquetWriter | None = None

path instance-attribute

path = Path(path)

flush_every_n_batches instance-attribute

flush_every_n_batches: int | None = flush_every_n_batches

postprocessor instance-attribute

postprocessor: Postprocessor | None = postprocessor

schema instance-attribute

schema: Schema | None = None

writer instance-attribute

writer: ParquetWriter | None = None

write_on_batch_end

write_on_batch_end(
    trainer: Trainer,
    pl_module: Model,
    output: dict[str, list[Prediction]],
    batch_indices: list[int] | None,
    batch: TensorDict[Address, TensorFieldBase],
    batch_idx: int,
    dataloader_idx: int,
) -> None
Source code in src/json2vec/inference/callback.py
def write_on_batch_end(
    self,
    trainer: lit.Trainer,
    pl_module: Model,
    output: dict[str, list[Prediction]],
    batch_indices: list[int] | None,
    batch: TensorDict[Address, TensorFieldBase],
    batch_idx: int,
    dataloader_idx: int,
) -> None:  # ty:ignore[invalid-method-override]
    num_rows = len(batch[TensorKey.metadata])

    predictions: dict[Address, dict[str, Any]] = pl_module.write(predictions=output["predictions"])
    postprocessor = self.postprocessor

    if postprocessor is not None:
        context = {
            "input": batch,
            "batch": batch,
            TensorKey.metadata: batch[TensorKey.metadata],
            "batch_indices": batch_indices,
            "batch_idx": batch_idx,
            "dataloader_idx": dataloader_idx,
        }
        processed = postprocessor(context, predictions)

        if processed is not None:
            predictions = processed

    if len(predictions) == 0:
        predictions_frame = pl.DataFrame({"predictions": [None] * num_rows})
    else:
        columns: list[pl.DataFrame] = []
        for address, values in predictions.items():
            field_frame = pl.DataFrame(data=values)
            columns.append(field_frame.select(pl.struct(pl.all()).alias(name=address)))

        nested: pl.DataFrame = pl.concat(items=columns, how="horizontal")
        predictions_frame = nested.select(pl.struct(pl.all()).alias(name="predictions"))

    items = [
        pl.from_records(data=batch[TensorKey.metadata], schema=["inputs"], orient="row"),
        predictions_frame,
    ]

    table: pa.Table = pl.concat(items=items, how="horizontal").to_arrow()

    if self.writer is None:
        self.path.mkdir(parents=True, exist_ok=True)
        self.schema = table.schema

        self.writer = pq.ParquetWriter(
            where=self.path / f"rank-{trainer.local_rank}.parquet",
            schema=self.schema,
        )

    if table.schema != self.schema:
        table = table.cast(self.schema)

    self.writer.write_table(table)

    flush = getattr(self.writer, "flush", None)
    if self.flush_every_n_batches and (batch_idx + 1) % self.flush_every_n_batches == 0 and callable(flush):
        flush()

on_predict_end

on_predict_end(
    trainer: Trainer, pl_module: LightningModule
) -> None
Source code in src/json2vec/inference/callback.py
def on_predict_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
    if self.writer:
        self.writer.close()
        self.writer = None
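A postprocessor receives the context dict built in write_on_batch_end and the per-address predictions. It can return a replacement dict, or return None after mutating its input in place, in which case the Writer keeps the original. The sketch below shows both styles and the selection rule. The addresses and column names ("orders/total", "prediction", "embedding") are hypothetical. Because each address's dict becomes a DataFrame that is concatenated horizontally with the batch metadata, each column list should presumably hold one entry per input row.

```python
from __future__ import annotations

from typing import Any


def drop_embeddings(context: dict[str, Any], predictions: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]] | None:
    """Hypothetical postprocessor: return a new dict without any "embedding" column."""
    return {
        address: {column: values for column, values in columns.items() if column != "embedding"}
        for address, columns in predictions.items()
    }


def tag_batch_in_place(context: dict[str, Any], predictions: dict[str, dict[str, Any]]) -> None:
    """Hypothetical postprocessor: add a batch_idx column in place and return None."""
    for columns in predictions.values():
        n_rows = len(next(iter(columns.values()), []))
        columns["batch_idx"] = [context["batch_idx"]] * n_rows


def apply_postprocessor(postprocessor, context, predictions):
    # Same rule as write_on_batch_end: a None result keeps the original predictions.
    processed = postprocessor(context, predictions)
    return predictions if processed is None else processed


preds = {"orders/total": {"prediction": [1.0, 2.0], "embedding": [[0.1], [0.2]]}}
print(apply_postprocessor(drop_embeddings, {"batch_idx": 0}, preds))  # {'orders/total': {'prediction': [1.0, 2.0]}}
```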

Preprocessing

preprocess

preprocess(
    func: Callable[..., Any] | None = None,
    *,
    yields: bool | None = None,
    **kwargs: Any,
) -> Callable[..., Any]

Register a callable as a json2vec preprocessor.

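Parameters:

func (Callable[..., Any] | None, default: None)
    Callable to register when used as bare @preprocess. Passing a non-callable raises TypeError.
yields (bool | None, default: None)
    Set to True for generator preprocessors. None is treated as False.
**kwargs (Any, default: {})
    Accepts only yield, an alias for yields that has to be passed as **{"yield": True} because yield is a reserved word. Passing both yields and yield, or any other keyword, raises TypeError.

Returns:

Callable[..., Any]
    The original callable, after registering it in PREPROCESSORS. When func is None (the @preprocess(...) form), the return value is a decorator that registers the callable it receives.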

Example
import json2vec as j2v

@j2v.preprocess
def normalize(record: dict) -> dict:
    return {**record, "amount": float(record["amount"])}
Source code in src/json2vec/preprocessors/base.py
def preprocess(
    func: Callable[..., Any] | None = None,
    *,
    yields: bool | None = None,
    **kwargs: Any,
) -> Callable[..., Any]:
    """Register a callable as a `json2vec` preprocessor.

    Args:
        func: Callable to register when used as `@preprocess`.
        yields: Set to `True` for generator preprocessors.
        **kwargs: Reserved for validation of unsupported decorator arguments.

    Returns:
        The original callable, after registering it in `PREPROCESSORS`.

    Example:
        ```python
        import json2vec as j2v

        @j2v.preprocess
        def normalize(record: dict) -> dict:
            return {**record, "amount": float(record["amount"])}
        ```
    """
    if "yield" in kwargs:
        if yields is not None:
            raise TypeError("use either 'yields' or 'yield', not both")
        yields = kwargs.pop("yield")

    if kwargs:
        unexpected = ", ".join(sorted(kwargs))
        raise TypeError(f"unexpected preprocess keyword argument(s): {unexpected}")

    if yields is None:
        yields = False

    mode = PreprocessorMode.from_yields(yields)

    def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
        return Preprocessor.register(inner, mode=mode)

    if func is None:
        return decorator

    if not callable(func):
        raise TypeError("preprocess can only decorate callables")

    return decorator(func)
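The source accepts three call forms: bare @preprocess, @preprocess(yields=True), and the yield alias. Because yield is a reserved word, the alias can only be passed through dict unpacking. The sketch below copies the argument handling into a hypothetical preprocess_sketch with its own REGISTRY, so it runs without json2vec. The mode strings stand in for PreprocessorMode, and the mapping of True to generator is assumed from the docstring.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical registry standing in for json2vec's PREPROCESSORS: name -> (callable, mode).
REGISTRY: dict[str, tuple[Callable[..., Any], str]] = {}


def preprocess_sketch(func: Callable[..., Any] | None = None, *, yields: bool | None = None, **kwargs: Any):
    # `yield` is a keyword, so it can only arrive through **kwargs.
    if "yield" in kwargs:
        if yields is not None:
            raise TypeError("use either 'yields' or 'yield', not both")
        yields = kwargs.pop("yield")
    if kwargs:
        raise TypeError(f"unexpected preprocess keyword argument(s): {', '.join(sorted(kwargs))}")
    mode = "generator" if yields else "transformation"

    def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
        REGISTRY[inner.__name__] = (inner, mode)
        return inner

    if func is None:
        return decorator  # called with arguments: @preprocess_sketch(...)
    if not callable(func):
        raise TypeError("preprocess can only decorate callables")
    return decorator(func)  # bare use: @preprocess_sketch


@preprocess_sketch
def normalize(record: dict) -> dict:
    return {**record, "amount": float(record["amount"])}


@preprocess_sketch(yields=True)
def explode(record: dict):
    # One raw order becomes one observation per item.
    for item in record["items"]:
        yield {"order": record["id"], "item": item}


@preprocess_sketch(**{"yield": True})  # alias form; writing yield=True here is a SyntaxError
def explode_alias(record: dict):
    yield from explode(record)


print(REGISTRY["normalize"][1], REGISTRY["explode"][1], REGISTRY["explode_alias"][1])
# transformation generator generator
```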

Preprocessor

Bases: BaseModel

Registered observation preprocessor.

A transformation preprocessor returns one dict. A generator preprocessor either yields dict objects or returns a list of them, and each dict becomes one processed observation.

outputs

outputs(
    observation: dict, **kwargs
) -> Iterator[list[dict[str, Any]]]

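Yield the processed observations for one raw observation, each wrapped in a one-element list. Raises TypeError when a generator preprocessor returns something other than a list or an iterator.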

Source code in src/json2vec/preprocessors/base.py
def outputs(self, observation: dict, **kwargs) -> Iterator[list[dict[str, Any]]]:
    """Yield normalized processed observations for one raw observation."""
    result = self(observation, **kwargs)

    if self.mode == PreprocessorMode.transformation:
        yield [self.require_object(result, mode=self.mode)]
        return

    if self.mode == PreprocessorMode.generator:
        if isinstance(result, list):
            iterable: list[Any] | Iterator[Any] = result
        elif isinstance(result, Iterator):
            iterable = result
        else:
            raise TypeError(
                f"generator preprocessor '{self.name}' must yield dict objects or return a list of dict objects, "
                f"got {type(result).__name__}"
            )

        for output in iterable:
            yield [self.require_object(output, mode=self.mode)]
        return

    raise ValueError(f"unsupported preprocessor mode: {self.mode}")
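To show what outputs produces for each mode, the sketch below reproduces its normalization as a plain function. normalize_outputs is a hypothetical mirror, and its inline require_object only assumes that non-dict outputs are rejected with TypeError. A generator preprocessor that explodes one order into several observations shows the one-element-list wrapping.

```python
from __future__ import annotations

from collections.abc import Iterator
from typing import Any


def normalize_outputs(mode: str, result: Any) -> Iterator[list[dict[str, Any]]]:
    """Hypothetical mirror of Preprocessor.outputs: wrap each dict in a one-element list."""

    def require_object(value: Any) -> dict[str, Any]:
        # Stand-in for require_object: only dict outputs are accepted.
        if not isinstance(value, dict):
            raise TypeError(f"{mode} preprocessor must produce dict objects, got {type(value).__name__}")
        return value

    if mode == "transformation":
        yield [require_object(result)]
        return
    if mode == "generator":
        # A list or any iterator (including a generator object) is accepted; a tuple is not.
        if not isinstance(result, (list, Iterator)):
            raise TypeError(f"generator preprocessor must yield dicts or return a list, got {type(result).__name__}")
        for output in result:
            yield [require_object(output)]
        return
    raise ValueError(f"unsupported preprocessor mode: {mode}")


def explode(record: dict):
    # One raw order becomes one observation per item.
    for item in record["items"]:
        yield {"order": record["id"], "item": item}


order = {"id": 7, "items": ["tea", "cake"]}
print(list(normalize_outputs("transformation", {"id": 7})))
# [[{'id': 7}]]
print(list(normalize_outputs("generator", explode(order))))
# [[{'order': 7, 'item': 'tea'}], [{'order': 7, 'item': 'cake'}]]
```

Returning a tuple of dicts fails the type check in both the source and the sketch. Only a list or an iterator passes.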

Serving

Deployment

Bases: BaseSettings

Serving configuration for a json2vec checkpoint or model instance.

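Deployment collects request/response schemas, an optional preprocessor and postprocessor, and queued update(...) mutations. The queued mutations are applied when the FastAPI application loads the model at startup. Each configuration method returns the Deployment, so calls can be chained.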

forge

forge(
    request: type[BaseModel] | None = None,
    response: type[BaseModel] | None = None,
) -> Deployment

Attach optional Pydantic request and response signatures.

Source code in src/json2vec/inference/deployment.py
@beartype
def forge(
    self,
    request: type[pydantic.BaseModel] | None = None,
    response: type[pydantic.BaseModel] | None = None,
) -> "Deployment":
    """Attach optional Pydantic request and response signatures."""
    self._request_signature = request
    self._response_signature = response

    return self

preprocess

preprocess(preprocessor, **kwargs: Any) -> Deployment

Attach an optional request preprocessor.

If this method is not called, request objects are encoded unchanged. Any keyword arguments are bound to the preprocessor with functools.partial.

Source code in src/json2vec/inference/deployment.py
@beartype
def preprocess(self, preprocessor, **kwargs: Any) -> "Deployment":
    """Attach an optional request preprocessor.

    If this method is not called, request objects are encoded unchanged.
    """
    self._preprocessor = functools.partial(preprocessor, **kwargs) if kwargs else preprocessor

    return self

postprocess

postprocess(postprocessor, **kwargs: Any) -> Deployment

Attach an optional response postprocessor. Any keyword arguments are bound to it with functools.partial.

Source code in src/json2vec/inference/deployment.py
@beartype
def postprocess(self, postprocessor, **kwargs: Any) -> "Deployment":
    """Attach an optional response postprocessor."""
    self._postprocessor = functools.partial(postprocessor, **kwargs) if kwargs else postprocessor

    return self
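preprocess and postprocess follow the same pattern: they bind keyword arguments with functools.partial and return self. The hypothetical DeploymentSketch below reproduces only that pattern. The postprocessor assumes the two-argument shape of the Postprocessor alias, with the first argument taken to be the request, and uses string keys where real code would use Address values.

```python
from __future__ import annotations

import functools
from typing import Any, Callable


class DeploymentSketch:
    """Hypothetical stand-in that copies the chaining and partial-binding pattern."""

    def __init__(self) -> None:
        self._preprocessor: Callable[..., Any] | None = None
        self._postprocessor: Callable[..., Any] | None = None

    def preprocess(self, preprocessor: Callable[..., Any], **kwargs: Any) -> DeploymentSketch:
        # Keyword arguments are frozen into the callable once, at configuration time.
        self._preprocessor = functools.partial(preprocessor, **kwargs) if kwargs else preprocessor
        return self  # returning self is what makes chaining work

    def postprocess(self, postprocessor: Callable[..., Any], **kwargs: Any) -> DeploymentSketch:
        self._postprocessor = functools.partial(postprocessor, **kwargs) if kwargs else postprocessor
        return self


def scale_amount(record: dict, *, factor: float) -> dict:
    # Hypothetical request preprocessor.
    return {**record, "amount": record["amount"] * factor}


def keep_keys(request: dict, predictions: dict, *, keys: tuple[str, ...]) -> dict:
    # Hypothetical postprocessor: (request, {address: outputs}) -> {address: outputs}.
    return {address: {k: v for k, v in out.items() if k in keys} for address, out in predictions.items()}


deployment = DeploymentSketch().preprocess(scale_amount, factor=0.5).postprocess(keep_keys, keys=("label",))
print(deployment._preprocessor({"amount": 250}))  # {'amount': 125.0}
print(deployment._postprocessor({}, {"status": {"label": "ok", "embedding": [0.1]}}))  # {'status': {'label': 'ok'}}
```

Binding at configuration time means the server calls each processor with the request data only, and it never needs to know about the extra options.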

update

update(
    *predicates: NodePredicate
    | NodeAttribute
    | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    **values: Any,
) -> Deployment

Queue a model schema mutation to apply during server startup.

This mirrors Model.update(...) and is useful for serving-time changes such as target=False.

Source code in src/json2vec/inference/deployment.py
@beartype
def update(
    self,
    *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
    strict: bool = True,
    allow_extra: bool = False,
    include_root: bool = True,
    validate: bool = True,
    **values: Any,
) -> "Deployment":
    """Queue a model schema mutation to apply during server startup.

    This mirrors `Model.update(...)` and is useful for serving-time changes
    such as `target=False`.
    """
    self._update_operations.append(
        (
            tuple(predicates),
            {
                "strict": strict,
                "allow_extra": allow_extra,
                "include_root": include_root,
                "validate": validate,
                **values,
            },
        )
    )

    return self
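update does not touch a model. It appends a (predicates, options) pair, with the four flags merged ahead of the field values. The hypothetical UpdateQueue below keeps that record layout and adds a replay step. Forwarding each pair to something like Model.update(...) at startup is an assumption based on the docstring's "mirrors Model.update(...)". The string predicates are placeholders for real NodePredicate values.

```python
from __future__ import annotations

from typing import Any, Callable

Operation = tuple[tuple[Any, ...], dict[str, Any]]


class UpdateQueue:
    """Hypothetical stand-in for Deployment.update: record now, replay later."""

    def __init__(self) -> None:
        self._update_operations: list[Operation] = []

    def update(self, *predicates: Any, strict: bool = True, allow_extra: bool = False,
               include_root: bool = True, validate: bool = True, **values: Any) -> UpdateQueue:
        # Same record layout as the source: flags first, then the field values to set.
        options = {"strict": strict, "allow_extra": allow_extra,
                   "include_root": include_root, "validate": validate, **values}
        self._update_operations.append((tuple(predicates), options))
        return self

    def replay(self, apply: Callable[..., Any]) -> None:
        # At startup, each queued call is forwarded in the order it was recorded.
        for predicates, options in self._update_operations:
            apply(*predicates, **options)


calls: list[Operation] = []
queue = UpdateQueue().update("status", target=False).update("amount", strict=False, target=True)
queue.replay(lambda *p, **o: calls.append((p, o)))
print(calls[0])
# (('status',), {'strict': True, 'allow_extra': False, 'include_root': True, 'validate': True, 'target': False})
```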

serve

serve() -> None

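Start the FastAPI server for the configured checkpoint or model. With workers > 1, a checkpoint path is required, and an in-memory model raises ValueError. Settings reach the worker processes through JSON2VEC_* environment variables, which are restored once uvicorn returns.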

Source code in src/json2vec/inference/deployment.py
def serve(self) -> None:
    """Start the FastAPI server for the configured checkpoint or model."""
    if self.workers > 1:
        if self.model is not None or isinstance(self.checkpoint, Model):
            raise ValueError("workers > 1 requires a checkpoint path, not an in-memory model")

        env_updates = {
            "JSON2VEC_CHECKPOINT": str(self.checkpoint),
            "JSON2VEC_MAX_BATCH_SIZE": str(self.max_batch_size),
            "JSON2VEC_BATCH_TIMEOUT": str(self.batch_timeout),
            "JSON2VEC_ACCELERATOR": self.accelerator.value,
            "JSON2VEC_MONITOR_QUERIES": str(self.monitor_queries),
            "JSON2VEC_QUERY_MONITOR_EVERY": str(self.query_monitor_every),
            "JSON2VEC_JSON_BACKEND": self.json_backend.value,
        }
        previous = {key: os.environ.get(key) for key in env_updates}
        try:
            os.environ.update(env_updates)
            uvicorn.run(
                "json2vec.inference.deployment:create_app",
                factory=True,
                workers=self.workers,
                host=self.host,
                port=self.port,
                log_level=self.log_level,
            )
        finally:
            for key, value in previous.items():
                if value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = value
        return

    uvicorn.run(
        self.app(),
        host=self.host,
        port=self.port,
        log_level=self.log_level,
    )
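With several workers, uvicorn needs the app as an import string (here a factory), so configuration cannot be passed as Python objects. The multi-worker branch therefore sets environment variables and restores them afterwards. The same save/restore technique fits in a small context manager. temporary_environ and the DEMO_* names are illustrative and not part of json2vec.

```python
from __future__ import annotations

import contextlib
import os
from collections.abc import Iterator


@contextlib.contextmanager
def temporary_environ(updates: dict[str, str]) -> Iterator[None]:
    """Set environment variables for the duration of a block, then restore them."""
    previous = {key: os.environ.get(key) for key in updates}
    try:
        os.environ.update(updates)
        yield
    finally:
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)  # the variable did not exist before
            else:
                os.environ[key] = value  # put back the caller's value


os.environ["DEMO_PORT"] = "8000"
with temporary_environ({"DEMO_PORT": "9000", "DEMO_WORKERS": "4"}):
    print(os.environ["DEMO_PORT"], os.environ["DEMO_WORKERS"])  # 9000 4
print(os.environ["DEMO_PORT"], "DEMO_WORKERS" in os.environ)  # 8000 False
```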

Tensorfield Extension API

Plugin

Plugin(name: str)

Registry object for a tensorfield implementation.

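Register request, tensorfield, embedder, decoder, loss, and write components with @plugin.register. The component type is inferred from the object's __name__, which must match a Component value. Plugin names may contain only lowercase letters, digits, and underscores. Creating a plugin whose name is already registered replaces the existing entry and emits a UserWarning.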

Source code in src/json2vec/tensorfields/base.py
def __init__(self, name: str):
    if not isinstance(name, str):
        raise TypeError("Plugin name must be a string")

    # only lowercase letters, digits, and underscores; note the pattern also accepts a leading digit
    if not re.match(r"^[a-z0-9_]+$", name):
        raise ValueError("Plugin name must consist of lowercase letters, numbers, and underscores only")

    self.name: str = name
    self.components: dict[Component, ComponentValue | None] = {}
    self.callback_factories: list[CallbackFactory] = []

    if name in TENSORFIELDS:
        warnings.warn(
            f"Plugin '{name}' already registered; overriding existing tensorfield plugin",
            UserWarning,
            stacklevel=2,
        )

    TENSORFIELDS[name] = self
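The name rule and the override warning are both plain standard-library behaviour. The sketch below applies the same regular expression to a few candidate names and reuses a name to trigger the warning. register_name and REGISTRY are hypothetical stand-ins for the constructor and TENSORFIELDS.

```python
from __future__ import annotations

import re
import warnings

REGISTRY: dict[str, object] = {}  # hypothetical stand-in for TENSORFIELDS


def register_name(name: str, plugin: object) -> None:
    # Same pattern as the source: lowercase letters, digits, underscores (a leading digit passes).
    if not re.match(r"^[a-z0-9_]+$", name):
        raise ValueError("Plugin name must consist of lowercase letters, numbers, and underscores only")
    if name in REGISTRY:
        warnings.warn(f"Plugin '{name}' already registered; overriding", UserWarning, stacklevel=2)
    REGISTRY[name] = plugin


for candidate in ["geo_point", "2d_box", "GeoPoint", "geo-point"]:
    print(candidate, bool(re.match(r"^[a-z0-9_]+$", candidate)))
# geo_point True
# 2d_box True
# GeoPoint False
# geo-point False

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    register_name("geo_point", "first")
    register_name("geo_point", "second")  # same name: replaced, with one warning
print(REGISTRY["geo_point"], len(caught))  # second 1
```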

callbacks property

callbacks: list[Callback]

Instantiate all registered callback factories.

register

register(obj: None, component: Component | str) -> None
register(
    obj: RegisterT, component: Component | str | None = None
) -> RegisterT
register(
    obj: RegisterT | None,
    component: Component | str | None = None,
) -> RegisterT | None

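Register one tensorfield component with this plugin. The component is inferred from obj.__name__. Passing obj=None requires an explicit component and is allowed only for write. Registering the same component twice raises ValueError.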

Source code in src/json2vec/tensorfields/base.py
def register(
    self,
    obj: RegisterT | None,
    component: Component | str | None = None,
) -> RegisterT | None:
    """Register one tensorfield component with this plugin."""
    if obj is None:
        if component is None:
            raise TypeError("component must be provided when registering None")

        key = Component(component)
        if key != Component.write:
            raise TypeError("only write may be registered as None")

        if key in self.components:
            raise ValueError(f"Component '{key}' already registered in plugin '{self.name}'")

        self.components[key] = None
        return None

    if not hasattr(obj, "__name__"):
        raise NameError(f"Object {obj} does not have a name")

    name: str = str(obj.__name__)
    try:
        key = Component(name)
    except ValueError:
        raise ValueError(f"Component '{name}' is not a valid Component enum value") from None

    if key in self.components:
        raise ValueError(f"Component '{key}' already registered in plugin '{self.name}'")

    match key:
        case Component.Request:
            if not isinstance(obj, type):
                raise TypeError("Request must be a class type")

            if not issubclass(obj, Node):
                raise TypeError("Request must be a subclass of Node")

        case Component.TensorField:
            if not isinstance(obj, type):
                raise TypeError("TensorField must be a class type")

            if not issubclass(obj, TensorFieldBase):
                raise TypeError("TensorField must be a subclass of TensorFieldBase")

        case Component.Embedder:
            if not isinstance(obj, type):
                raise TypeError("Embedder must be a class type")

            if not issubclass(obj, EmbedderBase):
                raise TypeError("Embedder must be a subclass of EmbedderBase")

            # confirm the init method is expecting hyperparameters and address
            init_params = list(obj.__init__.__annotations__.keys())
            if "hyperparameters" not in init_params or "address" not in init_params:
                raise TypeError("Embedder __init__ method must accept 'hyperparameters' and 'address' parameters")

        case Component.Decoder:
            if not isinstance(obj, type):
                raise TypeError("Decoder must be a class type")

            if not issubclass(obj, DecoderBase):
                raise TypeError("Decoder must be a subclass of DecoderBase")

            init_params = list(obj.__init__.__annotations__.keys())
            if "hyperparameters" not in init_params or "address" not in init_params:
                raise TypeError("Decoder __init__ method must accept 'hyperparameters' and 'address' parameters")

        case Component.loss:
            if not callable(obj):
                raise TypeError("Loss must be a callable function")

            expected_params: list[str] = ["module", "prediction", "batch", "strata"]
            func_params: list[str] = list(obj.__annotations__.keys())

            if not set(expected_params).issubset(set(func_params)):
                raise TypeError(
                    f"Loss function must accept the following parameters: {expected_params}, got {func_params}"
                )

        case Component.write:
            if obj is not None and not callable(obj):
                raise TypeError("Write must be a callable function")

            # check the signature of the function
            expected_params: list[str] = ["module", "prediction"]
            func_params: list[str] = list(obj.__annotations__.keys())

            if func_params != expected_params:
                raise TypeError(
                    f"Write function must accept the following parameters: {expected_params}, got {func_params}"
                )

    self.components[key] = obj

    return obj
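The loss and write checks read the function's __annotations__, so only annotated parameters count. The write check also compares the key list exactly, while the loss check uses a subset test. Python adds a "return" key when a return annotation is present, so with the check as written, a write function annotated -> None is rejected. The helper functions below are hypothetical reproductions of the two checks.

```python
from __future__ import annotations

from typing import Any


def loss_params_ok(func: Any) -> bool:
    # Loss check from the source: the required names must be a subset of the annotation keys.
    return {"module", "prediction", "batch", "strata"}.issubset(func.__annotations__)


def write_params_ok(func: Any) -> bool:
    # Write check from the source: the annotation keys must equal this list exactly.
    return list(func.__annotations__) == ["module", "prediction"]


def loss(module: Any, prediction: Any, batch: Any, strata: Any) -> float:
    return 0.0


def write(module: Any, prediction: Any):
    return None


def write_annotated(module: Any, prediction: Any) -> None:
    return None


def write_untyped(module, prediction):
    return None


print(loss_params_ok(loss), write_params_ok(write))  # True True
print(list(write_annotated.__annotations__))  # ['module', 'prediction', 'return']
print(write_params_ok(write_annotated), write_params_ok(write_untyped))  # False False
```

Only the functions named loss and write would map to Component values through __name__. The other names here exist only to compare annotation variants.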

callback

callback(factory: CallbackFactory) -> CallbackFactory
callback(
    factory: CallbackFactory, *factories: CallbackFactory
) -> tuple[CallbackFactory, ...]
callback(
    factory: CallbackFactory, *factories: CallbackFactory
)

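Register one or more Lightning callback factories for this tensorfield. Each factory is called once at registration to check that it produces a Lightning Callback. A single factory is returned unchanged, so the method works as a decorator; several factories are returned as a tuple.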

Source code in src/json2vec/tensorfields/base.py
def callback(self, factory: CallbackFactory, *factories: CallbackFactory):
    """Register one or more Lightning callback factories for this tensorfield."""
    registered = (factory, *factories)
    for callback_factory in registered:
        callback = callback_factory()
        if not isinstance(callback, Callback):
            raise TypeError(f"Plugin callback factory for '{self.name}' must produce a Lightning Callback")

    self.callback_factories.extend(registered)
    return factory if len(registered) == 1 else registered
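The return convention lets callback serve both as a decorator and as a plain call. Every factory also runs once during registration, so factories with side effects trigger them early. PluginSketch and the local Callback class below are hypothetical stand-ins for Plugin and Lightning's Callback. In this sketch, callbacks builds fresh instances on each access.

```python
from __future__ import annotations

from typing import Callable


class Callback:
    """Stand-in for Lightning's Callback base class."""


class PluginSketch:
    """Hypothetical mirror of Plugin.callback and Plugin.callbacks."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.callback_factories: list[Callable[[], Callback]] = []

    def callback(self, factory, *factories):
        registered = (factory, *factories)
        for callback_factory in registered:
            # Each factory is called once up front purely to validate what it produces.
            if not isinstance(callback_factory(), Callback):
                raise TypeError(f"Plugin callback factory for '{self.name}' must produce a Lightning Callback")
        self.callback_factories.extend(registered)
        return factory if len(registered) == 1 else registered

    @property
    def callbacks(self) -> list[Callback]:
        return [factory() for factory in self.callback_factories]


plugin = PluginSketch("geo_point")


@plugin.callback  # single factory: returned unchanged, so decorator use keeps the class
class LogShapes(Callback):
    pass


class EarlyCheck(Callback):
    pass


class LateCheck(Callback):
    pass


pair = plugin.callback(EarlyCheck, LateCheck)  # several factories: returned as a tuple
print(LogShapes.__name__, pair == (EarlyCheck, LateCheck), len(plugin.callbacks))  # LogShapes True 3
```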

TensorFieldBase

Bases: Renderable

Tensorized field values plus trainable target state.

STATE_PREVIEW_LIMIT class-attribute instance-attribute

STATE_PREVIEW_LIMIT: int = 80

STATE_LABELS class-attribute instance-attribute

STATE_LABELS: dict[int, str] = {
    value: "V",
    value: "N",
    value: "P",
    value: "M",
    value: "O",
}

STATE_STYLES class-attribute instance-attribute

STATE_STYLES: dict[int, str] = {
    value: "bold green",
    value: "bold yellow",
    value: "dim",
    value: "bold magenta",
    value: "bold cyan",
}

content instance-attribute

content: Tensor

state instance-attribute

state: Tensor

trainable instance-attribute

trainable: Tensor

targets instance-attribute

targets: TensorDict[TensorKey, Tensor]

new abstractmethod classmethod

new(
    values: list,
    address: Address,
    hyperparameters: Hyperparameters,
    strata: Strata,
) -> "TensorFieldBase"
Source code in src/json2vec/tensorfields/base.py
@classmethod
@abstractmethod
def new(
    cls,
    values: list,
    address: Address,
    hyperparameters: Hyperparameters,
    strata: Strata,
) -> "TensorFieldBase":
    raise NotImplementedError

empty abstractmethod classmethod

empty(
    batch_size: int,
    address: Address,
    hyperparameters: Hyperparameters,
) -> "TensorFieldBase"
Source code in src/json2vec/tensorfields/base.py
@classmethod
@abstractmethod
def empty(
    cls,
    batch_size: int,
    address: Address,
    hyperparameters: Hyperparameters,
) -> "TensorFieldBase":
    raise NotImplementedError

mask abstractmethod

mask(p_mask: float = 0.0, **kwargs: Any)
Source code in src/json2vec/tensorfields/base.py
@abstractmethod
def mask(self, p_mask: float = 0.0, **kwargs: Any):
    raise NotImplementedError

target abstractmethod

target(p_prune: float = 1.0)
Source code in src/json2vec/tensorfields/base.py
@abstractmethod
def target(self, p_prune: float = 1.0):
    raise NotImplementedError

hide

hide(
    selected: Tensor,
    *,
    cache_targets: bool = True,
    trainable: bool = True,
) -> None
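Not implemented on the base class and not declared abstract. Subclasses that need it must override it; otherwise calling it raises NotImplementedError.

Source code in src/json2vec/tensorfields/base.py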
def hide(self, selected: torch.Tensor, *, cache_targets: bool = True, trainable: bool = True) -> None:
    raise NotImplementedError
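TensorFieldBase declares new, empty, mask, and target as abstract, while hide only raises NotImplementedError. For subclass authors, the difference is that a missing abstract method blocks instantiation, whereas a missing hide fails only when called. The sketch below shows this with a hypothetical, list-backed FieldSketch that has no torch, Address, Hyperparameters, or Strata, and omits mask and target for brevity.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class FieldSketch(ABC):
    """Hypothetical list-backed stand-in for TensorFieldBase."""

    def __init__(self, content: list[float]) -> None:
        self.content = content

    @classmethod
    @abstractmethod  # classmethod outermost, abstractmethod innermost, as in the source
    def new(cls, values: list) -> FieldSketch: ...

    @classmethod
    @abstractmethod
    def empty(cls, batch_size: int) -> FieldSketch: ...

    def hide(self, selected: list[bool]) -> None:
        # Not abstract: subclasses instantiate without it and fail only when it is called.
        raise NotImplementedError


class NumberSketch(FieldSketch):
    @classmethod
    def new(cls, values: list) -> NumberSketch:
        return cls([float(v) for v in values])

    @classmethod
    def empty(cls, batch_size: int) -> NumberSketch:
        return cls([0.0] * batch_size)


class Incomplete(FieldSketch):
    @classmethod
    def new(cls, values: list) -> Incomplete:
        return cls([])


print(NumberSketch.new([1, 2]).content, NumberSketch.empty(3).content)  # [1.0, 2.0] [0.0, 0.0, 0.0]
print(sorted(Incomplete.__abstractmethods__))  # ['empty'], so Incomplete([]) raises TypeError
```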

EmbedderBase

EmbedderBase(
    hyperparameters: Hyperparameters, address: Address
)

Bases: Module

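Base class for tensorfield embedders. The base __init__ accepts hyperparameters and address but does not store them. Plugin.register requires an embedder's __init__ to annotate both parameters.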

Source code in src/json2vec/tensorfields/base.py
def __init__(self, hyperparameters: Hyperparameters, address: Address):
    super().__init__()

DecoderBase

DecoderBase(
    hyperparameters: Hyperparameters, address: Address
)

Bases: Module

Base class for tensorfield decoders.

Source code in src/json2vec/tensorfields/base.py
def __init__(self, hyperparameters: Hyperparameters, address: Address):
    super().__init__()

    self.address: Address = address
    self.sigma: torch.Tensor = torch.nn.Parameter(torch.zeros(1))

    request = hyperparameters.requests[address]
    n_context = 1
    for dimension in hyperparameters.shapes[address]:
        n_context *= dimension
    match request.pooling:
        case "query":
            self.pool = LearnedQueryCrossAttention(
                n_context=n_context,
                d_model=hyperparameters.d_model,
                nhead=request.n_heads,
                dropout=float(request.dropout or 0.0),
                n_linear=request.n_linear,
            )
        case "mean":
            self.pool = MeanPool(n_context=n_context)
        case _:
            raise ValueError(f"unsupported decoder pooling: {request.pooling}")

address instance-attribute

address: Address = address

sigma instance-attribute

sigma: Tensor = Parameter(zeros(1))

pool instance-attribute

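pool: LearnedQueryCrossAttention | MeanPool

Set from request.pooling: LearnedQueryCrossAttention for "query" and MeanPool for "mean". Any other value raises ValueError.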

decode

decode(
    pooled: Tensor,
) -> TensorDict[TensorKey, torch.Tensor]
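Map the pooled tensor to a TensorDict of decoded outputs. The base implementation raises NotImplementedError, so every concrete decoder must override it.

Source code in src/json2vec/tensorfields/base.py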
def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
    raise NotImplementedError("decoder must implement decode(pooled)")

forward

forward(
    parcels: list[Parcel], *, embed: bool = False
) -> Prediction
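Reshape each parcel payload to (N, -1, C), concatenate the parcels along the context axis, pool, and decode. With embed=True, the pooled tensor is also added to the payload under TensorKey.embedding. Raises ValueError if parcels is empty.

Source code in src/json2vec/tensorfields/base.py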
def forward(self, parcels: list[Parcel], *, embed: bool = False) -> Prediction:
    if len(parcels) == 0:
        raise ValueError("decoder requires at least one parcel")

    N, *_, C = parcels[0].payload.shape
    stacked = torch.cat([parcel.payload.reshape(N, -1, C) for parcel in parcels], dim=1)
    pooled = self.pool(stacked)

    payload = self.decode(pooled)
    if embed:
        payload[TensorKey.embedding] = pooled

    return Prediction(
        payload=payload,
        address=self.address,
        batch_size=pooled.shape[0],
    )
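forward flattens every parcel to (N, L_i, C) and concatenates them, so the pooled sequence has one position per element of each parcel's middle dimensions. The sketch below works out that shape arithmetic in plain Python and pairs it with an unmasked mean over the context axis. The mean is an assumption for illustration, and json2vec's MeanPool (which also takes n_context) may differ. Both helpers are hypothetical and assume every parcel shares N and C, as torch.cat requires.

```python
from __future__ import annotations

import math


def stacked_shape(parcel_shapes: list[tuple[int, ...]]) -> tuple[int, int, int]:
    """Shape after reshape(N, -1, C) on each parcel and concatenation along dim 1."""
    if not parcel_shapes:
        raise ValueError("decoder requires at least one parcel")
    N, *_, C = parcel_shapes[0]
    # Each parcel contributes prod(middle dims) positions; a bare (N, C) parcel contributes 1.
    positions = sum(math.prod(shape[1:-1]) for shape in parcel_shapes)
    return (N, positions, C)


def mean_pool(stacked: list[list[list[float]]]) -> list[list[float]]:
    """Unmasked average over the context axis: (N, L, C) -> (N, C)."""
    return [[sum(column) / len(rows) for column in zip(*rows)] for rows in stacked]


print(stacked_shape([(4, 3, 2, 16), (4, 5, 16)]))  # (4, 11, 16)
print(mean_pool([[[1.0, 2.0], [3.0, 6.0]]]))  # [[2.0, 4.0]]
```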