diff --git a/pkg_src/retrain_pipelines/dataset/hf_utils/hf_utils.py b/pkg_src/retrain_pipelines/dataset/hf_utils/hf_utils.py
index 1d9ea53..2b54018 100644
--- a/pkg_src/retrain_pipelines/dataset/hf_utils/hf_utils.py
+++ b/pkg_src/retrain_pipelines/dataset/hf_utils/hf_utils.py
@@ -12,108 +12,16 @@
import pandas as pd
import polars as pl
-from datetime import datetime
-
from typing import Optional, Callable, Iterator
-from huggingface_hub import list_repo_refs, list_repo_commits, \
- list_repo_files, hf_hub_download, HfApi
from huggingface_hub.utils import RevisionNotFoundError, \
EntryNotFoundError, HfHubHTTPError
from datasets import IterableDataset, DatasetDict
from retrain_pipelines import __version__
-from retrain_pipelines.utils.hf_utils import local_repo_folder_to_hub
-
-
-def _dataset_repo_branch_commits_files(
- repo_id: str,
- repo_branch: str
-) -> dict:
- """
- Params:
- - repo_id (str):
- Path to the HuggingFace dataset.
- - repo_branch (str):
- Branch (of the repository of interest)
- to be considered.
-
- Results:
- - (dict)
- 'commit_hash', 'created_at',
- 'title', 'files'
- """
- commits = list_repo_commits(repo_id, revision=repo_branch,
- repo_type="dataset",
- token=os.environ["HF_TOKEN"])
- commits_dict = {}
- for commit in commits:
- files = list_repo_files(
- repo_id, revision=commit.commit_id,
- repo_type="dataset",
- token=os.environ["HF_TOKEN"])
-
- commits_dict[commit.commit_id] = {
- "created_at": commit.created_at.strftime(
- "%Y-%m-%d %H:%M:%S UTC"),
- "title": commit.title,
- "files": files
- }
-
- return commits_dict
-
-
-def get_dataset_branches_commits_files(
- repo_id: str
-) -> dict:
- """
- Selection of metadata for (litterally)
- all files of all commits of a given
- HF dataset repo.
-
- Params:
- - repo_id (str):
- Path to the HuggingFace dataset.
-
- Results:
- - (dict)
- 'branches'
- (
- 'branch_name', 'commits',
- (
- 'commit_hash', 'created_at',
- 'title', 'files'
- )
- )
- """
-
- refs = list_repo_refs(repo_id, repo_type="dataset",
- token=os.environ["HF_TOKEN"])
-
- dataset_repo_branches = {
- "repo_standard_branches": {},
- "repo_convert_branches": {}
- }
- for repo_standard_branches in refs.branches:
- dataset_repo_branches[
- "repo_standard_branches"
- ][repo_standard_branches.name] = {
- "branch_name": repo_standard_branches.ref,
- "commits": _dataset_repo_branch_commits_files(
- repo_id, repo_standard_branches.ref)
- }
-
- for repo_convert_branch in refs.converts:
- dataset_repo_branches[
- "repo_convert_branches"
- ][repo_convert_branch.name] = {
- "branch_name": repo_convert_branch.ref,
- "commits": _dataset_repo_branch_commits_files(
- repo_id, repo_convert_branch.ref)
- }
-
- return dataset_repo_branches
+from retrain_pipelines.utils.hf_utils import \
+ get_repo_branches_commits_files, local_repo_folder_to_hub
def get_latest_commit(
@@ -140,7 +48,8 @@ def get_latest_commit(
"""
dataset_repo_branches = \
- get_dataset_branches_commits_files(repo_id)
+ get_repo_branches_commits_files(
+ repo_id=repo_id, repo_type="dataset")
latest_matching_commit = None
regex_pattern = re.compile(files_filter)
@@ -208,7 +117,8 @@ def get_commit(
return matching_commit
else:
dataset_repo_branches = \
- get_dataset_branches_commits_files(repo_id)
+ get_repo_branches_commits_files(
+ repo_id=repo_id, repo_type="dataset")
for \
branch_type, branches \
in dataset_repo_branches.items() \
@@ -521,83 +431,6 @@ def dataset_dict_to_config_str(
return result
-def get_latest_README_commit(
- repo_id: str,
- target_commit_hash: str,
- verbose: bool = True
-) -> (str, datetime):
- """
- Using a given commit as a starting point,
- look for the latest prior commit for which
- there was a README.md file.
-
- This is to address cases where
- 'the commit corresponding to this commit_hash
- didn't include a README and
- many entries are missing from `dataset_info`'.
- for instance, typical of 'auto-convert bot'
- (think duckdb or parquet,
- @see https://huggingface.co/docs/dataset-viewer/en/parquet#conversion-to-parquet).
-
- Params:
- - repo_id (str):
- Path to the HuggingFace dataset.
- - commit_hash (Optional, str):
- particular "revision" of the dataset
- to scan.
- - verbose (bool):
- whether or not to print commit
- hash and date (target vs latest README)
-
- Results:
- - (str, datetime):
- latest_README_commit_hash,
- latest_README_commit_date
- """
- hf_dataset_branches_commits_files = \
- get_dataset_branches_commits_files(repo_id=repo_id)
-
- target_date = None
- for repo, repo_data in hf_dataset_branches_commits_files.items():
- for branch, branch_data in repo_data.items():
- for commit_hash, commit_data in branch_data['commits'].items():
- if commit_hash == target_commit_hash:
- target_date = datetime.strptime(
- commit_data['created_at'], '%Y-%m-%d %H:%M:%S UTC')
- break
- if target_date:
- break
- if target_date:
- break
- if verbose:
- print("target commit : ".ljust(25), target_commit_hash, target_date)
-
- README_date = None
- README_commit_hash = None
- for repo, repo_data in hf_dataset_branches_commits_files.items():
- for branch, branch_data in repo_data.items():
- for commit_hash, commit_data in branch_data['commits'].items():
- if 'README.md' in commit_data['files']:
- commit_date = datetime.strptime(
- commit_data['created_at'], '%Y-%m-%d %H:%M:%S UTC')
- if commit_date <= target_date:
- README_date = datetime.strptime(
- commit_data['created_at'], '%Y-%m-%d %H:%M:%S UTC')
- README_commit_hash = commit_hash
- if verbose:
- print("lastest README commit : ".ljust(25),
- README_commit_hash, README_date)
- break
- else:
- continue
- break
- else:
- continue
- break
-
- return README_commit_hash, README_date
-
-
def push_dataset_version_to_hub(
repo_id: str,
dataset_dict: DatasetDict,
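
A minimal usage sketch (editorial, not part of the patch) of the shared helper these dataset utilities now delegate to; the repo id below is a placeholder and HF_TOKEN must be set, since the underlying hub calls read it from the environment:

    from retrain_pipelines.utils.hf_utils import get_repo_branches_commits_files

    # same nested structure the former dataset-only helper returned:
    # branch type -> branch name -> {"branch_name", "commits": {commit_hash: {...}}}
    branches = get_repo_branches_commits_files(
        repo_id="org/some-dataset", repo_type="dataset")
    for branch_type, branch_group in branches.items():
        for branch_name, branch_data in branch_group.items():
            print(branch_type, branch_name, len(branch_data["commits"]))
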
diff --git a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme.py b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme.py
index 2590c4b..855fdde 100644
--- a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme.py
+++ b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme.py
@@ -11,11 +11,10 @@
from retrain_pipelines import __version__
from retrain_pipelines.dataset.hf_utils import \
- get_latest_README_commit, get_size_category, \
- dataset_dict_to_config_str
+ get_size_category, dataset_dict_to_config_str
from retrain_pipelines.utils.hf_utils import \
- get_arxiv_codes, get_license_label, \
- get_pretty_name
+ get_latest_README_commit, get_arxiv_codes, \
+ get_license_label, get_pretty_name
def _dataset_readme_params(
@@ -80,13 +79,15 @@ def _dataset_readme_params(
main_commit_hash, main_commit_utc_date_str = \
get_latest_README_commit(
repo_id=hf_dataset_dict["repo_id"],
- target_commit_hash=hf_dataset_dict["commit_hash"]
+ target_commit_hash=hf_dataset_dict["commit_hash"],
+ repo_type="dataset"
)
enrich_commit_hash, enrich_commit_utc_date_str = \
get_latest_README_commit(
repo_id=hf_enrich_dataset_dict["repo_id"],
target_commit_hash=\
- hf_enrich_dataset_dict["commit_hash"]
+ hf_enrich_dataset_dict["commit_hash"],
+ repo_type="dataset"
)
main_pretty_name = get_pretty_name(
@@ -256,7 +257,8 @@ def get_dataset_readme_content(
version_label=version_label,
utc_timestamp_str=utc_timestamp_str,
mf_flow_name=mf_flow_name,
- mf_run_id=mf_run_id
+ mf_run_id=mf_run_id,
+ engine=engine
)
env = Environment(loader=FileSystemLoader(template_folder))
diff --git a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme_template.md b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme_template.md
index c4c6cb7..a1b9500 100644
--- a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme_template.md
+++ b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/dataset_readme_template.md
@@ -96,7 +96,9 @@ Data-enrichment rate : +{{ (enrichment_rate * 100)|round(1) ~ '%' }}
retrain-pipelines
+ {{ __version__ }}
-
Run by {{ run_user }}
-
{{ mf_flow_name }} - mf_run_id : {{ mf_run_id }}
diff --git a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme.py b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme.py
index e4de2b9..db8bd6e 100644
--- a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme.py
+++ b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme.py
@@ -11,18 +11,18 @@
from retrain_pipelines import __version__
-from retrain_pipelines.dataset.hf_utils import \
- get_latest_README_commit, \
- get_arxiv_codes, get_license_label, \
- get_pretty_name,
+from retrain_pipelines.utils.hf_utils import \
+ get_latest_README_commit, get_arxiv_codes, \
+ get_license_label, get_pretty_name
def _model_readme_params(
+ base_model_dict: dict,
+ training_dataset_dict: dict,
version_label: str,
utc_timestamp_str: str,
mf_flow_name: str,
mf_run_id: str,
- engine:str = "cpu"
) -> dict:
"""
Populates the params dict to be used
@@ -31,6 +31,12 @@ def _model_readme_params(
Built on metadata from the base model.
Params:
+        - base_model_dict (dict):
+            - repo_id
+            - commit_hash
+        - training_dataset_dict (dict):
+            - repo_id
+            - version_label
+            - commit_hash
+            - utc_timestamp_str
- version_label (str):
typical `retrain-pipelines`
version label are of format "major.minor"
@@ -38,8 +44,6 @@ def _model_readme_params(
timestampt for the new dataset version.
- mf_flow_name (str)
- mf_run_id (str)
- - engine (str):
- Polars engine (can be "cpu", "gpu"..)
Results:
- (dict)
@@ -47,28 +51,29 @@ def _model_readme_params(
pretty_name = "retrain-pipelines Function Caller"
- base_commit_hash, base_commit_utc_date_str = \
+ base_model_commit_hash, base_model_commit_utc_date_str = \
get_latest_README_commit(
- repo_id=hf_dataset_dict["repo_id"],
- target_commit_hash=hf_dataset_dict["commit_hash"]
+ repo_id=base_model_dict["repo_id"],
+ target_commit_hash=base_model_dict["commit_hash"],
+ repo_type="model"
)
- base_pretty_name = get_pretty_name(
- repo_id=hf_dataset_dict["repo_id"],
- commit_hash=base_commit_hash
+ base_model_pretty_name = get_pretty_name(
+ repo_id=base_model_dict["repo_id"],
+ commit_hash=base_model_commit_hash
)
- base_arxiv_codes = get_arxiv_codes(
- repo_id=hf_dataset_dict["repo_id"],
- commit_hash=base_commit_hash
+ base_model_arxiv_codes = get_arxiv_codes(
+ repo_id=base_model_dict["repo_id"],
+ commit_hash=base_model_commit_hash
)
- base_license_label = get_license_label(
- repo_id=hf_dataset_dict["repo_id"],
- commit_hash=base_commit_hash
+ base_model_license_label = get_license_label(
+ repo_id=base_model_dict["repo_id"],
+ commit_hash=base_model_commit_hash
)
- if not base_license_label:
- base_license_label = "unknown"
+ if not base_model_license_label:
+ base_model_license_label = "unknown"
return {
"new_version_label": version_label,
@@ -76,31 +81,21 @@ def _model_readme_params(
"pretty_name": pretty_name,
- "main_repo_id": \
- hf_dataset_dict["repo_id"],
- "enrich_repo_id": \
- hf_enrich_dataset_dict["repo_id"],
-
- "main_commit_hash": main_commit_hash,
- "enrich_commit_hash": enrich_commit_hash,
-
- "main_commit_utc_date_str": \
- main_commit_utc_date_str,
- "enrich_commit_utc_date_str": \
- enrich_commit_utc_date_str,
-
- "main_pretty_name": main_pretty_name,
- "enrich_pretty_name": enrich_pretty_name,
-
- "size_category": size_category,
-
- "main_arxiv_codes": main_arxiv_codes,
- "enrich_arxiv_codes": enrich_arxiv_codes,
-
- "main_license_label": main_license_label,
- "enrich_license_label": enrich_license_label,
-
- "main_format_description" : main_format_description,
+ "dataset_repo_id": \
+ training_dataset_dict["repo_id"],
+ "dataset_version_label": \
+ training_dataset_dict["version_label"],
+ "dataset_commit_hash": \
+ training_dataset_dict["commit_hash"],
+ "dataset_utc_timestamp_str": \
+ training_dataset_dict["utc_timestamp_str"],
+
+ "base_model_repo_id": base_model_dict["repo_id"],
+ "base_model_pretty_name": base_model_pretty_name,
+ "base_model_commit_hash": base_model_commit_hash,
+ "base_model_commit_utc_date_str": base_model_commit_utc_date_str,
+ "base_model_arxiv_codes": base_model_arxiv_codes,
+ "base_model_license_label": base_model_license_label,
"__version__": __version__,
"run_user": whoami()["name"],
@@ -109,16 +104,17 @@ def _model_readme_params(
}
-def get_dataset_readme_content(
+def get_model_readme_content(
template_folder: str,
- hf_dataset_dict: dict,
- hf_enrich_dataset_dict: dict,
- dataset_dict: DatasetDict,
+
+ base_model_dict: dict,
+ training_dataset_dict: dict,
+
version_label: str,
utc_timestamp_str: str,
+
mf_flow_name: str,
mf_run_id: str,
- engine:str = "cpu"
) -> str:
"""
@@ -132,18 +128,13 @@ def get_dataset_readme_content(
Params:
- template_folder (str)
- - hf_dataset_dict (dict):
+ - base_model_dict (dict)
- repo_id
- commit_hash
- - commit_utc_date_str
- - lazy_df
- - hf_enrich_dataset_dict (dict)
+ - training_dataset_dict (dict)
- repo_id
- commit_hash
- commit_utc_date_str
- - dataset_dict (DatasetDict):
- the dataset version to be pushed
- to the HF hub.
- version_label (str):
typical `retrain-pipelines`
version label are of format "major.minor"
@@ -151,17 +142,14 @@ def get_dataset_readme_content(
timestampt for the new dataset version.
- mf_flow_name (str)
- mf_run_id (str)
- - engine (str):
- Polars engine (can be "cpu", gpu"..)
Results:
- (str)
"""
- params = _dataset_readme_params(
- hf_dataset_dict=hf_dataset_dict,
- hf_enrich_dataset_dict=hf_enrich_dataset_dict,
- dataset_dict=dataset_dict,
+ params = _model_readme_params(
+ base_model_dict=base_model_dict,
+ training_dataset_dict=training_dataset_dict,
version_label=version_label,
utc_timestamp_str=utc_timestamp_str,
mf_flow_name=mf_flow_name,
@@ -169,7 +157,7 @@ def get_dataset_readme_content(
)
env = Environment(loader=FileSystemLoader(template_folder))
- template = env.get_template("dataset_readme_template.md")
+ template = env.get_template("model_readme_template.md")
readme_content = template.render(params)
return readme_content
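
A hedged calling sketch (editorial, not part of the patch) for the renamed helper; every value is a placeholder, and the dict keys mirror the ones read in _model_readme_params above:

    from retrain_pipelines.pipeline_card.mf_unsloth_func_call_litserve.model_readme \
        import get_model_readme_content

    readme_content = get_model_readme_content(
        template_folder="/path/to/template_dir",  # must contain model_readme_template.md
        base_model_dict={
            "repo_id": "org/some-base-model",
            "commit_hash": "0123456789abcdef0123456789abcdef01234567",
        },
        training_dataset_dict={
            "repo_id": "org/some-dataset",
            "version_label": "0.1",
            "commit_hash": "fedcba9876543210fedcba9876543210fedcba98",
            "utc_timestamp_str": "2025-01-01 00:00:00 UTC",
        },
        version_label="0.2",
        utc_timestamp_str="2025-01-01 00:00:00 UTC",
        mf_flow_name="SomeFlow",
        mf_run_id="42",
    )
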
diff --git a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme_template.md b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme_template.md
index ff5be55..1ca94ef 100644
--- a/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme_template.md
+++ b/pkg_src/retrain_pipelines/pipeline_card/mf_unsloth_func_call_litserve/model_readme_template.md
@@ -11,7 +11,7 @@ model_name: {{ pretty_name }}
base_model: {{ base_model_repo_id }}
library_name: peft
-license: {{license_label}}
+license: {{ base_model_license_label }}
language:
- en
@@ -62,13 +62,20 @@ model-index:
`version {{ new_version_label }}` - `{{ utc_timestamp }}`
-Training dataset :
- {{ dataset_repo_id }}
-{{ dataset_version_label }}
- {{ dataset_commit_commit_utc_date_str }}
-{{ dataset_commit_hash }}
+Training dataset :
+ {{ dataset_repo_id }}
+v{{ dataset_version_label }}
({{ dataset_commit_hash[:7] }})
+ target="_blank">{{ dataset_commit_hash[:7] }} - {{ dataset_utc_timestamp_str }})
+Base model :
+{% if base_model_arxiv_codes -%}
+arxiv :retrain-pipelines v{{ __version__ }}
-
+Powered by
+retrain-pipelines
+ {{ __version__ }}
-
Run by {{ run_user }}
-
{{ mf_flow_name }} - mf_run_id : {{ mf_run_id }}
diff --git a/pkg_src/retrain_pipelines/utils/__init__.py b/pkg_src/retrain_pipelines/utils/__init__.py
index c165d2d..d16e5f9 100644
--- a/pkg_src/retrain_pipelines/utils/__init__.py
+++ b/pkg_src/retrain_pipelines/utils/__init__.py
@@ -12,5 +12,6 @@
grant_read_access, \
tmp_os_environ, \
get_get_html, \
- get_get_dataset_readme_content
+ get_get_dataset_readme_content, \
+ get_get_model_readme_content
diff --git a/pkg_src/retrain_pipelines/utils/hf_utils.py b/pkg_src/retrain_pipelines/utils/hf_utils.py
index 4de7584..78bbb6d 100644
--- a/pkg_src/retrain_pipelines/utils/hf_utils.py
+++ b/pkg_src/retrain_pipelines/utils/hf_utils.py
@@ -4,12 +4,194 @@
import re
import traceback
+from datetime import datetime
from requests.exceptions import ReadTimeout
+from huggingface_hub import list_repo_refs, \
+ list_repo_commits, list_repo_files, HfApi
from huggingface_hub.utils import \
RepositoryNotFoundError, HfHubHTTPError
-from huggingface_hub import HfApi
+
+
+def _repo_branch_commits_files(
+ repo_id: str,
+ repo_type: str = "model",
+ repo_branch: str = "main"
+) -> dict:
+ """
+ Params:
+ - repo_id (str):
+        Path to the HuggingFace repository.
+ - repo_type (str):
+ can be "model", "dataset", "space".
+ - repo_branch (str):
+ Branch (of the repository of interest)
+ to be considered.
+
+ Results:
+ - (dict)
+ 'commit_hash', 'created_at',
+ 'title', 'files'
+ """
+ commits = list_repo_commits(repo_id, revision=repo_branch,
+ repo_type=repo_type,
+ token=os.environ["HF_TOKEN"])
+ commits_dict = {}
+ for commit in commits:
+ files = list_repo_files(
+ repo_id, revision=commit.commit_id,
+ repo_type=repo_type,
+ token=os.environ["HF_TOKEN"])
+
+ commits_dict[commit.commit_id] = {
+ "created_at": commit.created_at.strftime(
+ "%Y-%m-%d %H:%M:%S UTC"),
+ "title": commit.title,
+ "files": files
+ }
+
+ return commits_dict
+
+
+def get_repo_branches_commits_files(
+ repo_id: str,
+ repo_type: str = "model"
+) -> dict:
+ """
+    Selection of metadata for (literally)
+ all files of all commits of a given
+ HF repo.
+
+ Params:
+ - repo_id (str):
+        Path to the HuggingFace repository.
+ - repo_type (str):
+ can be "model", "dataset", "space".
+
+ Results:
+ - (dict)
+ 'branches'
+ (
+ 'branch_name', 'commits',
+ (
+ 'commit_hash', 'created_at',
+ 'title', 'files'
+ )
+ )
+ """
+
+ refs = list_repo_refs(repo_id, repo_type=repo_type,
+ token=os.environ["HF_TOKEN"])
+
+ repo_branches = {
+ "repo_standard_branches": {},
+ "repo_convert_branches": {}
+ }
+ for repo_standard_branches in refs.branches:
+ repo_branches[
+ "repo_standard_branches"
+ ][repo_standard_branches.name] = {
+ "branch_name": repo_standard_branches.ref,
+ "commits": _repo_branch_commits_files(
+ repo_id, repo_type,
+ repo_standard_branches.ref)
+ }
+
+ for repo_convert_branch in refs.converts:
+ repo_branches[
+ "repo_convert_branches"
+ ][repo_convert_branch.name] = {
+ "branch_name": repo_convert_branch.ref,
+ "commits": _repo_branch_commits_files(
+ repo_id, repo_type,
+ repo_convert_branch.ref)
+ }
+
+ return repo_branches
+
+
+def get_latest_README_commit(
+ repo_id: str,
+ target_commit_hash: str,
+ repo_type: str = "model",
+ verbose: bool = True
+) -> (str, datetime):
+ """
+ Using a given commit as a starting point,
+ look for the latest prior commit for which
+ there was a README.md file.
+
+ This is to address cases where
+ 'the commit corresponding to this commit_hash
+ didn't include a README and
+ many entries are missing from
+ `HfApi().dataset_info`, `HfApi().model_info`,
+ `HfApi().space_info`..'.
+    Such commits are typical, for instance, of the datasets 'auto-convert bot'
+ (think duckdb or parquet,
+ @see https://huggingface.co/docs/dataset-viewer/en/parquet#conversion-to-parquet).
+
+ Params:
+ - repo_id (str):
+ Path to the HuggingFace repository.
+    - target_commit_hash (str):
+ particular "revision" of the repository
+ to scan.
+ - repo_type (str):
+ can be "model", "dataset", "space".
+ - verbose (bool):
+ whether or not to print commit
+ hash and date (target vs latest README)
+
+ Results:
+ - (str, datetime):
+ latest_README_commit_hash,
+ latest_README_commit_date
+ """
+ hf_repo_branches_commits_files = \
+ get_repo_branches_commits_files(
+ repo_id=repo_id, repo_type=repo_type)
+
+ target_date = None
+ for repo, repo_data in hf_repo_branches_commits_files.items():
+ for branch, branch_data in repo_data.items():
+ for commit_hash, commit_data in branch_data['commits'].items():
+ if commit_hash == target_commit_hash:
+ target_date = datetime.strptime(
+ commit_data['created_at'], '%Y-%m-%d %H:%M:%S UTC')
+ break
+ if target_date:
+ break
+ if target_date:
+ break
+ if verbose:
+ print("target commit : ".ljust(25), target_commit_hash, target_date)
+
+ README_date = None
+ README_commit_hash = None
+ for repo, repo_data in hf_repo_branches_commits_files.items():
+ for branch, branch_data in repo_data.items():
+ for commit_hash, commit_data in branch_data['commits'].items():
+ if 'README.md' in commit_data['files']:
+ commit_date = datetime.strptime(
+ commit_data['created_at'], '%Y-%m-%d %H:%M:%S UTC')
+ if commit_date <= target_date:
+ README_date = datetime.strptime(
+ commit_data['created_at'], '%Y-%m-%d %H:%M:%S UTC')
+ README_commit_hash = commit_hash
+ if verbose:
+                            print("latest README commit : ".ljust(25),
+ README_commit_hash, README_date)
+ break
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+
+ return README_commit_hash, README_date
def get_arxiv_codes(
@@ -142,7 +324,7 @@ def get_pretty_name(
type(err), err, err.__traceback__))
print(stack_trace, file=sys.stderr)
except Exception as err:
- print(err, file=sys.stderr)
+ print(("get_pretty_name", err), file=sys.stderr)
if not pretty_name:
pretty_name = ' '.join(
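
A short sketch (editorial, not part of the patch) of the relocated get_latest_README_commit now that it is repo_type-aware; the repo id and commit hash are placeholders, and HF_TOKEN must be set in the environment:

    from retrain_pipelines.utils.hf_utils import get_latest_README_commit

    readme_hash, readme_date = get_latest_README_commit(
        repo_id="org/some-base-model",
        target_commit_hash="0123456789abcdef0123456789abcdef01234567",
        repo_type="model",
        verbose=False,
    )
    print(readme_hash, readme_date)
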
diff --git a/pkg_src/retrain_pipelines/utils/utils.py b/pkg_src/retrain_pipelines/utils/utils.py
index 67f8743..168439f 100644
--- a/pkg_src/retrain_pipelines/utils/utils.py
+++ b/pkg_src/retrain_pipelines/utils/utils.py
@@ -85,7 +85,7 @@ def get_get_dataset_readme_content(
pipeline_card_module_dir: str
) -> callable:
"""
- Loads the "pipeline_card" module,
+ Loads the "dataset_readme" module,
which can be user-provided
(path given through flow
"pipeline_card_module_dir" parameter)
@@ -106,6 +106,33 @@ def get_get_dataset_readme_content(
return get_dataset_readme_content
+
+def get_get_model_readme_content(
+ pipeline_card_module_dir: str
+) -> callable:
+ """
+ Loads the "model_readme" module,
+ which can be user-provided
+ (path given through flow
+ "pipeline_card_module_dir" parameter)
+ and returns its "get_model_readme_content" function.
+ """
+
+ pipeline_card_module_path = \
+ os.path.realpath(os.path.join(pipeline_card_module_dir,
+ "model_readme.py"))
+
+ get_model_readme_content = \
+ _load_and_get_function(
+ pipeline_card_module_path,
+ f"retrain_pipelines.pipeline_card."+
+            f"{retrain_pipeline_type}.model_readme",
+ "get_model_readme_content"
+ )
+
+ return get_model_readme_content
+
+
def get_get_html(
pipeline_card_module_dir: str
) -> callable:
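
Finally, a sketch (editorial, not part of the patch) of wiring the new loader added above, mirroring the existing get_get_dataset_readme_content; the directory path is a placeholder:

    from retrain_pipelines.utils import get_get_model_readme_content

    get_model_readme_content = get_get_model_readme_content(
        pipeline_card_module_dir="/path/to/custom/pipeline_card")
    # get_model_readme_content is the "get_model_readme_content" function
    # loaded from that folder's model_readme.py
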