Skip to content

Commit

Permalink
wip: classifier voting
Browse files Browse the repository at this point in the history
  • Loading branch information
Marie Dev Bot committed Feb 17, 2024
1 parent 02e2360 commit 8462c5f
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Dockerfiles/docker-compose.s3-debug.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: "3.8"
version: "3.9"

services:
s3server:
Expand Down
2 changes: 1 addition & 1 deletion marie/components/document_classifier/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def __getitem__(self, idx):
self.task == "text-classification-multimodal"
or self.task == "text-classification"
):
print(prediction[0])
formatted_prediction = {
"label": prediction[0]["label"],
"score": prediction[0]["score"],
Expand All @@ -313,7 +314,6 @@ def predict_document_image(
:param top_k: number of predictions to return
:return: prediction dictionary with label and score
"""

if self.id2label is None:
id2label = self.model.config.id2label
else:
Expand Down
11 changes: 10 additions & 1 deletion marie/pipe/classification_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
from abc import abstractmethod
from datetime import datetime
from typing import List, Optional, Union

Expand All @@ -25,6 +26,7 @@
split_filename,
store_assets,
)
from marie.pipe.voting import ClassificationResult, MaxScoreVoter
from marie.utils.docs import docs_from_image
from marie.utils.image_utils import hash_frames_fast
from marie.utils.json import store_json_object
Expand Down Expand Up @@ -61,7 +63,7 @@ def __init__(
pipelines_config: List[dict[str, any]] = None,
**kwargs,
) -> None:
# super().__init__(**kwargs)
super().__init__(**kwargs)
self.show_error = True # show prediction errors
self.logger = MarieLogger(context=self.__class__.__name__)

Expand Down Expand Up @@ -407,11 +409,18 @@ def execute_pipeline(
detail["classifier"] = classifier
class_by_page[page].append(detail)

voter = MaxScoreVoter()

# Classification strategy: max_score, max_votes, max_score_with_diff
# calculate max score for each page by max score
score_by_page = {}
for page, details in class_by_page.items():
max_score = 0.0

for detail in details:
print(detail)
result = ClassificationResult(**detail)

if page not in score_by_page:
score_by_page[page] = {}

Expand Down
32 changes: 32 additions & 0 deletions marie/pipe/voting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import List

from pydantic import BaseModel


class ClassificationResult(BaseModel):
    """One classifier prediction, declared as a pydantic model."""

    # Predicted class, declared as an integer.
    # NOTE(review): VotingStrategy.vote is annotated to return a ``str`` as the
    # first tuple element - confirm whether this field should be a label string.
    classification: int
    # Score attached to the prediction (presumably the classifier's
    # confidence - confirm the expected range with the producer).
    score: float
    # Identifier of the classifier that produced this prediction.
    classifier: str


class VotingStrategy(ABC):
    """Interface for strategies that reduce several classifier results to one.

    Deriving from ``ABC`` is what makes ``@abstractmethod`` enforceable: the
    base class cannot be instantiated, and neither can a subclass that does
    not override :meth:`vote`.
    """

    @abstractmethod
    def vote(self, results: List[ClassificationResult]) -> tuple[str, float]:
        """Pick the winning classification from ``results``.

        :param results: classification results to choose between.
        :return: a two-element tuple annotated as ``(str, float)``;
            presumably the winning label and its score - confirm.
        """


class MajorityVoter(VotingStrategy):
    """Placeholder voting strategy; the voting logic is not implemented yet.

    No constructor is declared, so instances are still created without
    arguments.
    """

    def vote(self, results: List[ClassificationResult]) -> tuple[str, float]:
        """Stub: ignore ``results`` and return ``None``.

        :param results: classification results to vote on (currently unused).
        :return: ``None`` for now, despite the annotated tuple return type.
        """
        return None


class MaxScoreVoter(VotingStrategy):
    """Placeholder voting strategy; the voting logic is not implemented yet.

    No constructor is declared, so instances are still created without
    arguments.
    """

    def vote(self, results: List[ClassificationResult]) -> tuple[str, float]:
        """Stub: ignore ``results`` and return ``None``.

        :param results: classification results to vote on (currently unused).
        :return: ``None`` for now, despite the annotated tuple return type.
        """
        return None
2 changes: 1 addition & 1 deletion marie/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def load_json_file(filename) -> Any:

def deserialize_value(json_str) -> Any:
    """Deserialize a JSON string to an object.

    :param json_str: JSON document as ``str``, ``bytes`` or ``bytearray``.
    :return: the decoded Python object.
    :raises json.JSONDecodeError: if ``json_str`` is not valid JSON.
    """
    # json.loads parses in-memory text; json.load expects a file-like object
    # with a .read() method and fails on a plain string.
    return json.loads(json_str)


Expand Down
2 changes: 1 addition & 1 deletion tests/integration/check_classification_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def setup_storage():

# s3_path = s3_asset_path(ref_id=filename, ref_type="pid", include_filename=True)
# StorageManager.write(img_path, s3_path, overwrite=True)

#
config = load_yaml(
os.path.join(
__config_dir__, "tests-integration", "pipeline-classify-005.partial.yml"
Expand Down
12 changes: 12 additions & 0 deletions workspaces/document-classifier-pipeline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
title: Document Classifier Pipeline
sdk: gradio
sdk_version: 1.0.0
app_file: app.py
license: apache-2.0
---

Check out the configuration reference at http://marieai.io/docs/ClassificationPipeline



168 changes: 168 additions & 0 deletions workspaces/document-classifier-pipeline/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import argparse
import os
import tempfile
from functools import partial
from typing import List, Union

import gradio as gr
import numpy as np
import torch as torch

from marie.conf.helper import load_yaml
from marie.helper import colored
from marie.logging.mdc import MDC
from marie.logging.profile import TimeContext
from marie.pipe.classification_pipeline import ClassificationPipeline
from marie.utils.docs import frames_from_file
from marie.utils.json import deserialize_value, to_json

# Whether PyTorch reports an available CUDA device.
# NOTE(review): not referenced anywhere else in this script as shown - confirm
# it is still needed.
use_cuda = torch.cuda.is_available()


def cleanup_json(mydict: dict):
    """Round-trip ``mydict`` through JSON and return the deserialized result.

    :param mydict: object to serialize with ``to_json``.
    :return: whatever ``deserialize_value`` produces from that JSON text.
    """
    serialized = to_json(mydict)
    return deserialize_value(serialized)


def process_frames(
    frames: Union[np.ndarray, List[np.ndarray]],
    pipeline: ClassificationPipeline,
):
    """Run the classification pipeline on one or more frames.

    :param frames: a single frame or a list of frames; a non-list value is
        wrapped in a one-element list.
    :param pipeline: pipeline whose ``execute`` method is invoked.
    :return: the pipeline results after a ``cleanup_json`` round trip.
    """
    MDC.put("request_id", "1")

    frame_list = frames if isinstance(frames, list) else [frames]
    ref_id = "test-gradio"  # fixed reference id used for every request

    with TimeContext("### ClassificationPipeline info"):
        raw_results = pipeline.execute(
            ref_id=ref_id,
            ref_type="pid",
            frames=frame_list,
            runtime_conf=None,
        )
        cleaned = cleanup_json(raw_results)
        print('val', cleaned)
        return cleaned


# Gallery item most recently stored by the gallery click handler defined in
# interface(); None until an item has been clicked.
gallery_selection = None


def process_all_frames(pipeline: ClassificationPipeline, image_src):
    """Classify every frame of the uploaded file.

    :param pipeline: classification pipeline to run.
    :param image_src: uploaded file object coming from the file input, or
        ``None`` when nothing has been uploaded.
    :return: pipeline results, or ``None`` when there is no upload to process.
    """
    MDC.put("request_id", "2")
    frames = gradio_src_to_frames(image_src)
    if frames is None:
        # gradio_src_to_frames returns None for a missing upload; forwarding it
        # would make process_frames wrap it as [None] and hand a bogus frame
        # to the pipeline.
        return None
    return process_frames(frames, pipeline)


def _clicked_gallery_item():
    """Return the item last stored in the module-level ``gallery_selection``."""
    # Separate helper because process_selection's own parameter of the same
    # name shadows the module-level variable.
    return gallery_selection


def process_selection(pipeline: ClassificationPipeline, gallery_selection):
    """Classify the first frame of the selected gallery image.

    :param pipeline: classification pipeline to run.
    :param gallery_selection: value delivered by the click event. When it can
        be indexed with ``"name"`` it is used directly as the selected item.
        Otherwise (presumably the whole gallery value, since the button is
        wired with ``inputs=[gallery]`` - confirm against the Gradio version
        in use) the item recorded by the gallery click handler is used.
    :return: pipeline results for the selected frame.
    :raises ValueError: if the input is not a mapping and no gallery item has
        been clicked yet.
    """
    print("process_selection")
    MDC.put("request_id", "3")

    try:
        filename = gallery_selection["name"]
    except TypeError:
        # Not a mapping (e.g. a list of gallery items): fall back to the item
        # remembered by the click handler instead of failing here.
        clicked = _clicked_gallery_item()
        if clicked is None:
            raise ValueError(
                "No gallery image selected; click an image in the gallery first."
            ) from None
        filename = clicked["name"]

    frame = frames_from_file(filename)[0]
    return process_frames(frame, pipeline)


def gradio_src_to_frames(image_src):
    """Load the frames of a file uploaded through the ``gr.File`` component.

    :param image_src: temporary-file wrapper for the upload, or ``None``.
    :return: result of ``frames_from_file`` for the upload's path, or ``None``
        when ``image_src`` is ``None``.
    :raises TypeError: if ``image_src`` is neither ``None`` nor a
        ``tempfile._TemporaryFileWrapper``.
    """
    if image_src is None:
        return None
    # NOTE: _TemporaryFileWrapper is a private class of the tempfile module.
    if not isinstance(image_src, tempfile._TemporaryFileWrapper):
        # TypeError is still an Exception subclass, so broad handlers keep working.
        raise TypeError(
            "Expected image_src to be of type tempfile._TemporaryFileWrapper, "
            "ensure that the source is set to 'upload' in the gr.File component."
        )
    return frames_from_file(image_src.name)


def interface(classifier: ClassificationPipeline):
    """Build the Gradio Blocks UI for the classification pipeline and launch it.

    :param classifier: pipeline bound (via ``partial``) to both "Classify"
        button handlers.
    """

    def gallery_click_handler(src_gallery, evt: gr.SelectData):
        # Remember the clicked gallery item in the module-level variable.
        global gallery_selection
        gallery_selection = src_gallery[evt.index]

    with gr.Blocks() as iface:
        # Page title.
        gr.HTML(
            """
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
Document Classification Pipeline
</h1>
</div>
"""
        )
        # File upload input.
        with gr.Row():
            src = gr.File(type="file", source="upload")

        with gr.Row():
            btn_reset = gr.Button("Clear")
            btn_grid = gr.Button("Build-Grid", variant="primary")

        # Gallery showing the frames of the uploaded file.
        with gr.Row(live=True):
            gallery = gr.Gallery(
                label="Image frames",
                show_label=False,
                elem_id="gallery",
                interactive=True,
            ).style(columns=4, object_fit="contain", height="auto")

        with gr.Row():
            btn_submit_all = gr.Button("Classify All", variant="primary")
            btn_submit_selected = gr.Button("Classify Selected", variant="primary")

        # JSON view that receives the classification results.
        with gr.Row():
            with gr.Column():
                json_output = gr.outputs.JSON()

        # "Classify All": run the pipeline over every frame of the upload.
        btn_submit_all.click(
            partial(process_all_frames, classifier),
            inputs=[src],
            outputs=[json_output],
        )

        # NOTE(review): this passes the gallery component's value to
        # process_selection, which indexes it with ["name"] - confirm the
        # shape of that value for the Gradio version in use.
        btn_submit_selected.click(
            partial(process_selection, classifier),
            inputs=[gallery],
            outputs=[json_output],
        )

        # "Build-Grid": load the uploaded file's frames into the gallery.
        btn_grid.click(gradio_src_to_frames, inputs=[src], outputs=gallery)
        # NOTE(review): the lambda calls src.clear() with no arguments and no
        # outputs are declared - confirm this actually resets the upload.
        btn_reset.click(lambda: src.clear())

        gallery.select(gallery_click_handler, inputs=[gallery])

    # server_name="0.0.0.0" presumably exposes the app on all network
    # interfaces - confirm this is intended outside local debugging.
    iface.launch(debug=True, share=False, server_name="0.0.0.0")


def parse_args():
    """Parse the command-line options of this script and echo them.

    :return: ``argparse.Namespace`` holding ``pipeline_path``.
    """
    arg_parser = argparse.ArgumentParser(
        description="Document classification pipeline."
    )
    arg_parser.add_argument(
        "--pipeline_path",
        help="Path to the pipeline configuration file",
        type=str,
    )
    parsed = arg_parser.parse_args()

    banner = colored("[√]", "green")
    print(f'{banner} Arguments are loaded.')
    print(parsed)
    return parsed


if __name__ == "__main__":
    # torch is already imported at module level.
    torch.set_float32_matmul_precision('high')
    torch.backends.cudnn.benchmark = False
    os.environ["MARIE_SUPPRESS_WARNINGS"] = "true"

    args = parse_args()
    # Honour --pipeline_path when it is given; without it, fall back to the
    # path that used to be hard-coded so existing invocations behave the same.
    pipeline_path = args.pipeline_path or (
        "~/dev/marieai/marie-ai/config/tests-integration/pipeline-classify-006.partial.yml"
    )
    config = load_yaml(os.path.expanduser(pipeline_path))

    pipelines_config = config["pipelines"]
    pipeline = ClassificationPipeline(pipelines_config=pipelines_config)

    interface(classifier=pipeline)

# Usage: python ./app.py --pipeline_path <path-to-pipeline-config.yml>
Empty file.
34 changes: 26 additions & 8 deletions workspaces/document-classifier/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
from marie.logging.predefined import default_logger as logger
from marie.ocr import DefaultOcrEngine, MockOcrEngine
from marie.ocr.util import get_words_and_boxes
from marie.registry.model_registry import ModelRegistry
from marie.utils.docs import docs_from_image, frames_from_file

use_cuda = torch.cuda.is_available()

# # TODO : add support for dependency injection
# MDC.put("request_id", "0")

mock_ocr = False
if mock_ocr:
ocr_engine = MockOcrEngine(cuda=use_cuda)
Expand All @@ -39,7 +37,6 @@ def process_frames(
frames = [frames]

ocr_results = ocr_engine.extract(frames)
# classifier = TransformersDocumentClassifier(model_name_or_path=model_name_or_path)
documents = docs_from_image(frames)

words = []
Expand Down Expand Up @@ -78,13 +75,19 @@ def process_all_frames(
return results


def process_selection(
    model_name_or_path: str,
    classifier: TransformersDocumentClassifier,
    gallery_selection,
):
    """Classify the first frame of the selected gallery image.

    :param model_name_or_path: model identifier forwarded to ``process_frames``.
    :param classifier: classifier instance forwarded to ``process_frames``.
    :param gallery_selection: mapping whose ``"name"`` entry is the path of the
        image file to load.
    :return: first element of the ``process_frames`` result.
    """
    print("process_selection")
    print("model_name_or_path", model_name_or_path)
    MDC.put("request_id", "3")
    filename = gallery_selection["name"]
    frame = frames_from_file(filename)[0]
    # Single call that passes the classifier along; an earlier call without it
    # only produced a result that was immediately overwritten.
    results = process_frames(
        frame, model_name_or_path=model_name_or_path, classifier=classifier
    )

    return results[0]

Expand Down Expand Up @@ -152,6 +155,7 @@ def gallery_click_handler(src_gallery, evt: gr.SelectData):
inputs=[src],
outputs=[json_output],
)

btn_submit_selected.click(
partial(process_selection, model_name_or_path, classifier),
inputs=[gallery],
Expand Down Expand Up @@ -183,6 +187,20 @@ def parse_args():
return args


def ensure_model(model_name_or_path):
    """Resolve a model name or path through ``ModelRegistry``.

    :param model_name_or_path: name or path to look up.
    :return: the value returned by ``ModelRegistry.get``.
    """
    # An optional "__model_path__" entry (custom model path) is left disabled.
    registry_options = {"use_auth_token": False}
    return ModelRegistry.get(
        model_name_or_path,
        version=None,
        raise_exceptions_for_missing_entries=True,
        **registry_options,
    )


if __name__ == "__main__":
import torch

Expand All @@ -192,7 +210,7 @@ def parse_args():
os.environ["MARIE_SUPPRESS_WARNINGS"] = "true"

args = parse_args()
model_name_or_path = args.pretrained_model_name_or_path
model_name_or_path = ensure_model(args.pretrained_model_name_or_path)

logger.info(f"Using model : {model_name_or_path}")
classifier = TransformersDocumentClassifier(model_name_or_path=model_name_or_path)
Expand Down

0 comments on commit 8462c5f

Please sign in to comment.