This blog post is part 1 of a series describing my attempt at implementing the Test-Time Training (TTT) model proposed by Sun et al. (2024) and Titans, proposed by Behrouz et al. (2024). At the time of writing, these are two strong recurrent language models, but neither has an official PyTorch implementation (TTT has only open-sourced a JAX implementation).
Introduction to Test-Time Training
Briefly explained, Test-Time Training (TTT) is an RNN model whose hidden state is replaced with an online learner, whose parameters are updated through gradient descent during inference. The goal is that this online learner compresses contextual information into its parameters. A TTT operator can be expressed as:

$$ z_t = f(x_t; W_t) $$

where $x_t$ is the input token at step $t$, $W_t$ denotes the online learner's parameters after processing token $t$, and $z_t$ is the output.
The Update Rule
The interesting part is that the update rule is implemented with gradient descent:

$$ W_t = W_{t-1} - \eta \nabla \ell(W_{t-1}; x_t) $$

where $\eta$ is the learning rate, and the loss is defined as:

$$ \ell(W; x_t) = \frac{1}{2} \left\| f(k_t; W) - v_t \right\|^2 $$

where $k_t$ and $v_t$ are key and value projections of the input token. In other words, the online learner is learning to produce the value vector $v_t$ when given the key vector $k_t$.
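As a concrete (and deliberately naive) sketch, here is what sequential TTT looks like with a linear online learner $f(x; W) = Wx$, written in NumPy. The function name, shapes, and learning rate are illustrative assumptions of mine, not the paper's exact setup:

```python
import numpy as np

def ttt_linear_naive(q, k, v, W0, lr=0.1):
    """Sequential TTT with a linear online learner f(x; W) = W x.

    q, k, v: (T, d) arrays of query/key/value vectors.
    W0: (d, d) initial state of the online learner.
    At each step, the learner takes one gradient step on the
    reconstruction loss l(W) = 0.5 * ||W k_t - v_t||^2, and then the
    updated state answers the query.
    """
    W = W0.copy()
    outputs = []
    for t in range(q.shape[0]):
        grad_W = np.outer(W @ k[t] - v[t], k[t])  # (d, d) gradient of the loss
        W = W - lr * grad_W                       # online gradient step
        outputs.append(W @ q[t])                  # query the updated state
    return np.stack(outputs), W
```

Note the sequential dependency: each token's gradient depends on the state left behind by the previous token, which is exactly what makes this loop slow.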
Mini-Batch Gradient Descent
Obviously, directly running this algorithm is extremely slow because it cannot be parallelized. Sun et al. (2024) proposed mini-batch gradient descent, in which the input sequence is split into mini-batches and all gradients within a mini-batch are computed based on the online learner’s state at the start of that mini-batch. This removes the sequential dependency between tokens in the same mini-batch, so their gradients can be computed in parallel.
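To make this concrete, here is a hedged NumPy sketch of mini-batch TTT for a linear learner: every token's gradient is taken at the batch-start state, and the output at token $t$ applies the prefix sum of the first $t$ gradients. The batch size, learning rate, and function name are illustrative:

```python
import numpy as np

def ttt_linear_minibatch(q, k, v, W0, lr=0.1, batch_size=4):
    """Mini-batch TTT for a linear learner f(x; W) = W x.

    All gradients inside a mini-batch are taken w.r.t. the state at the
    start of that batch, which removes the sequential dependency between
    tokens in the same batch (here they are still computed in a loop,
    but they could be computed in parallel).
    """
    W = W0.copy()
    outputs = []
    T = q.shape[0]
    for start in range(0, T, batch_size):
        qb = q[start:start + batch_size]
        kb = k[start:start + batch_size]
        vb = v[start:start + batch_size]
        # Per-token gradients at the batch-start state (parallelizable).
        errs = kb @ W.T - vb                       # (b, d): W k_s - v_s
        grads = errs[:, :, None] * kb[:, None, :]  # (b, d, d) outer products
        cum = np.cumsum(grads, axis=0)             # prefix sums of gradients
        # The output at token t uses the state after the first t updates.
        for t in range(qb.shape[0]):
            outputs.append((W - lr * cum[t]) @ qb[t])
        # Advance the state to the end of the batch.
        W = W - lr * cum[-1]
    return np.stack(outputs), W
```

With `batch_size=1` this reduces exactly to the fully sequential algorithm; larger batches trade a small approximation (stale gradients within the batch) for parallelism.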
Dual Form
In order to make use of matmuls and avoid the materialization of gradients, TTT further proposes to use a “dual form”. Without loss of generality, we consider a single mini-batch of $b$ tokens, and let $W_0$ denote the online learner’s state at the start of the mini-batch.
Dual Form of Linear Model
For simplicity, we first explain the dual form for a linear model:

$$ f(x; W) = W x $$

where $W \in \mathbb{R}^{d \times d}$. The gradient of the loss at token $s$, taken at the batch-start state $W_0$, is

$$ \nabla \ell(W_0; x_s) = (W_0 k_s - v_s)\, k_s^\top. $$
Now, let’s define the reconstruction error as

$$ g_s = W_0 k_s - v_s, $$

so the output at token $t$ becomes

$$ z_t = W_t q_t = W_0 q_t - \eta \sum_{s \le t} g_s \left( k_s^\top q_t \right), $$

where $k_s^\top q_t$ is a scalar attention-like score between key $s$ and query $t$.
It’s interesting to see how similar this is to ordinary causal linear attention.
So, to sum up, the dual form of TTT accepts a batch of queries, keys, and values, along with the online learner state $W_0$ at the start of the mini-batch, and returns the outputs for all tokens together with the updated state at the end of the batch, using only matmuls and a causal mask.
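This summary can be sketched in a few lines of NumPy. The function takes one mini-batch of queries, keys, and values plus the batch-start state `W0`, and returns the outputs and the updated state; variable names and the learning rate are my own illustrative choices:

```python
import numpy as np

def ttt_linear_dual(q, k, v, W0, lr=0.1):
    """Dual form of TTT for a linear learner f(x; W) = W x, one mini-batch.

    q, k, v: (b, d); W0: (d, d) state at the start of the mini-batch.
    Avoids materializing any (d, d) per-token gradient: everything is
    expressed with batch-level matmuls plus a causal mask.
    """
    G = k @ W0.T - v                   # (b, d): errors g_s = W0 k_s - v_s
    S = k @ q.T                        # (b, b): S[s, t] = k_s . q_t
    M = np.triu(np.ones_like(S))       # M[s, t] = 1 iff s <= t (causal mask)
    z = q @ W0.T - lr * (M * S).T @ G  # z_t = W0 q_t - lr * sum_{s<=t} g_s (k_s . q_t)
    W1 = W0 - lr * G.T @ k             # state at the end of the batch
    return z, W1
```

Note how the masked score matrix `M * S` plays exactly the role of the causal attention-score matrix in linear attention.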
Dual Form of MLP
The dual form can be extended to a two-layer MLP. The idea is to first apply the dual form to the first layer to get the intermediate representation, and then apply the dual form to the second layer to get the output. The forward pass (on the keys, at the batch-start state) can be implemented as:

$$ h_s = W_1 k_s, \quad \sigma_s = \mathrm{ReLU}(h_s), \quad \hat{y}_s = W_2 \sigma_s + k_s $$

Then the gradients with respect to the activations are:

$$ g_s = \hat{y}_s - v_s, \quad g_s^h = \left( W_2^\top g_s \right) \odot \mathbb{1}[h_s > 0] $$

Define the intermediate representation for the queries as

$$ h_t^q = W_1 q_t - \eta \sum_{s \le t} g_s^h \left( k_s^\top q_t \right), \quad \sigma_t^q = \mathrm{ReLU}(h_t^q), $$

so that the output is

$$ z_t = W_2 \sigma_t^q - \eta \sum_{s \le t} g_s \left( \sigma_s^\top \sigma_t^q \right) + q_t, $$

where the sums over $s \le t$ are again causal attention-like terms that can be computed with masked matmuls. The state at the end of the batch is then trivial to compute:

$$ W_1 \leftarrow W_1 - \eta \sum_{s=1}^{b} g_s^h k_s^\top, \quad W_2 \leftarrow W_2 - \eta \sum_{s=1}^{b} g_s \sigma_s^\top $$
Code Implementation
Existing Open-Source Implementation
I have found an excellent PyTorch implementation of Titans by lucidrains. Titans is an extension of TTT that adds momentum and data-dependent weight decay and learning rates. However, in this implementation, the memory query is performed on the memory state at the start of the batch, instead of on the state after the keys and values have been inserted. It also relies on PyTorch's autograd (the `grad` and `functional_call` functions) to avoid explicitly writing out the gradient computation. Since this is still not exactly the same as TTT, I will try to implement the original TTT model.
My Implementation
It should be pretty straightforward to implement the dual form for a two-layer MLP, which is what I will do. Assume that the online learner is defined as:

$$ f(x; W_1, W_2) = W_2\, \mathrm{ReLU}(W_1 x) + x $$

where the trailing $+\,x$ is a residual connection. The forward pass can be implemented as:
import torch
from torch import Tensor
from einops import einsum


def two_layer_mlp_dual_form(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    weights: dict[str, Tensor],
) -> tuple[Tensor, dict[str, Tensor]]:
    """
    This operator utilizes the dual form of TTT (https://arxiv.org/pdf/2407.04620)
    to avoid materializing the gradients.

    q, k, v: (b, d) tensors holding one mini-batch of queries, keys, and values.
    weights: the online learner's state {"w1": (dh, d), "w2": (d, dh)}
        at the start of the mini-batch.
    Returns the outputs (b, d) and the updated state at the end of the batch.
    """
    device = q.device
    batch_size = q.shape[0]
    w1 = weights["w1"]  # (dh, d)
    w2 = weights["w2"]  # (d, dh)
    # Forward pass on the keys, at the batch-start state
    h = einsum(w1, k, "dh d, b d -> b dh")
    h_relu = torch.relu(h)  # (b, dh)
    h2 = einsum(w2, h_relu, "d dh, b dh -> b d")
    pred = h2 + k  # (b, d), residual connection
    # Backward pass
    # We first compute the gradients w.r.t. the activations,
    # which can be computed efficiently because they only
    # involve matmuls.
    grad = pred - v  # (b, d)
    grad_h2 = einsum(w2, grad, "d dh, b d -> b dh")
    grad_h = grad_h2 * (h > 0).float()  # ReLU derivative
    # Then, we avoid materializing the per-token gradients w.r.t. the two
    # linear layers W1 and W2, by reassociating the matmuls with the queries.
    # mask[bk, bq] = 1 iff bk <= bq, so query bq only sees keys bk <= bq.
    mask = torch.triu(torch.ones(batch_size, batch_size, device=device))
    attn_scores = einsum(k, q, "bk d, bq d -> bk bq")  # (b, b)
    attn_scores = attn_scores * mask
    h_q = einsum(w1, q, "dh d, b d -> b dh")
    h_q = h_q - einsum(grad_h, attn_scores, "bk dh, bk bq -> bq dh")
    h_q_relu = torch.relu(h_q)  # (b, dh)
    # Compute the output
    attn_scores = einsum(h_relu, h_q_relu, "bk dh, bq dh -> bk bq")
    attn_scores = attn_scores * mask
    y = einsum(w2, h_q_relu, "d dh, b dh -> b d")
    y = y - einsum(grad, attn_scores, "bk d, bk bq -> bq d")
    y = y + q  # Residual connection
    # Compute the weights at the end of the batch
    # (the learning rate is folded into the gradients, i.e. eta = 1)
    w1 = w1 - einsum(grad_h, k, "b dh, b d -> dh d")
    w2 = w2 - einsum(grad, h_relu, "b d, b dh -> d dh")
    return y, {"w1": w1, "w2": w2}
How to Cite
@misc{chen2025ttt-implementation,
  author = {Yingfa Chen},
  title  = {Implementing Test-Time Training - Part 1},
  year   = {2025},
  url    = {https://chen-yingfa.github.io/2025/03/19/ttt-implementation/},
}
References
- Sun et al. 2024. Learning to (Learn at Test Time): RNNs with Expressive Hidden States. https://arxiv.org/abs/2407.04620.
- Behrouz et al. 2024. Titans: Learning to Memorize at Test Time. https://arxiv.org/abs/2501.00663.
- PyTorch Implementation of Titans by lucidrains. https://github.com/lucidrains/titans-pytorch.