Papers With Code


Data sourced from the PWC Archive (CC-BY-SA 4.0). Built by the community, for the community.


SRU++

Sequential · Introduced 2021 · 2 papers
Source Paper

Description

SRU++ is a self-attentive recurrent unit that combines fast recurrence and attention for sequence modeling, extending the SRU unit. The key modification of SRU++ is to incorporate more expressive non-linear operations into the recurrent network. Specifically, given the input sequence represented as a matrix $\mathbf{X} \in \mathbb{R}^{L \times d}$, the attention component computes the query, key and value representations using the following multiplications:

$$\mathbf{Q} = \mathbf{W}^{q} \mathbf{X}^{\top}, \qquad \mathbf{K} = \mathbf{W}^{k} \mathbf{Q}, \qquad \mathbf{V} = \mathbf{W}^{v} \mathbf{Q}$$

where $\mathbf{W}^{q} \in \mathbb{R}^{d' \times d}$ and $\mathbf{W}^{k}, \mathbf{W}^{v} \in \mathbb{R}^{d' \times d'}$ are model parameters, and $d'$ is the attention dimension, typically much smaller than $d$. Note that the keys $\mathbf{K}$ and values $\mathbf{V}$ are computed from $\mathbf{Q}$ instead of $\mathbf{X}$, so that the weight matrices $\mathbf{W}^{k}$ and $\mathbf{W}^{v}$ are significantly smaller.
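As a concrete illustration, here is a minimal NumPy sketch of these projections. The sizes `L`, `d`, and `d_attn` (standing in for $d'$) are arbitrary values chosen for the example, not values from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

L, d, d_attn = 8, 16, 4          # sequence length, model dim, attention dim d' << d
X = rng.standard_normal((L, d))  # input sequence, one row per position

W_q = rng.standard_normal((d_attn, d))       # W^q: d' x d
W_k = rng.standard_normal((d_attn, d_attn))  # W^k: d' x d' (smaller, since it acts on Q)
W_v = rng.standard_normal((d_attn, d_attn))  # W^v: d' x d'

Q = W_q @ X.T   # d' x L
K = W_k @ Q     # keys derived from Q rather than X
V = W_v @ Q     # values derived from Q rather than X

print(Q.shape, K.shape, V.shape)  # (4, 8) (4, 8) (4, 8)
```

Because $\mathbf{K}$ and $\mathbf{V}$ are built on top of $\mathbf{Q}$, their weight matrices have $d' \times d'$ entries instead of $d' \times d$.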

Next, we compute a weighted average output $\mathbf{A} \in \mathbb{R}^{d' \times L}$ using scaled dot-product attention:

$$\mathbf{A}^{\top} = \operatorname{softmax}\left(\frac{\mathbf{Q}^{\top} \mathbf{K}}{\sqrt{d'}}\right) \mathbf{V}^{\top}$$
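A small NumPy sketch of this step, using the same $(d' \times L)$ layout as the equations (the function name is just for illustration):

```python
import numpy as np

def scaled_dot_product(Q, K, V):
    """Compute A, where A^T = softmax(Q^T K / sqrt(d')) V^T; Q, K, V have shape (d', L)."""
    d_attn = Q.shape[0]
    scores = Q.T @ K / np.sqrt(d_attn)              # L x L attention scores
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax: each row sums to 1
    return (weights @ V.T).T                        # A has shape (d', L)

rng = np.random.default_rng(1)
Q = rng.standard_normal((4, 8))
A = scaled_dot_product(Q, Q, Q)  # self-attention over 8 positions
print(A.shape)  # (4, 8)
```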

The final output $\mathbf{U}$ required by the elementwise recurrence is obtained by another linear projection:

$$\mathbf{U}^{\top} = \mathbf{W}^{o}(\mathbf{Q} + \alpha \cdot \mathbf{A})$$

where $\alpha \in \mathbb{R}$ is a learned scalar and $\mathbf{W}^{o} \in \mathbb{R}^{3d \times d'}$ is a parameter matrix. $\mathbf{Q} + \alpha \cdot \mathbf{A}$ is a residual connection that improves gradient propagation and stabilizes training. We initialize $\alpha$ to zero, so that the output

$$\mathbf{U}^{\top} = \mathbf{W}^{o} \mathbf{Q} = \left(\mathbf{W}^{o} \mathbf{W}^{q}\right) \mathbf{X}^{\top}$$

initially falls back to a linear transformation of the input $\mathbf{X}$, skipping the attention transformation entirely. Intuitively, skipping attention encourages the model to leverage recurrence to capture sequential patterns during the early stage of training. As $|\alpha|$ grows, the attention mechanism can learn long-range dependencies. In addition, $\mathbf{W}^{o} \mathbf{W}^{q}$ can be interpreted as applying a matrix factorization trick with a small inner dimension $d' < d$, reducing the total number of parameters. The figure compares SRU, SRU with this factorization trick (but without attention), and SRU++.
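To make the parameter saving concrete, here is a quick back-of-the-envelope comparison; $d = 512$ and $d' = 128$ are arbitrary example sizes chosen for illustration, not values from the paper:

```python
# Parameters of a direct 3d x d projection vs the factored W^o W^q pair.
d, d_attn = 512, 128                    # illustrative sizes with d' < d

direct = 3 * d * d                      # one matrix in R^{3d x d}
factored = 3 * d * d_attn + d_attn * d  # W^o in R^{3d x d'} plus W^q in R^{d' x d}

print(direct, factored)  # 786432 262144
```

With these sizes the factored form uses one third of the parameters of a single $3d \times d$ matrix, and the ratio improves further as $d'/d$ shrinks.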

The last modification is adding layer normalization to each SRU++ layer. We apply normalization after the attention operation and before the matrix multiplication with $\mathbf{W}^{o}$:

$$\mathbf{U}^{\top} = \mathbf{W}^{o} \operatorname{layernorm}(\mathbf{Q} + \alpha \cdot \mathbf{A})$$

This corresponds to post-layer normalization, in which the normalization is applied after the residual connection.
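Putting the pieces together, the attention sub-layer with this normalization placement can be sketched end-to-end in NumPy. This is a hedged illustration (the function name `srupp_attention` is hypothetical, and normalizing over the $d'$ feature axis is an assumption of this sketch):

```python
import numpy as np

def srupp_attention(X, W_q, W_k, W_v, W_o, alpha, eps=1e-5):
    """Sketch of U^T = W^o layernorm(Q + alpha * A) for input X of shape (L, d)."""
    Q = W_q @ X.T                                   # d' x L
    K = W_k @ Q                                     # keys/values from Q, not X
    V = W_v @ Q
    d_attn = Q.shape[0]
    scores = Q.T @ K / np.sqrt(d_attn)              # L x L scaled dot products
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)              # softmax rows
    A = (w @ V.T).T                                 # d' x L
    R = Q + alpha * A                               # residual connection
    mu = R.mean(axis=0, keepdims=True)              # layer norm over the d' features
    var = R.var(axis=0, keepdims=True)
    R = (R - mu) / np.sqrt(var + eps)
    return (W_o @ R).T                              # U: L x 3d, fed to the recurrence

rng = np.random.default_rng(2)
L, d, d_attn = 8, 16, 4
X = rng.standard_normal((L, d))
W_q = rng.standard_normal((d_attn, d))
W_k = rng.standard_normal((d_attn, d_attn))
W_v = rng.standard_normal((d_attn, d_attn))
W_o = rng.standard_normal((3 * d, d_attn))
U = srupp_attention(X, W_q, W_k, W_v, W_o, alpha=0.0)
print(U.shape)  # (8, 48)
```

With `alpha=0.0`, as at initialization, the output does not depend on the attention branch at all, matching the fallback behavior described above.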

Papers Using This Method

- SRU++: Pioneering Fast Recurrence with Attention for Speech Recognition (2021-10-11)
- When Attention Meets Fast Recurrence: Training Language Models with Reduced Compute (2021-02-24)