From 578c0b84c107c1bd5febcb4c0a9f469eeb3fd0a6 Mon Sep 17 00:00:00 2001 From: phact Date: Tue, 19 Mar 2024 12:11:12 -0400 Subject: [PATCH] go to file extension stop list --- VERSION | 2 +- impl/routes/files.py | 1 - impl/services/file.py | 92 ++++++++++++++++++++++++++++++++----------- 3 files changed, 69 insertions(+), 26 deletions(-) diff --git a/VERSION b/VERSION index fad30c5..da14730 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v0.1.7 +v0.1.8 diff --git a/impl/routes/files.py b/impl/routes/files.py index 07a90fc..a613b32 100644 --- a/impl/routes/files.py +++ b/impl/routes/files.py @@ -9,7 +9,6 @@ Depends, File, Form, - Header, Path, Request, Query, diff --git a/impl/services/file.py b/impl/services/file.py index 9c6fcdc..312790f 100644 --- a/impl/services/file.py +++ b/impl/services/file.py @@ -13,6 +13,53 @@ from impl.astra_vector import HandledResponse from impl.models import Document +exclude_exts: list[str] = [ + ".map", + ".tfstate", + ".jar", + ".png", + ".jpg", + ".jpeg", + ".download", + ".gif", + ".bmp", + ".tiff", + ".ico", + ".mp3", + ".wav", + ".wma", + ".ogg", + ".flac", + ".mp4", + ".avi", + ".mkv", + ".mov", + ".patch", + ".wmv", + ".m4a", + ".m4v", + ".3gp", + ".3g2", + ".rm", + ".swf", + ".flv", + ".iso", + ".bin", + ".tar", + ".zip", + ".7z", + ".gz", + ".rar", + ".svg", + ".pyc", + ".pub", + ".pem", + ".ttf", + ".dfn", + ".dfm", + ".feature", + ".lock", +] async def get_document_from_file(file: UploadFile, file_id: str) -> Document: extracted_text = await extract_text_from_from_file(file) @@ -25,16 +72,25 @@ async def get_document_from_file(file: UploadFile, file_id: str) -> Document: def extract_text_from_filepath(filepath: str, mimetype: Optional[str] = None) -> str: """Return the text content of a file given its filepath.""" + #get extension from filepath for example /tmp/pytest.ini + extension = os.path.splitext(filepath)[1] if mimetype is None or mimetype == "application/octet-stream": # Get the mimetype of the file based on its extension mimetype, _ = mimetypes.guess_type(filepath) - if not mimetype: - if filepath.endswith(".md"): - mimetype = "text/markdown" + # when there's no mimetype, treat other valid extensions as text/plain, including files without extensions (i.e. Dockerfile) + if extension not in exclude_exts: + mimetype = "text/plain" else: - raise HTTPException(status_code=400, detail="Unsupported file type") - + # Unsupported file type + raise HTTPException( + status_code=400, + detail="Unsupported file type: {}".format(filepath), + ) + else: + # treat programming language extensions as text/plain regardless of mimetype + if extension in (".c", ".cpp", ".css", ".html", ".java", ".js", ".json", ".md", ".php", ".py", ".rb", ".ts", ".xml"): + mimetype = "text/plain" try: with open(filepath, "rb") as file: extracted_text = extract_text_from_file(file, mimetype) @@ -51,7 +107,7 @@ def extract_text_from_file(file: BufferedReader, mimetype: str) -> str: # Extract text from pdf using PyPDF2 reader = PdfReader(file) extracted_text = " ".join([page.extract_text() for page in reader.pages]) - elif mimetype == "text/plain" or mimetype == "text/markdown": + elif mimetype == "text/plain" or mimetype == "text/markdown" or "application/sql": # Read text from plain text file extracted_text = file.read().decode("utf-8") elif ( @@ -61,6 +117,7 @@ def extract_text_from_file(file: BufferedReader, mimetype: str) -> str: # Extract text from docx using docx2txt extracted_text = docx2txt.process(file) # TODO: supported formats should be Supported formats: "c", "cpp", "css", "csv", "docx", "gif", "html", "java", "jpeg", "jpg", "js", "json", "md", "pdf", "php", "png", "pptx", "py", "rb", "tar", "tex", "ts", "txt", "xlsx", "xml", "zip" + # figure out what they do with the images. elif mimetype == "text/csv": # Extract text from csv using csv module extracted_text = "" @@ -83,24 +140,11 @@ def extract_text_from_file(file: BufferedReader, mimetype: str) -> str: extracted_text += run.text + " " extracted_text += "\n" else: - raw_extension = mimetypes.guess_extension(mimetype) - if raw_extension is not None: - extension = raw_extension[1:] - if extension in ("c", "cpp", "css", "html", "java", "js", "json", "md", "php", "py", "rb", "ts", "xml"): - extracted_text = file.read().decode("utf-8") - else: - # Unsupported file type - raise HTTPException( - status_code=400, - detail="Unsupported file type: {}".format(mimetype), - ) - - else: - # Unsupported file type - raise HTTPException( - status_code=400, - detail="Unsupported file type: {}".format(mimetype), - ) + # Unsupported file type + raise HTTPException( + status_code=400, + detail="Unsupported file type: {}".format(mimetype), + ) return extracted_text