Scaling Rotational Embeddings for Long-Context Language Models

May 17, 2024

Gradient Team

In the past two weeks, we’ve released a series of Llama-3 long-context models by increasing RoPE theta and adding full-context-length supervised fine-tuning (SFT). One of our most interesting learnings was around scaling positional encodings, so in this post we take a deep dive into how we scaled the context length up to 4M.

Theoretical Limitations of Positional Encodings

LLMs typically come with a pre-set context window, which is defined during pre-training and taught to the model through the positional encoding of tokens. While there’s no fundamental limitation preventing inference on sequences longer than those used in training, there are nuances: running inference on positions outside the training distribution can cause serious extrapolation issues.

Today, there are multiple positional mechanisms to choose from. ALiBi [5], Rotary Position Embedding (RoPE) [8], and the T5 relative positional bias all aim to encode position efficiently, and each offers its own way of adapting to longer input lengths. Most open-source LLMs, including Llama-3, use RoPE, which is why we focus our learnings on this specific mechanism.

Adapting RoPE to Longer Input Lengths

Two main approaches are used to adapt RoPE to longer input lengths – position interpolation and theta scaling.

Position Interpolation [1]

Image from [1]. Extrapolation of sinusoidal positional encodings.

Looking at the encodings of a single head under RoPE, we hit a strong out-of-distribution scenario when we simply extrapolate the positional embeddings to a longer context length. This hurts even simple tasks, such as duplicating a string or reciting a telephone number mentioned earlier in the long context. As shown in the figure above, an alternative is to interpolate the target positions onto the same periods as the pre-trained range. This maps out-of-distribution positional embeddings back into distribution, so the model operates on interpolated positions it has effectively already seen in its training data.
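To make this concrete, here is a minimal sketch of linear position interpolation applied to the RoPE angles (our own illustration, not our training code), assuming Llama-3-style values of d = 128 per head and a base of 500k:

```python
import numpy as np

def rope_angles(position_ids, dim=128, base=500000.0, scale=1.0):
    """Rotation angle per (position, frequency pair) for one attention head.

    scale = target_length / pretrained_length. Position interpolation divides
    each position by this factor, squeezing all angles back into the range the
    model saw during pre-training instead of extrapolating beyond it.
    """
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))  # one frequency per pair
    return np.outer(np.asarray(position_ids) / scale, inv_freq)

# Extrapolation: position 32768 produces angles never seen with an 8192 window.
extrapolated = rope_angles([32768])
# Interpolation: squeezing a 32k context into the trained 8k range (scale = 4)
# yields exactly the angles of position 8192, which the model has seen.
interpolated = rope_angles([32768], scale=4.0)
```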

Theta Scaling [3]

Image from [3]. Extrapolation with different base thetas. The training context length is the blue cubic volume, and the positional encodings per dimension index are the curves. For small thetas (a), all dimensions see a full rotational period of positional encodings. For medium thetas (b), higher dimensions do not see a full period, leading to out-of-distribution extrapolation (red). For large thetas (c), this is partially mitigated (light blue) by the very low rotation frequencies of the remaining dimensions (red).

Another approach is to increase the RoPE base theta, which lowers the rotation frequencies and stretches the rotary periods. Compared to position interpolation, this takes all rotary frequency dimensions into consideration at once, and there’s a neat theoretical insight behind it. Each hidden dimension gets a different positional encoding because the rotation periodicity is scaled per dimension. Higher dimensions rotate more slowly, and if the model doesn’t see a full period of rotation during training (e.g. only up to 90 degrees), it can’t extrapolate past it (e.g. beyond 90 degrees). This is depicted in the figure above: for larger bases (b) and (c), the higher dimensions do not see a full rotational period of embeddings during training (blue box volume), so the model cannot extrapolate into those positions (red area).

In the case of Llama-3-8B, the hidden dimension is 4096, split across 32 heads of size d = 128 each, with a RoPE base of 500k and a pre-trained context length of 8192. Following the equations for the critical dimension in [3], we find that only the first 72 dimensions have seen a full period of positional encodings. The implication is that, when extending the context length, the first 72 of the 128 dimensions will observe previously seen positional encodings, while the remaining 56 dimensions will encounter novel encodings, leading to an out-of-distribution shift.
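To make the critical-dimension arithmetic concrete, here is a quick back-of-the-envelope sketch following the period argument in [3] (our own calculation; depending on the rounding convention it lands within a couple of dimensions of the 72 quoted above):

```python
import numpy as np

d, base, train_len = 128, 500_000.0, 8192

# Period (in tokens) of each rotary frequency pair: 2*pi / frequency.
inv_freq = 1.0 / (base ** (np.arange(0, d, 2) / d))
periods = 2 * np.pi / inv_freq

# Pairs whose full period fits inside the pre-trained context have seen every
# rotation angle during training; the remaining pairs go out of distribution
# as soon as the context grows.
full_period_pairs = int((periods <= train_len).sum())
print(full_period_pairs * 2)  # ~70 of 128 dimensions, close to the 72 above
```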

Finally, in practical implementations, you can also recompute the inverse frequencies on the fly for a given target length, as is done in the NTK-style scaling of the Transformers Llama implementation [6].
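For reference, here is a minimal sketch of that on-the-fly recomputation, mirroring the dynamic NTK-style scaling in the Transformers Llama implementation [6]; treat it as an illustration rather than the exact library code:

```python
import numpy as np

def dynamic_inv_freq(seq_len, dim=128, base=500000.0, max_pos=8192, factor=1.0):
    """Recompute RoPE inverse frequencies once the input outgrows the
    pre-trained window, by growing the base instead of shrinking positions."""
    if seq_len > max_pos:
        # Grow the base so that the lowest-frequency pair's wavelength scales
        # in proportion to the new sequence length (NTK-aware scaling).
        base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    return 1.0 / (base ** (np.arange(0, dim, 2) / dim))
```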

From Theory to Practice: Scaling Laws for RoPE Theta with Increasing Context Length

In training our long-context models, we took the theta scaling approach and scaled up RoPE theta such that the newly stretched period overlapped with the previous theta’s period, keeping it in distribution. Adapting the formulas from Liu et al. [3], we can derive the theta required to represent the encodings when extending the context length.
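Roughly, with head dimension d and critical dimension d_crit (the number of dimensions that complete a full rotation within the pre-trained length), the extrapolation bound from [3] has to reach the target length L_new:

$$
2\pi \, \theta_{\text{new}}^{\,d_{\text{crit}}/d} \;\ge\; L_{\text{new}}
\quad\Longrightarrow\quad
\theta_{\text{new}} \;\approx\; \left(\frac{L_{\text{new}}}{2\pi}\right)^{d/d_{\text{crit}}}
$$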

As an example, plugging in the Llama-3-8B values gives the new RoPE theta for extending the context length from 8k (the pre-trained length) to 262k.
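A quick sketch of that arithmetic, using the unrounded critical dimension (rounding it shifts the result somewhat):

```python
import numpy as np

d, theta_old, L_old, L_new = 128, 500_000.0, 8192, 262_144

# Unrounded critical dimension: how many dimensions complete a full
# rotational period within the pre-trained context length.
d_crit = d * np.log(L_old / (2 * np.pi)) / np.log(theta_old)

theta_new = (L_new / (2 * np.pi)) ** (d / d_crit)
print(f"{theta_new:.3e}")  # on the order of 1e8, vs. the original base of 5e5
```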

This is the theta we used for training our Llama-3-8B-262k model, which is publicly available on Hugging Face.

While the above formula implies that one could scale from the 8192 pre-trained context length to any target length in a single jump, there are nuances to taking this approach. We opted for a progressive, curriculum-style multi-step training that alternates between scaling theta and fine-tuning on more data, with stages of 65k, 262k, and 1048k context length.
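As an illustration of what a curriculum stage looks like in practice, here is a hypothetical sketch using the Hugging Face Transformers Llama config; the compute_theta helper simply re-applies the scaling law above, and none of this is our actual training pipeline:

```python
import math

from transformers import AutoConfig, AutoModelForCausalLM

def compute_theta(target_len, d=128, theta_old=500_000.0, pretrained_len=8192):
    """Scaling-law theta for a target context length (see the formula above)."""
    d_crit = d * math.log(pretrained_len / (2 * math.pi)) / math.log(theta_old)
    return (target_len / (2 * math.pi)) ** (d / d_crit)

# Hypothetical curriculum: bump theta and the position limit before each SFT stage.
for stage_len in (65_536, 262_144, 1_048_576):
    config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    config.rope_theta = compute_theta(stage_len)
    config.max_position_embeddings = stage_len
    # In a real curriculum, each stage would resume from the previous stage's checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct", config=config
    )
    # ... run full-context-length SFT at stage_len, then continue to the next stage ...
```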

Empirical Results for Step-Wise RoPE Theta Scaling

When looking at related projects that published results on scaling vanilla multi-head attention (MHA) architectures to long contexts, we did not find a strong rationale for the implemented theta schedule. For instance, the otherwise fantastic LWM paper [4] opts for a 1M-32k, 10M-131k, 10M-262k, 25M-524k, 50M-1048k (theta-context length) schedule. When extrapolating this schedule to new bases and longer contexts, we found that it led to a suboptimally high initial perplexity, which forced the model to deviate more strongly from the initial 8k-instruct version of Llama-3-8B.

To determine the optimal strategy for scaling theta, we measured the perplexity over the first training batch of 4M tokens when scaling Llama-3-8B from 262k to 1048k context length. From these experiments, we found that scaling theta according to the NTK approach [6] or the single-step scaling laws from [1] yielded lower step-1 perplexity, meaning fewer tokens were needed to reach a perplexity low enough for strong model performance.

Achieving 4194k Context Length for Llama-3-8B-Instruct

This blog is centered around a small but crucial detail of theta scaling laws. We combined these findings with ring attention in a novel way, which allowed us to train a Llama-3 model step-wise with a 4x to 8x increase in context length per curriculum step. Applying this method, we trained a model that performs well on context lengths up to 4194k (the longest open-sourced context length in the world) and also achieved SOTA results on long-context evals using RULER [9] (https://gradient.ai/blog/evaluating-models-beyond-niah).

You can follow along and check out all of our open-source releases on 🤗, https://huggingface.co/gradientai!

Bibliography

Chen et al. Extending Context Window of Large Language Models via Positional Interpolation (2023) [1]

Peng et al. YaRN: Efficient Context Window Extension of Large Language Models (2023) [2]

X. Liu et al. Scaling Laws of RoPE-based Extrapolation (2024) [3]

H. Liu et al. Large World Model (LWM): World Model on Million-Length Video and Language with RingAttention (2024) [4]

Press et al. Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation (ALiBi) (2022) [5]

Hugging Face Transformers, NTK-scaled Llama implementation (2023) [6]

EasyContext, training schedule (2024) [7]

Su et al. RoFormer: Enhanced Transformer with Rotary Position Embedding. Neurocomputing 568 (2024): 127063 [8]

Hsieh et al. RULER: What's the Real Context Size of Your Long-Context Language Models? arXiv:2404.06654 (2024) [9]