Dynamic Masking

Dynamic masking hides values during training so the model learns to reconstruct them from surrounding context. Leaf fields already support local stochastic masking with p_mask, pruning with p_prune, and supervised targets with target=True. Array masks add one more control: they select slices of a repeated context and apply that selection to descendant tensorfields.

Use array masks when the masking rule is about a repeated coordinate rather than one leaf value. Common examples are "mask two values from each word", "mask half of the last 10 retained events", or "mask a few line items but leave one field visible."

Array Masks

Attach j2v.Mask(...) to an Array(...) with mask= or masks=:

import json2vec as j2v

model = j2v.Model.from_schema(
    j2v.Array(
        j2v.Category("event_type", max_vocab_size=128),
        j2v.Number("amount"),
        name="events",
        max_length=64,
        overflow="tail",
        mask=j2v.Mask(rate=0.50, window=10),
    ),
    name="customer",
    d_model=64,
    n_layers=2,
    n_heads=4,
)

This selects the last 10 retained, non-padded slots of the events array for each observation and masks each of them with probability 0.50, so about half are masked on average. The same selected event rows are projected onto each active descendant leaf unless the mask excludes that leaf.

mask= is shorthand for a one-item list. Use masks= when one array needs more than one masking policy:

j2v.Array(
    j2v.Category("sku", max_vocab_size=4096),
    j2v.Number("quantity"),
    name="line_items",
    max_length=128,
    masks=[
        j2v.Mask(name="recent", rate=0.25, window=32),
        j2v.Mask(name="early", count=3, window=20, start=True),
    ],
)

Mask Options

Each Mask chooses a candidate window, then samples from that window.

Option   Default  Meaning
rate     None     Bernoulli probability for each candidate slot. Mutually exclusive with count.
count    None     Exact number of candidate slots to sample per parent coordinate, capped by available candidates. Mutually exclusive with rate.
window   None     Candidate span length. None means all slots in the selected coordinate space.
array    False    Use fixed padded array coordinates instead of observed non-padded coordinates.
start    False    Take the window from the start instead of the end.
offset   0        Skip slots away from the selected edge before taking the window.
exclude  None     One predicate, or a list of predicates, for descendant leaves that should not receive this mask.

Exactly one of rate or count is required.
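The difference between rate and count is easiest to see on a window that has already been resolved. The following is a minimal sketch in plain Python, not json2vec's implementation: sample_slots is a hypothetical helper, and it assumes the candidate slots are given as a list of integer indices.

```python
import random


def sample_slots(candidates, rng, rate=None, count=None):
    """Pick slots to mask from an already-resolved candidate window.

    Hypothetical helper for illustration; not part of json2vec.
    """
    # Exactly one of rate or count must be given.
    if (rate is None) == (count is None):
        raise ValueError("exactly one of rate or count is required")
    if rate is not None:
        # Independent Bernoulli draw per candidate: the total varies per call.
        return [slot for slot in candidates if rng.random() < rate]
    # Exact count, capped by the number of available candidates.
    k = min(count, len(candidates))
    return sorted(rng.sample(candidates, k))


rng = random.Random(0)
window = [2, 3, 4]

by_rate = sample_slots(window, rng, rate=0.5)   # anywhere from 0 to 3 slots
by_count = sample_slots(window, rng, count=2)   # always exactly 2 slots
capped = sample_slots(window, rng, count=10)    # capped at the 3 candidates
```

With rate, the number of masked slots is itself random. With count, it is fixed unless the window holds fewer candidates than requested.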

The default coordinate mode is array=False, start=False, which means "the end of the retained real values." This is usually the right choice for recency histories, especially with overflow="tail".

Window Coordinates

Mask windows are resolved after preprocessing, query extraction, overflow handling, and padding. Padding is never selected.

For an array with max_length=8, a mask with window=3, and retained values shaped like this (V is a retained value, P is padding):

slot:   0 1 2 3 4 5 6 7
state:  V V V V V P P P

The candidate windows are:

Configuration                                         Candidate slots
j2v.Mask(rate=0.2, window=3, start=True)              0, 1, 2
j2v.Mask(rate=0.2, window=3)                          2, 3, 4
j2v.Mask(rate=0.2, window=3, array=True, start=True)  0, 1, 2
j2v.Mask(rate=0.2, window=3, array=True)              none, because fixed slots 5, 6, 7 are padded

offset skips away from the chosen edge. With the same values and offset=2:

Configuration                                       Candidate slots
j2v.Mask(rate=0.2, window=3, start=True, offset=2)  2, 3, 4
j2v.Mask(rate=0.2, window=3, offset=2)              0, 1, 2
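The window rules above can be reproduced with a few lines of plain Python. This is an illustrative sketch of the documented behaviour, not the library's code: candidate_slots is a hypothetical function, and the state row is written by hand as a list of "V" and "P" markers.

```python
def candidate_slots(state, window=None, array=False, start=False, offset=0):
    """Return candidate slot indices for one row of a padded array.

    Hypothetical helper; `state` holds "V" (value) or "P" (padding) per slot.
    """
    if array:
        # Fixed coordinates: every slot position, padded or not.
        coords = list(range(len(state)))
    else:
        # Observed coordinates: only the retained, non-padded slots.
        coords = [i for i, s in enumerate(state) if s != "P"]
    if window is None:
        window = len(coords)
    if start:
        # Skip `offset` slots from the start, then take the window.
        span = coords[offset : offset + window]
    else:
        # Skip `offset` slots from the end, then take the window before them.
        end = max(len(coords) - offset, 0)
        span = coords[max(end - window, 0) : end]
    # Padding is never selected, whichever coordinate mode is used.
    return [i for i in span if state[i] != "P"]


state = list("VVVVVPPP")

print(candidate_slots(state, window=3, start=True))              # [0, 1, 2]
print(candidate_slots(state, window=3))                          # [2, 3, 4]
print(candidate_slots(state, window=3, array=True, start=True))  # [0, 1, 2]
print(candidate_slots(state, window=3, array=True))              # []
print(candidate_slots(state, window=3, start=True, offset=2))    # [2, 3, 4]
print(candidate_slots(state, window=3, offset=2))                # [0, 1, 2]
```

The printed lists match the two tables above, which makes the sketch a convenient way to reason about a configuration before binding it to a schema.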

Nested Arrays

An array mask is resolved at the array that owns it. If that array is nested, the mask is computed independently for each parent coordinate.

This example masks two letters from the last two retained letters in each word:

from __future__ import annotations

import random
import string

from rich import print

import json2vec as j2v

ALPHABET = string.ascii_uppercase
ADDRESS = "record/words/letters/letter"


def consecutive_letters(rng: random.Random) -> list[dict[str, str]]:
    length = rng.randint(3, 8)
    start = rng.randint(0, len(ALPHABET) - length)
    return [{"letter": value} for value in ALPHABET[start : start + length]]


def words(rng: random.Random) -> list[dict[str, list[dict[str, str]]]]:
    return [{"letters": consecutive_letters(rng)} for _ in range(rng.randint(2, 4))]


rng = random.Random(7)
data = [{"words": words(rng)} for _ in range(5)]

model = j2v.Model.from_schema(
    j2v.Array(
        j2v.Array(
            j2v.Category("letter", max_vocab_size=len(ALPHABET), p_unavailable=0.0),
            name="letters",
            max_length=8,
            mask=j2v.Mask(count=2, window=2),
        ),
        name="words",
        max_length=3,
    ),
    d_model=16,
    n_layers=1,
    n_heads=4,
)

inputs = model.encode(data, strata=j2v.Strata.train, mask=False)
masked = model.encode(data, strata=j2v.Strata.train, mask=True)

print(inputs[ADDRESS])
print(masked[ADDRESS])

The tensorfield display shows the first batch item and the root singleton dimension. Nested arrays are printed with one row per inner array:

TensorField [tensorfield] state=(5, 1, 3, 8) device=cpu trainable=0
 counts V=54 N=0 P=66 M=0 O=0
 state V V V V P P P P
       V V V V V V V V
       V V V P P P P P

After masking, selected values become M and cached targets are present:

TensorField [tensorfield] state=(5, 1, 3, 8) device=cpu trainable=22
 counts V=32 N=0 P=66 M=22 O=0
 state V V M M P P P P
       V V V V V V M M
       V M M P P P P P
 targets=content, state

Masking is stochastic, so the exact masked positions can differ between runs.
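To see why each word is masked independently, the rows displayed above can be replayed in plain Python. This is a sketch under stated assumptions: mask_row is a hypothetical helper, the rows are transcribed by hand from the display, and the selection logic only imitates the documented default (last window of non-padded slots, exact count).

```python
import random


def mask_row(row, count, window, rng):
    """Mask `count` of the last `window` non-padded slots in one inner row.

    Hypothetical helper; `row` is a string of state symbols such as "VVVPPPPP".
    """
    observed = [i for i, s in enumerate(row) if s != "P"]
    candidates = observed[-window:]
    # Exact count, capped by the number of candidates in this row.
    chosen = set(rng.sample(candidates, min(count, len(candidates))))
    return "".join("M" if i in chosen else s for i, s in enumerate(row))


# One parent coordinate per word: each row resolves its own window.
words = ["VVVVPPPP", "VVVVVVVV", "VVVPPPPP"]
rng = random.Random(0)
masked = [mask_row(word, count=2, window=2, rng=rng) for word in words]

for row in masked:
    print(" ".join(row))
# V V M M P P P P
# V V V V V V M M
# V M M P P P P P
```

In this sketch count equals window, so every candidate is chosen and the result does not depend on the seed. A smaller count or a rate would make the positions vary inside each word's window.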

Debugging Encodes

Model.encode(...) applies masking by default. Pass mask=False when you want to inspect the raw tensorized state before any dynamic masks, p_mask, or target=True pruning are applied:

unmasked = model.encode(records, strata=j2v.Strata.train, mask=False)
masked = model.encode(records, strata=j2v.Strata.train, mask=True)

This is most useful when validating mask windows for ragged arrays. In the rendered tensorfield state:

Symbol  Token
V       observed valued input
N       explicit null
P       padded slot
M       masked hidden input
O       other reserved state

The first tensor dimension is batch. The second dimension is the generated root singleton. The display previews state[0, 0].
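When comparing an unmasked and a masked encode by eye becomes tedious, a small diff over the state symbols can help. The sketch below works on state grids written as lists of strings; that representation, and the helpers state_counts and newly_masked, are assumptions made for illustration and are not a json2vec output format or API.

```python
from collections import Counter


def state_counts(grid):
    """Count each state symbol (V, N, P, M, O) across a grid of rows."""
    return Counter("".join(grid))


def newly_masked(before, after):
    """List (row, slot) coordinates that are M after masking but were not before."""
    changed = []
    for r, (b_row, a_row) in enumerate(zip(before, after)):
        for c, (b, a) in enumerate(zip(b_row, a_row)):
            if a == "M" and b != "M":
                changed.append((r, c))
    return changed


# Rows transcribed by hand from a state[0, 0] preview.
before = ["VVVVPPPP", "VVVVVVVV", "VVVPPPPP"]
after = ["VVMMPPPP", "VVVVVVMM", "VMMPPPPP"]

print(state_counts(before))
print(state_counts(after))
print(newly_masked(before, after))
# [(0, 2), (0, 3), (1, 6), (1, 7), (2, 1), (2, 2)]
```

Padding counts should be identical before and after masking, and every new M should fall inside the window you expect for its row.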

Excluding Leaves

Use exclude= when a row should be selected for some descendant fields but not others:

j2v.Array(
    j2v.Category("merchant", max_vocab_size=4096),
    j2v.Number("amount"),
    j2v.Category("country", max_vocab_size=256),
    name="transactions",
    max_length=128,
    mask=j2v.Mask(
        rate=0.25,
        window=32,
        exclude=j2v.where("name") == "country",
    ),
)

exclude accepts one predicate or a list of predicates. Predicates are resolved after the schema is bound, relative to the active descendant leaves of the array that owns the mask:

j2v.Mask(
    rate=0.25,
    window=32,
    exclude=[
        j2v.where("type") == "entity",
        j2v.where("name") == "country",
    ],
)
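Conceptually, exclude filters the set of leaves before the row selection is projected onto them. The following plain-Python sketch models that idea only: predicates are ordinary callables over hypothetical leaf metadata dicts rather than j2v.where(...) expressions, and the "type" values shown are placeholders.

```python
def project_mask(selected_rows, leaves, exclude=()):
    """Map one row selection onto every leaf that no predicate excludes.

    Hypothetical helper; `leaves` is a list of dicts with at least a "name".
    """
    # Accept one predicate or a list of predicates.
    predicates = [exclude] if callable(exclude) else list(exclude)
    masked = {}
    for leaf in leaves:
        if any(predicate(leaf) for predicate in predicates):
            continue  # this leaf stays visible for the selected rows
        masked[leaf["name"]] = list(selected_rows)
    if leaves and not masked:
        raise ValueError("exclude cannot remove every active descendant leaf")
    return masked


leaves = [
    {"name": "merchant", "type": "category"},  # "type" values are illustrative
    {"name": "amount", "type": "number"},
    {"name": "country", "type": "category"},
]

result = project_mask(
    [125, 127],
    leaves,
    exclude=lambda leaf: leaf["name"] == "country",
)
print(result)  # {'merchant': [125, 127], 'amount': [125, 127]}
```

The same rows are hidden in merchant and amount, while country remains visible in those rows and can serve as context for reconstructing the others.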

Prediction Placeholders

During prediction, every tensorfield accepts the literal string "<MASK>" as a placeholder for "decode this value":

model.predict(
    [
        {
            "events": [
                {"event_type": "login", "amount": 0.0},
                {"event_type": "<MASK>", "amount": "<MASK>"},
            ]
        }
    ]
)

"<MASK>" is predict-only. It raises an error in the train, validate, and test strata, and vocabulary-backed fields never learn it as an observed token.

For structured leaves such as Set or Vector, a scalar "<MASK>" masks the whole field. A collection containing "<MASK>" as one element is invalid.
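The placeholder rules above can be checked on raw records before they reach the model. This is a rough pre-flight sketch, not json2vec behaviour: placeholder_paths is a hypothetical helper, strata are modelled as plain strings instead of j2v.Strata members, and it assumes arrays arrive as lists of dicts while structured leaves arrive as lists of scalars.

```python
MASK = "<MASK>"


def placeholder_paths(value, strata, path=()):
    """Return the paths of scalar "<MASK>" placeholders in a nested record.

    Hypothetical helper; raises ValueError where the rules above forbid a mask.
    """
    if isinstance(value, str) and value == MASK:
        if strata != "predict":
            raise ValueError(f"{MASK} is predict-only, found at {path}")
        return [path]
    if isinstance(value, dict):
        found = []
        for key, item in value.items():
            found += placeholder_paths(item, strata, path + (key,))
        return found
    if isinstance(value, list):
        # A placeholder as one element of a collection is invalid:
        # mask the whole field with a scalar "<MASK>" instead.
        if MASK in value:
            raise ValueError(f"{MASK} cannot be a collection element at {path}")
        found = []
        for index, item in enumerate(value):
            found += placeholder_paths(item, strata, path + (index,))
        return found
    return []


record = {
    "events": [
        {"event_type": "login", "amount": 0.0},
        {"event_type": MASK, "amount": MASK},
    ]
}

print(placeholder_paths(record, "predict"))
# [('events', 1, 'event_type'), ('events', 1, 'amount')]
```

Running a check like this before prediction gives a quick list of the values the model will be asked to decode.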

Validation

Mask configuration is validated when the schema is bound:

  • Mask requires exactly one of rate or count.
  • rate must be between 0 and 1.
  • count must be non-negative.
  • window must be positive when provided.
  • offset must be non-negative and smaller than the owning array's max_length.
  • Passing both mask= and masks= to one array is invalid.
  • Masks on the generated root array are not supported.
  • A masked array must have at least one active descendant leaf.
  • exclude cannot remove every active descendant leaf.
  • Duplicate non-null mask names on the same array are invalid.
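Several of these rules can be mirrored in a few lines if you want to lint mask settings before building a schema. The sketch below is an assumption-laden stand-in, not the library's validator: MaskSpec and validate_masks are hypothetical names, only the numeric and naming rules are covered, and the rate bounds are assumed to be inclusive.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MaskSpec:
    """Hypothetical plain-data stand-in for a subset of the Mask options."""
    rate: Optional[float] = None
    count: Optional[int] = None
    window: Optional[int] = None
    offset: int = 0
    name: Optional[str] = None


def validate_masks(masks, max_length):
    """Return a list of rule violations for the masks on one array."""
    errors = []
    seen_names = set()
    for index, mask in enumerate(masks):
        if (mask.rate is None) == (mask.count is None):
            errors.append(f"mask {index}: exactly one of rate or count is required")
        # Assumption: the bounds 0 and 1 are both allowed.
        if mask.rate is not None and not 0 <= mask.rate <= 1:
            errors.append(f"mask {index}: rate must be between 0 and 1")
        if mask.count is not None and mask.count < 0:
            errors.append(f"mask {index}: count must be non-negative")
        if mask.window is not None and mask.window <= 0:
            errors.append(f"mask {index}: window must be positive")
        if not 0 <= mask.offset < max_length:
            errors.append(f"mask {index}: offset must be in [0, max_length)")
        if mask.name is not None:
            if mask.name in seen_names:
                errors.append(f"mask {index}: duplicate name {mask.name!r}")
            seen_names.add(mask.name)
    return errors


ok = validate_masks(
    [MaskSpec(name="recent", rate=0.25, window=32),
     MaskSpec(name="early", count=3, window=20)],
    max_length=128,
)
bad = validate_masks([MaskSpec(rate=1.5, window=0, offset=8)], max_length=8)
print(ok)        # []
print(len(bad))  # 3
```

The structural rules (root array masks, active descendant leaves, exclude coverage) depend on the bound schema and are left out of this sketch.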

Performance Notes

Mask resolution happens once per active leaf tensorfield. rate masks are fully vectorized and usually add little overhead. count masks sample an exact number of slots per parent coordinate and are more expensive, especially for large nested batches or many descendant leaves.

For very large schemas, prefer rate when exact counts are not required. If multiple leaves share the same array mask, expect cost to scale with the number of affected leaves.