from pydantic import BaseModel, Field
from typing import Annotated


# Input model for the tool: a tank that holds a list of shrimp.
class ShrimpTank(BaseModel):
    # One shrimp; nested so it is addressed as ShrimpTank.Shrimp.
    class Shrimp(BaseModel):
        # On a str, Field(max_length=10) limits the name to 10 characters.
        name: Annotated[str, Field(max_length=10)]

    shrimp: list[Shrimp]


@mcp.tool()
def name_shrimp(
    tank: ShrimpTank,
    # You can use pydantic Field in function signatures for validation.
    # On a list, max_length=10 limits the number of items, not their length.
    extra_names: Annotated[list[str], Field(max_length=10)],
) -> list[str]:
    """List all shrimp names in the tank"""
    # Names taken from the tank come first, followed by the extra names.
    return [shrimp.name for shrimp in tank.shrimp] + extra_names
class InvalidSignature(FastMCPError):
    """Invalid signature for use with FastMCP.

    Raised while building argument metadata for a function whose parameters
    cannot be turned into a pydantic argument model, for example a parameter
    name that starts with an underscore or a type annotation that cannot be
    evaluated.

    Derives from ``FastMCPError`` like the other exceptions in this module, so
    ``except FastMCPError`` also covers signature problems. It is still an
    ``Exception`` subclass, so existing ``except InvalidSignature`` and
    ``except Exception`` handlers keep working.
    """
import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import (
    Annotated,
    Any,
    Dict,
    ForwardRef,
)

from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema, create_model
from pydantic._internal._typing_extra import try_eval_type
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

from fastmcp.exceptions import InvalidSignature
from fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class ArgModelBase(BaseModel):
    """A model representing the arguments to a function."""

    def model_dump_one_level(self) -> dict[str, Any]:
        """Return a dict of the model's fields, one level deep.

        That is, sub-models etc are not dumped - they are kept as pydantic models.
        """
        # Read the field names from the class: looking up `model_fields` on an
        # instance is deprecated in newer pydantic 2.x releases.
        return {name: getattr(self, name) for name in type(self).model_fields}

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


