Abstract:

The main idea in test-time training (TTT) (Sun et al. 2024) is that a model with fixed parameters produces the supervision for another network that is updated at test time (or inference time). This article first reviews the TTT paper. Then, we discuss the problems with TTT and how LaCT addresses them, resulting in a powerful attention alternative that balances efficiency and performance.

Currently, “test-time training” is an overloaded term with multiple meanings. In this article, we use the term to refer to the test-time training paradigm proposed in Sun et al. 2024, which is a framework for recurrent architectures for sequence modeling.

Background

We start by formulating vanilla attention (Vaswani et al. 2017) and linear attention (Katharopoulos et al. 2020), which helps us set up the notations. It also allows us to better understand the TTT architecture by comparing it to attention variants.

Attention

To unify the notation, we formulate the popular attention mechanism as follows:

$$q_t = W_Q x_t, \quad k_t = W_K x_t, \quad v_t = W_V x_t$$

where $x_t \in \mathbb{R}^d$ is the input at time $t$, $d$ is the head size, and $W_Q, W_K, W_V \in \mathbb{R}^{d \times d}$ are the query, key, and value projection matrices, respectively. $K_t, V_t$ are called the KV cache, which are the concatenation of all keys and values seen so far:

$$K_t = [k_1, \dots, k_t], \quad V_t = [v_1, \dots, v_t]$$

And the output at each time step is:

$$o_t = V_t \, \mathrm{softmax}(K_t^\top q_t)$$

We omit the $\sqrt{d}$ denominator from here on, for simplicity and because it can be absorbed into the query or key projection matrices.
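To make the notation concrete, here is a minimal sketch (in PyTorch, with illustrative shapes and variable names, not from any official implementation) of single-head causal attention written in the recurrent KV-cache view above; the $\sqrt{d}$ scaling is omitted, as in the text.

```python
import torch

d, T = 16, 8                                # head size, sequence length
W_Q, W_K, W_V = (torch.randn(d, d) for _ in range(3))

keys, values, outputs = [], [], []
for t in range(T):
    x_t = torch.randn(d)                            # input at time t
    q_t, k_t, v_t = W_Q @ x_t, W_K @ x_t, W_V @ x_t
    keys.append(k_t)                                # KV cache grows with t
    values.append(v_t)
    K_t = torch.stack(keys, dim=1)                  # (d, t)
    V_t = torch.stack(values, dim=1)                # (d, t)
    o_t = V_t @ torch.softmax(K_t.T @ q_t, dim=0)   # o_t = V_t softmax(K_t^T q_t)
    outputs.append(o_t)
```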

Linear Attention

We formulate linear attention here as it has become a popular and competitive RNN variant, and it is fairly easy to understand.

In linear attention (Katharopoulos et al. 2020), we remove the softmax function from attention:

$$o_t = V_t (K_t^\top q_t) = \sum_{i=1}^{t} v_i (k_i^\top q_t) = \left( \sum_{i=1}^{t} v_i k_i^\top \right) q_t$$

This can be written as a recurrent equation:

$$S_t = S_{t-1} + v_t k_t^\top, \qquad o_t = S_t q_t$$
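The following sketch (illustrative only) checks that the recurrent form and the softmax-free attention form give the same output:

```python
import torch

d, T = 16, 8
S = torch.zeros(d, d)                        # recurrent state S_0
ks, vs = [], []
for t in range(T):
    q_t, k_t, v_t = torch.randn(d), torch.randn(d), torch.randn(d)
    S = S + torch.outer(v_t, k_t)            # S_t = S_{t-1} + v_t k_t^T
    o_recurrent = S @ q_t                    # query rule: o_t = S_t q_t
    ks.append(k_t); vs.append(v_t)
    # "Attention without softmax" form: o_t = sum_i v_i (k_i^T q_t)
    o_parallel = sum(v_i * (k_i @ q_t) for k_i, v_i in zip(ks, vs))
    assert torch.allclose(o_recurrent, o_parallel, atol=1e-4)
```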

Test-Time Training

In test-time training (TTT), the recurrent state $S_t$ is replaced with a neural network $f_W$, parameterized by $W$, which we call the fast weight. The following table summarizes the update and query rules of different attention variants and DeltaNet (Schlag et al. 2021) and TTT, which are two popular RNN variants. We will soon see that both linear attention and DeltaNet are special cases of TTT.

| Model | Update Rule | Query Rule | State Size |
| --- | --- | --- | --- |
| Attention | $K_t = [K_{t-1}, k_t],\; V_t = [V_{t-1}, v_t]$ | $o_t = V_t \, \mathrm{softmax}(K_t^\top q_t)$ | $O(td)$ |
| Linear Attention | $S_t = S_{t-1} + v_t k_t^\top$ | $o_t = S_t q_t$ | $O(d^2)$ |
| DeltaNet | $S_t = S_{t-1}(I - \eta_t k_t k_t^\top) + \eta_t v_t k_t^\top$ | $o_t = S_t q_t$ | $O(d^2)$ |
| TTT | $W_t = W_{t-1} - \eta_t \nabla_W \ell(W_{t-1}; k_t, v_t)$ | $o_t = f_{W_t}(q_t)$ | Arbitrary |

Test-Time Training Update Rule

TTT is trained with a self-supervised reconstruction loss:

$$\ell(W; k_t, v_t) = \lVert f_W(k_t) - v_t \rVert^2$$

The main principle behind TTT is that the fast weight $f_W$ is an associative memory that reconstructs $v_t$ from $k_t$.
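As a concrete illustration, here is a minimal sketch of the generic TTT recurrence, assuming the fast weight is a small two-layer MLP (the choice of network and the hyperparameters are illustrative, not the paper's configuration): each token takes one gradient step on the reconstruction loss, and the output is the updated fast network applied to the query.

```python
import torch
import torch.nn as nn

d, T, lr = 16, 8, 0.1
fast_net = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, d))  # fast weight f_W

for t in range(T):
    k_t, v_t, q_t = torch.randn(d), torch.randn(d), torch.randn(d)
    # Update rule: one gradient-descent step on ||f_W(k_t) - v_t||^2.
    loss = ((fast_net(k_t) - v_t) ** 2).sum()
    grads = torch.autograd.grad(loss, fast_net.parameters())
    with torch.no_grad():
        for p, g in zip(fast_net.parameters(), grads):
            p -= lr * g
    # Query rule: o_t = f_{W_t}(q_t).
    o_t = fast_net(q_t)
```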

Example 1 - Linear Attention:

Consider the simplest case where the fast weight is a linear model (i.e., $f_W(k_t) = W k_t$) and the loss function is the negative dot product of $v_t$ and the reconstructed value $f_W(k_t)$:

$$\ell(W; k_t, v_t) = -v_t^\top W k_t \;\;\Rightarrow\;\; W_t = W_{t-1} - \eta_t \nabla_W \ell(W_{t-1}; k_t, v_t) = W_{t-1} + \eta_t v_t k_t^\top$$

The update rule we get is the same as the update rule of linear attention (with the addition of the learning rate $\eta_t$). In practice, TTT uses a data-dependent learning rate $\eta_t$ to enhance the model's expressiveness.
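A quick sanity check (not from the paper's code; names and sizes are illustrative) that one gradient step on the negative dot-product loss recovers the linear attention update:

```python
import torch

d, eta = 16, 1.0
W = torch.randn(d, d, requires_grad=True)
k_t, v_t = torch.randn(d), torch.randn(d)

loss = -(v_t @ (W @ k_t))                  # l(W) = -v_t^T W k_t
loss.backward()                            # dl/dW = -v_t k_t^T
W_ttt = W.detach() - eta * W.grad                      # TTT gradient step
W_linattn = W.detach() + eta * torch.outer(v_t, k_t)   # linear attention update
assert torch.allclose(W_ttt, W_linattn, atol=1e-5)
```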

Example 2 - DeltaNet:

If $f_W$ is a linear model and the loss function is the MSE loss, we get the well-known Delta Rule:

$$\ell(W; k_t, v_t) = \tfrac{1}{2}\lVert W k_t - v_t \rVert^2 \;\;\Rightarrow\;\; W_t = W_{t-1} - \eta_t (W_{t-1} k_t - v_t) k_t^\top = W_{t-1}(I - \eta_t k_t k_t^\top) + \eta_t v_t k_t^\top$$

The reason we unlock a forgetting mechanism by simply changing the loss function is that, with a negative dot-product loss, the model can keep pushing the reconstructed value indefinitely in the direction of $v_t$. In contrast, with an MSE loss, $W k_t$ has to be close to $v_t$ in both direction and magnitude to minimize the loss. This creates a forgetting mechanism: the $W_{t-1}(I - \eta_t k_t k_t^\top)$ term shrinks the old state along the direction of $k_t$, so the effect of a previous key-value association is gradually forgotten.
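Likewise, a quick check (illustrative only) that one gradient step on the MSE loss reproduces the delta rule form above:

```python
import torch

d, eta = 16, 0.5
W = torch.randn(d, d, requires_grad=True)
k_t, v_t = torch.randn(d), torch.randn(d)

loss = 0.5 * ((W @ k_t - v_t) ** 2).sum()  # l(W) = 1/2 ||W k_t - v_t||^2
loss.backward()                            # dl/dW = (W k_t - v_t) k_t^T
W_ttt = W.detach() - eta * W.grad
W_delta = W.detach() @ (torch.eye(d) - eta * torch.outer(k_t, k_t)) \
          + eta * torch.outer(v_t, k_t)
assert torch.allclose(W_ttt, W_delta, atol=1e-4)
```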

Efficient Meta Training

Naively, TTT cannot be parallelized across time steps because the update rule is sequential: each $W_t$ depends on $W_{t-1}$.

To make it parallelizable, they use a mini-batch gradient descent trick: the sequence is split into chunks, and the gradient at each step is computed with respect to the state at the beginning of the chunk:

$$W_t = W_{t-1} - \eta_t \nabla_W \ell(W_{t'}; k_t, v_t)$$

where $t'$ is the time step at the beginning of the chunk.

This change allows us to compute the gradients at different positions within the same chunk in parallel, which improves the training speed.
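Below is a minimal sketch of this chunked trick, assuming a linear fast weight and MSE loss (shapes and the learning rate are illustrative): because every gradient inside the chunk is taken at the chunk-start state, all of them can be computed with batched matrix products.

```python
import torch

d, b, eta = 16, 4, 0.1                      # head size, chunk size, learning rate
W = torch.zeros(d, d)                       # fast weight at the start of the chunk
K = torch.randn(b, d)                       # keys of one chunk
V = torch.randn(b, d)                       # values of one chunk

# Per-token gradient of 1/2 ||W k_i - v_i||^2, all evaluated at the chunk-start W:
errs = K @ W.T - V                          # (b, d), row i = (W k_i - v_i)^T
grads = errs.unsqueeze(2) * K.unsqueeze(1)  # (b, d, d), grads[i] = (W k_i - v_i) k_i^T

# The cumulative sum gives the fast weight at every position in the chunk,
# and the last entry is the state carried into the next chunk.
W_all = W - eta * torch.cumsum(grads, dim=0)   # (b, d, d)
W_next = W_all[-1]
```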

The Problem with TTT

The original TTT models were trained on TPUs, and the official PyTorch implementation is extremely slow for training on GPUs (FLOPs utilization below 5%). The root cause is that the arithmetic intensity (FLOPs per memory access) of the TTT layer is very low, making it memory-bound. Assume a linear fast weight of shape $d \times d$ and a chunk input of shape $b \times d$ (chunk size $b$); the arithmetic intensity is:

$$\text{intensity} = \frac{O(b d^2)}{O(d^2) + O(b d)} = O\!\left(\frac{b d}{b + d}\right)$$

where $O(d^2)$ is the memory of the fast weight, and $O(bd)$ is the memory of the input and output.

Thus, in practice, where $b \ll d$, the arithmetic intensity of TTT is bounded by the chunk size $b$.
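For example, with $b = 16$ and $d = 768$ (illustrative numbers, not taken from the paper), the intensity is only about $\frac{16 \cdot 768}{16 + 768} \approx 15.7$: each element loaded from memory participates in roughly 16 useful FLOPs, far too few to keep the compute units of a modern GPU busy.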

Large Chunk Test-Time Training (LaCT)

LaCT (Zhang et al., 2025) increases the arithmetic intensity of TTT by increasing the chunk size from 16 to 2048 or even up to 1M tokens. However, directly increasing the chunk size results in poor modeling of local dependencies. Thus, LaCT incorporates a sliding window attention (SWA) layer to handle local dependencies. In practice, the SWA and TTT layers share the same set of query, key, and value (QKV) projections.
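Here is a minimal sketch of the large-chunk idea, assuming a linear fast weight, an MSE loss, and an "apply then update" order within each chunk (the fast-weight network, optimizer, and ordering used by LaCT itself may differ; local dependencies inside the chunk would be handled by the SWA layer):

```python
import torch

d, chunk, eta = 64, 2048, 0.1               # head size, large chunk size, learning rate
W = torch.zeros(d, d)                        # fast weight carried across chunks

def lact_chunk(W, Q, K, V, eta):
    """Process one large chunk: query with the current fast weight,
    then apply a single gradient step computed over the whole chunk."""
    O = Q @ W.T                              # apply: o_i = W q_i for every token in the chunk
    errs = K @ W.T - V                       # (chunk, d): W k_i - v_i
    grad = errs.T @ K                        # sum_i (W k_i - v_i) k_i^T, one big matmul
    W = W - eta * grad / K.shape[0]          # update once per chunk (mean gradient)
    return O, W

Q, K, V = (torch.randn(chunk, d) for _ in range(3))
O, W = lact_chunk(W, Q, K, V, eta)
```

With a chunk of thousands of tokens, the $d \times d$ fast weight is loaded once and reused across thousands of rows of matrix multiplication, which is what pushes the arithmetic intensity up.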

(Figure: the LaCT architecture.)

(to be continued…)


Thanks for reading. If you want further discussions, reach me at chenyingfa1999@gmail.com.