Commit d2f81d4

rebased
dan-garvey committed Aug 13, 2024
1 parent 76b524e commit d2f81d4
Showing 1 changed file with 1 addition and 18 deletions.
sharktank/sharktank/models/llama/llama.py (19 changes: 1 addition & 18 deletions)
@@ -296,23 +296,7 @@ def __init__(
         use_hf: bool = False,
     ):
         super().__init__(theta)
-
-        self.add_module(
-            "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
-        )
-        self.add_module("attn_q", LinearLayer(theta("attn_q")))
-        self.add_module("attn_k", LinearLayer(theta("attn_k")))
-        self.add_module("attn_v", LinearLayer(theta("attn_v")))
-        self.add_module("attn_output", LinearLayer(theta("attn_output")))
-        self.add_module(
-            "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
-        )
-        self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
-        self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
-        self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
-    ):
-        super().__init__(theta)
-        if hf:
+        if use_hf:
             # tensor = theta("self_attn.qkv.weight").tensor
             # tensor = tensor.reshape(head_count_kv, head_count // head_count_kv + 2, head_dim, head_dim * head_count)
             # print(tensor)
@@ -353,7 +337,6 @@ def __init__(
         self.block_index = block_index
         self.cache = cache
         assert isinstance(head_count, int)
->>>>>>> 4ed3c9d (add some fixes to run)
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
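
The deleted lines appear to be leftovers from the rebase named in the commit message: an apparently duplicated copy of the sublayer registrations, a stray "):" and repeated "super().__init__(theta)", plus an unresolved ">>>>>>>" conflict marker in the second hunk, with "if hf:" corrected to the actual "use_hf" parameter. The sketch below is a minimal, self-contained PyTorch illustration of the add_module() registration pattern those lines use; the class name, layer types, and dimensions are assumptions for illustration only, with nn.LayerNorm and nn.Linear standing in for sharktank's RMSNormLayer and LinearLayer, not the sharktank implementation itself.

import torch.nn as nn


class AttentionFFNBlock(nn.Module):
    # Illustrative sketch only, not the sharktank code.
    def __init__(self, dim: int = 64, ffn_dim: int = 256, eps: float = 1e-6):
        super().__init__()
        # add_module(name, module) registers each sublayer by name so it is
        # tracked by parameters()/state_dict() and reachable as self.<name>,
        # just like ordinary attribute assignment of an nn.Module.
        self.add_module("attn_norm", nn.LayerNorm(dim, eps=eps))
        self.add_module("attn_q", nn.Linear(dim, dim, bias=False))
        self.add_module("attn_k", nn.Linear(dim, dim, bias=False))
        self.add_module("attn_v", nn.Linear(dim, dim, bias=False))
        self.add_module("attn_output", nn.Linear(dim, dim, bias=False))
        self.add_module("ffn_norm", nn.LayerNorm(dim, eps=eps))
        self.add_module("ffn_gate", nn.Linear(dim, ffn_dim, bias=False))
        self.add_module("ffn_up", nn.Linear(dim, ffn_dim, bias=False))
        self.add_module("ffn_down", nn.Linear(ffn_dim, dim, bias=False))


if __name__ == "__main__":
    block = AttentionFFNBlock()
    # Registering the same name twice (as the duplicated block would have)
    # silently overwrites the first module rather than raising, so the
    # duplicate is easy to miss until the stray "):" breaks parsing.
    print(sorted(name for name, _ in block.named_children()))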
