Abstract:
The main idea in test-time training (TTT) (Sun et al. 2024) is that a model with fixed parameters produces the supervision for another network that is updated at test time (or inference time). This article first reviews the TTT paper. Then, we discuss the problems with TTT and how LaCT addresses them, resulting in a powerful attention alternative that balances efficiency and performance.
Currently, “test-time training” is an overloaded term with multiple meanings. In this article, we use the term to refer to the test-time training paradigm proposed in Sun et al. 2024, which is a framework for recurrent architectures for sequence modeling.
Background
We start by formulating vanilla attention (Vaswani et al. 2017) and linear attention (Katharopoulos et al. 2020), which helps us set up the notation. It also allows us to better understand the TTT architecture by comparing it to attention variants.
Attention
To unify the notation, we formulate the popular attention mechanism as follows:

$$q_t = W_Q x_t, \quad k_t = W_K x_t, \quad v_t = W_V x_t$$

where $x_t$ is the input at time step $t$ and $W_Q, W_K, W_V$ are learned projection matrices. And the output at each time step is:

$$o_t = \sum_{i=1}^{t} \frac{\exp(q_t^\top k_i / \sqrt{d})}{\sum_{j=1}^{t} \exp(q_t^\top k_j / \sqrt{d})} \, v_i$$

We are omitting the $\sqrt{d}$ denominator from here on for simplicity and because it can be absorbed into the query or key matrices.
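To make this concrete, here is a minimal per-time-step sketch in PyTorch (the function name and tensor shapes are illustrative, not taken from any paper):

```python
import torch

def attention_output(q, k, v):
    """Causal softmax attention, one output per time step.

    q, k, v: [T, d] tensors of queries, keys, and values.
    The 1/sqrt(d) scaling is omitted, as in the text.
    Returns a [T, d] tensor of outputs o_1, ..., o_T.
    """
    T = q.shape[0]
    outputs = []
    for t in range(T):
        scores = q[t] @ k[: t + 1].T              # [t+1] attention logits over positions 1..t
        weights = torch.softmax(scores, dim=-1)   # normalize with softmax
        outputs.append(weights @ v[: t + 1])      # weighted sum of values
    return torch.stack(outputs)
```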
Linear Attention
We formulate linear attention here as it has become a popular and competitive RNN variant, and it is fairly easy to understand.
In linear attention (Katharopoulos et al. 2020), we remove the softmax function from attention:

$$o_t = \sum_{i=1}^{t} (q_t^\top k_i)\, v_i = \left( \sum_{i=1}^{t} v_i k_i^\top \right) q_t$$

This can be written as a recurrent equation:

$$S_t = S_{t-1} + v_t k_t^\top, \qquad o_t = S_t q_t$$

where $S_t \in \mathbb{R}^{d \times d}$ is the recurrent state (with $S_0 = 0$).
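A minimal sketch of this recurrence (names are illustrative); note that the state is a single $d \times d$ matrix regardless of sequence length:

```python
import torch

def linear_attention(q, k, v):
    """Linear attention via the recurrence S_t = S_{t-1} + v_t k_t^T, o_t = S_t q_t.

    q, k, v: [T, d] tensors. Returns a [T, d] tensor of outputs.
    """
    T, d = q.shape
    S = torch.zeros(d, d)
    outputs = []
    for t in range(T):
        S = S + torch.outer(v[t], k[t])   # accumulate the outer product v_t k_t^T
        outputs.append(S @ q[t])          # query the state with q_t
    return torch.stack(outputs)
```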
Test-Time Training
In test-time training (TTT), the recurrent state is replaced with a neural network $f_W$ with parameters $W$ (the fast weight), which is updated by gradient descent at every time step and queried with the query vector to produce the output. The table below compares TTT with attention and its linear variants.
| Model | Update Rule | Query Rule | State Size |
|---|---|---|---|
| Attention | append $(k_t, v_t)$ to the KV cache | $o_t = \sum_{i=1}^{t} \frac{\exp(q_t^\top k_i)}{\sum_{j=1}^{t} \exp(q_t^\top k_j)} v_i$ | $O(t \cdot d)$, grows with the sequence |
| Linear Attention | $S_t = S_{t-1} + v_t k_t^\top$ | $o_t = S_t q_t$ | $d \times d$ |
| DeltaNet | $S_t = S_{t-1}(I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top$ | $o_t = S_t q_t$ | $d \times d$ |
| TTT | $W_t = W_{t-1} - \eta\, \nabla_W \mathcal{L}(W_{t-1}; k_t, v_t)$ | $o_t = f_{W_t}(q_t)$ | Arbitrary |
Test-Time Training Update Rule
TTT is trained with a self-supervised reconstruction loss:

$$\mathcal{L}(W; k_t, v_t) = \lVert f_W(k_t) - v_t \rVert^2$$

and the state is updated with one step of gradient descent per token:

$$W_t = W_{t-1} - \eta\, \nabla_W \mathcal{L}(W_{t-1}; k_t, v_t)$$

The main principle behind TTT is that the fast weight $W$ is an associative memory that reconstructs the value $v_t$ from the key $k_t$.
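Below is a hypothetical sketch of this recurrence: per token, the fast weight takes one gradient step on the reconstruction loss, and is then queried with the query vector. The function signature and the use of autograd are assumptions for illustration, not the official implementation (which, among other things, also backpropagates through this inner loop to train the slow weights that produce $q$, $k$, $v$).

```python
import torch

def ttt_layer(q, k, v, fast_weight, loss_fn, lr=1.0):
    """Generic TTT recurrence (illustrative sketch only).

    fast_weight: an nn.Module f_W mapping a key to a reconstructed value,
                 e.g. torch.nn.Linear(d, d, bias=False).
    loss_fn:     self-supervised loss, e.g. torch.nn.functional.mse_loss.
    q, k, v:     [T, d] tensors.
    """
    outputs = []
    for t in range(len(q)):
        # Update rule: one gradient step on the reconstruction loss.
        loss = loss_fn(fast_weight(k[t]), v[t])
        grads = torch.autograd.grad(loss, list(fast_weight.parameters()))
        with torch.no_grad():
            for p, g in zip(fast_weight.parameters(), grads):
                p -= lr * g
        # Query rule: read the updated memory with the query.
        outputs.append(fast_weight(q[t]))
    return torch.stack(outputs)
```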
Example 1 - Linear Attention:
Consider the simplest case where the fast weight is a linear model (i.e., $f_W(k) = Wk$) and the loss is the negative dot product $\mathcal{L}(W; k_t, v_t) = -v_t^\top W k_t$. The gradient is $\nabla_W \mathcal{L} = -v_t k_t^\top$, so one gradient step gives:

$$W_t = W_{t-1} + \eta\, v_t k_t^\top$$

The update rule we get is the same as the update rule of linear attention (with the addition of the learning rate $\eta$).
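A toy check of this equivalence (not from the paper): with $f_W(k) = Wk$ and the negative dot-product loss, autograd recovers exactly the outer-product update of linear attention.

```python
import torch

d = 8
W = torch.randn(d, d, requires_grad=True)
k, v = torch.randn(d), torch.randn(d)

loss = -(v @ (W @ k))                  # negative dot-product loss: -v^T W k
(grad,) = torch.autograd.grad(loss, W)

# The gradient is -v k^T, so W - lr * grad == W + lr * v k^T, i.e. linear attention.
assert torch.allclose(grad, -torch.outer(v, k), atol=1e-6)
```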
Example 2 - DeltaNet:
If we instead use the squared reconstruction loss $\mathcal{L}(W; k_t, v_t) = \frac{1}{2}\lVert W k_t - v_t \rVert^2$, the gradient is $\nabla_W \mathcal{L} = (W k_t - v_t) k_t^\top$, and one gradient step gives:

$$W_t = W_{t-1} - \eta\,(W_{t-1} k_t - v_t) k_t^\top = W_{t-1}(I - \eta\, k_t k_t^\top) + \eta\, v_t k_t^\top$$

which is the update rule of DeltaNet (the delta rule), with $\eta$ playing the role of $\beta_t$.

The reason we unlock a forgetting mechanism by simply changing the loss function is that with a negative dot product loss, the model can only keep pushing the reconstructed value $W k_t$ further in the direction of $v_t$: the loss has no minimum and nothing is ever erased. With the squared loss, the update stops once $W k_t = v_t$, and the term $-\eta\,(W_{t-1} k_t) k_t^\top$ removes the old value associated with $k_t$ before writing the new one, so the memory can forget.
Efficient Meta Training
The naive TTT update rule cannot be parallelized over the sequence, because each step depends on the state produced by the previous step.
To make it parallelizable, they use a mini-batch gradient descent trick: the sequence is split into chunks, and the gradient at each step is computed with respect to the state at the beginning of the chunk:

$$W_t = W_{t'} - \eta \sum_{s = t' + 1}^{t} \nabla_W \mathcal{L}(W_{t'}; k_s, v_s)$$

where $t'$ is the last time step of the previous chunk, so $W_{t'}$ is the state at the beginning of the current chunk.
This change allows us to compute the gradients at different positions within the same chunk in parallel, which improves the training speed.
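Here is a hypothetical sketch of the chunk-parallel computation for the linear fast weight with the squared loss from Example 2 (the function name and chunking details are assumptions for illustration). Because every gradient in a chunk is taken at the chunk-initial state, the whole chunk reduces to a few batched matrix operations instead of a token-by-token loop.

```python
import torch

def ttt_linear_minibatch(q, k, v, b, lr=1.0):
    """Chunk-parallel TTT with a linear fast weight and squared loss (illustrative sketch).

    q, k, v: [T, d] tensors; b: chunk size. Within a chunk, all gradients are
    evaluated at W0, the state at the start of the chunk.
    """
    T, d = q.shape
    W = torch.zeros(d, d)
    outputs = []
    for start in range(0, T, b):
        Q, K, V = q[start:start + b], k[start:start + b], v[start:start + b]
        # Per-token gradients at the chunk-initial state: (W0 k_i - v_i) k_i^T.
        errors = K @ W.T - V                              # [b, d] reconstruction errors
        grads = errors.unsqueeze(2) * K.unsqueeze(1)      # [b, d, d] outer products
        # Intermediate states W_i = W0 - lr * (sum of gradients up to token i).
        W_states = W - lr * torch.cumsum(grads, dim=0)    # [b, d, d]
        # Query each token against its own updated state.
        outputs.append(torch.einsum('bij,bj->bi', W_states, Q))
        W = W_states[-1]                                  # carry the state to the next chunk
    return torch.cat(outputs)
```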
The Problem with TTT
The original TTT models were trained on TPUs, and the official PyTorch implementation is extremely slow for training on GPUs (FLOPs utilization below 5%). The root cause is that the arithmetic intensity (FLOPs per memory access) of the layer is very low, making it memory-bound.
Assume a fast weight of shape $d \times d$ and a chunk size of $b$. For each chunk, the layer must load (and later write back) the $O(d^2)$ state from GPU memory, while performing only $O(b d^2)$ FLOPs of matrix multiplications between the $b \times d$ chunk and the state. The arithmetic intensity is therefore roughly

$$\frac{\text{FLOPs}}{\text{memory accesses}} \approx \frac{O(b d^2)}{O(d^2 + b d)} = O(b)$$

where $b$ is the chunk size and $d$ is the head dimension. Thus, in practice, the arithmetic intensity of TTT is bounded by the chunk size $b$; with the original chunk size of 16, this is far below the compute-to-bandwidth ratio of modern GPUs.
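As a rough back-of-the-envelope illustration (the function name, FLOP counts, and byte counts below are simplified assumptions, not measurements), the FLOPs-per-byte ratio grows roughly linearly with the chunk size $b$:

```python
def rough_arithmetic_intensity(b, d, bytes_per_elem=2):
    """Very rough FLOPs-per-byte estimate for one TTT chunk with a d x d fast weight.

    FLOPs:  ~2*b*d^2 for the forward pass, ~4*b*d^2 for the gradient and update.
    Memory: read + write the d x d state, plus read the b x d keys, values, queries.
    """
    flops = 6 * b * d * d
    bytes_moved = (2 * d * d + 3 * b * d) * bytes_per_elem
    return flops / bytes_moved

print(rough_arithmetic_intensity(b=16, d=1024))    # ~23 FLOPs per byte
print(rough_arithmetic_intensity(b=2048, d=1024))  # ~770 FLOPs per byte
```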
Large Chunk Test-Time Training (LaCT)
LaCT (Zhang et al. 2025) increases the arithmetic intensity of TTT by increasing the chunk size from 16 to anywhere between 2048 and one million tokens. However, directly increasing the chunk size results in poor modeling of local dependencies, so LaCT incorporates a sliding window attention (SWA) layer to handle them. In practice, the SWA and TTT layers share the same set of QKV projections.
Figure: the LaCT architecture.
(to be continued…)
Thanks for reading. If you want further discussions, reach me at chenyingfa1999@gmail.com.