From 412851e923c79bbfc2e19f1cd7305d5d5d3278d6 Mon Sep 17 00:00:00 2001 From: Mostafa Khalil Date: Sun, 15 Dec 2024 09:05:13 +0200 Subject: [PATCH] feat: make form part and file size limits configurable --- starlette/formparsers.py | 27 ++++++++++++++------------- starlette/requests.py | 38 ++++++++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 9255d1e95..d304abc01 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -122,8 +122,9 @@ async def parse(self) -> FormData: class MultiPartParser: - max_file_size = 1024 * 1024 # 1MB - max_part_size = 1024 * 1024 # 1MB + default_max_field_size = 1024 * 1024 # 1MB + default_max_file_mem_size = 1024 * 1024 # 1MB + default_max_file_disk_size = 1024 * 1024 * 1024 # 1GB def __init__( self, @@ -132,14 +133,18 @@ def __init__( *, max_files: int | float = 1000, max_fields: int | float = 1000, - max_part_file_size: int | float | None = None, + max_field_size: int | float | None = None, + max_file_mem_size: int | float | None = None, + max_file_disk_size: int | float | None = None, ) -> None: assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." 
self.headers = headers self.stream = stream self.max_files = max_files self.max_fields = max_fields - self.max_part_file_size = max_part_file_size + self.max_field_size: int | float = max_field_size or self.default_max_field_size + self.max_file_mem_size: int | float = max_file_mem_size or self.default_max_file_mem_size + self.max_file_disk_size: int | float = max_file_disk_size or self.default_max_file_disk_size self.items: list[tuple[str, str | UploadFile]] = [] self._current_files = 0 self._current_fields = 0 @@ -157,8 +162,8 @@ def on_part_begin(self) -> None: def on_part_data(self, data: bytes, start: int, end: int) -> None: message_bytes = data[start:end] if self._current_part.file is None: - if len(self._current_part.data) + len(message_bytes) > self.max_part_size: - raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.") + if len(self._current_part.data) + len(message_bytes) > self.max_field_size: + raise MultiPartException(f"Part exceeded maximum size of {int(self.max_field_size / 1024)}KB.") self._current_part.data.extend(message_bytes) else: self._file_parts_to_write.append((self._current_part, message_bytes)) @@ -203,7 +208,7 @@ def on_headers_finished(self) -> None: if self._current_files > self.max_files: raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") filename = _user_safe_decode(options[b"filename"], self._charset) - tempfile = SpooledTemporaryFile(max_size=self.max_file_size) + tempfile = SpooledTemporaryFile(max_size=self.max_file_mem_size) self._files_to_close_on_error.append(tempfile) self._current_part.file = UploadFile( file=tempfile, # type: ignore[arg-type] @@ -257,12 +262,8 @@ async def parse(self) -> FormData: # the main thread. 
for part, data in self._file_parts_to_write: assert part.file # for type checkers - if ( - self.max_part_file_size is not None - and part.file.size is not None - and part.file.size + len(data) > self.max_part_file_size - ): - raise MultiPartException(f"File exceeds maximum size of {self.max_part_file_size} bytes.") + if part.file.size is not None and part.file.size + len(data) > self.max_file_disk_size: + raise MultiPartException(f"File exceeds maximum size of {self.max_file_disk_size} bytes.") await part.file.write(data) for part in self._file_parts_to_finish: assert part.file # for type checkers diff --git a/starlette/requests.py b/starlette/requests.py index a662e2bdd..06492d4f5 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -250,7 +250,13 @@ async def json(self) -> typing.Any: return self._json async def _get_form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None + self, + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_field_size: int | float | None, + max_file_mem_size: int | float | None, + max_file_disk_size: int | float | None, ) -> FormData: if self._form is None: assert ( @@ -266,7 +272,9 @@ async def _get_form( self.stream(), max_files=max_files, max_fields=max_fields, - max_part_file_size=max_file_size, + max_field_size=max_field_size, + max_file_mem_size=max_file_mem_size, + max_file_disk_size=max_file_disk_size, ) self._form = await multipart_parser.parse() except MultiPartException as exc: @@ -281,10 +289,32 @@ async def _get_form( return self._form def form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None = None + self, + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_field_size: int | float | None = None, + max_file_mem_size: int | float | None = None, + max_file_disk_size: int | float | None = None, ) -> AwaitableOrContextManager[FormData]: + """ + Return a FormData 
instance, representing the form data in the request. + + :param max_files: The maximum number of files that can be parsed. + :param max_fields: The maximum number of fields that can be parsed. + :param max_field_size: The maximum size of each field part in bytes. + :param max_file_mem_size: The maximum size in bytes a file part may hold in memory before spooling to disk. + :param max_file_disk_size: The maximum size in bytes of each uploaded file; larger files are rejected. + See https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile for how spooling works. + """ return AwaitableOrContextManagerWrapper( - self._get_form(max_files=max_files, max_fields=max_fields, max_file_size=max_file_size) + self._get_form( + max_files=max_files, + max_fields=max_fields, + max_field_size=max_field_size, + max_file_mem_size=max_file_mem_size, + max_file_disk_size=max_file_disk_size, + ) ) async def close(self) -> None: