From 7de630de271a29757430bb271fefae9cb4b9af06 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Wed, 5 Jul 2023 22:38:38 +1000 Subject: [PATCH] add test --- .../torch/examples/test_01_getting_started.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 tests/unit/torch/examples/test_01_getting_started.py diff --git a/tests/unit/torch/examples/test_01_getting_started.py b/tests/unit/torch/examples/test_01_getting_started.py new file mode 100644 index 0000000000..d0ae8dc5f9 --- /dev/null +++ b/tests/unit/torch/examples/test_01_getting_started.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +from testbook import testbook + +from tests.conftest import REPO_ROOT + + +@testbook(REPO_ROOT / "examples/pytorch/01-Getting-started.ipynb", execute=False) +@pytest.mark.notebook +def test_example_01_getting_started(tb): + tb.inject( + """ + from unittest.mock import patch + from merlin.datasets.synthetic import generate_data + mock_train, mock_valid = generate_data( + input="movielens-1m", + num_rows=1000, + set_sizes=(0.8, 0.2) + ) + p1 = patch( + "merlin.datasets.entertainment.get_movielens", + return_value=[mock_train, mock_valid] + ) + p1.start() + """ + ) + tb.execute() + metrics = tb.ref("metrics") + assert set(metrics[0].keys()) == set( + [ + "val_loss", + "val_binary_accuracy", + "val_binary_auroc", + "val_binary_precision", + "val_binary_recall" + ] + )