Skip to content

Commit

Permalink
misc(trtllm): quality
Browse files (browse the repository at this point in the history)
  • Loading branch information
mfuntowicz committed Jan 15, 2025
1 parent 94d6e32 commit 9dd82e4
Showing 1 changed file with 0 additions and 7 deletions.
7 changes: 0 additions & 7 deletions tests/integration/test_causal_lm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,20 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import pytest
import torch
from transformers import AutoModelForCausalLM as TransformersAutoModelForCausalLM
from transformers import AutoTokenizer
from utils_testing import clean_cached_engines_for_model

from optimum.nvidia import AutoModelForCausalLM, ExportConfig
from optimum.nvidia.export.config import sharded
from optimum.nvidia.utils.nvml import get_device_count
from optimum.nvidia.utils.tests import (
assert_generated_partially_match,
)


MODEL_TO_TEST = {
Expand All @@ -48,8 +43,6 @@ def test_generation(model_id: str, batch_size: int, tp: int, pp: int):
if get_device_count() < tp * pp:
pytest.skip("Not enough GPU on the system")

torch_dtype = torch.float16 # TODO: test fp8, int4, int8, fp32

# TODO: test batched generation as well.
# TODO: This is flaky depending on the prompt for Mistral / Gemma, maybe see if it is a bug or not.
prompts = ["Today I am in Paris and I would like to eat crepes."]
Expand Down

0 comments on commit 9dd82e4

Please sign in to comment.