Skip to content

Commit

Permalink
misc(trtllm): attempt to fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Jan 15, 2025
1 parent 6aaa045 commit d236ac7
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/integration/test_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@


MODEL_TO_TEST = {
"google/gemma-2b-it",
# "google/gemma-2b-it",
"meta-llama/Llama-2-7b-chat-hf",
"mistralai/Mistral-7B-Instruct-v0.2",
"meta-llama/Meta-Llama-3-8B",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
# "mistralai/Mistral-7B-Instruct-v0.2",
# "meta-llama/Meta-Llama-3-8B",
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
}

MODEL_KWARGS_MAPS = {"Mixtral-8x7B-Instruct-v0.1": {"tp": 4}}
Expand Down Expand Up @@ -67,9 +67,9 @@ def test_generation(model_id: str, batch_size: int, tp: int, pp: int):

export_config = ExportConfig(
dtype="float16",
max_input_len=1024,
max_input_len=128,
max_batch_size=batch_size,
max_output_len=1000,
max_output_len=128,
max_num_tokens=max_new_tokens,
)
export_config = sharded(export_config, tp, pp)
Expand All @@ -81,7 +81,7 @@ def test_generation(model_id: str, batch_size: int, tp: int, pp: int):
)

trt_generated_ids = trt_model.generate(
inp, num_beams=1, do_sample=False, max_new_tokens=max_new_tokens, **kwargs
inp["input_ids"], num_beams=1, do_sample=False, max_new_tokens=max_new_tokens, **kwargs
)

# TODO: left/right padding is not aligned between Transformers and TRT-LLM.
Expand Down

0 comments on commit d236ac7

Please sign in to comment.