diff --git a/labelu/internal/adapter/persistence/crud_pre_annotation.py b/labelu/internal/adapter/persistence/crud_pre_annotation.py index 7b3ddcb6..ff7fa4c6 100644 --- a/labelu/internal/adapter/persistence/crud_pre_annotation.py +++ b/labelu/internal/adapter/persistence/crud_pre_annotation.py @@ -15,14 +15,14 @@ def batch(db: Session, pre_annotations: List[TaskPreAnnotation]) -> List[TaskPre def list_by( db: Session, - task_id: Union[int, None], owner_id: int, - sample_name: str | None, - after: Union[int, None], - before: Union[int, None], - pageNo: Union[int, None], - pageSize: int, - sorting: Union[str, None], + task_id: int | None = None, + sample_name: str | None = None, + after: int | None = None, + before: int | None = None, + pageNo: int | None = None, + sorting: str | None = None, + pageSize: int | None = 10, ) -> Tuple[List[TaskPreAnnotation], int]: # query filter diff --git a/labelu/internal/application/response/sample.py b/labelu/internal/application/response/sample.py index 63def954..391bafb2 100644 --- a/labelu/internal/application/response/sample.py +++ b/labelu/internal/application/response/sample.py @@ -26,6 +26,9 @@ class SampleResponse(BaseModel): file: Union[object, None] = Field( default=None, description="description: media attachment file" ) + is_pre_annotated: Union[bool, None] = Field( + default=False, description="description: is pre annotated" + ) annotated_count: Union[int, None] = Field( default=0, description="description: annotate result count" ) diff --git a/labelu/internal/application/service/sample.py b/labelu/internal/application/service/sample.py index 65df5195..695d9438 100644 --- a/labelu/internal/application/service/sample.py +++ b/labelu/internal/application/service/sample.py @@ -12,8 +12,9 @@ from labelu.internal.common.converter import converter from labelu.internal.common.error_code import ErrorCode from labelu.internal.common.error_code import LabelUException -from labelu.internal.adapter.persistence import crud_task +from labelu.internal.adapter.persistence import crud_pre_annotation, crud_task from labelu.internal.adapter.persistence import crud_sample +from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation from labelu.internal.domain.models.user import User from labelu.internal.domain.models.task import Task from labelu.internal.domain.models.task import TaskStatus @@ -28,6 +29,17 @@ from labelu.internal.application.response.sample import SampleResponse from labelu.internal.application.response.attachment import AttachmentResponse +def is_sample_pre_annotated(db: Session, task_id: int, current_user: User, sample_name: str = None) -> Tuple[List[TaskPreAnnotation], int]: + _, total = crud_pre_annotation.list_by( + db=db, + task_id=task_id, + owner_id=current_user.id, + sample_name=sample_name, + pageSize=1, + ) + + return total > 0 + async def create( db: Session, task_id: int, cmd: List[CreateSampleCommand], current_user: User ) -> CreateSampleResponse: @@ -74,42 +86,49 @@ async def list_by( sorting: Union[str, None], current_user: User, ) -> Tuple[List[SampleResponse], int]: + try: + samples = crud_sample.list_by( + db=db, + task_id=task_id, + owner_id=current_user.id, + after=after, + before=before, + pageNo=pageNo, + pageSize=pageSize, + sorting=sorting, + ) - samples = crud_sample.list_by( - db=db, - task_id=task_id, - owner_id=current_user.id, - after=after, - before=before, - pageNo=pageNo, - pageSize=pageSize, - sorting=sorting, - ) - - total = crud_sample.count(db=db, task_id=task_id, owner_id=current_user.id) + total = crud_sample.count(db=db, task_id=task_id, owner_id=current_user.id) - # response - return [ - SampleResponse( - id=sample.id, - inner_id=sample.inner_id, - state=sample.state, - data=json.loads(sample.data), - annotated_count=sample.annotated_count, - file=AttachmentResponse(id=sample.file.id, filename=sample.file.filename, url=sample.file.url) if sample.file else None, - created_at=sample.created_at, - created_by=UserResp( - id=sample.owner.id, - username=sample.owner.username, - ), - updated_at=sample.updated_at, - updated_by=UserResp( - id=sample.updater.id, - username=sample.updater.username, - ), + # response + return [ + SampleResponse( + id=sample.id, + inner_id=sample.inner_id, + state=sample.state, + data=json.loads(sample.data), + annotated_count=sample.annotated_count, + is_pre_annotated=is_sample_pre_annotated(db=db, task_id=task_id, current_user=current_user, sample_name=sample.file.filename), + file=AttachmentResponse(id=sample.file.id, filename=sample.file.filename, url=sample.file.url) if sample.file else None, + created_at=sample.created_at, + created_by=UserResp( + id=sample.owner.id, + username=sample.owner.username, + ), + updated_at=sample.updated_at, + updated_by=UserResp( + id=sample.updater.id, + username=sample.updater.username, + ), + ) + for sample in samples + ], total + except Exception as e: + logger.error(e) + raise LabelUException( + code=ErrorCode.CODE_55000_SAMPLE_LIST_PARAMETERS_ERROR, + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, ) - for sample in samples - ], total async def get( @@ -133,6 +152,7 @@ async def get( inner_id=sample.inner_id, state=sample.state, data=json.loads(sample.data), + is_pre_annotated=is_sample_pre_annotated(db=db, task_id=task_id, current_user=current_user, sample_name=sample.file.filename), file=AttachmentResponse(id=sample.file.id, filename=sample.file.filename, url=sample.file.url) if sample.file else None, annotated_count=sample.annotated_count, created_at=sample.created_at, @@ -210,6 +230,7 @@ async def patch( inner_id=updated_sample.inner_id, state=updated_sample.state, data=json.loads(updated_sample.data), + is_pre_annotated=is_sample_pre_annotated(db=db, task_id=task_id, current_user=current_user, sample_name=sample.file.filename), annotated_count=updated_sample.annotated_count, created_at=updated_sample.created_at, created_by=UserResp( diff --git a/pyproject.toml b/pyproject.toml index 2bb9c18f..0f3b469a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "labelu" -version = '1.1.0-alpha.24' +version = '1.1.4' description = "" license = "Apache-2.0" authors = ["shenguanlin "]