# NOTE: documented with comments rather than a class docstring on purpose -
# pydantic copies a model's docstring into its JSON schema as "description".
#
# Metadata about a function. For now this is only the pydantic model that
# describes (and validates) the function's arguments.
class FuncMetadata(BaseModel):
    # `WithJsonSchema(None)` keeps this field (a class, not data) out of the
    # JSON schema of any model that embeds a FuncMetadata.
    arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
    # We can add things in the future like
    #  - Maybe some args are excluded from attempting to parse from JSON
    #  - Maybe some args are special (like context) for dependency injection

    async def call_fn_with_arg_validation(
        self,
        fn: Callable[..., Any | Awaitable[Any]],
        fn_is_async: bool,
        arguments_to_validate: dict[str, Any],
        arguments_to_pass_directly: dict[str, Any] | None = None,
    ) -> Any:
        """Call the given function with arguments validated and injected.

        Arguments are first attempted to be parsed from JSON, then validated against
        the argument model, before being passed to the function.

        Args:
            fn: The function to call.
            fn_is_async: Whether ``fn`` returns an awaitable that must be awaited.
            arguments_to_validate: Raw arguments; they are pre-parsed and validated
                against ``arg_model``. This dict is not modified.
            arguments_to_pass_directly: Extra keyword arguments handed to ``fn``
                without validation. On a name clash they win over validated ones.

        Returns:
            Whatever ``fn`` returns (awaited when ``fn_is_async`` is true).

        Raises:
            pydantic.ValidationError: If the arguments do not match ``arg_model``.
        """
        arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
        arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
        # One level only: nested models must reach `fn` as model instances.
        arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()

        arguments_parsed_dict |= arguments_to_pass_directly or {}

        if fn_is_async:
            return await fn(**arguments_parsed_dict)
        return fn(**arguments_parsed_dict)

    def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
        """Pre-parse data from JSON.

        Return a dict with same keys as input but with values parsed from JSON
        if appropriate. The input dict is not modified.

        This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
        a string rather than an actual list. Claude desktop is prone to this - in fact
        it seems incapable of NOT doing this. For sub-models, it tends to pass
        dicts (JSON objects) as JSON strings, which can be pre-parsed here.

        Only strings that decode to a JSON container (list / object) or to
        ``null`` are replaced. Strings are left untouched when the field is
        annotated as plain ``str`` or when they decode to a JSON scalar.
        """
        new_data = data.copy()  # Shallow copy
        for field_name, field_info in self.arg_model.model_fields.items():
            raw_value = data.get(field_name)
            if not isinstance(raw_value, str):
                continue  # Field not supplied, or already a non-string value
            if field_info.annotation is str:
                # The parameter wants a string: text such as "123", "null" or
                # "[1, 2]" is a perfectly good value and must stay as sent.
                continue
            try:
                pre_parsed = json.loads(raw_value)
            except json.JSONDecodeError:
                continue  # Not JSON - skip
            if isinstance(pre_parsed, (str, int, float)):
                # Scalars are left alone:
                #  - the raw value `"hello"` should stay '"hello"' in Python, but
                #    parsing it as JSON would turn it into just 'hello';
                #  - numbers/booleans in a string are coerced by pydantic when
                #    the field wants them, whereas turning "123" into 123 here
                #    would make it fail validation wherever a `str` is accepted.
                continue
            new_data[field_name] = pre_parsed
        # Keys are never added or removed above, so new_data has data's key set.
        return new_data

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
    """Given a function, return metadata including a pydantic model representing its signature.

    The use case for this is
    ```
    meta = func_metadata(func)
    validated_args = meta.arg_model.model_validate(some_raw_data_dict)
    return func(**validated_args.model_dump_one_level())
    ```

    **critically** it also provides pre-parse helper to attempt to parse things from JSON.

    Args:
        func: The function to convert to a pydantic model
        skip_names: A list of parameter names to skip. These will not be included in
            the model.
    Returns:
        A ``FuncMetadata`` whose ``arg_model`` is a pydantic model (a subclass of
        ``ArgModelBase``) representing the function's signature.
    Raises:
        InvalidSignature: If a parameter name starts with an underscore, or a
            string type annotation cannot be evaluated.
    """
    sig = _get_typed_signature(func)
    dynamic_pydantic_model_params: dict[str, Any] = {}
    for param in sig.parameters.values():
        if param.name.startswith("_"):
            raise InvalidSignature(
                f"Parameter {param.name} of {func.__name__} may not start with an underscore"
            )
        if param.name in skip_names:
            continue

        # A parameter without a default becomes a required field.
        default = (
            param.default
            if param.default is not inspect.Parameter.empty
            else PydanticUndefined
        )
        annotation = param.annotation

        if annotation is None:
            # `x: None` / `x: None = None`
            annotation = Annotated[None, Field(default=default)]
        elif annotation is inspect.Parameter.empty:
            # Untyped field: accept anything, but advertise it as a string in
            # the JSON schema (a best guess - nothing better is known).
            annotation = Annotated[
                Any,
                Field(),
                WithJsonSchema({"title": param.name, "type": "string"}),
            ]

        field_info = FieldInfo.from_annotated_attribute(annotation, default)
        dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)

    arguments_model = create_model(
        f"{func.__name__}Arguments",
        **dynamic_pydantic_model_params,
        __base__=ArgModelBase,
    )
    return FuncMetadata(arg_model=arguments_model)


def _get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Evaluate a string annotation against ``globalns``; return others unchanged.

    Raises:
        InvalidSignature: If the string annotation cannot be evaluated.
    """
    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
        # NOTE: `try_eval_type` is a private pydantic helper.
        annotation, status = try_eval_type(annotation, globalns, globalns)

        # This check and raise could perhaps be skipped, and we (FastMCP) just call
        # model_rebuild right before using it
        if status is False:
            raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")

    return annotation


def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Get function signature while evaluating forward references"""
    signature = inspect.signature(call)
    # Callables without `__globals__` get an empty namespace for evaluation.
    globalns = getattr(call, "__globals__", {})
    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=_get_typed_annotation(param.annotation, globalns),
        )
        for param in signature.parameters.values()
    ]
    # `replace` keeps the return annotation, which a fresh Signature would drop.
    return signature.replace(parameters=typed_params)
