This blog post is part 1 of a series that describes my attempt at implementing the Test-Time Training (TTT) model proposed by Sun et al. (2024), and Titans, proposed by Behrouz et al. (2024). At the time of writing, these are two of the strongest recurrent language models, but neither has an official PyTorch implementation (TTT has only open-sourced a JAX implementation).

Introduction to Test-Time Training

Briefly explained, Test-Time Training (TTT) is an RNN model whose hidden state is replaced with an online learner, whose parameters are updated through gradient descent during inference. The goal is for this online learner to compress contextual information into its parameters. A TTT operator can be expressed as:

$$W_t = f_{\text{update}}(W_{t-1}; k_t, v_t), \qquad o_t = f_{\text{query}}(W_t; q_t),$$

where $q_t, k_t, v_t \in \mathbb{R}^d$ are the query, key, and value vectors at time $t$, respectively, and $W_t$ is the recurrent state, i.e., the parameters of an online learner $f$. The inner learning rate $\eta$ is a hyperparameter. The function $f_{\text{update}}$ is the “update rule”, and $f_{\text{query}}$ is the “query rule”.

The Update Rule

The interesting part is that the update rule is implemented with gradient descent:

$$W_t = W_{t-1} - \eta \nabla_W \ell(W_{t-1}; k_t, v_t),$$

where the loss is defined as:

$$\ell(W; k_t, v_t) = \big\| f(k_t; W) - v_t \big\|^2.$$

In other words, the online learner is learning to produce the value vector $v_t$ when we query it with the key vector $k_t$.
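As a concrete (if slow) reference for these rules, here is a minimal NumPy sketch of a single TTT step with a linear learner $f(x; W) = Wx$. The learning rate and dimensions are illustrative, and the constant factor 2 from the squared loss is folded into `lr`:

```python
import numpy as np

def ttt_step(W, q, k, v, lr=0.1):
    """One TTT step with a linear online learner f(x; W) = W @ x.

    Update rule: one gradient-descent step on the reconstruction loss
    ||W @ k - v||^2 (the constant factor 2 is folded into lr).
    Query rule: read out the *updated* state with q.
    """
    err = W @ k - v                  # reconstruction error
    W = W - lr * np.outer(err, k)    # grad of the loss w.r.t. W is err k^T
    return W @ q, W

rng = np.random.default_rng(0)
d = 4
q, k, v = rng.standard_normal((3, d))
# Starting from W = 0 with lr = 1, one step stores the (k, v) pair exactly.
o, W_new = ttt_step(np.zeros((d, d)), q, k, v, lr=1.0)
```

Starting from a zero state, one step with `lr=1.0` leaves `W_new` equal to `np.outer(v, k)`, so querying with `q` retrieves `v` scaled by the key-query similarity `k @ q`.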

Mini-Batch Gradient Descent

Obviously, directly running this algorithm is extremely slow because it cannot be parallelized. Sun et al. (2024) proposed mini-batch gradient descent, in which the input sequence is split into mini-batches of size $b$, and the gradients are computed based on the online learner’s state at the start of the mini-batch:

$$W_t = W_{t'} - \eta \sum_{s=t'+1}^{t} \nabla_W \ell(W_{t'}; k_s, v_s),$$

where $t' = b \lfloor (t-1)/b \rfloor$ is the last timestep of the previous mini-batch.

Dual Form

In order to make use of matmuls and avoid the materialization of per-token gradients, TTT further proposes to use a “dual form”. Without loss of generality, we consider a single mini-batch of size $b$, where $W_0$ is the initial state of the online learner.

Dual Form of Linear Model

For simplicity, we first explain the dual form for a linear model $f(x; W) = W x$. Within one mini-batch, the state at time $t$ is:

$$W_t = W_0 - \eta \sum_{s=1}^{t} (W_0 k_s - v_s)\, k_s^\top,$$

where $W_0 \in \mathbb{R}^{d \times d}$ (the constant factor 2 from the squared loss is absorbed into $\eta$). We now add in the query vectors and make use of the associativity of matrix multiplication to directly compute the output:

$$o_t = W_t q_t = W_0 q_t - \eta \sum_{s=1}^{t} (W_0 k_s - v_s)\,(k_s^\top q_t).$$

Now, let’s define the reconstruction error as $e_t = W_0 k_t - v_t$, and stack the vectors row-wise into matrices $Q, K, V, E \in \mathbb{R}^{b \times d}$, whose rows are $q_t^\top, k_t^\top, v_t^\top, e_t^\top$. Then we have:

$$O = Q W_0^\top - \eta\,\big(M \odot (K Q^\top)\big)^\top E,$$

where $O \in \mathbb{R}^{b \times d}$ is the output matrix, $Q, K, V$ are the query, key, and value matrices, $M$ is a mask matrix (upper triangular matrix, $M_{st} = 1$ iff $s \le t$), and $\odot$ is the element-wise product. We also need the state at the end of the batch, which is simply:

$$W_b = W_0 - \eta\, E^\top K.$$

It’s interesting to see how similar this is to ordinary causal linear attention.
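The equivalence between the primal loop and the dual form is easy to check numerically. The following NumPy sketch (dimensions, seed, and learning rate are illustrative) computes one mini-batch both ways and confirms they match:

```python
import numpy as np

rng = np.random.default_rng(0)
b, d, lr = 8, 4, 0.1
Q, K, V = rng.standard_normal((3, b, d))
W0 = rng.standard_normal((d, d)) * 0.1

# --- Primal form: explicit loop, all gradients taken at W0 ---
O_primal = np.empty((b, d))
G = np.zeros((d, d))                    # accumulated gradient
for t in range(b):
    err = W0 @ K[t] - V[t]              # reconstruction error at W0
    G += np.outer(err, K[t])
    O_primal[t] = (W0 - lr * G) @ Q[t]  # query the state after step t

# --- Dual form: two matmuls and a causal mask ---
E = K @ W0.T - V                        # rows are the errors e_t
M = np.triu(np.ones((b, b)))            # M[s, t] = 1 iff s <= t
A = M * (K @ Q.T)                       # masked attention-like scores
O_dual = Q @ W0.T - lr * A.T @ E
W_b = W0 - lr * E.T @ K

assert np.allclose(O_primal, O_dual)
assert np.allclose(W_b, W0 - lr * G)
```

The `A.T @ E` term is exactly where the linear-attention similarity shows up: the masked scores play the role of causal attention weights over the error vectors.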

So, to sum up, the dual form of TTT accepts a batch of queries, keys, and values $(Q, K, V)$ and an online learner state $W_0$, and returns the output matrix $O$ and the state at the end of the batch $W_b$. The state at the end of the batch is then used as the initial state for the next batch.

Dual Form of MLP

The dual form can be extended to a two-layer MLP, $f(x; W_1, W_2) = W_2\,\sigma(W_1 x)$, where $\sigma$ is the ReLU activation. The idea is to first apply the dual form to the first layer to get the intermediate representation, and then apply the dual form to the second layer to get the output. The forward pass on the keys can be written as:

$$h_t = \sigma(W_1 k_t), \qquad \hat{v}_t = W_2 h_t.$$

Then the gradients are:

$$g_t = \hat{v}_t - v_t, \qquad \nabla_{W_2}\ell_t = g_t h_t^\top, \qquad \nabla_{W_1}\ell_t = \big((W_2^\top g_t) \odot \sigma'(W_1 k_t)\big)\,k_t^\top.$$

Define $\delta_t = (W_2^\top g_t) \odot \sigma'(W_1 k_t)$, stacked row-wise into $\Delta \in \mathbb{R}^{b \times d_h}$; then we can apply the dual form to the first layer:

$$\tilde{H} = \sigma\!\big(Q W_1^\top - \eta\,(M \odot (K Q^\top))^\top \Delta\big),$$

where $\tilde{H}$ is the intermediate representation when querying the updated state. Since its rows are the input queries for the second layer, with the rows of $H = \sigma(K W_1^\top)$ acting as keys and the output errors $G$ (stacking $g_t^\top$) as values, we can apply the dual form to the second layer as if it were a linear model:

$$O = \tilde{H} W_2^\top - \eta\,(M \odot (H \tilde{H}^\top))^\top G.$$

The state at the end of the batch should be trivial to compute: $W_1' = W_1 - \eta\,\Delta^\top K$ and $W_2' = W_2 - \eta\,G^\top H$.
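As a sanity check on the two-layer derivation, here is a NumPy sketch (dimensions and learning rate are illustrative; residual connections omitted) comparing the naive per-token loop against the dual form within one mini-batch:

```python
import numpy as np

rng = np.random.default_rng(1)
b, d, dh, lr = 8, 4, 16, 0.1
Q, K, V = rng.standard_normal((3, b, d))
W1 = rng.standard_normal((dh, d)) * 0.1
W2 = rng.standard_normal((d, dh)) * 0.1

# --- Primal form: loop, all gradients taken at the batch-start state ---
O_primal = np.empty((b, d))
G1, G2 = np.zeros_like(W1), np.zeros_like(W2)
for t in range(b):
    pre = W1 @ K[t]
    h = np.maximum(pre, 0.0)            # ReLU activation
    g = W2 @ h - V[t]                   # output error
    delta = (W2.T @ g) * (pre > 0)      # error at the hidden pre-activation
    G1 += np.outer(delta, K[t])
    G2 += np.outer(g, h)
    # Query the state after step t with both layers updated.
    O_primal[t] = (W2 - lr * G2) @ np.maximum((W1 - lr * G1) @ Q[t], 0.0)

# --- Dual form: first layer, then second layer as a "linear model" ---
Pre = K @ W1.T
H = np.maximum(Pre, 0.0)                # rows h_t
Gm = H @ W2.T - V                       # rows g_t
Delta = (Gm @ W2) * (Pre > 0)           # rows delta_t
M = np.triu(np.ones((b, b)))            # causal mask within the batch
H_q = np.maximum(Q @ W1.T - lr * (M * (K @ Q.T)).T @ Delta, 0.0)
O_dual = H_q @ W2.T - lr * (M * (H @ H_q.T)).T @ Gm

assert np.allclose(O_primal, O_dual)
```

Note that in the second layer, the rows of `H` play the role of keys and the rows of `H_q` the role of queries, exactly as in the linear case.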

Code Implementation

Existing Open-Source Implementation

I have found an excellent PyTorch implementation of Titans by lucidrains. Titans is an extension of TTT that adds momentum and data-dependent weight decay and learning rates. However, in this implementation, the memory query is performed on the memory at the start of the batch:

$$o_t = f(q_t; W_0),$$

instead of querying the state after KV insertion, $o_t = f(q_t; W_t)$. Not only does this make the implementation much simpler and faster, it also allows for the use of PyTorch’s autograd engine (specifically the grad and functional_call functions) to avoid explicitly writing the gradient computation.
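Here is a minimal sketch of that autograd-based approach, using `torch.func` with a toy single-layer memory. The module and variable names are illustrative, not lucidrains’ actual code:

```python
import torch
from torch import nn
from torch.func import functional_call, grad

# Toy memory module; in Titans this would be a deeper MLP.
memory = nn.Linear(4, 4, bias=False)
params = {n: p.detach() for n, p in memory.named_parameters()}

def recon_loss(params, k, v):
    # Reconstruction loss evaluated at the batch-start state.
    pred = functional_call(memory, params, (k,))
    return ((pred - v) ** 2).sum()

k, v, q = torch.randn(3, 8, 4).unbind(0)
grads = grad(recon_loss)(params, k, v)       # autograd, no manual backward pass
params = {n: p - 0.1 * grads[n] for n, p in params.items()}
out = functional_call(memory, params, (q,))  # query the updated memory
```

Because the loss only depends on the batch-start parameters, a single `grad` call suffices; querying the per-token state $W_t$ instead would require the masked dual-form machinery above.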

However, this is still not exactly the same as TTT. So, I will try to implement the original TTT model.

My Implementation

It should be pretty straightforward to implement the dual form for a two-layer MLP, which is what I will do. Assume that the online learner is defined as (with $\sigma$ denoting ReLU):

$$f(x; W_1, W_2) = x + W_2\,\sigma(W_1 x).$$

The dual-form operator can then be implemented as:

import torch
from einops import einsum
from torch import Tensor


def two_layer_mlp_dual_form(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    weights: dict[str, Tensor],
) -> tuple[Tensor, dict[str, Tensor]]:
    """
    This operator utilizes the dual form of TTT (https://arxiv.org/pdf/2407.04620)
    to avoid materializing the gradients. The inner learning rate is folded into
    the gradients (i.e., eta = 1).
    """
    device = q.device
    batch_size = q.shape[0]  # the TTT mini-batch size b
    w1 = weights["w1"]  # (dm, d)
    w2 = weights["w2"]  # (d, dm)

    # Forward pass on the keys
    h = einsum(w1, k, "dm d, b d -> b dm")
    h_relu = torch.relu(h)  # (b, dm)
    h2 = einsum(w2, h_relu, "d dm, b dm -> b d")
    pred = h2 + k  # residual connection, (b, d)

    # Backward pass.
    # We first compute the gradients w.r.t. the activations,
    # which can be computed efficiently because they are matmuls.
    grad = pred - v  # (b, d)
    grad_h2 = einsum(w2, grad, "d dm, b d -> b dm")
    grad_h = grad_h2 * (h_relu > 0).float()  # (b, dm)

    # Then, we avoid materializing the per-token gradients w.r.t. the two
    # linear layers W1 and W2, by associativity with the queries.
    mask = torch.triu(torch.ones(batch_size, batch_size, device=device))
    attn_scores = einsum(k, q, "bk d, bq d -> bk bq")
    attn_scores = attn_scores * mask  # (b, b)
    h_q = einsum(w1, q, "dm d, b d -> b dm")
    h_q = h_q - einsum(grad_h, attn_scores, "bk dm, bk bq -> bq dm")
    h_q_relu = torch.relu(h_q)  # (b, dm)

    # Compute the output (second layer, treated as a linear model whose
    # queries are h_q_relu and whose keys are h_relu)
    attn_scores = einsum(h_relu, h_q_relu, "bk dm, bq dm -> bk bq")
    attn_scores = attn_scores * mask
    y = einsum(w2, h_q_relu, "d dm, b dm -> b d")
    y = y - einsum(grad, attn_scores, "bk d, bk bq -> bq d")
    y = y + q  # residual connection

    # Compute the weights at the end of the batch
    grad_w1 = einsum(grad_h, k, "b dm, b d -> dm d")
    grad_w2 = einsum(grad, h_relu, "b d, b dm -> d dm")
    weights = {
        "w1": w1 - grad_w1,
        "w2": w2 - grad_w2,
    }
    return y, weights

How to Cite

@misc{chen2025ttt-implementation,
    author = {Yingfa Chen},
    title = {Implementing Test-Time Training - Part 1},
    year = {2025},
    url = {https://chen-yingfa.github.io/2025/03/19/ttt-implementation/},
}

References