Skip to content

Commit

Permalink
feat(lab-2766): support llm projects in append_many_to_dataset (#1680)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hervé <nicolas.herve@kili-technology.com>
  • Loading branch information
HNicolas and Nicolas Hervé authored Apr 8, 2024
1 parent 50ea73b commit c3b128b
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/kili/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"Pdf": "application/pdf",
"Text": "text/plain",
"TimeSeries": "text/csv",
"LLM": "application/json",
}

mime_extensions_for_IV2 = {
Expand All @@ -27,6 +28,7 @@
"URL": "",
"VIDEO": mime_extensions["Video"],
"VIDEO_LEGACY": "",
"LLM_RLHF": mime_extensions["LLM"],
}

mime_extensions_for_py_scripts = ["text/x-python"]
Expand Down
3 changes: 2 additions & 1 deletion src/kili/domain/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .tag import TagId

ProjectId = NewType("ProjectId", str)
InputType = Literal["IMAGE", "PDF", "TEXT", "VIDEO"]
InputType = Literal["IMAGE", "PDF", "TEXT", "VIDEO", "LLM_RLHF"]


class InputTypeEnum(str, Enum):
Expand All @@ -20,6 +20,7 @@ class InputTypeEnum(str, Enum):
PDF = "PDF"
TEXT = "TEXT"
VIDEO = "VIDEO"
LLM_RLHF = "LLM_RLHF"


ComplianceTag = Literal["PHI", "PII"]
Expand Down
6 changes: 4 additions & 2 deletions src/kili/entrypoints/mutations/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MutationsAsset(BaseOperationEntrypointMixin):
def append_many_to_dataset(
self,
project_id: str,
content_array: Optional[List[str]] = None,
content_array: Optional[Union[List[str], List[dict]]] = None,
multi_layer_content_array: Optional[List[List[dict]]] = None,
external_id_array: Optional[List[str]] = None,
id_array: Optional[List[str]] = None,
Expand All @@ -67,7 +67,9 @@ def append_many_to_dataset(
- For a VIDEO project, the content can be either URLs pointing to videos hosted on a web server or paths to
existing video files on your computer. If you want to import video from frames, look at the json_content
section below.
- For an `VIDEO_LEGACY` project, the content can be only be URLs
- For a `VIDEO_LEGACY` project, the content can only be URLs.
- For an `LLM_RLHF` project, the content can be dicts with the keys `prompt` and `completions`,
paths to local json files or URLs to json files.
multi_layer_content_array: List containing multiple lists of paths.
Each path correspond to a layer of a geosat asset. Should be used only for `IMAGE` projects.
external_id_array: List of external ids given to identify the assets.
Expand Down
2 changes: 2 additions & 0 deletions src/kili/services/asset_import/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ProjectParams,
)
from .image import ImageDataImporter
from .llm import LLMDataImporter
from .pdf import PdfDataImporter
from .text import TextDataImporter
from .types import AssetLike
Expand All @@ -28,6 +29,7 @@
"TEXT": TextDataImporter,
"VIDEO": VideoDataImporter,
"VIDEO_LEGACY": VideoDataImporter,
"LLM_RLHF": LLMDataImporter,
}


Expand Down
1 change: 1 addition & 0 deletions src/kili/services/asset_import/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def _get_organization(self, email: str, options: QueryOptions) -> Dict:
)

def _check_upload_is_allowed(self, assets: List[AssetLike]) -> None:
    """Raise if the assets to import contain local data the caller may not upload.

    Args:
        assets: Assets about to be imported.

    Raises:
        UploadFromLocalDataForbiddenError: If at least one asset content is not
            hosted (presumably a local file or raw data) and uploading from
            local data is not permitted.
    """
    # TODO: avoid querying API for each asset to upload when doing this check
    if not self.is_hosted_content(assets) and not self._can_upload_from_local_data():
        raise UploadFromLocalDataForbiddenError("Cannot upload content from local data")

