

How does weight-tying work in an RNN?
7 Oct 2020 · 965 words

Weight-tying means that a language model uses the same weight matrix for both the input-to-embedding layer (the input embedding) and the hidden-to-softmax layer (the output embedding). The idea is that these two matrices contain essentially the same information, each having one row per word in the vocabulary. The concept appears to have been first proposed by Press and Wolf in 2016.

Below we illustrate this with an RNN. This post assumes that you are already somewhat familiar with how an RNN works. Keep an eye out for the matrices $U$ and $V$ below: these are the input and output embedding matrices.

The RNN setup

Say you have a document split into $T$ words, $w_1, \dots, w_T$. The RNN will take one word at a time as input. Call this $w_t$.

First $w_t$ is transformed into a one-hot vector $x_t$. This vector has length $C$, where $C$ is the number of words in the vocabulary, which means $x_t$ has shape $C \times 1$.
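A minimal pure-Python sketch of the one-hot step, assuming a hypothetical five-word toy vocabulary (`vocab`) and a helper `one_hot` written only for illustration:

```python
def one_hot(word: str, vocab: list[str]) -> list[float]:
    """Return the length-C one-hot vector x_t for `word`."""
    x = [0.0] * len(vocab)          # C zeros, one slot per vocab word
    x[vocab.index(word)] = 1.0      # a single 1 at the word's index
    return x

vocab = ["the", "cat", "sat", "on", "mat"]   # toy vocabulary, so C = 5
x_t = one_hot("sat", vocab)
print(x_t)  # [0.0, 0.0, 1.0, 0.0, 0.0]
```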

$U$ is the input embedding matrix and is usually pre-trained. Say we are using $D$-dimensional GloVe embeddings (e.g. 50-dimensional). Then $U$ has shape $C \times D$, with one row per word in the vocab. The purpose of $U$ is to convert the one-hot vector $x_t$ into its embedding representation $e_t$ via $e_t = U^Tx_t$. The embedding vector $e_t$ now has shape $D \times 1$.
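To see why $e_t = U^Tx_t$ simply selects one row of $U$, here is a sketch using plain Python lists as matrices (a list of rows). The toy values in `U` and the helpers `transpose` and `matvec` are illustrative assumptions, not part of any library:

```python
def transpose(M):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*M)]

def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

# Toy input embedding U: C = 4 words, D = 3 dimensions (one row per word).
U = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
]
x_t = [0.0, 0.0, 1.0, 0.0]        # one-hot vector for word index 2

e_t = matvec(transpose(U), x_t)   # e_t = U^T x_t, length D
assert e_t == U[2]                # identical to looking up row 2 of U
```

In practice, libraries implement this as a row lookup rather than an actual matrix product, since the result is the same and the lookup is far cheaper.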

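Next, some operations are applied to $e_t$ to obtain the hidden state $h_t$ of length $H$. For an RNN this looks something like $$h_t = f \big(P_{DH} e_t + W_{hh} h_{t-1} + b_1 \big) $$ where $f$ is a nonlinearity such as tanh, $P_{DH}$ is an $H \times D$ matrix that maps the embedding into the hidden space, $W_{hh}$ is the $H \times H$ hidden-to-hidden matrix, and $b_1$ is a bias vector of length $H$.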
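One hidden-state update can be sketched as below. This assumes $f = \tanh$, a zero bias, small random weights and list-of-rows matrices; `rnn_step` and `matvec` are hypothetical helpers, not a library API:

```python
import math
import random

def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rnn_step(e_t, h_prev, P_DH, W_hh, b_1):
    """h_t = tanh(P_DH e_t + W_hh h_{t-1} + b_1), taking f = tanh."""
    a = matvec(P_DH, e_t)       # (H x D) times length-D -> length H
    r = matvec(W_hh, h_prev)    # (H x H) times length-H -> length H
    return [math.tanh(ai + ri + bi) for ai, ri, bi in zip(a, r, b_1)]

D, H = 3, 5
rng = random.Random(0)
P_DH = [[rng.uniform(-0.1, 0.1) for _ in range(D)] for _ in range(H)]  # H x D
W_hh = [[rng.uniform(-0.1, 0.1) for _ in range(H)] for _ in range(H)]  # H x H
b_1 = [0.0] * H

h_t = rnn_step([0.7, 0.8, 0.9], [0.0] * H, P_DH, W_hh, b_1)
print(len(h_t))  # 5
```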

Now that the hidden state for the current time step has been calculated, the next step is to obtain a vector of scores $s_t$, which are used to calculate probabilities. The score vector $s_t$ has shape $C \times 1$, with one entry per word in the vocab, and is called the hidden scores or the logits layer. It is given by

$$s_t = Vh_t + b_2 $$

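where $V$ is the output embedding matrix, of shape $C \times H$ with one row per word in the vocab, and $b_2$ is a bias vector of length $C$.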
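A sketch of the score computation for the untied case, where $V$ is a small $C \times H$ matrix chosen by hand purely for illustration (the values and the helper `scores` are assumptions):

```python
def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def scores(V, h_t, b_2):
    """s_t = V h_t + b_2: one score (logit) per vocabulary word."""
    return [vi + bi for vi, bi in zip(matvec(V, h_t), b_2)]

# Untied case: V is C x H and has nothing to do with U (which is C x D).
C, H = 4, 2
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
s_t = scores(V, [0.5, -0.5], [0.0] * C)
print(s_t)  # [0.5, -0.5, 0.0, -0.5]
```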

The final vector of next-word probabilities $p_t$ is obtained from $s_t$ via the softmax function $\sigma$, so that $p_t = \sigma(s_t)$. This layer is known as the softmax layer, and it completes one time step.
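The softmax step can be sketched as follows. Subtracting the maximum score before exponentiating is a standard numerical-stability trick that leaves the result unchanged; the input scores are illustrative:

```python
import math

def softmax(s):
    """p = sigma(s), computed stably by shifting scores by their max."""
    m = max(s)
    exps = [math.exp(v - m) for v in s]
    z = sum(exps)
    return [e / z for e in exps]

p_t = softmax([0.5, -0.5, 0.0, -0.5])
predicted = max(range(len(p_t)), key=lambda i: p_t[i])  # most likely next word
```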

Adding weight tying

Now, let’s consider the shapes of the input and output embedding matrices $U$ and $V$ in the above formulation: $U$ has shape $C \times D$, while $V$ has shape $C \times H$.

These are different at the moment, but if you set the size of the hidden layer $H$ to be equal to the size of each embedding vector $D$, then these two matrices become the same size. Then we can do weight-tying and use a common matrix for both $U$ and $V$.
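A sketch of what tying means in code, assuming $D = H$ and a hypothetical toy matrix `E`: the input lookup and the output scores both read from the same matrix object, so only one set of embedding weights exists.

```python
def transpose(M):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*M)]

def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

# With D == H, a single C x D matrix E can serve as both U and V.
E = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]   # C = 3 words, D = H = 2
U = V = E                                  # same object: the weights are tied

x_t = [0.0, 1.0, 0.0]                      # one-hot for word index 1
e_t = matvec(transpose(U), x_t)            # input side: row 1 of E
h_t = [1.0, 1.0]                           # stand-in hidden state, length H = D
s_t = matvec(V, h_t)                       # output side: C scores from the same E
```

Because `U` and `V` are the same object, any update to `E` changes both the input embeddings and the output scores; in an autodiff framework the gradients from both uses would be summed into the shared matrix.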

Setting $D = H$ is required in some Python examples of weight-tying (here is one example). The downside is that you are forced to keep those two values equal, which is restrictive and can hurt performance. For example, if you are using 300-d GloVe vectors and want to compare against the 50-d version, the hidden vector will also shrink from length 300 to length 50, affecting performance in a way you probably didn’t intend.

Can we have $H$ different from $D$? We can, if we modify the above equations by inserting a projection matrix $P_{HD}$ of size $D \times H$ between the hidden state and $V$. Now we get

$$s_t = VP_{HD}h_t + b_2 $$

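where $P_{HD}$ projects the length-$H$ hidden state down to length $D$. The output embedding $V$ now has shape $C \times D$, the same as $U$, so the two can be tied by setting $V = U$.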
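With $D \neq H$, the projection sits between the hidden state and the shared matrix. A sketch with hand-picked toy values (`E`, `P_HD` and the sizes are assumptions), where this particular $P_{HD}$ simply keeps the first $D$ entries of $h_t$:

```python
def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

C, D, H = 3, 2, 4
E = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # shared C x D matrix (U = V = E)
P_HD = [[1.0, 0.0, 0.0, 0.0],            # D x H: maps length-H h_t to length D
        [0.0, 1.0, 0.0, 0.0]]
b_2 = [0.0] * C

h_t = [1.0, 2.0, 3.0, 4.0]               # hidden state, length H
z = matvec(P_HD, h_t)                    # projected state, length D
s_t = [a + b for a, b in zip(matvec(E, z), b_2)]  # s_t = V P_HD h_t + b_2
```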

So to summarise, our RNN with weight tying for the case $D \neq H$ looks like

$$\begin{cases} e_t = U^Tx_t \\
h_t = f \big(P_{DH}e_t + W_{hh}h_{t-1}+b_1\big) \\
s_t = VP_{HD}h_t + b_2 \\
p_t = \sigma (s_t) \end{cases}$$
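Putting the $D \neq H$ system together, here is a sketch of a forward pass over a short sequence of word indices. Assumptions: $f = \tanh$, $h_0 = 0$, randomly initialised toy weights, and the shapes used earlier ($P_{DH}$ is $H \times D$, $P_{HD}$ is $D \times H$). The one-hot product $U^Tx_t$ is replaced by the equivalent row lookup `E[i]`, and no training is shown. `tied_rnn_forward` is a hypothetical name.

```python
import math
import random

def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(s):
    """Numerically stable softmax."""
    m = max(s)
    exps = [math.exp(v - m) for v in s]
    z = sum(exps)
    return [e / z for e in exps]

def tied_rnn_forward(word_ids, E, P_DH, W_hh, b_1, P_HD, b_2):
    """Run the tied RNN over word indices; E is the shared C x D matrix (U = V = E).

    Returns one next-word distribution p_t (length C) per input word.
    """
    h = [0.0] * len(W_hh)                    # h_0 = 0 (an assumption)
    outputs = []
    for i in word_ids:
        e = E[i]                             # e_t = U^T x_t, as a row lookup
        pre = [a + r + b for a, r, b in zip(matvec(P_DH, e), matvec(W_hh, h), b_1)]
        h = [math.tanh(v) for v in pre]      # f = tanh (an assumption)
        s = [a + b for a, b in zip(matvec(E, matvec(P_HD, h)), b_2)]  # V P_HD h_t + b_2
        outputs.append(softmax(s))           # p_t = sigma(s_t)
    return outputs

def rand_matrix(rng, rows, cols):
    """Toy random initialisation (an assumption, for illustration only)."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

C, D, H = 5, 3, 4
rng = random.Random(0)
E = rand_matrix(rng, C, D)
P_DH = rand_matrix(rng, H, D)
W_hh = rand_matrix(rng, H, H)
P_HD = rand_matrix(rng, D, H)
probs = tied_rnn_forward([0, 3, 1], E, P_DH, W_hh, [0.0] * H, P_HD, [0.0] * C)
```

The shared $C \times D$ matrix `E` is stored once instead of twice; allowing $D \neq H$ costs only the extra $D \times H$ matrix $P_{HD}$.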

and for weight tying with $D = H$, it becomes

$$ \begin{cases} e_t = U^Tx_t \\
h_t = f \big(P_{DH}e_t + W_{hh}h_{t-1}+b_1\big) \\
s_t = Vh_t + b_2 \\
p_t = \sigma (s_t) \end{cases} $$

A second implicit assumption here is that the number of words with an input embedding vector is the same as the number of words the language model can predict. So if you have a pre-trained embedding with a vocab of 500,000 words, you can predict only those 500,000 words. We have used the single symbol $C$ to denote both these values, but I can imagine situations where they should be different.

