Fixed the order of gated_act to be (act, linear) in the Pax2TE conversion.
Signed-off-by: Ming Huang <mingh@nvidia.com>
mingxu1067 committed Jun 10, 2024
1 parent 0005642 commit 458cef5
Showing 1 changed file with 42 additions and 11 deletions.
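
For context (this sketch is not part of the commit): the converter fuses Pax's two FFN1 kernels into a single TE wi_kernel by stacking them along a new axis (stack_dim = -2), and TE's gated MLP is assumed here to read index 0 as the activation (gate) branch and index 1 as the linear branch, i.e. the (act, linear) order named in the commit title. The names w_act, w_linear and gated_ffn1 below are hypothetical, and silu merely stands in for whatever activation the model config actually selects.

    import jax
    import jax.numpy as jnp

    hidden_dim, mlp_intermediate_dim = 4, 8   # toy sizes for illustration
    key = jax.random.PRNGKey(0)

    # Pax stores the gated FFN1 as two separate kernels.
    w_act = jax.random.normal(key, (hidden_dim, mlp_intermediate_dim))     # ffn_layer1_gate.linear.w
    w_linear = jax.random.normal(key, (hidden_dim, mlp_intermediate_dim))  # ffn_layer1.linear.w

    # The converter stacks them along a new second-to-last axis; the fix makes
    # the fused kernel ordered (act, linear) rather than (linear, act).
    wi_kernel = jnp.stack([w_act, w_linear], axis=-2)  # (hidden_dim, 2, mlp_intermediate_dim)

    # Assumed TE-style gated forward: the activation branch gates the linear branch.
    def gated_ffn1(x):
        return jax.nn.silu(x @ wi_kernel[:, 0, :]) * (x @ wi_kernel[:, 1, :])

    x = jax.random.normal(key, (2, hidden_dim))
    print(gated_ffn1(x).shape)  # (2, mlp_intermediate_dim)
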
@@ -38,18 +38,38 @@ def _generate_ckpt_map(self):
         hidden_dim = num_of_head * head_dim
         mlp_intermediate_dim = self.model_config.mlp_intermediate_dim
 
         for i in range(self.model_config.num_of_layer):
-            ckpt_map.update({
-                f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w":
-                    self._get_convert_pkg(
-                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
-                        (hidden_dim, mlp_intermediate_dim), 0,
-                        extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"],
-                        stack_dim = -2) if self.use_gated_act else \
-                    self._get_convert_pkg(
-                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
-                        (hidden_dim, mlp_intermediate_dim), 0,
-                        lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
+            ckpt_map_for_ffn1 = {}
+            if self.use_gated_act:
+                ckpt_map_for_ffn1[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"] = \
+                    self._get_convert_pkg(
+                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
+                        (hidden_dim, mlp_intermediate_dim), 0,
+                        extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"],
+                        stack_dim = -2)
+            else:
+                ckpt_map_for_ffn1[f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w"] = \
+                    self._get_convert_pkg(
+                        f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel",
+                        (hidden_dim, mlp_intermediate_dim), 0,
+                        lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])))
+
+            ckpt_map.update({
+                **ckpt_map_for_ffn1,
                 f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w":
                     self._get_convert_pkg(
                         f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel",
@@ -313,17 +333,28 @@ def _generate_ckpt_map(self):
         hidden_dim = num_of_head * head_dim
         mlp_intermediate_dim = self.model_config.mlp_intermediate_dim
 
+        ckpt_map_for_ffn1 = {}
+        if self.use_gated_act:
+            ckpt_map_for_ffn1['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1_gate.linear.w'] = \
+                self._get_convert_pkg(
+                    f'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
+                    (hidden_dim, mlp_intermediate_dim), 0,
+                    extra_src_paths = ['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w'],
+                    stack_dim = -2)
+        else:
+            ckpt_map_for_ffn1['lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w'] = \
+                self._get_convert_pkg(
+                    'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
+                    (num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
+                    lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1])))
+
         ckpt_map.update({
             'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b':
                 self._get_convert_pkg(
                     'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_bias',
                     (num_of_layer, mlp_intermediate_dim), None,
                     lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
-            'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w':
-                self._get_convert_pkg(
-                    'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wi_kernel',
-                    (num_of_layer, hidden_dim, mlp_intermediate_dim), 1,
-                    lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))),
+            **ckpt_map_for_ffn1,
             'lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b':
                 self._get_convert_pkg(
                     'lm.transformer.repeat.sub.x_layers_0.transformerlayer.cld.mlp.wo_bias',
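
Reading both hunks the same way (again an illustration, not part of the commit): the mapping key now names the ffn_layer1_gate kernel while extra_src_paths carries the plain ffn_layer1 kernel, so the source named by the key is presumably stacked first. The helper below, stack_sources, is a hypothetical sketch of that assumed convention and is not the repository's _get_convert_pkg.

    from typing import Sequence
    import jax.numpy as jnp

    def stack_sources(primary: jnp.ndarray, extras: Sequence[jnp.ndarray], stack_dim: int) -> jnp.ndarray:
        # Assumed convention: the array at the mapping key comes first, then the
        # arrays listed in extra_src_paths, all stacked along stack_dim.
        return jnp.stack([primary, *extras], axis=stack_dim)

    # With the fixed mapping, primary is ffn_layer1_gate (act) and extras is
    # [ffn_layer1 (linear)], so index 0 of the stacked axis is the activation
    # kernel and index 1 the linear kernel.
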
