Skip to content

Commit

Permalink
implemented logic for blessed vs non-blessed model version management…
Browse files Browse the repository at this point in the history
… on the HF hub.
  • Loading branch information
aurelienmorgan committed Jan 1, 2025
1 parent 5860cc9 commit 2f03296
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 47 deletions.
80 changes: 79 additions & 1 deletion pkg_src/retrain_pipelines/model/hf_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

import os
from datetime import datetime

from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from retrain_pipelines import __version__
from retrain_pipelines.utils.hf_utils import \
Expand All @@ -8,11 +12,12 @@

def push_model_version_to_hub(
repo_id: str,
model_version_blessed: bool,
version_label: str,
timestamp_str: str,
model_dir: str,
model_readme_content: str,
hf_token: str = None,
hf_token: str = os.getenv("HF_TOKEN", None)
) -> str:
"""
Loads locally-serialized model safetensor
Expand All @@ -28,6 +33,10 @@ def push_model_version_to_hub(
- repo_id (str):
Path to the HuggingFace model version
(is created if needed and if authorized).
- model_version_blessed (bool):
Whether the model version is blessed ;
dictates the branch on which to
publish it on the HF hub.
- version_label (str):
value associated to the version
to be published on the HF hub.
Expand Down Expand Up @@ -60,9 +69,15 @@ def push_model_version_to_hub(
"Upload model and tokenizer with README."
print(commit_message)

branch_name=(
"main" if model_version_blessed
else "retrain-pipelines_not-blessed"
)

model_version_commit_hash = \
local_repo_folder_to_hub(
repo_id=repo_id,
branch_name=branch_name,
local_folder=model_dir,
commit_message=commit_message,
repo_type="model",
Expand All @@ -71,3 +86,66 @@ def push_model_version_to_hub(

return model_version_commit_hash


def current_blessed_model_version_dict(
    repo_id: str,
    hf_token: str = os.getenv("HF_TOKEN", None)
) -> dict:
    """
    Describes the model version currently
    sitting on the "main" branch of the repo
    (i.e. the latest blessed one).

    None if no prior model version
    exists on the HF Hub.

    Params:
        - repo_id (str):
            Path to the HuggingFace model.
        - hf_token (Optional, str):
            "create on namespace" permission required.
            NOTE(review): this function only reads
            from the hub ; a read-scoped token
            is presumably enough - to be confirmed.

    Results:
        - (dict):
            - mf_run_id (str)
            - commit_hash (str)
            - version_label (str)
            - commit_datetime (datetime)
            - perf_metrics (dict)
          or None when the repo can't be found
          or when its "main" revision
          carries no model-card metadata.
    """
    # Local import : "sys" is needed for the stderr
    # messages below and is not among the imports
    # visible at the top of this module.
    import sys

    try:
        model_info = HfApi().repo_info(
            repo_id=repo_id,
            revision="main",
            token=hf_token
        )
    except RepositoryNotFoundError as err:
        print(f"repo {repo_id} not found.\n" +
              "If you are trying to access a " +
              "private or gated repo, " +
              "make sure you are authenticated " +
              "and your credentials allow it.",
              file=sys.stderr)
        print(err, file=sys.stderr)
        return None

    if not model_info:
        return None

    # "card_data" is the attribute name used elsewhere
    # in this package ; "cardData" is kept as a fallback
    # for objects only exposing the legacy spelling.
    model_version_card_data = (
        getattr(model_info, "card_data", None)
        or getattr(model_info, "cardData", None)
    )
    if not model_version_card_data:
        # The repo exists but "main" has no model card
        # metadata (e.g. only not-blessed versions were
        # ever published, on their dedicated branch) :
        # there is no blessed version to report.
        return None

    commit_datetime = datetime.strptime(
        model_version_card_data["timestamp"],
        "%Y%m%d_%H%M%S%f_%Z")

    # Evaluation metrics are read from the first result
    # of the first "model-index" entry of the card.
    model_index = getattr(model_info, "model_index", None)
    eval_results_dict = {}
    if model_index:
        eval_results_dict = {
            m['type']: m['value']
            for m in model_index[0]['results'][0]['metrics']
        }

    return {
        "mf_run_id": model_version_card_data["mf_run_id"],
        "commit_hash": model_info.sha,
        "version_label": model_version_card_data["version"],
        "commit_datetime": commit_datetime,
        "perf_metrics": eval_results_dict
    }

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import os
import json
import textwrap

from ast import literal_eval
from datetime import datetime
Expand All @@ -22,6 +23,7 @@ def _model_readme_params(
training_dataset_dict: dict,
version_label: str,
commit_datetime: datetime,
perf_metrics: dict,
mf_flow_name: str,
mf_run_id: str,
) -> dict:
Expand All @@ -43,6 +45,9 @@ def _model_readme_params(
version label are of format "major.minor"
- commit_datetime (datetime):
timestamp for the new model version.
- perf_metrics (dict):
metric_name/metric_value as
key/value pairs.
- mf_flow_name (str)
- mf_run_id (str)
Expand Down Expand Up @@ -79,6 +84,13 @@ def _model_readme_params(
if not base_model_license_label:
base_model_license_label = "unknown"

perf_metrics_yaml = textwrap.indent(
" metrics:\n" + "\n".join(
[f" - type: {key}\n value: {value}"
for key, value in perf_metrics.items()])
, ' '
)

return {
"new_version_label": version_label,
"commit_datetime": commit_datetime,
Expand All @@ -101,6 +113,8 @@ def _model_readme_params(
"base_model_arxiv_codes": base_model_arxiv_codes,
"base_model_license_label": base_model_license_label,

"perf_metrics": perf_metrics_yaml,

"__version__": __version__,
"run_user": whoami()["name"],
"mf_flow_name": mf_flow_name,
Expand All @@ -116,6 +130,7 @@ def get_model_readme_content(

version_label: str,
commit_datetime: datetime,
perf_metrics: dict,

mf_flow_name: str,
mf_run_id: str,
Expand Down Expand Up @@ -144,6 +159,9 @@ def get_model_readme_content(
version label are of format "major.minor"
- commit_datetime (datetime):
timestamp for the new model version.
- perf_metrics (dict):
metric_name/metric_value as
key/value pairs.
- mf_flow_name (str)
- mf_run_id (str)
Expand All @@ -156,6 +174,7 @@ def get_model_readme_content(
training_dataset_dict=training_dataset_dict,
version_label=version_label,
commit_datetime=commit_datetime,
perf_metrics=perf_metrics,
mf_flow_name=mf_flow_name,
mf_run_id=mf_run_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ widget:
output:
text: "Hello my name is Julien"

mf_run_id: '{{ mf_run_id }}'

# @see https://huggingface.co/docs/huggingface_hub/guides/model-cards#include-evaluation-results
model-index:
Expand All @@ -49,11 +50,7 @@ model-index:
dataset:
name: Beans
type: beans
metrics:
- type: accuracy
value: 0.7
- type: f1
value: 0.65
{{ perf_metrics }}

---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,12 @@ def get_source(self, environment, template):
previsous_blessed_card_href=(
None if model_version_blessed
else previsous_blessed_card_href),
previsous_blessed_card_url =(
previsous_blessed_card_url=(
None if model_version_blessed
else previsous_blessed_card_url),
previous_blessed_model_commit_hash=(
None if model_version_blessed
else params['current_blessed_model_commit_hash']),
###################################

# infra validation status => #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,12 @@ <h1 style="padding-left: 30px;" class="shiny-gold-text">{{title}}</h1>
{% if current_blessed_run_id is not none %}
<font color="#C0C0C0">Last blessed run&nbsp;:
run_id {{current_blessed_run_id}} ({{current_blessed_run_finished}})
[<a id="blessed_run_link" style="color: orange;" target="_blank"
href="{{previsous_blessed_card_href}}">here</a>]
or
[<a href="https://huggingface.co/{{ model_repo_id }}/blob/{{ previous_blessed_model_commit_hash }}/README.md"
target="_blank">HuggingFace &#x1F917;</a>]
</font>
<a id="blessed_run_link" style="color: orange;" target="_blank"
href="{{previsous_blessed_card_href}}">here</a>
<script language="javascript">
window.onload = function() {
var link = document.getElementById("blessed_run_link");
Expand Down Expand Up @@ -590,8 +593,10 @@ <h2 class="btn-sub" style="color: #6082B6;">
</td>
<td style="border: none; text-align: left;
padding-bottom: 0;">
<b>{{ dataset_repo_id }} &nbsp;
v{{ dataset_version_label }}</b>
<b>{{ dataset_repo_id }}</b>
</td>
<td style="border: none; padding-bottom: 0;">
<b>v{{ dataset_version_label }}</b>
<span class="info-container">
<span style="background-color: #C0C0C0;"
class="info-icon">i</span>
Expand All @@ -608,7 +613,7 @@ <h2 class="btn-sub" style="color: #6082B6;">
<tr style="all: initial !important;
display: table-row !important;
color: #C0C0C0 !important;;">
<td colspan="3"
<td colspan="4"
style="border: none; font-style: italic;
text-align: left; font-size: smaller;
padding-top: 0;">
Expand All @@ -625,10 +630,11 @@ <h2 class="btn-sub" style="color: #6082B6;">
padding-top: 0;">
model version&nbsp;:
</td>
<td style="border: none; padding-top: 0;">
<b>{{ model_repo_id }}</b>
<td style="border: none; text-align: left;
padding-top: 0;">
<b>{{ model_repo_id }} &nbsp;
v{{ model_version_label }}</b>
padding-bottom: 0;">
<b>v{{ model_version_label }}</b>
<span class="info-container">
<span style="background-color: #C0C0C0;"
class="info-icon">i</span>
Expand Down
77 changes: 45 additions & 32 deletions pkg_src/retrain_pipelines/utils/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,41 +370,54 @@ def get_new_repo_minor_version(
new version label
"""

