From bfc1f2763b1d744e50986db1379c408b6840bea5 Mon Sep 17 00:00:00 2001
From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com>
Date: Fri, 12 Jul 2024 23:19:11 +0900
Subject: [PATCH] longrope (#886)

---
 llms/mlx_lm/models/phi3.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index e4a8cc7dc..dd2d6d823 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -33,9 +33,9 @@ def __post_init__(self):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] not in ["su", "linear"]:
+            if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
                 print(
-                    "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
+                    "[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
                 )
                 self.rope_scaling = None
 
@@ -58,7 +58,7 @@ def __init__(self, args: ModelArgs):
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
         rope_scale = 1.0
-        if args.rope_scaling and args.rope_scaling["type"] == "su":
+        if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
            self.rope = SuScaledRotaryEmbedding(
                head_dim,
                traditional=False,
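
For context, a self-contained sketch of the behavior this patch introduces: "longrope" (the type name newer Phi-3 configs use) is accepted as an alias for the existing "su" scaling path, so configs using either spelling keep their rope_scaling and select SuScaledRotaryEmbedding. The helper validate_rope_scaling below is hypothetical, written only to mirror the logic in ModelArgs.__post_init__ above; it is not a function in the repo.

    # Illustrative sketch of the validation added in the diff above.
    def validate_rope_scaling(rope_scaling):
        """Mirror of ModelArgs.__post_init__: drop unsupported scaling types."""
        if rope_scaling is None:
            return None
        required_keys = {"long_factor", "short_factor", "type"}
        if not all(key in rope_scaling for key in required_keys):
            raise ValueError(f"rope_scaling must contain keys {required_keys}")
        if rope_scaling["type"] not in ["longrope", "su", "linear"]:
            print("[WARNING] unsupported rope_scaling 'type'; disabling rope scaling.")
            return None
        return rope_scaling

    # Both spellings now survive validation and would take the
    # SuScaledRotaryEmbedding branch in Attention.__init__
    # (factor lists are placeholder values, not real Phi-3 weights):
    for cfg_type in ("su", "longrope"):
        cfg = validate_rope_scaling(
            {"type": cfg_type, "long_factor": [1.0], "short_factor": [1.0]}
        )
        assert cfg is not None and cfg["type"] in ["longrope", "su"]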