Transformers have proven to be a successful model for a variety of sequence modeling tasks. However, computing the attention matrix, which is their key component, has quadratic complexity with respect to the sequence length, thus making them prohibitively expensive for long sequences. To address this, we propose clustered attention, which instead of computing the attention for every query, groups queries into clusters and computes attention just for the centroids. To further improve this approximation, we use the computed clusters to identify the keys with the highest attention per query and compute the exact key/query dot products. This results in a model with linear complexity with respect to the sequence length for a fixed number of clusters.
This blog post presents a summary of our work on scaling transformers with clustering. For more details, please refer to our paper, Fast Transformers with Clustered Attention.
For the rest of the blog post, for any matrix $M \in \R^{P \times Q}$, $M_i$ denotes the $i$-th row of $M$.
1 Vanilla Self-Attention

We start with a brief recap of the vanilla self-attention introduced in the original transformer architecture. Given queries $Q \in \R^{N \times D_k}$, keys $K \in \R^{N \times D_k}$ and values $V \in \R^{N \times D_v}$ for a sequence of length $N$, self-attention computes \begin{equation} A = \softmax{\frac{Q K^T}{\sqrt{D_k}}}, \qquad \hat{V} = A V. \end{equation} Materializing the attention matrix $A \in \R^{N \times N}$ requires $\bigO{N^2}$ time and memory, which is the cost we set out to reduce.
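For reference, a minimal NumPy sketch of this computation might look as follows (our illustrative code, not the paper's implementation); materializing the $N \times N$ matrix $A$ is exactly where the quadratic cost comes from.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def vanilla_attention(Q, K, V):
    # Q, K: (N, D_k), V: (N, D_v); returns (N, D_v).
    D_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(D_k), axis=-1)   # (N, N) attention matrix
    return A @ V
```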
The main idea of this work is based on the observation that queries which are close in Euclidean space have similar attention distributions.
More precisely, given two queries $Q_i$ and $Q_j$ such that $\norm{Q_i - Q_j}_2 \leq \epsilon$, \begin{equation} \norm{\softmax{Q_i K^T} - \softmax{Q_j K^T}}_2 \leq \epsilon \norm{K}_2 \,, \end{equation} where $\norm{K}_2$ denotes the spectral norm of $K$.
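As a quick numerical sanity check (ours, not from the paper's code), the following snippet verifies this bound empirically for random keys and pairs of nearby queries:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

N, D_k = 64, 32                                 # sequence length and key dimension
K = rng.normal(size=(N, D_k))
spectral_norm_K = np.linalg.norm(K, ord=2)      # largest singular value of K

for _ in range(1000):
    Q_i = rng.normal(size=D_k)
    Q_j = Q_i + 1e-2 * rng.normal(size=D_k)     # a nearby query
    eps = np.linalg.norm(Q_i - Q_j)
    lhs = np.linalg.norm(softmax(Q_i @ K.T) - softmax(Q_j @ K.T))
    assert lhs <= eps * spectral_norm_K         # the bound above holds
```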
2.1 Clustered Attention

We exploit the above observation to improve the computational complexity of self-attention. To do so, we first group the queries into $C$ non-overlapping clusters. The partitioning of the queries is denoted by $S \in \{0,1\}^{N \times C}$, where $S_{ij} = 1$ if the $i$-th query $Q_i$ belongs to the $j$-th cluster and $0$ otherwise. Instead of computing the attention weights for every query, we compute them only for the cluster centroids, yielding the clustered attention $A^c \in \R^{C \times N}$. Using the clustered attention $A^c$, we compute the new values $V^c \in \R^{C \times D_v}$. Finally, all queries that belong to the same cluster share the same attention weights and new values. In the following figure we show the different steps of clustered attention:
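To make the steps concrete, here is a minimal single-head NumPy sketch of clustered attention (our illustrative code, not the paper's implementation; for simplicity it groups queries with sklearn's standard K-means, whereas the paper clusters the queries approximately using locality-sensitive hashing followed by K-means):

```python
import numpy as np
from sklearn.cluster import KMeans

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def clustered_attention(Q, K, V, n_clusters=100):
    N, D_k = Q.shape
    # 1. Cluster the queries and compute the centroids.
    km = KMeans(n_clusters=n_clusters, n_init=10).fit(Q)
    labels = km.labels_                    # cluster index of each query
    centroids = km.cluster_centers_        # (C, D_k)

    # 2. Attention for the C centroids only: A^c in R^{C x N}.
    A_c = softmax(centroids @ K.T / np.sqrt(D_k), axis=-1)

    # 3. New values per cluster: V^c in R^{C x D_v}.
    V_c = A_c @ V

    # 4. Queries in the same cluster share the same output.
    return V_c[labels]                     # (N, D_v)

# Example usage with random data.
rng = np.random.default_rng(0)
Q = rng.normal(size=(512, 64))
K = rng.normal(size=(512, 64))
V = rng.normal(size=(512, 64))
print(clustered_attention(Q, K, V).shape)  # (512, 64)
```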
2.2 Improved Clustered Attention

In this section, we improve the clustered approximation of the vanilla attention. We start by assuming that for any query $Q_i$, we are given the $k$ keys with the highest attention weights. Given these keys, a simple way to improve the clustered approximation is to recompute the attention on them and substitute the corresponding clustered attention weights with the recomputed ones.
However, we need to scale the recomputed attention so that the total attention weight still sums to one. The scaling factor is the total attention weight assigned by the clustered attention $A^c$ to these top-$k$ keys.
Recomputing the attention on these keys always improves the approximation (proof in Appendix B of the paper). For any query $Q_i$ belonging to cluster $j$, the $k$ keys with the highest attention weights in the clustered attention $A^c_j$ are a good candidate set for the top-$k$ keys.
More formally, for any query $Q_i$ belonging to cluster $j$, we introduce $T \in \{0, 1\}^{C \times N}$, where $T_{ji} = 1$ if the $i$-th key is among the top-$k$ keys for the $j$-th cluster and $0$ otherwise. The scaling factor $\hat{m}_j$ is then the total clustered attention weight on these top-$k$ keys: \begin{equation} \hat{m}_j = \sum_{i=1}^{N} T_{ji} A^c_{ji}. \end{equation} Let us also denote the set of $k$ keys with the highest attention weights in $A^c_j$ by $K^t_j$. The updated attention weights $A^u_i$ on these keys are computed as \begin{equation} A^u_i = \hat{m}_j \softmax{\frac{Q_i {K^t_j}^T}{\sqrt{D_k}}}. \end{equation}
The new values $\hat{V}_i$ can then be efficiently computed using the following decomposition: \begin{align} \hat{V}_i = \hat{V}_i^{t} + \hat{V}_i^{b}, \label{eq:improved_values_fast} \end{align} where $\hat{V}_i^{t}$ is the weighted sum of the values corresponding to the top-$k$ keys, with the recomputed attention $A^u_i$ as weights, and $\hat{V}_i^{b}$ is the weighted sum of the remaining values, with the clustered attention $A^c_j$ as weights.
In the following figure, we show how we can efficiently compute $\hat{V}_i^{t}$ and $\hat{V}_i^{b}$ for a single query $Q_i$ belonging to cluster $j$. For clarity, we denote the set of values corresponding to the top-$k$ keys by $V^t_j$.
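Putting the pieces together, the following single-head NumPy sketch implements improved clustered attention end to end (again our illustrative code, not the paper's implementation; it loops over queries for clarity, whereas an efficient implementation batches the computation per cluster):

```python
import numpy as np
from sklearn.cluster import KMeans

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def improved_clustered_attention(Q, K, V, n_clusters=100, topk=32):
    N, D_k = Q.shape
    km = KMeans(n_clusters=n_clusters, n_init=10).fit(Q)
    labels, centroids = km.labels_, km.cluster_centers_

    # Clustered attention A^c in R^{C x N}, as in the previous sketch.
    A_c = softmax(centroids @ K.T / np.sqrt(D_k), axis=-1)

    # Indices of the top-k keys per cluster and the scaling factors m_hat_j.
    top = np.argsort(A_c, axis=-1)[:, -topk:]                 # (C, k)
    m_hat = np.take_along_axis(A_c, top, axis=-1).sum(-1)     # (C,)

    out = np.empty((N, V.shape[1]))
    for i in range(N):
        j = labels[i]
        K_t, V_t = K[top[j]], V[top[j]]    # top-k keys/values of cluster j

        # \hat{V}_i^t: exact attention on the top-k keys, rescaled by m_hat_j.
        A_u = m_hat[j] * softmax(Q[i] @ K_t.T / np.sqrt(D_k), axis=-1)
        V_top = A_u @ V_t

        # \hat{V}_i^b: clustered attention weights on the remaining keys.
        A_rest = A_c[j].copy()
        A_rest[top[j]] = 0.0
        V_bottom = A_rest @ V

        out[i] = V_top + V_bottom
    return out
```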
2.3 Computational Complexity

The table below summarizes the per-layer memory and time complexity of vanilla attention, clustered attention, and improved clustered attention:

Complexity | Vanilla Attention | Clustered Attention | Improved Clustered Attention |
---|---|---|---|
Memory | $\bigO{N^2 \max\left(D_k, D_v\right)}$ | $\bigO{N C \max\left(D_k, D_v\right)}$ | $\bigO{N \max\left(C, k\right) \max\left(D_k, D_v\right)}$ |
Time | $\bigO{N^2 \max\left(D_k, D_v\right)}$ | $\bigO{N C \max\left(D_k, D_v\right)}$ | $\bigO{N \max\left(C, k\right) \max\left(D_k, D_v\right)}$ |
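As a rough back-of-the-envelope illustration (our numbers, not results from the paper), the dominant terms above already suggest large savings for long sequences:

```python
# Compare the dominant cost terms for an illustrative configuration.
N, D, C, k = 16384, 64, 100, 32      # sequence length, head dim, clusters, top-k

vanilla = N * N * D                   # O(N^2 * max(D_k, D_v))
improved = N * max(C, k) * D          # O(N * max(C, k) * max(D_k, D_v))

print(f"vanilla / improved ~ {vanilla / improved:.0f}x")   # ~164x fewer operations
```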
3 Experiments

In the following, we evaluate the performance of our model with respect to its computational requirements, its accuracy on automatic speech recognition, and its ability to approximate a pre-trained RoBERTa model on the GLUE and SQuAD benchmarks. We compare our model with the vanilla transformers and with Reformer.
3.1 Time and Memory Benchmark

We compare the memory consumption and computation time on artificially generated sequences of various lengths. For the clustered attention variants we use $100$ clusters. For Reformer, we evaluate variants with different numbers of hashing rounds; lsh-$1$ denotes the variant with a single hashing round.
All methods other than vanilla transformers scale linearly with respect to the sequence length. Note that with respect to per-sample memory, both clustered and improved clustered attention perform better than all other methods. Although lsh-$1$ is faster than improved clustered attention, Reformer typically requires multiple hashing rounds to generalize. In contrast, our clustered attention achieves good performance with $100$ clusters or fewer.
3.2 Automatic Speech Recognition

In this experiment, we evaluate the different transformer variants on automatic speech recognition (ASR) using the Wall Street Journal and Switchboard datasets.
We compare the variants under an equalized computational budget: each transformer variant is trained with varying capacities to obtain a range of required computation times and achieved Phone Error Rates (PER) or Word Error Rates (WER). In the figure below, we plot the achieved PER/WER on the validation set against the time required for a full forward pass. We observe that improved clustered consistently achieves lower PER/WER than all other baselines for a given computational budget.
3.3 RoBERTa Approximation

To highlight the ability of our clustered attention model to approximate arbitrarily complicated attention distributions, we evaluate our proposed method on the approximation of a fine-tuned RoBERTa model on the GLUE and SQuAD benchmarks.
For the GLUE tasks, the maximum sequence length is 128 while for SQuAD, it is 384. For each task, we use $25$ clusters for approximation which is less than $20\%$ and $10\%$ of the input sequence length for GLUE and SQuAD tasks respectively. In the table below, we summarize the performance per task. We observe that improved clustered performs as well as the full transformer in all tasks but SQuAD, in which it is only marginally worse. Moreover, we note that clustered performs significantly worse in tasks that require more complicated attention patterns such as question answering (SQuAD) and textual entailment (RTE).
Attention | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STS-B | WNLI | SQuAD |
---|---|---|---|---|---|---|---|---|---|---|
full | 0.601 | 0.880 | 0.868 | 0.929 | 0.915 | 0.682 | 0.947 | 0.900 | 0.437 | 0.904 |
clustered-25 | 0.598 | 0.794 | 0.436 | 0.746 | 0.894 | 0.498 | 0.944 | 0.789 | 0.437 | 0.006 |
i-clustered-25 | 0.601 | 0.880 | 0.873 | 0.930 | 0.915 | 0.704 | 0.947 | 0.900 | 0.437 | 0.876 |
In the figure below, we provide a qualitative comparison between the full attention and the clustered attention variants used for approximation. As described previously, we use $25$ clusters for both attention variants. For a randomly selected question-context pair from the SQuAD dataset, we show the attention distribution for the question tokens. With only a few clusters, improved clustered approximates the full attention very closely, even when the attention distribution has complicated and sparse patterns. In contrast, clustered attention fails to capture such a distribution. Moreover, for almost all question tokens, full and improved clustered attention assign the highest attention weights to the same tokens. Note that the CLS token refers to the classification token prepended to the question. The context passage of the selected example is reproduced below.
Manning finished the year with a career-low 67.9 passer rating, throwing for 2,249 yards and nine touchdowns, with 17 interceptions. In contrast, Osweiler threw for 1,967 yards, 10 touchdowns and six interceptions for a rating of 86.4. Veteran receiver Demaryius Thomas led the team with 105 receptions for 1,304 yards and six touchdowns, while Emmanuel Sanders caught 76 passes for 1,135 yards and six scores, while adding another 106 yards returning punts. Tight end Owen Daniels was also a big element of the passing game with 46 receptions for 517 yards. Running back C. J. Anderson was the team's leading rusher 863 yards and seven touchdowns, while also catching 25 passes for 183 yards. Running back Ronnie Hillman also made a big impact with 720 yards, five touchdowns, 24 receptions, and a 4.7 yards per carry average. Overall, the offense ranked 19th in scoring with 355 points and did not have any Pro Bowl selections.
We have presented clustered attention, a method that approximates vanilla transformers with significantly lower computational requirements. In particular, we have shown that for a given computational budget our model outperforms all other baselines. In contrast to recent fast variants of transformers, we have also shown that our method can efficiently approximate pre-trained models with full attention while retaining linear asymptotic complexity.
The proposed method opens several research directions towards applying transformers to long-sequence tasks such as music generation and scene flow estimation. We consider masked language modeling for long texts to be of particular importance, as it will allow fine-tuning for downstream tasks that need a context longer than the commonly used 512 tokens.
Apoorv Vyas was supported by the Swiss National Science Foundation under grant number FNS-30213 "SHISSM". Angelos Katharopoulos was supported by the Swiss National Science Foundation under grant numbers FNS-30209 "ISUL" and FNS-30224 "CORTI".
The article template is provided by distill.pub and many formatting styles are inspired by the articles appearing on Distill.