+ list_str_or_str: list[str] | str, + an_int_annotated_with_field: Annotated[ + int, Field(description="An int with a field") + ], + an_int_annotated_with_field_and_others: Annotated[ + int, + str, # Should be ignored, really + Field(description="An int with a field"), + annotated_types.Gt(1), + ], + an_int_annotated_with_junk: Annotated[ + int, + "123", + 456, + ], + field_with_default_via_field_annotation_before_nondefault_arg: Annotated[ + int, Field(1) + ], + unannotated, + my_model_a: TestInputModelA, + my_model_a_forward_ref: "TestInputModelA", + my_model_b: TestInputModelB, + an_int_annotated_with_field_default: Annotated[ + int, + Field(1, description="An int with a field"), + ], + unannotated_with_default=5, + my_model_a_with_default: TestInputModelA = TestInputModelA(), # noqa: B008 + an_int_with_default: int = 1, + must_be_none_with_default: None = None, + an_int_with_equals_field: int = Field(1, ge=0), + int_annotated_with_default: Annotated[int, Field(description="hey")] = 5, +) -> str: + _ = ( + an_int, + must_be_none, + must_be_none_dumb_annotation, + list_of_ints, + list_str_or_str, + an_int_annotated_with_field, + an_int_annotated_with_field_and_others, + an_int_annotated_with_junk, + field_with_default_via_field_annotation_before_nondefault_arg, + unannotated, + an_int_annotated_with_field_default, + unannotated_with_default, + my_model_a, + my_model_a_forward_ref, + my_model_b, + my_model_a_with_default, + an_int_with_default, + must_be_none_with_default, + an_int_with_equals_field, + int_annotated_with_default, + ) + return "ok!" 
async def test_complex_function_runtime_arg_validation_non_json():
    """Test that basic non-JSON arguments are validated correctly"""
    meta = func_metadata(complex_arguments_fn)

    # Minimum required arguments, all already in their Python form.
    valid_arguments = {
        "an_int": 1,
        "must_be_none": None,
        "must_be_none_dumb_annotation": None,
        "list_of_ints": [1, 2, 3],
        "list_str_or_str": "hello",
        "an_int_annotated_with_field": 42,
        "an_int_annotated_with_field_and_others": 5,
        "an_int_annotated_with_junk": 100,
        "unannotated": "test",
        "my_model_a": {},
        "my_model_a_forward_ref": {},
        "my_model_b": {"how_many_shrimp": 5, "ok": {"x": 1}, "y": None},
    }
    result = await meta.call_fn_with_arg_validation(
        complex_arguments_fn,
        fn_is_async=False,
        arguments_to_validate=valid_arguments,
        arguments_to_pass_directly=None,
    )
    assert result == "ok!"

    # Test with invalid types. Every other argument stays valid, so the error
    # can only come from `an_int` (passing `an_int` alone would also fail just
    # because the other required arguments are missing).
    with pytest.raises(ValueError, match="an_int"):
        await meta.call_fn_with_arg_validation(
            complex_arguments_fn,
            fn_is_async=False,
            arguments_to_validate={**valid_arguments, "an_int": "not an int"},
            arguments_to_pass_directly=None,
        )


async def test_complex_function_runtime_arg_validation_with_json():
    """Test that JSON string arguments are parsed and validated correctly"""
    meta = func_metadata(complex_arguments_fn)

    result = await meta.call_fn_with_arg_validation(
        complex_arguments_fn,
        fn_is_async=False,
        arguments_to_validate={
            "an_int": 1,
            "must_be_none": None,
            "must_be_none_dumb_annotation": None,
            "list_of_ints": "[1, 2, 3]",  # JSON string
            "list_str_or_str": '["a", "b", "c"]',  # JSON string
            "an_int_annotated_with_field": 42,
            "an_int_annotated_with_field_and_others": "5",  # JSON string
            "an_int_annotated_with_junk": 100,
            "unannotated": "test",
            "my_model_a": "{}",  # JSON string
            "my_model_a_forward_ref": "{}",  # JSON string
            "my_model_b": '{"how_many_shrimp": 5, "ok": {"x": 1}, "y": null}',  # JSON string
        },
        arguments_to_pass_directly=None,
    )
    assert result == "ok!"
def test_str_vs_list_str():
    """Test handling of string vs list[str] type annotations.

    This is tricky as '"hello"' can be parsed as a JSON string or a Python string.
    We want to make sure it's kept as a python string.
    """

    def func_with_str_types(str_or_list: str | list[str]):
        return str_or_list

    meta = func_metadata(func_with_str_types)

    # Test string input for union type
    result = meta.pre_parse_json({"str_or_list": "hello"})
    assert result["str_or_list"] == "hello"

    # Test string input that contains valid JSON for union type
    # We want to see here that the JSON-valid string is NOT parsed as JSON, but
    # rather kept as a raw string
    result = meta.pre_parse_json({"str_or_list": '"hello"'})
    assert result["str_or_list"] == '"hello"'

    # Test list input for union type
    result = meta.pre_parse_json({"str_or_list": '["hello", "world"]'})
    assert result["str_or_list"] == ["hello", "world"]


def test_skip_names():
    """Test that skipped parameters are not included in the model"""

    def func_with_many_params(
        keep_this: int, skip_this: str, also_keep: float, also_skip: bool
    ):
        return keep_this, skip_this, also_keep, also_skip

    # Skip some parameters
    meta = func_metadata(func_with_many_params, skip_names=["skip_this", "also_skip"])

    # Check model fields
    assert "keep_this" in meta.arg_model.model_fields
    assert "also_keep" in meta.arg_model.model_fields
    assert "skip_this" not in meta.arg_model.model_fields
    assert "also_skip" not in meta.arg_model.model_fields

    # Validate that we can call with only non-skipped parameters
    model = meta.arg_model.model_validate({"keep_this": 1, "also_keep": 2.5})
    assert model.keep_this == 1
    assert model.also_keep == 2.5


async def test_lambda_function():
    """Test lambda function schema and validation"""
    fn = lambda x, y=5: x  # noqa: E731
    # Build the metadata from the very lambda that is called below, so the
    # schema and the calls are guaranteed to describe the same function.
    meta = func_metadata(fn)

    # Test schema. The model is named f"{func.__name__}Arguments" and a
    # lambda's __name__ is "<lambda>".
    assert meta.arg_model.model_json_schema() == {
        "properties": {
            "x": {"title": "x", "type": "string"},
            "y": {"default": 5, "title": "y", "type": "string"},
        },
        "required": ["x"],
        "title": "<lambda>Arguments",
        "type": "object",
    }

    async def check_call(args):
        return await meta.call_fn_with_arg_validation(
            fn,
            fn_is_async=False,
            arguments_to_validate=args,
            arguments_to_pass_directly=None,
        )

    # Basic calls
    assert await check_call({"x": "hello"}) == "hello"
    assert await check_call({"x": "hello", "y": "world"}) == "hello"
    assert await check_call({"x": '"hello"'}) == '"hello"'

    # Missing required arg
    with pytest.raises(ValueError):
        await check_call({"y": "world"})
"type": "integer", + }, + "an_int_annotated_with_field_and_others": { + "description": "An int with a field", + "exclusiveMinimum": 1, + "title": "An Int Annotated With Field And Others", + "type": "integer", + }, + "an_int_annotated_with_junk": { + "title": "An Int Annotated With Junk", + "type": "integer", + }, + "field_with_default_via_field_annotation_before_nondefault_arg": { + "default": 1, + "title": "Field With Default Via Field Annotation Before Nondefault Arg", + "type": "integer", + }, + "unannotated": {"title": "unannotated", "type": "string"}, + "my_model_a": {"$ref": "#/$defs/TestInputModelA"}, + "my_model_a_forward_ref": {"$ref": "#/$defs/TestInputModelA"}, + "my_model_b": {"$ref": "#/$defs/TestInputModelB"}, + "an_int_annotated_with_field_default": { + "default": 1, + "description": "An int with a field", + "title": "An Int Annotated With Field Default", + "type": "integer", + }, + "unannotated_with_default": { + "default": 5, + "title": "unannotated_with_default", + "type": "string", + }, + "my_model_a_with_default": { + "$ref": "#/$defs/TestInputModelA", + "default": {}, + }, + "an_int_with_default": { + "default": 1, + "title": "An Int With Default", + "type": "integer", + }, + "must_be_none_with_default": { + "default": None, + "title": "Must Be None With Default", + "type": "null", + }, + "an_int_with_equals_field": { + "default": 1, + "minimum": 0, + "title": "An Int With Equals Field", + "type": "integer", + }, + "int_annotated_with_default": { + "default": 5, + "description": "hey", + "title": "Int Annotated With Default", + "type": "integer", + }, + }, + "required": [ + "an_int", + "must_be_none", + "must_be_none_dumb_annotation", + "list_of_ints", + "list_str_or_str", + "an_int_annotated_with_field", + "an_int_annotated_with_field_and_others", + "an_int_annotated_with_junk", + "unannotated", + "my_model_a", + "my_model_a_forward_ref", + "my_model_b", + ], + "title": "complex_arguments_fnArguments", + "type": "object", + } diff --git 
a/tests/test_tool_manager.py b/tests/test_tool_manager.py index 3192454..4356a9a 100644 --- a/tests/test_tool_manager.py +++ b/tests/test_tool_manager.py @@ -3,7 +3,7 @@ import pytest from pydantic import BaseModel - +import json from fastmcp.exceptions import ToolError from fastmcp.tools import ToolManager @@ -156,6 +156,74 @@ async def test_call_unknown_tool(self): with pytest.raises(ToolError): await manager.call_tool("unknown", {"a": 1}) + async def test_call_tool_with_list_int_input(self): + def sum_vals(vals: list[int]) -> int: + return sum(vals) + + manager = ToolManager() + manager.add_tool(sum_vals) + # Try both with plain list and with JSON list + result = await manager.call_tool("sum_vals", {"vals": "[1, 2, 3]"}) + assert result == 6 + result = await manager.call_tool("sum_vals", {"vals": [1, 2, 3]}) + assert result == 6 + + async def test_call_tool_with_list_str_or_str_input(self): + def concat_strs(vals: list[str] | str) -> str: + return vals if isinstance(vals, str) else "".join(vals) + + manager = ToolManager() + manager.add_tool(concat_strs) + # Try both with plain python object and with JSON list + result = await manager.call_tool("concat_strs", {"vals": ["a", "b", "c"]}) + assert result == "abc" + result = await manager.call_tool("concat_strs", {"vals": '["a", "b", "c"]'}) + assert result == "abc" + result = await manager.call_tool("concat_strs", {"vals": "a"}) + assert result == "a" + result = await manager.call_tool("concat_strs", {"vals": '"a"'}) + assert result == '"a"' + + async def test_call_tool_with_complex_model(self): + from fastmcp import Context + + class MyShrimpTank(BaseModel): + class Shrimp(BaseModel): + name: str + + shrimp: list[Shrimp] + x: None + + def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: + return [x.name for x in tank.shrimp] + + manager = ToolManager() + manager.add_tool(name_shrimp) + result = await manager.call_tool( + "name_shrimp", + {"tank": {"x": None, "shrimp": [{"name": "rex"}, {"name": 
class TestToolSchema:
    """Checks on the JSON schema generated for a registered tool."""

    async def test_context_arg_excluded_from_schema(self):
        """The Context parameter must not show up in the tool's schema or model."""
        from fastmcp import Context

        def something(a: int, ctx: Context) -> int:
            return a

        tool = ToolManager().add_tool(something)

        # Serialize once and look for both the parameter name and its type name.
        schema_text = json.dumps(tool.parameters)
        for forbidden in ("ctx", "Context"):
            assert forbidden not in schema_text
        assert "ctx" not in tool.fn_metadata.arg_model.model_fields