Expand Down
76 changes: 76 additions & 0 deletions src/kili/services/asset_import/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Functions to import assets into an LLM_RLHF project."""

import json
import os
from enum import Enum
from typing import List, Optional, Tuple

from kili.core.helpers import is_url

from .base import (
BaseAbstractAssetImporter,
BatchParams,
ContentBatchImporter,
)
from .exceptions import ImportValidationError
from .types import AssetLike


class LLMDataType(Enum):
    """Kind of content provided for an LLM asset.

    Used by ``LLMDataImporter`` to pick the right batch importer.
    """

    # Content given as an in-memory dict (serialized to JSON before upload).
    DICT = "DICT"
    # Content given as a path to a JSON file on the local filesystem.
    LOCAL_FILE = "LOCAL_FILE"
    # Content given as a URL to an already-hosted JSON file.
    HOSTED_FILE = "HOSTED_FILE"


class JSONBatchImporter(ContentBatchImporter):
    """Batch importer for LLM assets whose content is already JSON-serialized data.

    Unlike the generic ``ContentBatchImporter``, the MIME type does not need to
    be guessed from a file path: it is always ``application/json``.
    """

    def get_content_type_and_data_from_content(self, content: Optional[str]) -> Tuple[str, str]:
        """Return the content unchanged along with the JSON MIME type.

        Args:
            content: JSON-serialized asset content; ``None`` is mapped to ``""``.

        Returns:
            A ``(data, content_type)`` tuple with ``content_type`` fixed to
            ``"application/json"``.
        """
        return content or "", "application/json"


class LLMDataImporter(BaseAbstractAssetImporter):
    """Class for importing data into an LLM_RLHF project.

    Asset content may be given as URLs to hosted JSON files, paths to local
    JSON files, or in-memory dicts (e.g. with ``prompt`` and ``completions``
    keys). All assets of a single call must use the same kind of content.
    """

    @staticmethod
    def get_data_type(assets: List[AssetLike]) -> LLMDataType:
        """Determine the type of data to upload from the service payload.

        Args:
            assets: Assets whose ``content`` values are inspected.

        Returns:
            The common :class:`LLMDataType` of all the assets.

        Raises:
            ImportValidationError: If the contents are not homogeneously URLs,
                existing local file paths, or dicts.
        """
        content_array = [asset.get("content", None) for asset in assets]
        if all(is_url(content) for content in content_array):
            return LLMDataType.HOSTED_FILE
        if all(isinstance(content, str) and os.path.exists(content) for content in content_array):
            return LLMDataType.LOCAL_FILE
        if all(isinstance(content, dict) for content in content_array):
            return LLMDataType.DICT
        raise ImportValidationError("Invalid value in content for LLM project.")

    def import_assets(self, assets: List[AssetLike]):
        """Import LLM assets into Kili.

        Args:
            assets: Assets to import; see :meth:`get_data_type` for the
                accepted ``content`` forms.

        Returns:
            The result of the batched import of the assets.
        """
        self._check_upload_is_allowed(assets)
        data_type = self.get_data_type(assets)
        assets = self.filter_duplicate_external_ids(assets)
        if data_type == LLMDataType.LOCAL_FILE:
            assets = self.filter_local_assets(assets, self.raise_error)
            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
            batch_importer = ContentBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
        elif data_type == LLMDataType.HOSTED_FILE:
            batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
            batch_importer = ContentBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
        elif data_type == LLMDataType.DICT:
            # Serialize dict contents to JSON bytes so they can be uploaded
            # like regular (non-hosted) file data.
            for asset in assets:
                if "content" in asset and isinstance(asset["content"], dict):
                    asset["content"] = json.dumps(asset["content"]).encode("utf-8")
            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
            batch_importer = JSONBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
        else:  # defensive: get_data_type only returns the three types above
            raise ImportValidationError("Unsupported data type for LLM project.")
        return self.import_assets_by_batch(assets, batch_importer)
2 changes: 1 addition & 1 deletion src/kili/services/asset_import/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class AssetLike(TypedDict, total=False):
"""General type of an asset object through the import functions."""

content: Union[str, bytes]
content: Union[str, bytes, dict]
multi_layer_content: Union[List[dict], None]
json_content: Union[dict, str, list]
external_id: str
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/services/asset_import/test_import_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest.mock import patch

from kili.services.asset_import import import_assets
from tests.unit.services.asset_import.base import ImportTestCase
from tests.unit.services.asset_import.mocks import (
mocked_request_signed_urls,
mocked_unique_id,
mocked_upload_data_via_rest,
)


@patch("kili.utils.bucket.request_signed_urls", mocked_request_signed_urls)
@patch("kili.utils.bucket.upload_data_via_rest", mocked_upload_data_via_rest)
@patch("kili.utils.bucket.generate_unique_id", mocked_unique_id)
class LLMTestCase(ImportTestCase):
    """Unit tests for importing assets into an LLM_RLHF project."""

    def test_upload_from_one_local_file(self, *_):
        # A local JSON file is uploaded to the bucket and referenced via a signed URL.
        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
        file_url = "https://storage.googleapis.com/label-public-staging/asset-test-sample/llm/test_llm_file.json"
        local_path = self.downloader(file_url)
        import_assets(
            self.kili,
            self.project_id,
            [{"content": local_path, "external_id": "local llm file"}],
        )
        expected_call = self.get_expected_sync_call(
            ["https://signed_url?id=id"],
            ["local llm file"],
            ["unique_id"],
            [False],
            [""],
            ["{}"],
        )
        self.kili.graphql_client.execute.assert_called_with(*expected_call)

    def test_upload_from_one_hosted_text_file(self, *_):
        # Hosted content is passed through untouched (no bucket upload).
        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
        hosted_assets = [
            {"content": "https://hosted-data", "external_id": "hosted file", "id": "unique_id"}
        ]
        import_assets(self.kili, self.project_id, hosted_assets)
        expected_call = self.get_expected_sync_call(
            ["https://hosted-data"], ["hosted file"], ["unique_id"], [False], [""], ["{}"]
        )
        self.kili.graphql_client.execute.assert_called_with(*expected_call)

    def test_upload_from_dict(self, *_):
        # Dict content is serialized to JSON, uploaded, and referenced via a signed URL.
        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
        dict_content = {
            "prompt": "does it contain code ?",
            "completions": ["first completion", "second completion", "#this is markdown"],
            "type": "markdown",
        }
        import_assets(
            self.kili,
            self.project_id,
            [{"content": dict_content, "external_id": "dict"}],
        )
        expected_call = self.get_expected_sync_call(
            ["https://signed_url?id=id"], ["dict"], ["unique_id"], [False], [""], ["{}"]
        )
        self.kili.graphql_client.execute.assert_called_with(*expected_call)

0 comments on commit c3b128b

Please sign in to comment.