Skip to content

Commit

Permalink
Fix model name
Browse files Browse the repository at this point in the history
  • Loading branch information
zhhsplendid committed Nov 29, 2024
1 parent 841ed0a commit 99a3c58
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
2 changes: 2 additions & 0 deletions internlm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ class ModelType(Enum):
GEMMA = 8
QWEN2MOE = 9
MIXTRALMOE = 10
CHAMELEON = 11


class DataType(Enum):
streaming = 1
tokenized = 2
megatron = 3
mocked = 4
lumina_pickle = 5


class TensorParallelMode(Enum):
Expand Down
23 changes: 20 additions & 3 deletions tests/test_model/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def check_norm(args):
# init
rank, world_size, free_port = args
rank, world_size, free_port, convert_to_input_dtype = args
build_environment(rank, world_size, free_port)
device = get_current_device()
rtol, atol = (1e-3, 5e-3)
Expand All @@ -22,7 +22,12 @@ def check_norm(args):
seed_all(1024)

# define norm
norm = new_layer_norm(norm_type="rmsnorm", normalized_shape=hidden_size, eps=layer_norm_epsilon)
norm = new_layer_norm(
norm_type="rmsnorm",
normalized_shape=hidden_size,
eps=layer_norm_epsilon,
convert_to_input_dtype=convert_to_input_dtype,
)
norm = norm.to(device)

# create input
Expand Down Expand Up @@ -83,8 +88,20 @@ def check_norm(args):
def test_norm():
ctx = mp.get_context("spawn")
free_port = str(find_free_port())
convert_input_dtype = False
with ctx.Pool(processes=8) as pool:
pool.map(check_norm, [[rank, 8, free_port] for rank in range(8)])
pool.map(check_norm, [[rank, 8, free_port, convert_input_dtype] for rank in range(8)])
pool.close()
pool.join()


@pytest.mark.norm_convert_to_input_dtype
def test_norm_convert_to_input_dtype():
ctx = mp.get_context("spawn")
free_port = str(find_free_port())
convert_input_dtype = True
with ctx.Pool(processes=8) as pool:
pool.map(check_norm, [[rank, 8, free_port, convert_input_dtype] for rank in range(8)])
pool.close()
pool.join()

Expand Down

0 comments on commit 99a3c58

Please sign in to comment.