Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complete json_cache to read missing artifacts also from github catalog #1493

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import re
import warnings
from abc import abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union, final

import requests

from .dataclass import (
AbstractField,
Dataclass,
Expand All @@ -25,8 +28,8 @@
from .text_utils import camel_to_snake_case, is_camel_case
from .type_utils import isoftype, issubtype
from .utils import (
artifacts_json_cache,
json_dump,
load_json,
save_to_file,
shallow_copy,
)
Expand Down Expand Up @@ -118,6 +121,15 @@ def reset(self):
self.catalogs = []


@lru_cache(maxsize=None)
def artifacts_json_cache(artifact_path, catalog):
if catalog is None or catalog.is_local:
return load_json(artifact_path)
# github catalog
response = requests.get(artifact_path)
return response.json()


def maybe_recover_artifacts_structure(obj):
if Artifact.is_possible_identifier(obj):
return verbosed_fetch_artifact(obj)
Expand Down Expand Up @@ -280,8 +292,8 @@ def from_dict(cls, d, overwrite_args=None):
return cls._recursive_load(d)

@classmethod
def load(cls, path, artifact_identifier=None, overwrite_args=None):
d = artifacts_json_cache(path)
def load(cls, catalog, path, artifact_identifier=None, overwrite_args=None):
d = artifacts_json_cache(path, catalog)
if "artifact_linked_to" in d and d["artifact_linked_to"] is not None:
# d stands for an ArtifactLink
artifact_link = ArtifactLink.from_dict(d)
Expand Down Expand Up @@ -498,12 +510,15 @@ def load(self, overwrite_args: dict) -> Artifact:
for catalog in catalogs:
if self.artifact_linked_to in catalog:
needed_catalog = catalog
if needed_catalog.is_local:
# prefer a local catalog, if possible
break

if needed_catalog is None:
raise UnitxtArtifactNotFoundError(self.artifact_linked_to, catalogs)

path = needed_catalog.path(self.artifact_linked_to)
d = artifacts_json_cache(path)
d = artifacts_json_cache(path, needed_catalog)
# if needed, follow, in a recursive manner, over multiple links,
# passing through instantiating of the ArtifactLink-s on the way, triggering
# deprecatioin warning as needed.
Expand Down Expand Up @@ -582,7 +597,7 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]

# If local file
if isinstance(artifact_rep, str) and Artifact.is_artifact_file(artifact_rep):
artifact_to_return = Artifact.load(artifact_rep)
artifact_to_return = Artifact.load(catalog=None, path=artifact_rep)
if isinstance(artifact_rep, ArtifactLink):
artifact_to_return = fetch_artifact(artifact_to_return.artifact_linked_to)

Expand Down
4 changes: 4 additions & 0 deletions src/unitxt/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def load(self, artifact_identifier: str, overwrite_args=None):
), f"Artifact with name {artifact_identifier} does not exist"
path = self.path(artifact_identifier)
return Artifact.load(
self,
path,
artifact_identifier=artifact_identifier,
overwrite_args=overwrite_args,
Expand Down Expand Up @@ -106,6 +107,9 @@ def prepare(self):
tag = version
self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

def get_with_overwrite(self, name, overwrite_args):
return self.load(name, overwrite_args=overwrite_args)

def load(self, artifact_identifier: str, overwrite_args=None):
url = self.path(artifact_identifier)
response = requests.get(url)
Expand Down
6 changes: 0 additions & 6 deletions src/unitxt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import re
import threading
from collections import OrderedDict
from functools import lru_cache
from typing import Any, Dict

from .text_utils import is_made_of_sub_strings
Expand Down Expand Up @@ -122,11 +121,6 @@ def flatten_dict(
return dict(items)


@lru_cache(maxsize=None)
def artifacts_json_cache(artifact_path):
return load_json(artifact_path)


def load_json(path):
with open(path) as f:
try:
Expand Down
17 changes: 17 additions & 0 deletions tests/library/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import unitxt
from unitxt import add_to_catalog
from unitxt.artifact import Artifact, Catalogs
from unitxt.catalog import GithubCatalog
from unitxt.error_utils import UnitxtError
from unitxt.register import (
_reset_env_local_catalogs,
Expand Down Expand Up @@ -80,3 +81,19 @@ class ClassToSave(Artifact):
content = json.load(f)

self.assertDictEqual(content, {"__type__": "class_to_save", "t": 1})

def test_from_github_catalog(self):
github_catalog = GithubCatalog()
path = github_catalog.path("tasks.qa.with_context.abstractive")
artifact = Artifact.load(
catalog=github_catalog,
path=path,
artifact_identifier="task",
overwrite_args={"metrics": ["metrics.anls"]},
)

# assert reached the linked_to artifact
self.assertEqual("tasks.qa.with_context", artifact.__id__)

# assert employed overwrites
self.assertListEqual(["metrics.anls"], artifact.metrics)
Loading