wall-attention-release
Attention variant with per-channel multiplicative decay
Wall Attention is a Python library that implements a variant of the attention mechanism used in transformer language models. In standard attention, the score between a query position and a key position is the scaled dot product of their query and key vectors. Wall Attention modifies that score by applying a learned decay to each individual channel of the query-key product, so the model can control, channel by channel, how quickly each dimension of the context fades as the distance between the two positions grows. The authors describe this as a generalization of simpler decay approaches used in other recent architectures.
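The description above does not give the exact formula, so the following NumPy sketch assumes one plausible reading: each position t carries a per-channel decay factor `decay[t, c]` in (0, 1], and the score between query i and key j multiplies channel c of the query-key product by the product of the decay factors between the two positions. The function name `decayed_attention_reference` and the formula itself are illustrative assumptions, not the library's API. It is a slow, loop-based reference meant only to make the idea concrete.

```python
import numpy as np

def decayed_attention_reference(q, k, v, decay):
    """Naive causal attention with an ASSUMED per-channel multiplicative decay.

    q, k: (T, d) queries and keys; v: (T, d_v) values.
    decay: (T, d) per-position, per-channel factors in (0, 1] (hypothetical form).
    Assumed score for query i and key j <= i:
        s[i, j] = sum_c q[i, c] * k[j, c] * prod_{t=j+1..i} decay[t, c] / sqrt(d)
    """
    T, d = q.shape
    scores = np.full((T, T), -np.inf)  # future keys stay at -inf (causal mask)
    for i in range(T):
        for j in range(i + 1):
            # Per-channel decay accumulated between key j and query i
            # (an empty product, i.e. all ones, when j == i).
            gap = np.prod(decay[j + 1:i + 1], axis=0)
            scores[i, j] = np.sum(q[i] * k[j] * gap) / np.sqrt(d)
    scores -= scores.max(axis=1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
T, d = 6, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
decay = rng.uniform(0.8, 1.0, size=(T, d))  # stand-in for learned decays
out = decayed_attention_reference(q, k, v, decay)
print(out.shape)  # (6, 4)
```

Under this assumed form, setting every decay factor to 1 recovers ordinary causal softmax attention, and using the same factor in every channel collapses the per-channel decay into a single per-head decay.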
The library ships two GPU kernels written in Triton, a Python-based GPU programming framework. The first handles training and prefill (the initial processing of a full input sequence), using a fused, FlashAttention-style computation that avoids materializing the full attention score matrix. The second handles decode, where the model generates one token at a time. For decode, the library pre-processes the key-value cache so that applying the decay to each new token costs only a small, fixed amount of extra work regardless of context length.
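The following NumPy sketch shows one way a per-channel decay could fit both kernels. It assumes the score multiplies channel c of the query-key product by the product of per-position decay factors `decay[t, c]` between key j and query i. Under that assumption the decay can be folded into the inputs: scale each query by the cumulative decay up to its position and each key by the inverse cumulative decay up to its own position. The decayed score is then an ordinary dot product, so keys can be processed block by block with an online softmax (the FlashAttention idea), and a decode step only has to scale the new query and the one new key before appending it to the cache. `fold_decay`, `online_softmax_attention`, `decode_step`, and the cache layout are hypothetical; we do not know whether the Triton kernels work this way.

```python
import numpy as np

def fold_decay(q, k, decay):
    """Fold an ASSUMED cumulative per-channel decay into queries and keys.

    With L[t] = sum_{s<=t} log(decay[s]), the decayed score
    sum_c q[i,c] k[j,c] prod_{t=j+1..i} decay[t,c] equals (q[i]*exp(L[i])) . (k[j]*exp(-L[j])).
    exp(-L) grows with position, so a real kernel would need rescaling to avoid
    overflow on long sequences; this sketch ignores that.
    """
    logcum = np.cumsum(np.log(decay), axis=0)
    return q * np.exp(logcum), k * np.exp(-logcum), logcum

def online_softmax_attention(q_t, k_t, v, q_pos, block=2):
    """Causal attention computed one key block at a time (FlashAttention-style).

    Only a (num_queries x block) score tile exists at any moment; the full
    score matrix is never materialized.
    """
    n_q, d = q_t.shape
    m = np.full(n_q, -np.inf)          # running max score per query
    l = np.zeros(n_q)                  # running softmax denominator
    acc = np.zeros((n_q, v.shape[1]))  # running unnormalized output
    for start in range(0, k_t.shape[0], block):
        kb, vb = k_t[start:start + block], v[start:start + block]
        s = q_t @ kb.T / np.sqrt(d)
        k_pos = np.arange(start, start + kb.shape[0])
        s = np.where(k_pos[None, :] <= q_pos[:, None], s, -np.inf)  # causal mask
        m_new = np.maximum(m, s.max(axis=1))  # finite after block 0, since key 0 is always visible
        rescale = np.exp(m - m_new)           # shrink old statistics to the new max
        p = np.exp(s - m_new[:, None])
        l = l * rescale + p.sum(axis=1)
        acc = acc * rescale[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]

def decode_step(cache, q_new, k_new, v_new, decay_new):
    """Process one generated token against a cache of pre-scaled keys.

    Only the new key is scaled; existing cache entries are never touched, so the
    decay bookkeeping is O(d) per token. The softmax over the cache still reads
    every cached key, as in any softmax attention.
    """
    cache["logcum"] = cache["logcum"] + np.log(decay_new)
    cache["k"] = np.vstack([cache["k"], k_new * np.exp(-cache["logcum"])])
    cache["v"] = np.vstack([cache["v"], v_new])
    q_scaled = (q_new * np.exp(cache["logcum"]))[None, :]
    pos = np.array([cache["k"].shape[0] - 1])
    return online_softmax_attention(q_scaled, cache["k"], cache["v"], pos)[0]

rng = np.random.default_rng(0)
T, d = 7, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
decay = rng.uniform(0.8, 1.0, size=(T, d))
q_t, k_t, _ = fold_decay(q, k, decay)
prefill_out = online_softmax_attention(q_t, k_t, v, np.arange(T))
cache = {"logcum": np.zeros(d), "k": np.zeros((0, d)), "v": np.zeros((0, d))}
decode_out = np.stack([decode_step(cache, q[t], k[t], v[t], decay[t]) for t in range(T)])
print(np.allclose(prefill_out, decode_out))  # True
```

Processing the sequence token by token through the pre-scaled cache reproduces the prefill result, which is the property a decode kernel built on this kind of preprocessing would need to preserve.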
Both kernels support grouped query attention (GQA), in which the number of query heads can exceed the number of key-value heads and each key-value head is shared by a group of query heads. Optional features include a scalar gate per head, an attention sink bias, and a sliding window that limits how far back each query can attend. The library also supports variable-length sequence packing, in which sequences of different lengths are concatenated into a single batch without padding.
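The optional features can be illustrated with a small NumPy reference that omits the decay for brevity. It assumes conventions common in other attention libraries rather than anything documented here: packed sequences are described by cumulative offsets (`cu_seqlens`), a sliding window of size `window` lets a query see itself and the `window - 1` preceding tokens, the sink is a per-head logit that enters the softmax denominator but contributes no value, and query head h reads key-value head `h // (n_q_heads // n_kv_heads)`. All function names are hypothetical, and the per-head scalar gate is not modeled because its exact form is not described.

```python
import numpy as np

def packed_mask(cu_seqlens, window=None):
    """Boolean (N, N) visibility mask for sequences packed end to end.

    cu_seqlens holds cumulative offsets, e.g. [0, 3, 5] packs lengths 3 and 2.
    Query i may see key j only if both are in the same sequence, j <= i, and
    (if window is set) i - j < window.
    """
    n = cu_seqlens[-1]
    seq_id = np.zeros(n, dtype=int)
    for s, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        seq_id[start:end] = s
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    mask = (seq_id[:, None] == seq_id[None, :]) & (j <= i)
    if window is not None:
        mask &= (i - j) < window
    return mask

def softmax_with_sink(scores, mask, sink):
    """Softmax whose denominator also includes exp(sink); the sink carries no value."""
    s = np.where(mask, scores, -np.inf)
    m = np.maximum(s.max(axis=-1, keepdims=True), sink)  # stabilize against the sink too
    p = np.exp(s - m)
    return p / (p.sum(axis=-1, keepdims=True) + np.exp(sink - m))

def gqa_attention(q, k, v, cu_seqlens, sinks, window=None):
    """q: (Hq, N, d); k, v: (Hkv, N, d) with Hq a multiple of Hkv; sinks: (Hq,)."""
    n_q_heads, _, d = q.shape
    n_kv_heads = k.shape[0]
    assert n_q_heads % n_kv_heads == 0
    group = n_q_heads // n_kv_heads
    mask = packed_mask(cu_seqlens, window)
    out = np.empty_like(q)  # assumes value dim == query dim
    for h in range(n_q_heads):
        kv = h // group  # consecutive query heads share one key-value head
        scores = q[h] @ k[kv].T / np.sqrt(d)
        out[h] = softmax_with_sink(scores, mask, sinks[h]) @ v[kv]
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 5, 8))  # 4 query heads, 5 packed tokens
k = rng.standard_normal((2, 5, 8))  # 2 key-value heads, each shared by 2 query heads
v = rng.standard_normal((2, 5, 8))
out = gqa_attention(q, k, v, cu_seqlens=[0, 3, 5], sinks=np.zeros(4), window=2)
print(out.shape)  # (4, 5, 8)
```

Because the mask is block-diagonal across packed sequences, the tokens of the second sequence get the same output they would get if that sequence were processed on its own, which is what makes packing a drop-in replacement for padding.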
The library is aimed at machine learning researchers and engineers who want to experiment with or build on this attention variant. Tests verify that the kernels match a reference PyTorch implementation and that gradients are numerically correct. The code is MIT licensed and links to a blog post from Tilde Research for additional theoretical background.