Following his first blog on "Attention Sinks from the Graph perspective",
@tensorqt has now released a new blogpost, titled "Beyond Attention as a Graph".
First and foremost, tensorqt explains why standard neural networks require depth, despite the issues it introduces in sequence modeling (most notably, gradient instabilities).
In the specific case of Transformers, however, depth, while still problematic, is easy to justify: since "attention is an operation that message-passes between pairs of tokens in a graph", depth (intended as the number of decoder layers) ends up approximating n-hop information transmission between pairs of tokens.
But what if these n hops of information passing between pairs of tokens could be approximated without resorting to depth?
To answer this, drawing on existing literature, 2-Simplicial Attention (and, more generally, Higher-order Attention) is introduced.
The intuition is the following: instead of giving the query a single key per token to attend to, one can project the tokens into two key subspaces, K = XW_k and K' = XW_k', which turns the attention calculation into a multilinear (here, trilinear) product.
The result is that while standard attention scores A_ij "represented the link weight going from node i to node j, now each entry can instead be seen as the collective weight assigned to the triangle determined by the (directed) walk from node i, passing through node j, and ending up in node s".
This idea can also be extended to n > 2 key projections; the equations describing the resulting n-order attention scores are in the attached picture (the case described above is n = 2).
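To make the n = 2 case concrete, here is a minimal single-head sketch. It is not taken from the blog: the trilinear score, the joint softmax over key pairs, and the element-wise combination of two value projections are assumptions about one plausible formulation.

```python
import torch

def two_simplicial_attention(x, W_q, W_k, W_k2, W_v, W_v2):
    """Minimal sketch of 2-simplicial attention (single head, no masking).

    Assumed formulation: the score of the triangle (i, j, s) is the
    trilinear form <q_i, k_j, k'_s>, softmax is taken jointly over (j, s),
    and two value projections are combined element-wise.
    """
    q  = x @ W_q    # (L, d)
    k1 = x @ W_k    # (L, d)  first key projection  K  = X W_k
    k2 = x @ W_k2   # (L, d)  second key projection K' = X W_k'
    v1 = x @ W_v    # (L, d)
    v2 = x @ W_v2   # (L, d)

    L, d = q.shape
    # Trilinear scores: A[i, j, s] = sum_c q[i, c] * k1[j, c] * k2[s, c]
    # (scaling by sqrt(d) is an assumption, mirroring standard attention)
    scores = torch.einsum("ic,jc,sc->ijs", q, k1, k2) / d ** 0.5

    # Joint softmax over all (j, s) pairs for each query i
    attn = torch.softmax(scores.reshape(L, L * L), dim=-1).reshape(L, L, L)

    # Combine the two value streams element-wise and aggregate
    return torch.einsum("ijs,jc,sc->ic", attn, v1, v2)
```

The L x L x L score tensor built by the first einsum already hints at the cost problem discussed next.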
It is immediate, though, that the already quadratic cost of ordinary attention explodes to O(L**(n+1)), where L is the sequence length and n the attention order.
One proposed way to address this builds on DeepSeek Sparse Attention (DSA): first, compute the dot product between the query vector of token i at each head h and a (shared across heads) indexer key for each token j, pass the result through a ReLU, and multiply it by a per-head learned weight.
Sum the resulting scores across heads and, for the actual attention computation, only retain for each query the k keys with the largest scores: the final computational complexity, in the context of standard attention, goes from O(L**2) to O(Lk).
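A rough sketch of that indexer-plus-top-k step, with names and shapes assumed for illustration (this is not DeepSeek's actual implementation):

```python
import torch

def dsa_topk_indices(x, W_q_idx, w_k_idx, head_weights, k):
    """Sketch of a DSA-style indexer that picks k candidate keys per query.

    x:            (L, d_model) token representations
    W_q_idx:      (H, d_model, d_idx) per-head indexer query projections
    w_k_idx:      (d_model, d_idx) key projection shared across heads
    head_weights: (H,) learned per-head weights
    k:            number of keys kept per query
    """
    q_idx = torch.einsum("ld,hde->hle", x, W_q_idx)   # (H, L, d_idx)
    k_idx = x @ w_k_idx                               # (L, d_idx), shared across heads

    # Per-head query-key dot products, ReLU, per-head weight, sum over heads
    scores = torch.einsum("hie,je->hij", q_idx, k_idx)        # (H, L, L)
    scores = torch.relu(scores) * head_weights[:, None, None]
    index_scores = scores.sum(dim=0)                          # (L, L)

    # For each query i, keep the k keys with the largest index scores;
    # full attention is then computed only over these candidates.
    return index_scores.topk(k, dim=-1).indices               # (L, k)
```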
As such, while the original paper sparsifies 2-simplicial attention with a local sliding window, tensorqt adapts DSA to n-order attention, testing his framework in the 2-simplicial case: he first substitutes the ReLU with a softmax, and then simplifies the scoring further by directly reusing the standard QK attention scores from previous layers and taking the top-k keys based on those, bringing the cost from O(L**(n+1)) down to O(L*k**n).
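Under the same assumptions as the dense sketch above, that simplified variant might look like this, where prev_attn is an assumed name for the previous layer's standard attention scores:

```python
import torch

def sparse_two_simplicial_attention(q, k1, k2, v1, v2, prev_attn, topk):
    """Sketch of top-k sparsified 2-simplicial attention, assuming candidate
    keys are selected from the previous layer's standard QK attention.

    q, k1, k2, v1, v2: (L, d) projections as in the dense sketch
    prev_attn:         (L, L) attention scores from the previous layer
    topk:              number of candidate keys kept per query
    """
    L, d = q.shape
    idx = prev_attn.topk(topk, dim=-1).indices   # (L, topk) candidates per query

    k1_sel, k2_sel = k1[idx], k2[idx]            # (L, topk, d)
    v1_sel, v2_sel = v1[idx], v2[idx]            # (L, topk, d)

    # Trilinear scores only over the selected k x k pairs:
    # O(L * k^2) per layer instead of O(L^3) for the dense n = 2 case
    scores = torch.einsum("ic,ijc,isc->ijs", q, k1_sel, k2_sel) / d ** 0.5
    attn = torch.softmax(scores.reshape(L, -1), dim=-1).reshape(L, topk, topk)

    return torch.einsum("ijs,ijc,isc->ic", attn, v1_sel, v2_sel)
```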
All in all, given the potential of higher-order attention, further research into making it computationally tractable is welcome.
Link to the blog below: in the picture, the aforementioned equations governing n-order attention.