Commit
adds a test for exporting moe block
runs in ~5 seconds
dan-garvey committed Sep 3, 2024
1 parent 0bc76f6 commit 766dd0a
Showing 2 changed files with 57 additions and 30 deletions.
60 changes: 31 additions & 29 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -27,39 +27,40 @@ def __init__(

         super().__init__(theta)

-        try:
-            merged_tensor = theta.tensor("ffn_gate_exps", "weight")
+        # try:
+        print(theta.flatten())
+        merged_tensor = theta.tensor("ffn_gate_exps", "weight")

-            expert_tensor = extract_ffn_layer(
-                merged_tensor=merged_tensor,
-                layer_name="ffn_gate",
-                expert_idx=expert_idx,
-            )
+        expert_tensor = extract_ffn_layer(
+            merged_tensor=merged_tensor,
+            layer_name="ffn_gate",
+            expert_idx=expert_idx,
+        )

-            self.add_module("ffn_gate", LinearLayer(Theta({"weight": expert_tensor})))
+        self.add_module("ffn_gate", LinearLayer(Theta({"weight": expert_tensor})))

-            merged_tensor = theta.tensor("ffn_up_exps", "weight")
+        merged_tensor = theta.tensor("ffn_up_exps", "weight")

-            expert_tensor = extract_ffn_layer(
-                merged_tensor=merged_tensor, layer_name="ffn_up", expert_idx=expert_idx
-            )
+        expert_tensor = extract_ffn_layer(
+            merged_tensor=merged_tensor, layer_name="ffn_up", expert_idx=expert_idx
+        )

-            self.add_module("ffn_up", LinearLayer(Theta({"weight": expert_tensor})))
+        self.add_module("ffn_up", LinearLayer(Theta({"weight": expert_tensor})))

-            merged_tensor = theta.tensor("ffn_down_exps", "weight")
+        merged_tensor = theta.tensor("ffn_down_exps", "weight")

-            expert_tensor = extract_ffn_layer(
-                merged_tensor=merged_tensor,
-                layer_name="ffn_down",
-                expert_idx=expert_idx,
-            )
+        expert_tensor = extract_ffn_layer(
+            merged_tensor=merged_tensor,
+            layer_name="ffn_down",
+            expert_idx=expert_idx,
+        )

-            self.add_module("ffn_down", LinearLayer(Theta({"weight": expert_tensor})))
+        self.add_module("ffn_down", LinearLayer(Theta({"weight": expert_tensor})))

-        except:
-            self.add_module("ffn_gate", LinearLayer(theta("ffn_gate", expert_idx)))
-            self.add_module("ffn_up", LinearLayer(theta("ffn_up", expert_idx)))
-            self.add_module("ffn_down", LinearLayer(theta("ffn_down", expert_idx)))
+        # except:
+        # self.add_module("ffn_gate", LinearLayer(theta("ffn_gate", expert_idx)))
+        # self.add_module("ffn_up", LinearLayer(theta("ffn_up", expert_idx)))
+        # self.add_module("ffn_down", LinearLayer(theta("ffn_down", expert_idx)))

     def forward(
         self,
@@ -74,11 +74,12 @@ def forward(
 def extract_ffn_layer(
     merged_tensor: DefaultPrimitiveTensor, layer_name: str, expert_idx: int
 ):
-
-    expert_layer_name = (
-        f"blk.{merged_tensor.name.split('.')[1]}.{layer_name}.{expert_idx}.weight"
-    )
+    print(merged_tensor.name)
+    # blk.0.ffn_down_exps.weight
+    # expert_layer_name = (
+    #     f"blk.{merged_tensor.name.split('.')[1]}.{layer_name}.{expert_idx}.weight"
+    # )
     expert_tensor = DefaultPrimitiveTensor(
-        name=expert_layer_name, data=merged_tensor.as_torch()[expert_idx]
+        name="", data=merged_tensor.as_torch()[expert_idx]
     )
     return expert_tensor
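
Note: the per-expert extraction above amounts to indexing the merged "ffn_*_exps" weight along its leading expert dimension. A minimal standalone sketch with plain torch tensors (dimensions and variable names are illustrative only, not taken from the diff; the real tensors produced by make_moe_block_theta below are much larger):

    import torch

    # Merged expert weight laid out as (num_experts, rows, cols), like the
    # "blk.0.ffn_*_exps.weight" tensors built by make_moe_block_theta below.
    num_experts, rows, cols = 8, 16, 32
    merged = torch.rand(num_experts, rows, cols)

    expert_idx = 3
    # extract_ffn_layer effectively does merged_tensor.as_torch()[expert_idx]:
    expert_weight = merged[expert_idx]
    assert expert_weight.shape == (rows, cols)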
27 changes: 26 additions & 1 deletion sharktank/sharktank/models/llama/testing.py
@@ -14,7 +14,7 @@

 # Range of torch.rand() is [0,1)
 # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
-def make_rand_torch(shape, dtype):
+def make_rand_torch(shape, dtype=torch.float32):
     return torch.rand(shape, dtype=dtype) * 2 - 1

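A quick usage note on the change above (illustrative calls only):

    import torch
    from sharktank.models.llama.testing import make_rand_torch

    x = make_rand_torch((4, 4))                       # dtype now defaults to torch.float32
    y = make_rand_torch((4, 4), dtype=torch.float64)  # explicit dtype still works; values lie in [-1, 1)
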
@@ -54,3 +54,28 @@ def make_attention_block_theta(
             ),
         }
     )
+
+
+def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta:
+    return Theta(
+        {
+            "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((feature_dim, ffn_dim))
+            ),
+            "blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((ffn_dim))
+            ),
+            "blk.0.layer_output_norm.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((ffn_dim))
+            ),
+            "blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((8, feature_dim * num_experts, ffn_dim))
+            ),
+            "blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((8, feature_dim * num_experts, ffn_dim))
+            ),
+            "blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch((8, ffn_dim, feature_dim * num_experts))
+            ),
+        }
+    )
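
The test added by this commit is not among the two files shown here. As a rough, hedged sketch of how the new helper might be exercised, the snippet below builds the MoE theta and checks the merged expert shapes; the assumption that Theta.flatten() returns a dict keyed by the dotted names above is inferred from the print(theta.flatten()) call in ffn_moe_block.py, not confirmed by this diff:

    from sharktank.models.llama.testing import make_moe_block_theta

    # Defaults from the helper above: feature_dim=1024, ffn_dim=6144, num_experts=8.
    theta = make_moe_block_theta()

    # Assumption: flatten() maps dotted tensor names to tensors.
    flat = theta.flatten()
    gate_exps = flat["blk.0.ffn_gate_exps.weight"]
    assert tuple(gate_exps.as_torch().shape) == (8, 1024 * 8, 6144)

The export step itself (wrapping the MoE block module built from this theta and running it through sharktank's AOT tooling or torch.export) is omitted here, since neither the module name nor the export entry point appears in this diff.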
