Skip to content

Commit

Permalink
feat: make form part and file size limits configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
khadrawy committed Dec 15, 2024
1 parent 0ba66c6 commit 716f22f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
23 changes: 12 additions & 11 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ async def parse(self) -> FormData:


class MultiPartParser:
max_file_size = 1024 * 1024 # 1MB
max_part_size = 1024 * 1024 # 1MB
default_max_part_size = 1024 * 1024 # 1MB
default_max_file_mem_size = 1024 * 1024 # 1MB
default_max_file_disk_size = 1024 * 1024 * 1024 # 1GB

def __init__(
self,
Expand All @@ -132,14 +133,18 @@ def __init__(
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_file_size: int | float | None = None,
max_part_size: int | float | None = None,
max_file_mem_size: int | float | None = None,
max_file_disk_size: int | float | None = None,
) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.max_part_file_size = max_part_file_size
self.max_part_size: int | float = max_part_size or self.default_max_part_size
self.max_part_file_mem_size: int | float = max_file_mem_size or self.default_max_file_mem_size
self.max_part_file_disk_size: int | float = max_file_disk_size or self.default_max_file_disk_size
self.items: list[tuple[str, str | UploadFile]] = []
self._current_files = 0
self._current_fields = 0
Expand Down Expand Up @@ -203,7 +208,7 @@ def on_headers_finished(self) -> None:
if self._current_files > self.max_files:
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
tempfile = SpooledTemporaryFile(max_size=self.max_part_file_mem_size)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
Expand Down Expand Up @@ -257,12 +262,8 @@ async def parse(self) -> FormData:
# the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
if (
self.max_part_file_size is not None
and part.file.size is not None
and part.file.size + len(data) > self.max_part_file_size
):
raise MultiPartException(f"File exceeds maximum size of {self.max_part_file_size} bytes.")
if part.file.size is not None and part.file.size + len(data) > self.max_part_file_disk_size:
raise MultiPartException(f"File exceeds maximum size of {self.max_part_file_disk_size} bytes.")
await part.file.write(data)
for part in self._file_parts_to_finish:
assert part.file # for type checkers
Expand Down
38 changes: 34 additions & 4 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,13 @@ async def json(self) -> typing.Any:
return self._json

async def _get_form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int | float | None,
max_file_mem_size: int | float | None,
max_file_disk_size: int | float | None,
) -> FormData:
if self._form is None:
assert (
Expand All @@ -266,7 +272,9 @@ async def _get_form(
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_file_size=max_file_size,
max_part_size=max_part_size,
max_file_mem_size=max_file_mem_size,
max_file_disk_size=max_file_disk_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
Expand All @@ -281,10 +289,32 @@ async def _get_form(
return self._form

def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000, max_file_size: int | None = None
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int | float | None = None,
max_file_mem_size: int | float | None = None,
max_file_disk_size: int | float | None = None,
) -> AwaitableOrContextManager[FormData]:
"""
Return a FormData instance, representing the form data in the request.
:param max_files: The maximum number of files that can be parsed.
:param max_fields: The maximum number of fields that can be parsed.
:param max_part_size: The maximum size of each part in bytes.
:param max_file_mem_size: The maximum memory size for each file in bytes.
:param max_file_disk_size: The maximum disk size for each file in bytes.
https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile
"""
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields, max_file_size=max_file_size)
self._get_form(
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
max_file_mem_size=max_file_mem_size,
max_file_disk_size=max_file_disk_size,
)
)

async def close(self) -> None:
Expand Down

0 comments on commit 716f22f

Please sign in to comment.