From e532adce937c3fb0d64a62d987109c19fdb9d8a1 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Tue, 7 Jan 2025 21:36:25 +0200 Subject: [PATCH 1/3] complete json_cache to read missing artifacts also from github catalog Signed-off-by: dafnapension --- src/unitxt/artifact.py | 20 ++++++++++++++++---- src/unitxt/catalog.py | 4 ++++ src/unitxt/utils.py | 6 ------ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index ed9036225..0c0e5397e 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -6,8 +6,11 @@ import re import warnings from abc import abstractmethod +from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple, Union, final +import requests + from .dataclass import ( AbstractField, Dataclass, @@ -25,8 +28,8 @@ from .text_utils import camel_to_snake_case, is_camel_case from .type_utils import isoftype, issubtype from .utils import ( - artifacts_json_cache, json_dump, + load_json, save_to_file, shallow_copy, ) @@ -118,6 +121,15 @@ def reset(self): self.catalogs = [] +@lru_cache(maxsize=None) +def artifacts_json_cache(artifact_path, catalog): + if catalog.is_local: + return load_json(artifact_path) + # github catalog + response = requests.get(artifact_path) + return response.json() + + def maybe_recover_artifacts_structure(obj): if Artifact.is_possible_identifier(obj): return verbosed_fetch_artifact(obj) @@ -280,8 +292,8 @@ def from_dict(cls, d, overwrite_args=None): return cls._recursive_load(d) @classmethod - def load(cls, path, artifact_identifier=None, overwrite_args=None): - d = artifacts_json_cache(path) + def load(cls, catalog, path, artifact_identifier=None, overwrite_args=None): + d = artifacts_json_cache(path, catalog) if "artifact_linked_to" in d and d["artifact_linked_to"] is not None: # d stands for an ArtifactLink artifact_link = ArtifactLink.from_dict(d) @@ -503,7 +515,7 @@ def load(self, overwrite_args: dict) -> Artifact: raise UnitxtArtifactNotFoundError(self.artifact_linked_to, catalogs) path = needed_catalog.path(self.artifact_linked_to) - d = artifacts_json_cache(path) + d = artifacts_json_cache(path, needed_catalog) # if needed, follow, in a recursive manner, over multiple links, # passing through instantiating of the ArtifactLink-s on the way, triggering # deprecatioin warning as needed. diff --git a/src/unitxt/catalog.py b/src/unitxt/catalog.py index ee4347cfc..20713d135 100644 --- a/src/unitxt/catalog.py +++ b/src/unitxt/catalog.py @@ -51,6 +51,7 @@ def load(self, artifact_identifier: str, overwrite_args=None): ), f"Artifact with name {artifact_identifier} does not exist" path = self.path(artifact_identifier) return Artifact.load( + self, path, artifact_identifier=artifact_identifier, overwrite_args=overwrite_args, @@ -106,6 +107,9 @@ def prepare(self): tag = version self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}" + def get_with_overwrite(self, name, overwrite_args): + return self.load(name, overwrite_args=overwrite_args) + def load(self, artifact_identifier: str, overwrite_args=None): url = self.path(artifact_identifier) response = requests.get(url) diff --git a/src/unitxt/utils.py b/src/unitxt/utils.py index 353de79e7..c11c57889 100644 --- a/src/unitxt/utils.py +++ b/src/unitxt/utils.py @@ -5,7 +5,6 @@ import re import threading from collections import OrderedDict -from functools import lru_cache from typing import Any, Dict from .text_utils import is_made_of_sub_strings @@ -122,11 +121,6 @@ def flatten_dict( return dict(items) -@lru_cache(maxsize=None) -def artifacts_json_cache(artifact_path): - return load_json(artifact_path) - - def load_json(path): with open(path) as f: try: From 8b8c7b27184d26915ca11fd21d050f012c54d066 Mon Sep 17 00:00:00 2001 From: dafnapension Date: Tue, 7 Jan 2025 22:04:45 +0200 Subject: [PATCH 2/3] fixed for load artifact from file Signed-off-by: dafnapension --- src/unitxt/artifact.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index 0c0e5397e..4cbe6d09b 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -123,7 +123,7 @@ def reset(self): @lru_cache(maxsize=None) def artifacts_json_cache(artifact_path, catalog): - if catalog.is_local: + if catalog is None or catalog.is_local: return load_json(artifact_path) # github catalog response = requests.get(artifact_path) @@ -510,6 +510,9 @@ def load(self, overwrite_args: dict) -> Artifact: for catalog in catalogs: if self.artifact_linked_to in catalog: needed_catalog = catalog + if needed_catalog.is_local: + # prefer a local catalog, if possible + break if needed_catalog is None: raise UnitxtArtifactNotFoundError(self.artifact_linked_to, catalogs) @@ -594,7 +597,7 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None] # If local file if isinstance(artifact_rep, str) and Artifact.is_artifact_file(artifact_rep): - artifact_to_return = Artifact.load(artifact_rep) + artifact_to_return = Artifact.load(catalog=None, path=artifact_rep) if isinstance(artifact_rep, ArtifactLink): artifact_to_return = fetch_artifact(artifact_to_return.artifact_linked_to) From 49197fecf5c6c057dca46163275594907febd0ff Mon Sep 17 00:00:00 2001 From: dafnapension Date: Tue, 7 Jan 2025 23:54:13 +0200 Subject: [PATCH 3/3] added test Signed-off-by: dafnapension --- tests/library/test_catalogs.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/library/test_catalogs.py b/tests/library/test_catalogs.py index 76eb45f89..7dbdfba01 100644 --- a/tests/library/test_catalogs.py +++ b/tests/library/test_catalogs.py @@ -5,6 +5,7 @@ import unitxt from unitxt import add_to_catalog from unitxt.artifact import Artifact, Catalogs +from unitxt.catalog import GithubCatalog from unitxt.error_utils import UnitxtError from unitxt.register import ( _reset_env_local_catalogs, @@ -80,3 +81,19 @@ class ClassToSave(Artifact): content = json.load(f) self.assertDictEqual(content, {"__type__": "class_to_save", "t": 1}) + + def test_from_github_catalog(self): + github_catalog = GithubCatalog() + path = github_catalog.path("tasks.qa.with_context.abstractive") + artifact = Artifact.load( + catalog=github_catalog, + path=path, + artifact_identifier="task", + overwrite_args={"metrics": ["metrics.anls"]}, + ) + + # assert reached the linked_to artifact + self.assertEqual("tasks.qa.with_context", artifact.__id__) + + # assert employed overwrites + self.assertListEqual(["metrics.anls"], artifact.metrics)