-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(lab-2766): support llm projects in append_many_to_dataset (#1680)
Co-authored-by: Nicolas Hervé <nicolas.herve@kili-technology.com>
- Loading branch information
Showing
8 changed files
with
147 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Functions to import assets into an LLM_RLHF project.""" | ||
|
||
import json | ||
import os | ||
from enum import Enum | ||
from typing import List, Optional, Tuple | ||
|
||
from kili.core.helpers import is_url | ||
|
||
from .base import ( | ||
BaseAbstractAssetImporter, | ||
BatchParams, | ||
ContentBatchImporter, | ||
) | ||
from .exceptions import ImportValidationError | ||
from .types import AssetLike | ||
|
||
|
||
class LLMDataType(Enum):
    """Kinds of asset content that can be imported into an LLM project."""

    # The content is an in-memory dictionary.
    DICT = "DICT"
    # The content is a path to a file on the local disk.
    LOCAL_FILE = "LOCAL_FILE"
    # The content is a URL pointing at a hosted file.
    HOSTED_FILE = "HOSTED_FILE"
|
||
|
||
class JSONBatchImporter(ContentBatchImporter):
    """Batch importer for LLM_RLHF assets whose content was given as a dictionary."""

    def get_content_type_and_data_from_content(self, content: Optional[str]) -> Tuple[str, str]:
        """Return the data to upload together with its content type.

        Args:
            content: The asset content; a falsy value is replaced by an empty string.

        Returns:
            A ``(data, content_type)`` pair; the content type is always
            ``"application/json"``.
        """
        data = content if content else ""
        return data, "application/json"
|
||
|
||
class LLMDataImporter(BaseAbstractAssetImporter):
    """Class for importing data into an LLM_RLHF project."""

    @staticmethod
    def get_data_type(assets: List[AssetLike]) -> LLMDataType:
        """Determine the type of data to upload from the service payload.

        All assets must share the same kind of content. The kinds are tested in
        this order: hosted URL, path existing on the local disk, dictionary.

        Args:
            assets: The assets to import; only their ``"content"`` value is inspected.

        Returns:
            The data type shared by every asset.

        Raises:
            ImportValidationError: If the contents do not all match one supported kind.
        """
        content_array = [asset.get("content", None) for asset in assets]
        # NOTE: with an empty asset list every ``all(...)`` is vacuously true, so the
        # first check wins and HOSTED_FILE is returned.
        if all(is_url(content) for content in content_array):
            return LLMDataType.HOSTED_FILE
        if all(isinstance(content, str) and os.path.exists(content) for content in content_array):
            return LLMDataType.LOCAL_FILE
        if all(isinstance(content, dict) for content in content_array):
            return LLMDataType.DICT
        raise ImportValidationError("Invalid value in content for LLM project.")

    def import_assets(self, assets: List[AssetLike]):
        """Import LLM assets into Kili.

        Args:
            assets: The assets to import. The dictionaries passed in are not modified.

        Returns:
            Whatever ``import_assets_by_batch`` returns for the prepared assets.

        Raises:
            ImportValidationError: If the asset contents are not of a supported kind.
        """
        self._check_upload_is_allowed(assets)
        # The data type is computed on the full payload, before any filtering.
        data_type = self.get_data_type(assets)
        assets = self.filter_duplicate_external_ids(assets)
        if data_type == LLMDataType.LOCAL_FILE:
            assets = self.filter_local_assets(assets, self.raise_error)
            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
            batch_importer = ContentBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
        elif data_type == LLMDataType.HOSTED_FILE:
            batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
            batch_importer = ContentBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
        elif data_type == LLMDataType.DICT:
            # Serialize each dictionary to UTF-8 encoded JSON. Work on shallow copies
            # so the caller's asset dictionaries keep their original content.
            assets = [
                {**asset, "content": json.dumps(asset["content"]).encode("utf-8")}
                if isinstance(asset.get("content"), dict)
                else asset
                for asset in assets
            ]
            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
            batch_importer = JSONBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
        else:
            # Defensive fallback: get_data_type only returns the members handled above.
            raise ImportValidationError(f"Unsupported LLM data type: {data_type}")
        return self.import_assets_by_batch(assets, batch_importer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from unittest.mock import patch | ||
|
||
from kili.services.asset_import import import_assets | ||
from tests.unit.services.asset_import.base import ImportTestCase | ||
from tests.unit.services.asset_import.mocks import ( | ||
mocked_request_signed_urls, | ||
mocked_unique_id, | ||
mocked_upload_data_via_rest, | ||
) | ||
|
||
|
||
@patch("kili.utils.bucket.request_signed_urls", mocked_request_signed_urls)
@patch("kili.utils.bucket.upload_data_via_rest", mocked_upload_data_via_rest)
@patch("kili.utils.bucket.generate_unique_id", mocked_unique_id)
class LLMTestCase(ImportTestCase):
    """Unit tests for importing assets into an LLM_RLHF project."""

    def test_upload_from_one_local_file(self, *_):
        """Import one downloaded local file and check the resulting sync call."""
        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
        local_path = self.downloader(
            "https://storage.googleapis.com/label-public-staging/asset-test-sample/llm/test_llm_file.json"
        )
        import_assets(
            self.kili,
            self.project_id,
            [{"content": local_path, "external_id": "local llm file"}],
        )
        expected_call = self.get_expected_sync_call(
            ["https://signed_url?id=id"],
            ["local llm file"],
            ["unique_id"],
            [False],
            [""],
            ["{}"],
        )
        self.kili.graphql_client.execute.assert_called_with(*expected_call)

    def test_upload_from_one_hosted_text_file(self, *_):
        """Import one hosted file and check its URL is passed through unchanged."""
        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
        hosted_asset = {
            "content": "https://hosted-data",
            "external_id": "hosted file",
            "id": "unique_id",
        }
        import_assets(self.kili, self.project_id, [hosted_asset])
        expected_call = self.get_expected_sync_call(
            ["https://hosted-data"],
            ["hosted file"],
            ["unique_id"],
            [False],
            [""],
            ["{}"],
        )
        self.kili.graphql_client.execute.assert_called_with(*expected_call)

    def test_upload_from_dict(self, *_):
        """Import one asset whose content is a dictionary and check the sync call."""
        self.kili.kili_api_gateway.get_project.return_value = {"inputType": "LLM_RLHF"}
        conversation = {
            "prompt": "does it contain code ?",
            "completions": ["first completion", "second completion", "#this is markdown"],
            "type": "markdown",
        }
        import_assets(
            self.kili,
            self.project_id,
            [{"content": conversation, "external_id": "dict"}],
        )
        expected_call = self.get_expected_sync_call(
            ["https://signed_url?id=id"],
            ["dict"],
            ["unique_id"],
            [False],
            [""],
            ["{}"],
        )
        self.kili.graphql_client.execute.assert_called_with(*expected_call)