Commit 84bfe34

add LinearRotaryEmbedding

gaoyang07 committed Jan 25, 2024
1 parent 34fc28d commit 84bfe34

Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions internlm/model/embedding.py
@@ -234,6 +234,47 @@ def _single_eval_forward(self, x, seqlen_offset=0):
        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])


class LinearRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev.
    Reference implementation:
    https://github.com/huggingface/transformers/blob/200009566639b5a83604e522a41df3a9 \
        5b6056ed/src/transformers/models/llama/modeling_llama.py#L159C1-L176C1
    """

    def __init__(
        self, dim: int, base=10000, scale_base=0, device=None, max_position_embeddings=2048, scaling_factor=1.0
    ):
        super().__init__(dim=dim, base=base, scale_base=scale_base, device=device)
        self.max_position_embeddings = max_position_embeddings
        self.scaling_factor = scaling_factor

    def _update_cos_sin_cache(self, x, indexes):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
        if not isinstance(indexes, int):
            seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1

        t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq.to(device=t.device))
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(x.dtype)
            self._sin_cached = torch.sin(freqs).to(x.dtype)
        else:
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
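
For context, a minimal standalone sketch (not part of the commit) of what the `t = t / self.scaling_factor` line achieves: positions are compressed before the outer product with the inverse frequencies, so the same range of rotary angles seen during training at `max_position_embeddings` is stretched over roughly `scaling_factor` times as many positions at inference. The sketch assumes the parent RotaryEmbedding defines `inv_freq` in the standard RoPE way; `dim`, `base`, and `scaling_factor` values are illustrative only.

import torch

# Illustrative values, not taken from the commit
dim, base, scaling_factor = 64, 10000, 4.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # standard RoPE inverse frequencies

seqlen = 8192
t = torch.arange(seqlen, dtype=inv_freq.dtype) / scaling_factor  # linear position scaling
freqs = torch.outer(t, inv_freq)                                 # inputs to the cos/sin caches
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)

# Largest rotary angle is (seqlen - 1) / scaling_factor, i.e. about the same angle range
# an unscaled cache would reach at seqlen / scaling_factor positions.
print(freqs.max().item())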


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
