From e0418780e2fe4e027409a5418ad915c93978b612 Mon Sep 17 00:00:00 2001 From: Riccardo Orlando Date: Fri, 2 Aug 2024 13:40:00 +0200 Subject: [PATCH] chore: Update Dockerfile entrypoint to use start-gunic.sh script, update relik version --- relik/cli/cli.py | 3 +- relik/inference/annotator.py | 40 +++++++++++++++++++-- relik/inference/serve/backend/fastapi_be.py | 17 +++++++++ relik/inference/serve/frontend/gradio_fe.py | 25 +++---------- 4 files changed, 60 insertions(+), 25 deletions(-) diff --git a/relik/cli/cli.py b/relik/cli/cli.py index 248f7ba..5db8e9b 100644 --- a/relik/cli/cli.py +++ b/relik/cli/cli.py @@ -216,10 +216,9 @@ def serve( annotation_type=annotation_type, host=host, port=port, + frontend=frontend ) - if frontend: - serve_gradio() if __name__ == "__main__": diff --git a/relik/inference/annotator.py b/relik/inference/annotator.py index 65fefb5..6fc51fd 100644 --- a/relik/inference/annotator.py +++ b/relik/inference/annotator.py @@ -348,6 +348,37 @@ def __call__( self.window_manager = WindowManager( self.tokenizer, self.sentence_splitter ) + else: + if isinstance(self.sentence_splitter, WindowSentenceSplitter): + if not isinstance(window_size, int): + logger.warning( + "With WindowSentenceSplitter the window_size must be an integer. " + f"Using the default window size {self.window_manager.window_size}." + f"If you want to change the window size to `sentence` or `none`, " + f"please create a new Relik instance." + ) + window_size = self.window_manager.window_size + window_stride = self.window_manager.window_stride + if isinstance(self.sentence_splitter, SpacySentenceSplitter): + if window_size != "sentence": + logger.warning( + "With SpacySentenceSplitter the window_size must be `sentence`. " + f"Using the default window size {self.window_manager.window_size}." + f"If you want to change the window size to an integer or `none`, " + f"please create a new Relik instance." + ) + window_size = "sentence" + window_stride = None + if isinstance(self.sentence_splitter, BlankSentenceSplitter): + if window_size != "none" or window_stride is not None: + logger.warning( + "With BlankSentenceSplitter the window_size must be `none`. " + f"Using the default window size {self.window_manager.window_size}." + f"If you want to change the window size to an integer or `sentence`, " + f"please create a new Relik instance." + ) + window_size = "none" + window_stride = None # sanity check for window size and stride if ( @@ -428,7 +459,8 @@ def __call__( **kwargs, ) windows_candidates[task_type] = [ - [p.document for p in predictions] for predictions in retriever_out + [p.document for p in predictions] + for predictions in retriever_out ] else: # check if the candidates are a list of lists @@ -467,7 +499,8 @@ def __call__( **kwargs, ) windows_candidates[task_type] = [ - [p.document for p in predictions] for predictions in retriever_out + [p.document for p in predictions] + for predictions in retriever_out ] # clean up None's @@ -518,6 +551,8 @@ def __call__( windows = windows + blank_windows windows.sort(key=lambda x: (x.doc_id, x.offset)) + print(windows) + # if there is no reader, just return the windows if self.reader is None: # normalize window candidates to be a list of lists, like when the reader is used @@ -528,6 +563,7 @@ def __call__( merged_windows = self.window_manager.merge_windows(windows) # transform predictions into RelikOutput objects + print(merged_windows) output = [] for w in merged_windows: span_labels = [] diff --git a/relik/inference/serve/backend/fastapi_be.py b/relik/inference/serve/backend/fastapi_be.py index 67a0ef5..d6e0a3d 100644 --- a/relik/inference/serve/backend/fastapi_be.py +++ b/relik/inference/serve/backend/fastapi_be.py @@ -148,6 +148,16 @@ async def relik_endpoint( relation_threshold: float = 0.5, ) -> List: try: + if window_size: + # check if window size is a number as string + if window_size.isdigit(): + window_size = int(window_size) + + if window_stride: + # check if window stride is a number as string + if window_stride.isdigit(): + window_stride = int(window_stride) + # get predictions for the retriever return await self( text=text, @@ -182,6 +192,7 @@ def main( workers: int = None, host: str = "localhost", port: int = 8000, + frontend: bool = False, ): app = FastAPI( title="ReLiK - A blazing fast and lightweight Information Extraction model for Entity Linking and Relation Extraction.", @@ -201,6 +212,12 @@ def main( annotation_type=annotation_type, ) app.include_router(server.router) + if frontend: + from relik.inference.serve.frontend.gradio_fe import main as serve_frontend + import threading + + threading.Thread(target=serve_frontend, daemon=True).start() + uvicorn.run(app, host=host, port=port, log_level="info", workers=workers) diff --git a/relik/inference/serve/frontend/gradio_fe.py b/relik/inference/serve/frontend/gradio_fe.py index c096ff2..11d61c2 100644 --- a/relik/inference/serve/frontend/gradio_fe.py +++ b/relik/inference/serve/frontend/gradio_fe.py @@ -305,22 +305,8 @@ def generate_graph( RELIK = os.getenv("RELIK", "localhost:8000/api/relik") -def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride): - global loaded_model - if Model is None: - return "", "" - # if loaded_model is None or loaded_model["key"] != Model: - # relik = Relik.from_pretrained(Model, index_precision="bf16") - # loaded_model = {"key": Model, "model": relik} - # else: - # relik = loaded_model["model"] - # if Model not in relik_models: - # raise ValueError(f"Model {Model} not found.") - # relik = relik_models[Model] - # spacy for span visualization - +def text_analysis(Text, Relation_Threshold, Window_Size, Window_Stride): relik = RELIK - nlp = spacy.blank("xx") # annotated_text = relik( # Text, @@ -331,8 +317,10 @@ def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride): # window_size=Window_Size, # window_stride=Window_Stride, # ) + print(f"Using ReLiK at {relik}") + print(f"Querying ReLiK with ?text={Text}&relation_threshold={Relation_Threshold}&window_size={Window_Size}&window_stride={Window_Stride}&annotation_type=word&remove_nmes=False") response = requests.get( - f"{relik}?text={Text}&relation_threshold={Relation_Threshold}&window_size={Window_Size}&window_stride={Window_Stride}&annotation_type=word&remove_nmes=False" + f"http://{relik}/?text={Text}&relation_threshold={Relation_Threshold}&window_size={Window_Size}&window_stride={Window_Stride}&annotation_type=word&remove_nmes=False", ) if response.status_code != 200: raise gr.Error(response.text) @@ -390,11 +378,6 @@ def text_analysis(Text, Model, Relation_Threshold, Window_Size, Window_Stride): text_analysis, [ gr.Textbox(label="Input Text", placeholder="Enter sentence here..."), - # gr.Dropdown( - # relik_available_models, - # value=relik_available_models[0], - # label="Relik Model", - # ), gr.Slider( minimum=0, maximum=1,