Transformers have proven a successful model for a variety of tasks in sequence modeling. However, computing the attention matrix, which is their key component, has quadratic complexity with respect to the sequence length, thus making them prohibitively expensive for large sequences. To address this, we propose clustered attention, which instead of computing the attention for every query, groups queries into clusters and computes attention just for the centroids. To further improve this approximation, we use the computed clusters to identify the keys with the highest attention per query and compute the exact key/query dot products. This results in a model with linear complexity with respect to the sequence length for a fixed number of clusters.

This blog post presents a summary of our paper on scaling transformers with clustering. For more details, please refer to our paper, Fast Transformers with Clustered Attention.

For the rest of this blog post, for any matrix $M \in \R^{P \times Q}$, $M_i$ denotes the $i$-th row of the matrix $M$.

1 Recap: Vanilla Self-Attention

We start with a brief recap of the vanilla self-attention introduced in the original Transformer. For any sequence of length $N$, given the queries denoted by $Q \in \R^{N \times D_k}$ and keys denoted by $K \in \R^{N \times D_k}$, we define the standard dot product attention matrix $A \in \R^{N \times N}$ as:
$$A = \softmax{\frac{Q K^T}{\sqrt{D_k}}}.$$
Using the attention weights $A$ and the values $V \in \R^{N \times D_v}$, we compute the new values $\hat{V}$ as follows:
$$\hat{V} = A V.$$
Computing the attention matrix $A$ requires $\bigO{N^2 D_k}$ operations and computing the new values $\hat{V}$ requires $\bigO{N^2 D_v}$ operations. This results in an asymptotic complexity of $\bigO{N^2 \max\left(D_k, D_v\right)}$.
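To make the recap concrete, the following is a minimal PyTorch sketch of the computation above (the function and variable names are ours and not taken from the paper's codebase):

```python
import torch

def full_attention(Q, K, V):
    """Vanilla softmax attention: quadratic in the sequence length N."""
    D_k = Q.shape[-1]
    # A is the N x N attention matrix from the equation above.
    A = torch.softmax(Q @ K.transpose(-2, -1) / D_k ** 0.5, dim=-1)
    # The new values V_hat have shape (N, D_v).
    return A @ V

# Example with a single sequence of length N = 1024.
N, D_k, D_v = 1024, 64, 64
Q, K, V = torch.randn(N, D_k), torch.randn(N, D_k), torch.randn(N, D_v)
V_hat = full_attention(Q, K, V)  # (1024, 64)
```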

2 Self-Attention Approximation

The main idea of this work is based on the observation that queries close in Euclidean space have similar attention distributions.

More precisely, given two queries $Q_i$ and $Q_j$ such that $\norm{Q_i - Q_j}_2 \leq \epsilon$,
$$\norm{\softmax{Q_i K^T} - \softmax{Q_j K^T}}_2 \leq \epsilon \norm{K}_2,$$
where $\norm{K}_2$ denotes the spectral norm of $K$.
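As a quick numerical sanity check of this bound (our own illustration, not part of the paper), we can compare both sides for a pair of nearby random queries:

```python
import torch

torch.manual_seed(0)
N, D_k, eps = 128, 32, 1e-2

K = torch.randn(N, D_k)
Q_i = torch.randn(D_k)
# Construct a nearby query with ||Q_i - Q_j||_2 = eps.
delta = torch.randn(D_k)
Q_j = Q_i + eps * delta / delta.norm()

lhs = (torch.softmax(Q_i @ K.T, dim=-1) -
       torch.softmax(Q_j @ K.T, dim=-1)).norm()
rhs = eps * torch.linalg.matrix_norm(K, ord=2)  # spectral norm of K
print(lhs.item(), "<=", rhs.item())  # the bound holds
```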

2.1 Clustered Self-Attention

We exploit the above observation to improve the computational complexity of self-attention. To do so, we first group the queries into $C$ non-overlapping clusters. The partitioning of queries into $C$ clusters is denoted by $S \in \{0,1\}^{N \times C}$, where $S_{ij} = 1$ if the $i$-th query $Q_i$ belongs to the $j$-th cluster and $0$ otherwise. Instead of computing the attention weights for every query, we compute them only for the cluster centroids, yielding the clustered attention $A^c \in \R^{C \times N}$. Using the clustered attention $A^c$, we compute the new values $\hat{V}^c \in \R^{C \times D_v}$ for the centroids. Finally, we broadcast the same attention weights and new values to all queries that belong to the same cluster. In the following figure we show the different steps of clustered attention:

Flow-chart demonstrating the computation for clustered attention. Different colors represent the query groups and the computed centroids. The same colors are then used to show the attention weights $A^c$, new values for the centroids $\hat{V}^c$, and the resulting values $\hat{V}$ after broadcasting. Since we compute the attention weights and new values using the $C$ centroids, the computational complexity is $\bigO{N C \max\left(D_k, D_v\right)}$.
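The following rough PyTorch sketch illustrates these steps. The actual model uses a faster clustering scheme and batched operations; the plain K-Means below and all names are ours, purely for illustration:

```python
import torch

def clustered_attention(Q, K, V, C=100, kmeans_iters=10):
    """Sketch of clustered attention: O(N*C) instead of O(N^2)."""
    N, D_k = Q.shape
    # 1. Cluster the queries into C non-overlapping groups (plain K-Means).
    centroids = Q[torch.randperm(N)[:C]].clone()
    for _ in range(kmeans_iters):
        assign = torch.cdist(Q, centroids).argmin(dim=-1)      # (N,)
        for c in range(C):
            members = Q[assign == c]
            if len(members) > 0:
                centroids[c] = members.mean(dim=0)
    # 2. Compute attention only for the C centroids.
    A_c = torch.softmax(centroids @ K.T / D_k ** 0.5, dim=-1)  # (C, N)
    V_c = A_c @ V                                              # (C, D_v)
    # 3. Broadcast each centroid's new value to the queries of its cluster.
    return V_c[assign]                                         # (N, D_v)

Q, K, V = torch.randn(1024, 64), torch.randn(1024, 64), torch.randn(1024, 64)
V_hat = clustered_attention(Q, K, V, C=100)
```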
2.2 Improved Clustered Attention

In this section, we improve the clustered approximation of the vanilla attention. We start by assuming that for any query $Q_i$, we are given the $k$ keys with the highest attention weights. Given these keys, one simple way to improve the clustered attention approximation is to recompute the attention on these keys and substitute the clustered attention weights on them with the recomputed weights.

However, we need to scale the recomputed attention so that the total attention weight sums to one. The scaling factor is the total attention weight assigned by the clustered attention $A^c$ to these top-$k$ keys.

Recomputing the attention as explained above always improves the approximation (proof in Appendix B of the paper). For any query $Q_i$ belonging to cluster $j$, the set of $k$ keys with the highest attention weights in the clustered attention $A^c_j$ is a good candidate for the set of top-$k$ keys.

More formally, for any query $Q_i$ belonging to cluster $j$, we start by introducing $T \in \{0, 1\}^{C \times N}$, where $T_{ji} = 1$ if the $i$-th key is among the top-$k$ keys for the $j$-th cluster and $0$ otherwise. We can then compute the scaling factor $\hat{m}_j$, i.e. the total attention weight on the top-$k$ keys, as follows:
$$\hat{m}_j = \sum_{i=1}^{N} T_{ji} A^c_{ji}.$$
Let us also denote the set of $k$ keys with the highest attention weights in $A^c_j$ by $K^t_j$. The updated attention weights $A^u_i$ on these keys can then be computed as follows:
$$A^u_i = \hat{m}_j \softmax{\frac{Q_i {K^t_j}^T}{\sqrt{D_k}}}.$$
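The two equations above translate directly into a few tensor operations. Below is a small single-query sketch on random tensors (all names are ours):

```python
import torch

N, D_k, k = 1024, 64, 32
Q_i = torch.randn(D_k)         # a query belonging to cluster j
K = torch.randn(N, D_k)
centroid_j = torch.randn(D_k)  # centroid of cluster j
# Clustered attention of cluster j over all N keys, shape (N,).
A_c_j = torch.softmax(centroid_j @ K.T / D_k ** 0.5, dim=-1)

# The k keys with the highest clustered attention weight (the non-zeros of T_j).
topk = A_c_j.topk(k).indices
# Scaling factor m_hat_j: total clustered attention mass on those keys.
m_hat_j = A_c_j[topk].sum()
# Recompute the exact dot products of Q_i with the top-k keys and rescale.
A_u_i = m_hat_j * torch.softmax(Q_i @ K[topk].T / D_k ** 0.5, dim=-1)  # (k,)
```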

The new values $\hat{V_i}$ can be efficiently computed using the following decomposition:
$$\hat{V_i} = \hat{V_i}^{t} + \hat{V_i}^{b},$$
where $\hat{V_i}^{t}$ is the weighted average of the values corresponding to the top-$k$ keys, with the recomputed attention $A^u_i$ as weights, and $\hat{V_i}^{b}$ is the weighted average of the remaining values, with the clustered attention $A^c_j$ as weights.

In the following figures, we show how we can efficiently compute $\hat{V_i}^t$ and $\hat{V_i}^{b}$ for a single query $Q_i$ belonging to cluster $j$. For clarity, we denote the set of values corresponding to the top-$k$ keys by $V^t_j$.

Flow-chart demonstrating the efficient $\hat{V_i}^b$ computation for a query $Q_i$ belonging to cluster $j$. Note that this is exactly the same computation as clustered attention, but with the attention weights corresponding to the top-$k$ keys masked out. Thus it has a computational complexity of $\bigO{N C \max\left(D_k, D_v\right)}$.
Flow-chart demonstrating the efficient $\hat{V_i}^t$ computation for a query $Q_i$ belonging to cluster $j$. Note that computing $A^u_i$ requires a sparse dot product between the query $Q_i$ and the keys $K^t_j$. Similarly, computing $\hat{V_i}^t$ requires a weighted average of $V^t_j$. The overall computational complexity is $\bigO{N k \max\left(D_k, D_v\right)}$.
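Putting the pieces together for a single query, the decomposition $\hat{V_i} = \hat{V_i}^{t} + \hat{V_i}^{b}$ can be sketched as follows, reusing the quantities from the previous snippet (again, the function and names are ours):

```python
import torch

def improved_values(A_c_j, A_u_i, topk, V):
    """V_hat_i = V_hat_i^t + V_hat_i^b for a single query of cluster j.

    A_c_j: (N,)     clustered attention weights of the query's cluster
    A_u_i: (k,)     recomputed, rescaled attention on the top-k keys
    topk:  (k,)     indices of the top-k keys
    V:     (N, D_v) values
    """
    # V_hat^t: weighted average of the top-k values with the recomputed weights.
    V_t = A_u_i @ V[topk]
    # V_hat^b: clustered attention with the top-k weights masked out.
    A_b = A_c_j.clone()
    A_b[topk] = 0.0
    V_b = A_b @ V
    return V_t + V_b
```

Note that $\hat{V_i}^{b}$ depends only on the cluster, so in practice it is computed once per cluster (the $\bigO{N C \max\left(D_k, D_v\right)}$ part), while only the small top-$k$ term $\hat{V_i}^{t}$ is recomputed per query (the $\bigO{N k \max\left(D_k, D_v\right)}$ part).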
2.3 Computational Complexity

Comparing computational complexities for the different attention variants. Note that in practice $C \ll N$, implying significant gains. Also, for $k \leq C$, improved clustered and clustered attention have the same asymptotic complexity.
|  | Vanilla Attention | Clustered Attention | Improved Clustered Attention |
| --- | --- | --- | --- |
| Memory | $\bigO{N^2 \max\left(D_k, D_v\right)}$ | $\bigO{N C \max\left(D_k, D_v\right)}$ | $\bigO{N \max\left(C, k\right) \max\left(D_k, D_v\right)}$ |
| Time | $\bigO{N^2 \max\left(D_k, D_v\right)}$ | $\bigO{N C \max\left(D_k, D_v\right)}$ | $\bigO{N \max\left(C, k\right) \max\left(D_k, D_v\right)}$ |
3 Experimental Results

In the following, we evaluate the performance of our model with respect to its computational requirements, its accuracy on the task of Automatic Speech Recognition, and its ability to approximate a pre-trained RoBERTa model on the GLUE and SQuAD benchmarks in Natural Language Processing.

We compare our model with the vanilla transformer, which we refer to as full, and with the Reformer using various rounds of hashing, which we refer to as lsh-X, where $X$ denotes the number of hashing rounds. Following the Reformer paper, we set the number of buckets to $64$ and the chunk size to $32$. We refer to clustered attention as clustered-X and to improved clustered attention as i-clustered-X, where $X$ denotes the number of clusters. For all experiments, we set $k=32$ for improved clustered attention.

3.1 Time and Memory Benchmark

We compare the memory consumption and computation time on artificially generated sequences of various lengths. For the clustered attention variants we use $100$ clusters. For the Reformer, we evaluate two variants using $1$ and $4$ rounds of hashing.

All methods other than the vanilla transformer scale linearly with respect to the sequence length. Note that with respect to per-sample memory, both clustered and improved clustered attention perform better than all other methods. Although lsh-$1$ is faster than improved clustered attention, the Reformer typically requires multiple hashing rounds to generalize. In contrast, our clustered attention performs well with $100$ clusters or fewer.

Per element GPU time and memory consumption for a forward/backward pass. All models, except full, scale linearly with respect to the sequence length since they have constant time and memory per element.
3.2 Automatic Speech Recognition

In this experiment, we evaluate different transformer variants on automatic speech recognition (ASR) using the Wall Street Journal and Switchboard databases.

Speed-Accuracy Trade-off

We compare different transformer variants under an equalized computational budget. We train each transformer variant with varying capacities to obtain a range of required computation times and achieved Phone Error Rates (PER) or Word Error Rates (WER). In the figure below, we plot the achieved PER/WER on the validation set against the time required for a full forward pass. We observe that improved clustered consistently achieves lower error rates than all other baselines for a given computational budget.

(a) Wall Street Journal
(b) Switchboard
We compare the performance of various transformer models under an equalized computational budget. The numbers near the data points denote the number of layers and the number of clusters or hashing rounds where applicable. Improved clustered is consistently better than all baselines for a given computational budget on both the WSJ and Switchboard datasets.
3.3 RoBERTa Approximation

To highlight the ability of our clustered attention model to approximate arbitrarily complicated attention distributions, we evaluate our proposed method on the approximation of a fine-tuned RoBERTa model on the GLUE and SQuAD benchmarks.

For the GLUE tasks, the maximum sequence length is 128, while for SQuAD it is 384. For each task, we use $25$ clusters for the approximation, which is less than $20\%$ and $10\%$ of the input sequence length for the GLUE and SQuAD tasks respectively. In the table below, we summarize the performance per task. We observe that improved clustered performs as well as the full transformer on all tasks but SQuAD, on which it is only marginally worse. Moreover, we note that clustered performs significantly worse on tasks that require more complicated attention patterns, such as question answering (SQuAD) and textual entailment (RTE).

| Attention | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI | SQuAD |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| full | 0.601 | 0.880 | 0.868 | 0.929 | 0.915 | 0.682 | 0.947 | 0.900 | 0.437 | 0.904 |
| clustered-25 | 0.598 | 0.794 | 0.436 | 0.746 | 0.894 | 0.498 | 0.944 | 0.789 | 0.437 | 0.006 |
| i-clustered-25 | 0.601 | 0.880 | 0.873 | 0.930 | 0.915 | 0.704 | 0.947 | 0.900 | 0.437 | 0.876 |
Evaluating the approximation performance on GLUE and SQuAD benchmarks. We report accuracy for all tasks except STS-B and SQuAD, where we report Pearson correlation and F1-score respectively. For all metrics higher is better.
3.4 Attention Patterns (Qualitative Comparison)

In the figure below, we provide a qualitative comparison between full attention and the two clustered attention variants used for approximation. As described previously, we use $25$ clusters for both variants. For a randomly selected question-context tuple from the SQuAD dataset, we show the attention distribution for the question tokens. It can be seen that with only a few clusters, improved clustered approximates the full attention very closely, even when the attention distribution has complicated and sparse patterns. In contrast, clustered attention fails to capture such a distribution. Moreover, for almost all question tokens, both full and improved clustered assign their highest attention weights to the same tokens. Note that the CLS token refers to the classification token prepended to the question.

Manning finished the year with a career-low 67.9 passer rating, throwing for 2,249 yards and nine touchdowns, with 17 interceptions. In contrast, Osweiler threw for 1,967 yards, 10 touchdowns and six interceptions for a rating of 86.4. Veteran receiver Demaryius Thomas led the team with 105 receptions for 1,304 yards and six touchdowns, while Emmanuel Sanders caught 76 passes for 1,135 yards and six scores, while adding another 106 yards returning punts. Tight end Owen Daniels was also a big element of the passing game with 46 receptions for 517 yards. Running back C. J. Anderson was the team's leading rusher 863 yards and seven touchdowns, while also catching 25 passes for 183 yards. Running back Ronnie Hillman also made a big impact with 720 yards, five touchdowns, 24 receptions, and a 4.7 yards per carry average. Overall, the offense ranked 19th in scoring with 355 points and did not have any Pro Bowl selections.

(a) Context
(b) Full Attention
(c) Improved Clustered Attention
(d) Clustered Attention
Attention matrices for a question-context tuple under full attention and under the clustered and i-clustered attention used for approximation. (a) shows the context for the question with the answer highlighted. (b) shows the attention distribution for full, while (c) and (d) show the approximations using i-clustered and clustered attention respectively. Note that the improved clustered attention patterns are very similar to full, while clustered attention shows qualitatively different patterns. For each question token, we also present the tokens with the highest attention on the right axis. Darker colors represent higher attention weights.
4 Conclusions

We have presented clustered attention, a method that approximates vanilla transformers with significantly lower computational requirements. In particular, we have shown that for a given computational budget our model outperforms all other baselines. In contrast to recent fast variants of transformers, we have also shown that our method can efficiently approximate pre-trained models with full attention while retaining linear asymptotic complexity.

The proposed method opens several research directions towards applying transformers to long-sequence tasks such as music generation and scene flow estimation. We consider masked language modeling for long texts to be of particular importance, as it will allow fine-tuning for downstream tasks that need a context longer than the commonly used 512 tokens.

Acknowledgments

Apoorv Vyas was supported by the Swiss National Science Foundation under grant number FNS-30213 "SHISSM". Angelos Katharopoulos was supported by the Swiss National Science Foundation under grant numbers FNS-30209 "ISUL" and FNS-30224 "CORTI".

The article template is provided by distill.pub and many formatting styles are inspired by the articles appearing on Distill.