
Postprocessors

Postprocessors are the output-side counterpart to preprocessors. A preprocessor shapes raw observations before json2vec encodes them. A postprocessor shapes decoded predictions and embeddings after the model writes them.

Use a postprocessor when the model output is correct but not in the shape your API, batch job, or warehouse table should expose.

from collections.abc import Callable
from typing import Any

import json2vec as j2v

Postprocessor = Callable[
    [dict[str, Any], dict[j2v.Address, dict[str, Any]]],
    dict[j2v.Address, dict[str, Any]] | None,
]

The public alias is j2v.Postprocessor. It is a callable that receives:

  • context: runtime details such as raw input, encoded batch tensors, metadata, batch indices, and deployment request data.
  • predictions: a dictionary keyed by schema address. Each value contains the written payload for that address, such as decoded content, state, or embedding.

Return a replacement dictionary when you want to reshape the output, or return None if you mutated predictions in place.

Note

The type above documents the stable, model-facing contract. Postprocessors passed to Model.predict or to a deployment may instead return a compact API dictionary that is not keyed by address. Postprocessors used with a prediction writer should return row-aligned values so the Parquet writer can build a table.

Common context keys include:

| Key | Meaning |
| --- | --- |
| "batch" | The raw batch passed to prediction. |
| "observations" | Processed observation metadata from encoding. |
| "input" | Encoded tensorfield input. |
| j2v.TensorKey.metadata | Row-aligned metadata used by writers and joins. |

Where They Run

| Entry point | How postprocessing is used |
| --- | --- |
| Model.predict(..., postprocess=...) | Runs after raw predictions are converted with model.write(...). |
| j2v.Writer(..., postprocessor=...) | Runs before batch prediction output is written to Parquet. |
| j2v.Deployment(...).postprocess(...) | Runs before the serving response is returned. |

Postprocessors should usually be deterministic and have few side effects. If a postprocessor logs output, or enriches it with data from another system, make that dependency explicit and handle its failures so they cannot break the prediction itself.
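One way to keep an external dependency explicit is to pass it into a factory that builds the postprocessor, and to catch its failures there. This is a minimal sketch under stated assumptions: make_region_enricher, lookup_region, and the "request_id" context key are all hypothetical names chosen for illustration, not part of json2vec.

```python
from typing import Any, Callable


def make_region_enricher(
    lookup_region: Callable[[Any], str],
) -> Callable[[dict[str, Any], dict[Any, dict[str, Any]]], dict[str, Any]]:
    """Build a postprocessor whose external dependency is passed in explicitly."""

    def enrich(context: dict[str, Any], predictions: dict[Any, dict[str, Any]]) -> dict[str, Any]:
        # "request_id" is a hypothetical context key used only for this sketch.
        request_id = context.get("request_id")
        try:
            region = lookup_region(request_id)
        except Exception:
            # A failing external system degrades this field instead of the prediction.
            region = None
        return {"predictions": predictions, "region": region}

    return enrich
```

Because the lookup is an argument, tests can pass a stub and production code can pass a real client without changing the postprocessor.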

Inspect Address-Keyed Output

Model.predict(...) returns predictions keyed by schema address:

predictions = model.predict(records)

label = predictions[j2v.Address("record", "label")]
root = predictions[j2v.Address("record")]

The exact payload depends on the datatype and schema settings. A categorical target usually writes state and content; a node configured with embed=True writes embedding.
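Because the payload keys vary by datatype and schema settings, it can help to check which keys each address actually wrote before writing a postprocessor against them. A minimal sketch, assuming only that predictions is the address-keyed dictionary described above; payload_keys is a hypothetical helper, and the example input uses plain strings in place of addresses.

```python
from typing import Any


def payload_keys(predictions: dict[Any, dict[str, Any]]) -> dict[str, list[str]]:
    # Map each address (rendered with str()) to the sorted keys it wrote.
    return {str(address): sorted(payload) for address, payload in predictions.items()}
```

For example, print(payload_keys(predictions)) after the predict call above lists, per address, whether state, content, or embedding is present.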

Flatten For An API

This pattern is useful when callers should not need to know the schema address layout.

import lightning.pytorch as lit
import polars as pl
import torch

import json2vec as j2v


def get_payload(predictions, path: str) -> dict:
    # Match addresses by their string form, e.g. "record/species".
    for address, payload in predictions.items():
        if str(address) == path:
            return payload
    return {}


def compact_response(context, predictions):
    # Pull the decoded label and the root embedding out of the address-keyed output.
    label = get_payload(predictions, "record/species").get("content", {})
    root = get_payload(predictions, "record")

    # Return a flat API dictionary instead of the address-keyed structure.
    return {
        "species": label.get("value"),
        "species_probability": label.get("probability"),
        "record_embedding": root.get("embedding"),
        "metadata": context[j2v.TensorKey.metadata],
    }


records = pl.read_ndjson("docs/data/iris.jsonl").head(36)

model = j2v.Model.from_schema(
    j2v.Number("sepal_length"),
    j2v.Number("petal_length"),
    j2v.Category("species", target=True, max_vocab_size=4, topk=[2]),
    d_model=16,
    n_layers=1,
    n_heads=4,
    batch_size=8,
    embed=True,
    optimizer=lambda module: torch.optim.AdamW(module.parameters(), lr=1e-2),
)

datamodule = j2v.PolarsDataModule(
    model=model,
    train=records,
    validate=records,
    num_workers=0,
    persistent_workers=False,
    pin_memory=False,
    observation_buffer_size=32,
    sample_rate=1.0,
)

trainer = lit.Trainer(
    accelerator="cpu",
    max_epochs=1,
    logger=False,
    enable_progress_bar=False,
    enable_checkpointing=False,
    enable_model_summary=False,
    limit_train_batches=1,
    limit_val_batches=1,
)

trainer.fit(model=model, datamodule=datamodule)
response = model.predict(records.to_dicts()[:3], postprocess=compact_response)
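get_payload scans every address for each field it looks up. When a response pulls several fields, one option is to index the payloads by their string path once. The sketch assumes, as get_payload does, that str(address) yields paths like "record/species". FakeAddress and index_by_path are hypothetical names, and the payload values are made up so the snippet runs without a trained model.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeAddress:
    """Hypothetical stand-in for j2v.Address, assumed to render as "record/species"."""

    parts: tuple[str, ...]

    def __str__(self) -> str:
        return "/".join(self.parts)


def index_by_path(predictions: dict[Any, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    # One pass over the predictions instead of one scan per looked-up field.
    return {str(address): payload for address, payload in predictions.items()}


# Made-up payloads shaped like the ones compact_response reads.
predictions = {
    FakeAddress(("record", "species")): {"content": {"value": "setosa", "probability": 0.9}},
    FakeAddress(("record",)): {"embedding": [0.1, 0.2]},
}
by_path = index_by_path(predictions)
species = by_path.get("record/species", {}).get("content", {})
```

Calling a postprocessor on a hand-built dictionary like this is also a cheap way to unit-test it without training a model.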

Add Confidence Flags

You can add derived fields while keeping the original address-keyed structure.

import copy

import json2vec as j2v


def add_confidence_flags(context, predictions, threshold: float = 0.80):
    # Work on a copy so the caller's predictions stay unchanged.
    output = copy.deepcopy(predictions)
    label = output.get(j2v.Address("record", "label"), {})
    content = label.get("content", {})
    probability = content.get("probability")

    # probability may be a list (one score per row) or a single value.
    if isinstance(probability, list):
        content["high_confidence"] = [
            score >= threshold if score is not None else None
            for score in probability
        ]
    elif probability is not None:
        content["high_confidence"] = probability >= threshold

    return output

This keeps downstream consumers that expect schema addresses working while adding an application-specific flag.
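Because every entry point calls the postprocessor with just (context, predictions), one way to choose a different threshold is to bind it with functools.partial. This sketch assumes only that calling convention. flag_confidence is a simplified, hypothetical in-place variant that flags scalar probabilities on every address, and tuples stand in for j2v.Address keys.

```python
import functools
from typing import Any


def flag_confidence(
    context: dict[str, Any],
    predictions: dict[Any, dict[str, Any]],
    threshold: float = 0.80,
) -> None:
    # Mutates predictions in place and returns None (the in-place return style).
    for payload in predictions.values():
        content = payload.get("content")
        if isinstance(content, dict) and isinstance(content.get("probability"), (int, float)):
            content["high_confidence"] = content["probability"] >= threshold


# strict still takes (context, predictions), so it can be passed wherever a postprocessor is expected.
strict = functools.partial(flag_confidence, threshold=0.95)
```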

Attach Input Metadata

Batch and serving systems often need to join predictions back to original requests. Use the context dictionary for that, not the model payload.

import json2vec as j2v


def attach_request_ids(context, predictions):
    metadata = context[j2v.TensorKey.metadata]

    return {
        "request_id": [item.get("request_id") for item in metadata],
        "predictions": predictions,
    }

For batch prediction, context can include encoded tensors, raw batch input, metadata, batch_indices, batch_idx, and dataloader_idx. For deployments, it can include the original request, preprocessed observations, and encoded input.
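Since the available keys differ between batch prediction and deployments, a postprocessor shared across entry points can read metadata with .get instead of indexing, so a context without row metadata does not raise KeyError. A minimal sketch; METADATA_KEY is a hypothetical string stand-in for j2v.TensorKey.metadata so the snippet runs without the library.

```python
from typing import Any

# Hypothetical stand-in; with json2vec this would be j2v.TensorKey.metadata.
METADATA_KEY = "metadata"


def attach_request_ids_safe(context: dict[Any, Any], predictions: dict[Any, Any]) -> dict[str, Any]:
    # .get avoids a KeyError in contexts that carry no row metadata.
    metadata = context.get(METADATA_KEY) or []
    return {
        "request_id": [item.get("request_id") for item in metadata],
        "predictions": predictions,
    }
```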

Strip Private Or Large Fields

Embeddings are useful for offline retrieval and clustering, but they may be too large or too sensitive for a public response.

import copy


def without_embeddings(context, predictions):
    output = copy.deepcopy(predictions)

    for payload in output.values():
        payload.pop("embedding", None)

    return output

Use this at the serving boundary when the same model is used for both internal analysis and public responses.
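An alternative design is an allowlist: keep only the fields you intend to publish, so a payload key added later (for example after a schema change) is not exposed by default. keep_only is a hypothetical helper; the field names are the ones described above (content, state, embedding), and tuples stand in for addresses in the example.

```python
import copy
from typing import Any, Callable

Postprocess = Callable[[dict[str, Any], dict[Any, dict[str, Any]]], dict[Any, dict[str, Any]]]


def keep_only(*fields: str) -> Postprocess:
    """Build a postprocessor that keeps only the listed payload fields."""
    allowed = set(fields)

    def postprocess(context: dict[str, Any], predictions: dict[Any, dict[str, Any]]) -> dict[Any, dict[str, Any]]:
        # Copy kept values so the caller's predictions are not shared or mutated.
        return {
            address: {key: copy.deepcopy(value) for key, value in payload.items() if key in allowed}
            for address, payload in predictions.items()
        }

    return postprocess
```

For example, keep_only("content") drops both state and embedding from every address.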

Warehouse Rows

For batch output, prefer stable column shapes that preserve row alignment.

import json2vec as j2v


def warehouse_rows(context, predictions):
    label = predictions.get(j2v.Address("record", "label"), {}).get("content", {})
    metadata = context[j2v.TensorKey.metadata]

    # Each column below should hold one entry per row in the prediction batch.
    return {
        j2v.Address("record", "warehouse_row"): {
            "request_id": [item.get("request_id") for item in metadata],
            "label": label.get("value"),
            "label_probability": label.get("probability"),
        }
    }

Keep the returned lists the same length as the prediction batch. That makes the writer output easy to join back to evaluation or production records.
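One way to catch misaligned output before it reaches the writer is a small check on the row dictionary. A sketch under stated assumptions: check_row_aligned is hypothetical, it only inspects plain Python lists, and values of any other type are skipped.

```python
from typing import Any


def check_row_aligned(row: dict[str, Any], n_rows: int) -> None:
    """Raise ValueError if any list-valued column does not have exactly n_rows entries."""
    for column, values in row.items():
        if isinstance(values, list) and len(values) != n_rows:
            raise ValueError(f"column {column!r} has {len(values)} values, expected {n_rows}")
```

Inside a postprocessor like warehouse_rows, n_rows could come from len(metadata).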

When Not To Use One

Do not use a postprocessor to fix schema problems. If the model is reading the wrong source values, update the schema query or use a preprocessor. If you need a different target or embedding, change the schema with target=True, embed=True, or model.update(...).

Postprocessors are best for naming, filtering, thresholding, redaction, metadata joins, and response formatting.

Where Next