diff --git a/src/kili/core/constants.py b/src/kili/core/constants.py
index 2e4d2de31..50480ab8b 100644
--- a/src/kili/core/constants.py
+++ b/src/kili/core/constants.py
@@ -15,6 +15,7 @@
     "Pdf": "application/pdf",
     "Text": "text/plain",
     "TimeSeries": "text/csv",
+    "LLM": "application/json",
 }
 
 mime_extensions_for_IV2 = {
@@ -27,6 +28,7 @@
     "URL": "",
     "VIDEO": mime_extensions["Video"],
     "VIDEO_LEGACY": "",
+    "LLM_RLHF": mime_extensions["LLM"],
 }
 
 mime_extensions_for_py_scripts = ["text/x-python"]
diff --git a/src/kili/domain/project.py b/src/kili/domain/project.py
index ca9400442..2d0326ea8 100644
--- a/src/kili/domain/project.py
+++ b/src/kili/domain/project.py
@@ -10,7 +10,7 @@ from .tag import TagId
 
 ProjectId = NewType("ProjectId", str)
 
-InputType = Literal["IMAGE", "PDF", "TEXT", "VIDEO"]
+InputType = Literal["IMAGE", "PDF", "TEXT", "VIDEO", "LLM_RLHF"]
 
 
 class InputTypeEnum(str, Enum):
@@ -20,6 +20,7 @@ class InputTypeEnum(str, Enum):
     PDF = "PDF"
     TEXT = "TEXT"
     VIDEO = "VIDEO"
+    LLM_RLHF = "LLM_RLHF"
 
 
 ComplianceTag = Literal["PHI", "PII"]
diff --git a/src/kili/entrypoints/mutations/asset/__init__.py b/src/kili/entrypoints/mutations/asset/__init__.py
index 93e9efddd..f5b8f182d 100644
--- a/src/kili/entrypoints/mutations/asset/__init__.py
+++ b/src/kili/entrypoints/mutations/asset/__init__.py
@@ -40,7 +40,7 @@ class MutationsAsset(BaseOperationEntrypointMixin):
     def append_many_to_dataset(
         self,
         project_id: str,
-        content_array: Optional[List[str]] = None,
+        content_array: Optional[Union[List[str], List[dict]]] = None,
         multi_layer_content_array: Optional[List[List[dict]]] = None,
         external_id_array: Optional[List[str]] = None,
         id_array: Optional[List[str]] = None,
@@ -67,7 +67,9 @@ def append_many_to_dataset(
                 - For a VIDEO project, the content can be either URLs pointing to videos hosted on a web server
                     or paths to existing video files on your computer. If you want to import video from frames,
                     look at the json_content section below.
-                - For an `VIDEO_LEGACY` project, the content can be only be URLs
+                - For a `VIDEO_LEGACY` project, the content can only be URLs.
+                - For an `LLM_RLHF` project, the content can be dicts with the keys `prompt` and `completions`,
+                    paths to local json files or URLs to json files.
             multi_layer_content_array: List containing multiple lists of paths. Each path correspond to a layer of a geosat asset.
                 Should be used only for `IMAGE` projects.
             external_id_array: List of external ids given to identify the assets.
diff --git a/src/kili/services/asset_import/__init__.py b/src/kili/services/asset_import/__init__.py
index db423a34d..fb593d3e9 100644
--- a/src/kili/services/asset_import/__init__.py
+++ b/src/kili/services/asset_import/__init__.py
@@ -14,6 +14,7 @@
     ProjectParams,
 )
 from .image import ImageDataImporter
+from .llm import LLMDataImporter
 from .pdf import PdfDataImporter
 from .text import TextDataImporter
 from .types import AssetLike
@@ -28,6 +29,7 @@
     "TEXT": TextDataImporter,
     "VIDEO": VideoDataImporter,
     "VIDEO_LEGACY": VideoDataImporter,
+    "LLM_RLHF": LLMDataImporter,
 }
 
 
diff --git a/src/kili/services/asset_import/base.py b/src/kili/services/asset_import/base.py
index 2779f080b..e8fdef227 100644
--- a/src/kili/services/asset_import/base.py
+++ b/src/kili/services/asset_import/base.py
@@ -486,6 +486,7 @@ def _get_organization(self, email: str, options: QueryOptions) -> Dict:
         )
 
     def _check_upload_is_allowed(self, assets: List[AssetLike]) -> None:
+        # TODO: avoid querying API for each asset to upload when doing this check
        if not self.is_hosted_content(assets) and not self._can_upload_from_local_data():
             raise UploadFromLocalDataForbiddenError("Cannot upload content from local data")
 
