Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSK-3237] Persistent on HF using repo_type and repo_id #57

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 37 additions & 59 deletions giskard_cicd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import giskard
import json
import logging
import opendal
import os
import uuid
from huggingface_hub import HfApi

from giskard_cicd.automation import (
commit_to_dataset,
Expand Down Expand Up @@ -201,88 +201,66 @@ def main():
from giskard_cicd.persistent import PERSIST_CONFIG_ENV

persist_scan_config = json.loads(os.environ.get(PERSIST_CONFIG_ENV, "{}"))
scheme = persist_scan_config.pop("scheme")
op = opendal.Operator(scheme=scheme, **persist_scan_config)
repo_type = persist_scan_config.get("repo_type", "space")
repo_id = persist_scan_config.get("repo_id", "giskard-bot/scan-report")

model_uuid = str(uuid.uuid5(uuid.NAMESPACE_OID, args.model))
scan_uuid = str(uuid.uuid4())

# Log in to Hugging Face: use the given token, or fall back to HF_WRITE_TOKEN from the environment
check_env_vars_and_login(hf_token=args.hf_token)

api = HfApi()
# Configurations
scanned_configs = {
"giskard_version": giskard.__version__,
**anonymous_runner_kwargs,
}
if "scan_config" in scanned_configs and scanned_configs["scan_config"]:
with open(scanned_configs["scan_config"], "r") as f:
op.write(
f"{model_uuid}/{scan_uuid}/scan_config.yaml",
f.read().encode(),
content_type="text/x-yaml",
api.upload_file(
path_or_fileobj=f,
path_in_repo=f"{model_uuid}/{scan_uuid}/scan_config.yaml",
repo_id=repo_id,
repo_type=repo_type,
)
op.write(
f"{model_uuid}/{scan_uuid}/runner_config.json",
json.dumps(scanned_configs).encode(),
content_type="application/json",

api.upload_file(
path_or_fileobj=json.dumps(scanned_configs).encode(),
path_in_repo=f"{model_uuid}/{scan_uuid}/runner_config.json",
repo_id=repo_id,
repo_type=repo_type,
)

# HTML report
html_report = report.to_html()
op.write(
f"{model_uuid}/{scan_uuid}/report.html",
html_report.encode(),
content_type="text/html",
report_url = api.upload_file(
path_or_fileobj=html_report.encode(),
path_in_repo=f"{model_uuid}/{scan_uuid}/report.html",
repo_id=repo_id,
repo_type=repo_type,
)

# AVID report
avid_reports = report.to_avid()
avid_report = "\n".join(list(map(lambda r: r.json(), avid_reports)))
op.write(
f"{model_uuid}/{scan_uuid}/avid.jsonl",
avid_report.encode(),
content_type="application/jsonl",
api.upload_file(
path_or_fileobj=avid_report.encode(),
path_in_repo=f"{model_uuid}/{scan_uuid}/avid.jsonl",
repo_id=repo_id,
repo_type=repo_type,
)

# Get URL from S3
if scheme == "s3" and "bucket" in persist_scan_config:
from giskard_cicd.persistent import s3_utils

s3_utils.init_s3_client(
access_key=(
persist_scan_config["access_key_id"]
if "access_key_id" in persist_scan_config
else ""
),
secret_key=(
persist_scan_config["secret_access_key"]
if "secret_access_key" in persist_scan_config
else ""
),
endpoint_url=(
persist_scan_config["endpoint"]
if "endpoint" in persist_scan_config
else ""
),
region_name=(
persist_scan_config["region"]
if "region" in persist_scan_config
else "auto"
),
)
s3_root = (
persist_scan_config["root"] if "root" in persist_scan_config else ""
)
if s3_root == "/":
# Trim root
s3_root = ""

persistent_url = s3_utils.get_s3_url(
persist_scan_config["bucket"],
f"{s3_root}{model_uuid}/{scan_uuid}/report.html",
)
# Get URL from HF
if repo_type == "space":
base_url = f"https://{repo_id.replace('/', '-')}.static.hf.space/"
persistent_url = f"{base_url}{model_uuid}/{scan_uuid}/report.html"
else:
persistent_url = report_url

logger.info(
f"Scan report persisted under {scheme}://{model_uuid}/{scan_uuid} ({persistent_url})"
)
logger.info(
f"Scan report persisted on Hugging Face Space ({repo_id}): {model_uuid}/{scan_uuid} ({persistent_url})"
)
except Exception:
logger.warning(
"Failed to persist scan report for "
Expand Down