refs = list_repo_refs(
repo_id=repo_id,
repo_type=repo_type,
token=hf_token
)

api = HfApi()
api_info_method = \
api.model_info if "model" == repo_type \
else api.dataset_info if "dataset" == repo_type \
else api.space_info # if "space" == repo_type

latest_version_major = 0
latest_version_minor = 0
for branch in refs.branches:
branch_model_info = api_info_method(
refs = None
new_version_label = "0.1"
try:
refs = list_repo_refs(
repo_id=repo_id,
revision=branch.target_commit,
repo_type=repo_type,
token=hf_token
)
branch_card_data = branch_model_info.card_data
if branch_card_data and "version" in branch_card_data:
branch_version_label = \
branch_card_data["version"]
branch_major, branch_minor = \
map(int, branch_version_label.split('.'))
if branch_major > latest_version_major:
latest_version_major, latest_version_minor = \
branch_major, branch_minor
elif branch_minor > latest_version_minor:
latest_version_major, latest_version_minor = \
branch_major, branch_minor
# print(branch_card_data["version"])

new_version_label = f"{branch_major}.{branch_minor+1}"
except RepositoryNotFoundError as err:
print(f"repo {repo_id} not found.\n" +
"If you are trying to access a " +
"private or gated repo, " +
"make sure you are authenticated " +
"and your credentials allow it.",
file=sys.stderr)
print(err, file=sys.stderr)

if refs:
api = HfApi()
api_info_method = \
api.model_info if "model" == repo_type \
else api.dataset_info if "dataset" == repo_type \
else api.space_info # if "space" == repo_type

latest_version_major = 0
latest_version_minor = 0
for branch in refs.branches:
branch_model_info = api_info_method(
repo_id=repo_id,
revision=branch.target_commit,
token=hf_token
)
branch_card_data = branch_model_info.card_data
if branch_card_data and "version" in branch_card_data:
branch_version_label = \
branch_card_data["version"]
branch_major, branch_minor = \
map(int, branch_version_label.split('.'))
if branch_major > latest_version_major:
latest_version_major, latest_version_minor = \
branch_major, branch_minor
elif branch_minor > latest_version_minor:
latest_version_major, latest_version_minor = \
branch_major, branch_minor
# print(branch_card_data["version"])

new_version_label = \
f"{latest_version_major}.{latest_version_minor+1}"

return new_version_label

Expand Down

0 comments on commit 2f03296

Please sign in to comment.