diff --git a/src/kili/services/asset_import/llm.py b/src/kili/services/asset_import/llm.py
new file mode 100644
index 000000000..2ef9cebe6
--- /dev/null
+++ b/src/kili/services/asset_import/llm.py
@@ -0,0 +1,76 @@
+"""Functions to import assets into an LLM_RLHF project."""
+
+import json
+import os
+from enum import Enum
+from typing import List, Optional, Tuple
+
+from kili.core.helpers import is_url
+
+from .base import (
+    BaseAbstractAssetImporter,
+    BatchParams,
+    ContentBatchImporter,
+)
+from .exceptions import ImportValidationError
+from .types import AssetLike
+
+
+class LLMDataType(Enum):
+    """LLM data type."""
+
+    DICT = "DICT"
+    LOCAL_FILE = "LOCAL_FILE"
+    HOSTED_FILE = "HOSTED_FILE"
+
+
+class JSONBatchImporter(ContentBatchImporter):
+    """Class for importing a batch of LLM assets with dict content into an LLM_RLHF project."""
+
+    def get_content_type_and_data_from_content(self, content: Optional[str]) -> Tuple[str, str]:
+        """Returns the data of the content (path) and its content type."""
+        return content or "", "application/json"
+
+
+class LLMDataImporter(BaseAbstractAssetImporter):
+    """Class for importing data into an LLM_RLHF project."""
+
+    @staticmethod
+    def get_data_type(assets: List[AssetLike]) -> LLMDataType:
+        """Determine the type of data to upload from the service payload."""
+        content_array = [asset.get("content", None) for asset in assets]
+        if all(is_url(content) for content in content_array):
+            return LLMDataType.HOSTED_FILE
+        if all(isinstance(content, str) and os.path.exists(content) for content in content_array):
+            return LLMDataType.LOCAL_FILE
+        if all(isinstance(content, dict) for content in content_array):
+            return LLMDataType.DICT
+        raise ImportValidationError("Invalid value in content for LLM project.")
+
+    def import_assets(self, assets: List[AssetLike]):
+        """Import LLM assets into Kili."""
+        self._check_upload_is_allowed(assets)
+        data_type = self.get_data_type(assets)
+        assets = self.filter_duplicate_external_ids(assets)
+        if data_type == LLMDataType.LOCAL_FILE:
+            assets = self.filter_local_assets(assets, self.raise_error)
+            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
+            batch_importer = ContentBatchImporter(
+                self.kili, self.project_params, batch_params, self.pbar
+            )
+        elif data_type == LLMDataType.HOSTED_FILE:
+            batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
+            batch_importer = ContentBatchImporter(
+                self.kili, self.project_params, batch_params, self.pbar
+            )
+        elif data_type == LLMDataType.DICT:
+            for asset in assets:
+                if "content" in asset and isinstance(asset["content"], dict):
+                    asset["content"] = json.dumps(asset["content"]).encode("utf-8")
+            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
+            batch_importer = JSONBatchImporter(
+                self.kili, self.project_params, batch_params, self.pbar
+            )
+        else:
+            raise ImportValidationError
+        return self.import_assets_by_batch(assets, batch_importer)
diff --git a/src/kili/services/asset_import/types.py b/src/kili/services/asset_import/types.py
index c575c639a..4785b31d7 100644
--- a/src/kili/services/asset_import/types.py
+++ b/src/kili/services/asset_import/types.py
@@ -8,7 +8,7 @@
 class AssetLike(TypedDict, total=False):
     """General type of an asset object through the import functions."""
 
-    content: Union[str, bytes]
+    content: Union[str, bytes, dict]
     multi_layer_content: Union[List[dict], None]
     json_content: Union[dict, str, list]
     external_id: str
diff --git a/tests/unit/services/asset_import/test_import_llm.py b/tests/unit/services/asset_import/test_import_llm.py
new file mode 100644
index 000000000..3a3adf763
--- /dev/null
+++ b/tests/unit/services/asset_import/test_import_llm.py
@@ -0,0 +1,59 @@
+from unittest.mock import patch
+
+from kili.services.asset_import import import_assets
+from tests.unit.services.asset_import.base import ImportTestCase
+from tests.unit.services.asset_import.mocks import (
+    mocked_request_signed_urls,
+    mocked_unique_id,
+    mocked_upload_data_via_rest,
+)
+
+
+@patch("kili.utils.bucket.request_signed_urls", mocked_request_signed_urls)
+@patch("kili.utils.bucket.upload_data_via_rest", mocked_upload_data_via_rest)
+@patch("kili.utils.bucket.generate_unique_id", mocked_unique_id)
+class LLMTestCase(ImportTestCase):
+    def test_upload_from_one_local_file(self, *_):
+        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
+        url = "https://storage.googleapis.com/label-public-staging/asset-test-sample/llm/test_llm_file.json"
+        path = self.downloader(url)
+        assets = [{"content": path, "external_id": "local llm file"}]
+        import_assets(self.kili, self.project_id, assets)
+        expected_parameters = self.get_expected_sync_call(
+            ["https://signed_url?id=id"],
+            ["local llm file"],
+            ["unique_id"],
+            [False],
+            [""],
+            ["{}"],
+        )
+        self.kili.graphql_client.execute.assert_called_with(*expected_parameters)
+
+    def test_upload_from_one_hosted_text_file(self, *_):
+        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
+        assets = [
+            {"content": "https://hosted-data", "external_id": "hosted file", "id": "unique_id"}
+        ]
+        import_assets(self.kili, self.project_id, assets)
+        expected_parameters = self.get_expected_sync_call(
+            ["https://hosted-data"], ["hosted file"], ["unique_id"], [False], [""], ["{}"]
+        )
+        self.kili.graphql_client.execute.assert_called_with(*expected_parameters)
+
+    def test_upload_from_dict(self, *_):
+        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
+        assets = [
+            {
+                "content": {
+                    "prompt": "does it contain code ?",
+                    "completions": ["first completion", "second completion", "#this is markdown"],
+                    "type": "markdown",
+                },
+                "external_id": "dict",
+            }
+        ]
+        import_assets(self.kili, self.project_id, assets)
+        expected_parameters = self.get_expected_sync_call(
+            ["https://signed_url?id=id"], ["dict"], ["unique_id"], [False], [""], ["{}"]
+        )
+        self.kili.graphql_client.execute.assert_called_with(*expected_parameters)