Dynamic Masking
Dynamic masking hides values during training so the model learns to reconstruct
them from surrounding context. Leaf fields already support local stochastic
masking with p_mask, pruning with p_prune, and supervised targets with
target=True. Array masks add one more control: they select slices of a
repeated context and apply that selection to descendant tensorfields.
Use array masks when the masking rule is about a repeated coordinate rather than one leaf value. Common examples are "mask two values from each word", "mask half of the last 10 retained events", or "mask a few line items but leave one field visible."
Array Masks
Attach j2v.Mask(...) to an Array(...) with mask= or masks=:
import json2vec as j2v

model = j2v.Model.from_schema(
    j2v.Array(
        j2v.Category("event_type", max_vocab_size=128),
        j2v.Number("amount"),
        name="events",
        max_length=64,
        overflow="tail",
        mask=j2v.Mask(rate=0.50, window=10),
    ),
    name="customer",
    d_model=64,
    n_layers=2,
    n_heads=4,
)
This selects the last 10 retained, non-padded slots of the events array for
each observation and masks about half of them. The same selected event rows
are projected onto each active descendant leaf unless the mask excludes that
leaf.
mask= is shorthand for a one-item list. Use masks= when one array needs
more than one masking policy:
j2v.Array(
    j2v.Category("sku", max_vocab_size=4096),
    j2v.Number("quantity"),
    name="line_items",
    max_length=128,
    masks=[
        j2v.Mask(name="recent", rate=0.25, window=32),
        j2v.Mask(name="early", count=3, window=20, start=True),
    ],
)
Mask Options
Each Mask chooses a candidate window, then samples from that window.
| Option | Default | Meaning |
|---|---|---|
| rate | None | Bernoulli probability for each candidate slot. Mutually exclusive with count. |
| count | None | Exact number of candidate slots to sample per parent coordinate, capped by available candidates. Mutually exclusive with rate. |
| window | None | Candidate span length. None means all slots in the selected coordinate space. |
| array | False | Use fixed padded array coordinates instead of observed non-padded coordinates. |
| start | False | Take the window from the start instead of the end. |
| offset | 0 | Skip slots away from the selected edge before taking the window. |
| exclude | None | One predicate, or a list of predicates, for descendant leaves that should not receive this mask. |
Exactly one of rate or count is required.
The default coordinate mode is array=False, start=False, which means "the end
of the retained real values." This is usually the right choice for recency
histories, especially with overflow="tail".
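The difference between rate and count can be sketched in plain Python. This is a minimal illustration under stated assumptions, not json2vec's implementation: sample_slots is a hypothetical helper, and candidates stands for the slot indices of a window that has already been resolved.

```python
import random


def sample_slots(candidates, *, rate=None, count=None, rng=None):
    """Hypothetical helper: pick slots to mask from a resolved candidate window."""
    if (rate is None) == (count is None):
        raise ValueError("exactly one of rate or count is required")
    rng = rng or random.Random()
    if rate is not None:
        # rate: an independent Bernoulli draw per candidate slot,
        # so the number of masked slots varies from call to call.
        return [slot for slot in candidates if rng.random() < rate]
    # count: an exact number of slots, capped by the available candidates.
    return sorted(rng.sample(candidates, min(count, len(candidates))))


# Assumed example: the last 10 slots of a fully populated 64-slot array.
window = list(range(54, 64))
print(sample_slots(window, rate=0.5, rng=random.Random(0)))  # size varies
print(sample_slots(window, count=3, rng=random.Random(0)))   # always 3 slots
```

With rate, the expected number of masked slots is rate times the window size, but any single draw can select more or fewer. With count, the size is fixed and only the positions vary.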
Window Coordinates
Mask windows are resolved after preprocessing, query extraction, overflow handling, and padding. Padding is never selected.
For an array with max_length=8, window=3, and five retained values in slots 0
through 4 followed by padding in slots 5, 6, and 7, the candidate windows are:
| Configuration | Candidate slots |
|---|---|
| j2v.Mask(rate=0.2, window=3, start=True) | 0, 1, 2 |
| j2v.Mask(rate=0.2, window=3) | 2, 3, 4 |
| j2v.Mask(rate=0.2, window=3, array=True, start=True) | 0, 1, 2 |
| j2v.Mask(rate=0.2, window=3, array=True) | none, because fixed slots 5, 6, 7 are padded |
offset skips away from the chosen edge. With the same values and offset=2:
| Configuration | Candidate slots |
|---|---|
| j2v.Mask(rate=0.2, window=3, start=True, offset=2) | 2, 3, 4 |
| j2v.Mask(rate=0.2, window=3, offset=2) | 0, 1, 2 |
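The window rules in the two tables above can be reproduced with a small pure-Python function. This is a sketch of the documented behavior, not the library's code: candidate_slots and its parameter n_observed are hypothetical names, and it assumes real values are left-aligned with padding at the end, as in the example.

```python
def candidate_slots(n_observed, max_length, window=None, *,
                    array=False, start=False, offset=0):
    """Hypothetical helper: slot indices a mask may sample from."""
    # Coordinate space: fixed padded slots when array=True,
    # otherwise only the observed, non-padded slots.
    space = max_length if array else n_observed
    span = space if window is None else window
    if start:
        lo = offset
        hi = min(space, offset + span)
    else:
        hi = space - offset
        lo = max(0, hi - span)
    # Padding is never selected, so drop any slot past the real values.
    return [slot for slot in range(lo, hi) if slot < n_observed]


# Five real values in an array padded to max_length=8.
print(candidate_slots(5, 8, 3, start=True))               # start of observed values
print(candidate_slots(5, 8, 3))                           # end of observed values
print(candidate_slots(5, 8, 3, array=True))               # fixed slots 5, 6, 7 are padded
print(candidate_slots(5, 8, 3, offset=2))                 # skip two slots from the end
```

Resolving the window first and filtering padding afterwards is what makes array=True return no candidates here: the fixed window lands entirely on padded slots.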
Nested Arrays
An array mask is resolved at the array that owns it. If that array is nested, the mask is computed independently for each parent coordinate.
This example masks both of the last two retained letters in each word (count=2 within window=2):
from __future__ import annotations

import random
import string

from rich import print

import json2vec as j2v

ALPHABET = string.ascii_uppercase
ADDRESS = "record/words/letters/letter"


def consecutive_letters(rng: random.Random) -> list[dict[str, str]]:
    length = rng.randint(3, 8)
    start = rng.randint(0, len(ALPHABET) - length)
    return [{"letter": value} for value in ALPHABET[start : start + length]]


def words(rng: random.Random) -> list[dict[str, list[dict[str, str]]]]:
    return [{"letters": consecutive_letters(rng)} for _ in range(rng.randint(2, 4))]


rng = random.Random(7)
data = [{"words": words(rng)} for _ in range(5)]

model = j2v.Model.from_schema(
    j2v.Array(
        j2v.Array(
            j2v.Category("letter", max_vocab_size=len(ALPHABET), p_unavailable=0.0),
            name="letters",
            max_length=8,
            mask=j2v.Mask(count=2, window=2),
        ),
        name="words",
        max_length=3,
    ),
    d_model=16,
    n_layers=1,
    n_heads=4,
)

inputs = model.encode(data, strata=j2v.Strata.train, mask=False)
masked = model.encode(data, strata=j2v.Strata.train, mask=True)

print(inputs[ADDRESS])
print(masked[ADDRESS])
The tensorfield display shows the first batch item and the root singleton dimension. Nested arrays are printed with one row per inner array:
TensorField [tensorfield] state=(5, 1, 3, 8) device=cpu trainable=0
  counts  V=54 N=0 P=66 M=0 O=0
  state   V V V V P P P P
          V V V V V V V V
          V V V P P P P P
After masking, selected values become M and cached targets are present:
TensorField [tensorfield] state=(5, 1, 3, 8) device=cpu trainable=22
  counts  V=32 N=0 P=66 M=22 O=0
  state   V V M M P P P P
          V V V V V V M M
          V M M P P P P P
  targets=content, state
Mask sampling is intentionally stochastic, so exact masked positions can differ between runs.
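The per-word behavior can be checked without the library. The sketch below is an illustration under assumptions, not json2vec's implementation: mask_last_window is a hypothetical helper, and the word lengths 4, 8, and 3 are read off the state preview above. Each word is one parent coordinate and gets its own window and its own draw.

```python
import random


def mask_last_window(lengths, max_length, count, window, rng):
    """Hypothetical helper: render V/M/P rows for one batch item."""
    rows = []
    for n in lengths:  # resolve the mask independently for each word
        # Candidates are the last `window` observed slots of this word.
        candidates = list(range(max(0, n - window), n))
        chosen = set(rng.sample(candidates, min(count, len(candidates))))
        row = [
            "M" if slot in chosen else "V" if slot < n else "P"
            for slot in range(max_length)
        ]
        rows.append(" ".join(row))
    return rows


rows = mask_last_window([4, 8, 3], max_length=8, count=2, window=2,
                        rng=random.Random(0))
for row in rows:
    print(row)
```

Because count equals window in this configuration, both candidate slots of every word are selected regardless of the seed, which matches the M positions in the preview. A smaller count or a wider window would make the positions vary between runs.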
Debugging Encodes
Model.encode(...) applies masking by default. Pass mask=False when you want
to inspect the raw tensorized state before any dynamic masks, p_mask, or
target=True pruning are applied:
unmasked = model.encode(records, strata=j2v.Strata.train, mask=False)
masked = model.encode(records, strata=j2v.Strata.train, mask=True)
This is most useful when validating mask windows for ragged arrays. In the rendered tensorfield state:
| Symbol | Token |
|---|---|
| V | observed valued input |
| N | explicit null |
| P | padded slot |
| M | masked hidden input |
| O | other reserved state |
The first tensor dimension is batch. The second dimension is the generated root
singleton. The display previews state[0, 0].
Excluding Leaves
Use exclude= when a row should be selected for some descendant fields but not
others:
j2v.Array(
    j2v.Category("merchant", max_vocab_size=4096),
    j2v.Number("amount"),
    j2v.Category("country", max_vocab_size=256),
    name="transactions",
    max_length=128,
    mask=j2v.Mask(
        rate=0.25,
        window=32,
        exclude=j2v.where("name") == "country",
    ),
)
exclude accepts one predicate or a list of predicates. Predicates are resolved
after the schema is bound, relative to the active descendant leaves of the array
that owns the mask:
j2v.Mask(
    rate=0.25,
    window=32,
    exclude=[
        j2v.where("type") == "entity",
        j2v.where("name") == "country",
    ],
)
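Conceptually, the array mask picks rows once and exclude decides which leaves receive that selection. The sketch below models this with plain Python callables instead of j2v.where predicates. It is an illustration only: leaves_to_mask, the leaf dictionaries, and the "type" strings are hypothetical, and it assumes a leaf is excluded when any predicate in the list matches.

```python
def leaves_to_mask(leaves, exclude=()):
    """Hypothetical helper: names of descendant leaves that receive the mask."""
    if callable(exclude):
        exclude = [exclude]  # a single predicate behaves like a one-item list
    # Assumption: a leaf is excluded if ANY predicate matches it.
    kept = [leaf for leaf in leaves if not any(pred(leaf) for pred in exclude)]
    if not kept:
        raise ValueError("exclude cannot remove every active descendant leaf")
    return [leaf["name"] for leaf in kept]


# Assumed leaf metadata; the "type" values are illustrative placeholders.
leaves = [
    {"name": "merchant", "type": "category"},
    {"name": "amount", "type": "number"},
    {"name": "country", "type": "category"},
]
selected_rows = [3, 7]  # transaction rows chosen by the array mask
targets = leaves_to_mask(leaves, exclude=lambda leaf: leaf["name"] == "country")
print({name: selected_rows for name in targets})
```

The same rows are hidden in merchant and amount, while country stays visible in those rows and can serve as context for reconstructing the others.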
Prediction Placeholders
During prediction, every tensorfield accepts the literal string "<MASK>" as a
placeholder for "decode this value":
model.predict(
    [
        {
            "events": [
                {"event_type": "login", "amount": 0.0},
                {"event_type": "<MASK>", "amount": "<MASK>"},
            ]
        }
    ]
)
"<MASK>" is predict-only. It raises in train, validate, and test strata, and
vocabulary-backed fields never learn it as an observed token.
For structured leaves such as Set or Vector, a scalar "<MASK>" masks the
whole field. A collection containing "<MASK>" as one element is invalid.
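Before calling predict, it can help to list which values a request asks the model to decode. The helper below is a hypothetical pre-flight check written against plain dictionaries and lists, not part of json2vec. It assumes that any list holding "<MASK>" directly as an element is a structured leaf, which the rule above treats as invalid.

```python
MASK = "<MASK>"


def find_placeholders(value, path=""):
    """Hypothetical helper: paths whose value is the scalar "<MASK>"."""
    if isinstance(value, str):
        return [path] if value == MASK else []
    if isinstance(value, dict):
        found = []
        for key, child in value.items():
            found += find_placeholders(child, f"{path}/{key}" if path else key)
        return found
    if isinstance(value, list):
        # A collection containing "<MASK>" as one element is invalid;
        # mask the whole field with a scalar "<MASK>" instead.
        if MASK in value:
            raise ValueError(f"{path or '<root>'}: '<MASK>' cannot be a collection element")
        found = []
        for index, child in enumerate(value):
            found += find_placeholders(child, f"{path}/{index}" if path else str(index))
        return found
    return []


record = {
    "events": [
        {"event_type": "login", "amount": 0.0},
        {"event_type": MASK, "amount": MASK},
    ]
}
print(find_placeholders(record))
```

The returned paths use array indices, so the second event's fields appear under events/1.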
Validation
Mask configuration is validated when the schema is bound:
- Mask requires exactly one of rate or count.
- rate must be between 0 and 1.
- count must be non-negative.
- window must be positive when provided.
- offset must be non-negative and smaller than the owning array's max_length.
- Passing both mask= and masks= to one array is invalid.
- Masks on the generated root array are not supported.
- A masked array must have at least one active descendant leaf.
- exclude cannot remove every active descendant leaf.
- Duplicate non-null mask names on the same array are invalid.
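The per-mask rules above translate into a few guard clauses. The function below is a hypothetical stand-in for illustration, not the library's validator: validate_mask is an invented name, the error messages are made up, and the bounds on rate are assumed to be inclusive. The array-level rules (root array, descendant leaves, duplicate names) need schema context and are left out.

```python
def validate_mask(*, max_length, rate=None, count=None, window=None, offset=0):
    """Hypothetical checker mirroring the per-mask rules listed above."""
    if (rate is None) == (count is None):
        raise ValueError("exactly one of rate or count is required")
    # Assumption: both endpoints of the rate interval are allowed.
    if rate is not None and not 0 <= rate <= 1:
        raise ValueError("rate must be between 0 and 1")
    if count is not None and count < 0:
        raise ValueError("count must be non-negative")
    if window is not None and window <= 0:
        raise ValueError("window must be positive when provided")
    if not 0 <= offset < max_length:
        raise ValueError("offset must be non-negative and smaller than max_length")


validate_mask(max_length=64, rate=0.5, window=10)      # passes silently
validate_mask(max_length=128, count=3, window=20)      # passes silently
```

Failing early at bind time keeps configuration mistakes out of the training loop, where a silently empty window would be much harder to notice.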
Performance Notes
Mask resolution happens once per active leaf tensorfield. rate masks are fully
vectorized and usually add little overhead. count masks sample an exact number
of slots per parent coordinate and are more expensive, especially for large
nested batches or many descendant leaves.
For very large schemas, prefer rate when exact counts are not required. If
multiple leaves share the same array mask, expect cost to scale with the number
of affected leaves.