Transformers have proven a successful model for a variety of tasks in sequence modeling. However, computing the attention matrix, which is their key component, has quadratic complexity with respect to the sequence length, thus making them prohibitively expensive for large sequences. To address this, we propose clustered attention, which instead of computing the attention for every query, groups queries into clusters and computes attention just for the centroids. To further improve this approximation, we use the computed clusters to identify the keys with the highest attention per query and compute the exact key/query dot products. This results in a model with linear complexity with respect to the sequence length for a fixed number of clusters.

This blog post summarizes our paper on scaling transformers with clustering. For more details, please read our paper, Fast Transformers with Clustered Attention.

For the rest of this blog post, for any matrix $M \in \R^{P \times Q}$, $M_i$ denotes the $i$-th row of $M$.

1
We start with a brief recap of vanilla self-attention. Given
*queries* denoted
by $Q \in \R^{N \times D_k}$ and *keys* denoted by $K \in
\R^{N \times D_k}$, we define the standard dot product attention
matrix $A \in \R^{N \times N}$ as:
$*
\begin{aligned}
A = \softmax{\frac{Q K^T}{\sqrt{D_k}}}.
\end{aligned}
$*
Using the attention weights $A$ and the values $V \in \R^{N \times
D_v}$, we compute the new values $\hat{V}$ as follows:
$*
\begin{aligned}
\hat{V} = A V.
\end{aligned}
$*
Computing the attention matrix $A$ requires $\bigO{N^2 D_k}$ operations,
and computing the new values $\hat{V}$ requires $\bigO{N^2 D_v}$ operations.
This results in an asymptotic complexity of $\bigO{N^2 \max
\left(D_k, D_v\right)}$.
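The two equations above can be sketched in a few lines of NumPy. This is only a minimal illustration, not our released implementation: the function names `softmax` and `vanilla_attention` and the toy sizes are assumptions made for the example.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def vanilla_attention(Q, K, V):
    # A = softmax(Q K^T / sqrt(D_k)) has shape (N, N): this is the quadratic part.
    D_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(D_k))
    # New values V_hat = A V have shape (N, D_v).
    return A @ V, A


# Toy sizes chosen only for illustration.
rng = np.random.default_rng(0)
N, D_k, D_v = 6, 4, 3
Q = rng.normal(size=(N, D_k))
K = rng.normal(size=(N, D_k))
V = rng.normal(size=(N, D_v))

V_hat, A = vanilla_attention(Q, K, V)
print(A.shape, V_hat.shape)
```

Both matrix products touch all $N^2$ query/key pairs, which is where the quadratic cost comes from.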

The main idea of this work rests on the observation that queries that are close in Euclidean space have similar attention distributions.

More precisely, given two queries $Q_i$ and $Q_j$ such that $\norm{Q_i - Q_j}_2 \leq\epsilon$, $* \begin{equation} \begin{aligned} \norm{\softmax{Q_i K^T} - \softmax{Q_j K^T}}_2 \leq \epsilon \norm{K}_2 \,, \end{aligned} \end{equation} $* where $\norm{K}_2$ denotes the spectral norm of $K$.
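The bound can be checked numerically on random data. The snippet below is a sanity check under assumed toy sizes and an assumed perturbation size `epsilon`; it is not a proof, and the variable names are ours.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax for a 1-D vector.
    e = np.exp(x - x.max())
    return e / e.sum()


rng = np.random.default_rng(0)
N, D_k = 50, 16          # illustrative sizes
epsilon = 0.1            # illustrative distance between the two queries

K = rng.normal(size=(N, D_k))
q_i = rng.normal(size=D_k)

# Build q_j at Euclidean distance epsilon from q_i.
direction = rng.normal(size=D_k)
q_j = q_i + epsilon * direction / np.linalg.norm(direction)

# Left-hand side: distance between the two attention distributions.
lhs = np.linalg.norm(softmax(K @ q_i) - softmax(K @ q_j))
# Right-hand side: epsilon times the spectral norm (largest singular value) of K.
rhs = epsilon * np.linalg.norm(K, ord=2)

print(lhs <= rhs)
```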

2.1
We exploit the above observation to reduce the computational complexity of self-attention. To do so, we first cluster the queries into $C$ non-overlapping clusters. The partitioning of queries into $C$ clusters is denoted by $S \in \{0,1\}^{N \times C}$, such that $S_{ij} = 1$ if the $i$-th query $Q_i$ belongs to the $j$-th cluster and $0$ otherwise. We then compute attention weights $A^c \in \R^{C \times N}$ using the cluster centroids instead of computing them for every query. Using the clustered attention $A^c$, we compute the new values $V^c \in \R^{C \times D_v}$. Finally, we use the same attention weights and new values for all queries that belong to the same cluster. In the following figure we show the different steps of clustered attention:
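The steps above can be sketched as follows. This is a simplified illustration rather than our actual implementation: the cluster assignments are taken as a given input (`assignments` is a hypothetical argument, and the clustering step itself is not shown), and the centroid of a cluster is assumed to be the mean of its queries.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def clustered_attention(Q, K, V, assignments, C):
    """assignments[i] in [0, C) is the cluster of query i (assumed given)."""
    N, D_k = Q.shape
    # One-hot membership matrix S in {0,1}^{N x C}.
    S = np.zeros((N, C))
    S[np.arange(N), assignments] = 1.0
    # Centroids: mean of the queries in each cluster (guard against empty clusters).
    counts = np.maximum(S.sum(axis=0), 1.0)
    centroids = (S.T @ Q) / counts[:, None]            # (C, D_k)
    # Attention is computed only for the C centroids.
    A_c = softmax(centroids @ K.T / np.sqrt(D_k))      # (C, N)
    V_c = A_c @ V                                      # (C, D_v)
    # Every query reuses the new values of its cluster.
    return S @ V_c, A_c                                # (N, D_v), (C, N)


rng = np.random.default_rng(0)
N, C, D_k, D_v = 8, 3, 4, 5                            # illustrative sizes
Q = rng.normal(size=(N, D_k))
K = rng.normal(size=(N, D_k))
V = rng.normal(size=(N, D_v))
assignments = rng.integers(0, C, size=N)               # stand-in for a real clustering

V_hat, A_c = clustered_attention(Q, K, V, assignments, C)
print(V_hat.shape, A_c.shape)
```

A useful sanity check for this sketch is that with one cluster per query ($C = N$) it reduces to vanilla attention.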

2.2
In this section, we improve our clustered approximation of vanilla attention. We start by assuming that, for any query $Q_i$, we are given the $k$ keys with the highest attention weights. Given these keys, one simple way to improve the clustered attention approximation is to recompute the exact attention on these keys and substitute the recomputed weights for the clustered attention weights on these keys.

However, we need to scale the recomputed attention so that the total attention weight still sums to one. The scaling factor is the total attention weight assigned by the clustered attention $A^c$ to these top-$k$ keys.

Recomputing the attention as explained above always improves the approximation (see Appendix B of the paper for the proof). For any query $Q_i$ belonging to the cluster $j$, the set of $k$ keys with the highest attention weight in the clustered attention $A^c_j$ is a good candidate for the set of top-$k$ keys.

More formally, for any query $Q_i$ belonging to the cluster $j$, we start by introducing $T \in \{0, 1\}^{C \times N}$, where $T_{jl} = 1$ if the $l$-th key is among the top-$k$ keys for the $j$-th cluster and $0$ otherwise. We can then compute the scaling factor, denoted by $\hat{m}_j$, as the total attention weight on the top-$k$ keys as follows: $* \begin{equation} \hat{m}_j = \sum_{l=1}^{N}T_{jl}A^c_{jl}. \end{equation} $* Let us also denote the set of $k$ keys with the highest attention weights in $A^c_j$ by $K^t_j$. The updated attention weights $A^u_i$ on these keys can be computed as follows: $* \begin{equation} A^u_i = \hat{m}_j \softmax{\frac{Q_i {K^t_j}^T}{\sqrt{D_k}}}. \end{equation} $*

The new values, $\hat{V_i}$, can be efficiently computed using the following decomposition: $* \begin{align} \hat{V_i} = \hat{V_i}^{t} + \hat{V_i}^{b}, \label{eq:improved_values_fast} \end{align} $* where $\hat{V_i}^{t}$ is the weighted sum of the values corresponding to the top-$k$ keys, with weights given by the recomputed attention $A^u_i$, and $\hat{V_i}^b$ is the weighted sum of the rest of the values, with weights given by the clustered attention $A^c_j$.
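The decomposition into a top-$k$ part and a remaining part can be sketched as below. As before, this is an illustrative NumPy version under stated assumptions, not our optimized implementation: cluster assignments are assumed given (`assignments` is a hypothetical input), centroids are assumed to be cluster means, and a plain Python loop over clusters is used for readability.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def improved_clustered_attention(Q, K, V, assignments, C, k):
    N, D_k = Q.shape
    scale = np.sqrt(D_k)
    out = np.zeros((N, V.shape[1]))
    for j in range(C):
        members = np.flatnonzero(assignments == j)
        if members.size == 0:
            continue                                   # skip empty clusters
        centroid = Q[members].mean(axis=0)
        A_c_j = softmax(centroid @ K.T / scale)        # clustered attention row, (N,)
        top = np.argsort(A_c_j)[-k:]                   # indices of the top-k keys
        m_j = A_c_j[top].sum()                         # scaling factor m_j
        # Remaining part: clustered weights with the top-k entries zeroed out.
        A_b = A_c_j.copy()
        A_b[top] = 0.0
        V_b = A_b @ V                                  # shared by the whole cluster
        # Top-k part: exact dot products per query, rescaled by m_j.
        A_u = m_j * softmax(Q[members] @ K[top].T / scale)
        out[members] = A_u @ V[top] + V_b
    return out


rng = np.random.default_rng(0)
N, C, k, D_k, D_v = 8, 3, 2, 4, 5                      # illustrative sizes
Q = rng.normal(size=(N, D_k))
K = rng.normal(size=(N, D_k))
V = rng.normal(size=(N, D_v))
assignments = rng.integers(0, C, size=N)               # stand-in for a real clustering

V_hat = improved_clustered_attention(Q, K, V, assignments, C, k)
print(V_hat.shape)
```

In this sketch the remaining part is computed once per cluster, while the top-$k$ part costs only $k$ dot products per query. Setting $k = N$ recovers vanilla attention exactly, which makes for a convenient correctness check.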

In the following figure, we show how we can efficiently compute $\hat{V_i}^t$ and $\hat{V_i}^{b}$ for a single query $Q_i$ belonging to cluster $j$. For clarity, we denote the set of values corresponding to the top-$k$ keys by $V^t_j$.

2.3

Complexity | Vanilla Attention | Clustered Attention | Improved Clustered Attention |
---|---|---|---|
Memory | $\bigO{N^2 \max\left(D_k, D_v\right)}$ | $\bigO{N C \max\left(D_k, D_v\right)}$ | $\bigO{N \max\left(C, k\right) \max\left(D_k, D_v\right)}$ |
Time | $\bigO{N^2 \max\left(D_k, D_v\right)}$ | $\bigO{N C \max\left(D_k, D_v\right)}$ | $\bigO{N \max\left(C, k\right) \max\left(D_k, D_v\right)}$ |
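To get a feel for these asymptotic terms, the snippet below plugs in concrete numbers. It only evaluates the leading terms from the table and ignores constants, so it is a rough guide rather than a runtime prediction. The sequence length and feature dimensions are illustrative assumptions; $C = 100$ and $k = 32$ match the settings used in our experiments.

```python
def attention_cost(N, D_k, D_v, C=None, k=None):
    """Leading-order term from the complexity table (constants ignored)."""
    D = max(D_k, D_v)
    if C is None:
        return N * N * D                 # vanilla attention
    if k is None:
        return N * C * D                 # clustered attention
    return N * max(C, k) * D             # improved clustered attention


N, D_k, D_v = 4096, 64, 64               # illustrative sizes (assumptions)
C, k = 100, 32

full = attention_cost(N, D_k, D_v)
clustered = attention_cost(N, D_k, D_v, C=C)
improved = attention_cost(N, D_k, D_v, C=C, k=k)

print(f"vanilla / clustered: {full / clustered:.2f}")
print(f"vanilla / improved:  {full / improved:.2f}")
```

Since $k \leq C$ here, both clustered variants share the same leading term, and the ratio to vanilla attention is simply $N / C$.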

In the following, we evaluate our model
with respect to its computational requirements,
its accuracy on the task of automatic speech recognition,
and its approximation of a pre-trained RoBERTa
model on the
GLUE and SQuAD benchmarks.

We compare our model with vanilla transformers
(**full**) and the Reformer (**lsh-X**), where $X$ denotes the
number of hashing rounds. Following the Reformer paper, we set
the number of buckets to $64$ and the chunk size to $32$. We refer
to clustered attention as **clustered-X** and to improved
clustered attention as **i-clustered-X**, where $X$ denotes the
number of clusters. For all experiments, we set $k$ to $32$ for
improved clustered attention.

We compare the memory consumption and computation time on
artificially generated sequences of various lengths. For clustered
attention variants we use $100$ clusters.
For Reformer

All methods other than vanilla transformers scale linearly with respect to the sequence length. Note that with respect to per-sample memory, both clustered and improved clustered attention perform better than all other methods. Although lsh-$1$ is faster than improved clustered attention, Reformer typically requires multiple hashing rounds to generalize. In contrast, our clustered attention demonstrates good performance with 100 clusters or fewer.

3.2
In this experiment, we evaluate different transformer variants on automatic speech recognition (ASR) using the Wall Street Journal and Switchboard databases.

In this experiment, we compare different transformer variants under
an equalized computational budget. We train each transformer
variant with varying capacities to get a range of the required
computation time and achieved *Phone Error Rate* (PER) or
*Word Error Rate* (WER). In the figure below, we plot the
achieved PER/WER on the validation set with respect to the required
time to perform a full forward pass. We observe that
improved clustered consistently achieves lower PER than all other
baselines for a given computational budget.

To highlight the ability of our clustered attention model to
approximate arbitrarily complicated attention distributions, we
evaluate our proposed method on the approximation of a fine-tuned
RoBERTa model.

For the GLUE tasks, the maximum sequence length is 128 while for SQuAD, it is 384. For each task, we use $25$ clusters for approximation which is less than $20\%$ and $10\%$ of the input sequence length for GLUE and SQuAD tasks respectively. In the table below, we summarize the performance per task. We observe that improved clustered performs as well as the full transformer in all tasks but SQuAD, in which it is only marginally worse. Moreover, we note that clustered performs significantly worse in tasks that require more complicated attention patterns such as question answering (SQuAD) and textual entailment (RTE).

Attention | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI | SQuAD |
---|---|---|---|---|---|---|---|---|---|---|
full | 0.601 | 0.880 | 0.868 | 0.929 | 0.915 | 0.682 | 0.947 | 0.900 | 0.437 | 0.904 |
clustered-25 | 0.598 | 0.794 | 0.436 | 0.746 | 0.894 | 0.498 | 0.944 | 0.789 | 0.437 | 0.006 |
i-clustered-25 | 0.601 | 0.880 | 0.873 | 0.930 | 0.915 | 0.704 | 0.947 | 0.900 | 0.437 | 0.876 |

In the figure below, we provide a qualitative comparison between
full attention and the clustered attention variants used for
approximation. As described previously, we use $25$ clusters for
both attention variants. For a randomly selected question-context
tuple from the SQuAD dataset, we show the attention distribution for
the question tokens. With only a few clusters,
improved clustered approximates the full attention very closely,
even when the attention distribution has complicated and sparse
patterns. In contrast, clustered attention fails to capture such
an attention distribution during approximation. Moreover, for
almost all question tokens, both full and
improved clustered have the same tokens with the highest attention
weights. Note that the *CLS* token refers to the classification
token prepended to the question.

Manning finished the year with a career-low 67.9 passer rating, throwing for 2,249 yards and nine touchdowns, with 17 interceptions. In contrast, Osweiler threw for 1,967 yards, 10 touchdowns and six interceptions for a rating of 86.4. Veteran receiver Demaryius Thomas led the team with 105 receptions for 1,304 yards and six touchdowns, while Emmanuel Sanders caught 76 passes for 1,135 yards and six scores, while adding another 106 yards returning punts. Tight end Owen Daniels was also a big element of the passing game with 46 receptions for 517 yards. Running back C. J. Anderson was the team's leading rusher 863 yards and seven touchdowns, while also catching 25 passes for 183 yards. Running back Ronnie Hillman also made a big impact with 720 yards, five touchdowns, 24 receptions, and a 4.7 yards per carry average. Overall, the offense ranked 19th in scoring with 355 points and did not have any Pro Bowl selections.

(a) Context

We have presented *clustered attention*, a method that
approximates vanilla transformers with significantly lower
computational requirements. In particular, we have shown that for a
given computational budget our model outperforms all other baselines.
In contrast to recent fast variations of transformers, we have also
shown that our method can efficiently approximate pre-trained models
with full attention while retaining the linear asymptotic complexity.

The proposed method opens several research directions towards applying transformers to long-sequence tasks such as music generation and scene flow estimation. We consider masked language modeling for long texts to be of particular importance, as it will allow fine-tuning for downstream tasks that need a context longer than the commonly used 512 tokens.

Apoorv Vyas was supported by the Swiss National Science Foundation under grant number FNS-30213 "SHISSM". Angelos Katharopoulos was supported by the Swiss National Science Foundation under grant numbers FNS-30209 "ISUL" and FNS-30224 "CORTI".

The article template is provided by distill.pub and many formatting styles are inspired from the articles appearing on Distill.