From f5d82bd22966af44c4603104147ba98f09e67c3c Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 11 Nov 2024 20:05:51 +0000 Subject: [PATCH 1/3] Remove helper types The helper types in mcp.server.types got really confusioning during implementation as they overlapped with mcp.types. I now believe it is better if we stay more low level to the spec types. To do this, we now only use mcp.types everywhere. We renamed mcp.server.types to mcp.server.models and removed it to the absolute minimum. --- src/mcp/server/__init__.py | 58 +++++++++--------------------------- src/mcp/server/__main__.py | 2 +- src/mcp/server/models.py | 19 ++++++++++++ src/mcp/server/session.py | 2 +- src/mcp/server/types.py | 46 ---------------------------- tests/conftest.py | 2 +- tests/server/test_session.py | 2 +- uv.lock | 7 ++--- 8 files changed, 40 insertions(+), 98 deletions(-) create mode 100644 src/mcp/server/models.py delete mode 100644 src/mcp/server/types.py diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index f7e66a3..29133e0 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -7,7 +7,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl -from mcp.server import types +from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext @@ -15,6 +15,10 @@ from mcp.types import ( METHOD_NOT_FOUND, CallToolRequest, + GetPromptResult, + GetPromptRequest, + GetPromptResult, + ImageContent, ClientNotification, ClientRequest, CompleteRequest, @@ -84,7 +88,7 @@ def create_initialization_options( self, notification_options: NotificationOptions | None = None, experimental_capabilities: dict[str, dict[str, Any]] | None = None, - ) -> types.InitializationOptions: + ) -> InitializationOptions: """Create initialization options from this server instance.""" def pkg_version(package: str) -> str: @@ -99,7 +103,7 @@ def pkg_version(package: str) -> str: return "unknown" - return types.InitializationOptions( + return InitializationOptions( server_name=self.name, server_version=pkg_version("mcp"), capabilities=self.get_capabilities( @@ -168,50 +172,16 @@ async def handler(_: Any): return decorator def get_prompt(self): - from mcp.types import ( - GetPromptRequest, - GetPromptResult, - ImageContent, - ) - from mcp.types import ( - Role as Role, - ) - def decorator( func: Callable[ - [str, dict[str, str] | None], Awaitable[types.PromptResponse] + [str, dict[str, str] | None], Awaitable[GetPromptResult] ], ): logger.debug("Registering handler for GetPromptRequest") async def handler(req: GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - messages: list[PromptMessage] = [] - for message in prompt_get.messages: - match message.content: - case str() as text_content: - content = TextContent(type="text", text=text_content) - case types.ImageContent() as img_content: - content = ImageContent( - type="image", - data=img_content.data, - mimeType=img_content.mime_type, - ) - case types.EmbeddedResource() as resource: - content = EmbeddedResource( - type="resource", resource=resource.resource - ) - case _: - raise ValueError( - f"Unexpected content type: {type(message.content)}" - ) - - prompt_message = PromptMessage(role=message.role, content=content) - messages.append(prompt_message) - - return ServerResult( - GetPromptResult(description=prompt_get.desc, messages=messages) - ) + return ServerResult(prompt_get) self.request_handlers[GetPromptRequest] = handler return func @@ -338,7 +308,7 @@ def call_tool(self): def decorator( func: Callable[ ..., - Awaitable[Sequence[str | types.ImageContent | types.EmbeddedResource]], + Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]], ], ): logger.debug("Registering handler for CallToolRequest") @@ -351,15 +321,15 @@ async def handler(req: CallToolRequest): match result: case str() as text: content.append(TextContent(type="text", text=text)) - case types.ImageContent() as img: + case ImageContent() as img: content.append( ImageContent( type="image", data=img.data, - mimeType=img.mime_type, + mimeType=img.mimeType, ) ) - case types.EmbeddedResource() as resource: + case EmbeddedResource() as resource: content.append( EmbeddedResource( type="resource", resource=resource.resource @@ -427,7 +397,7 @@ async def run( self, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], - initialization_options: types.InitializationOptions, + initialization_options: InitializationOptions, # When True, exceptions are returned as messages to the client. # When False, exceptions are raised, which will cause the server to shut down # but also make tracing exceptions much easier during testing and when using diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index 2313f46..417380d 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -6,7 +6,7 @@ from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server -from mcp.server.types import InitializationOptions +from mcp.server.models import InitializationOptions from mcp.types import ServerCapabilities if not sys.warnoptions: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py new file mode 100644 index 0000000..8920597 --- /dev/null +++ b/src/mcp/server/models.py @@ -0,0 +1,19 @@ +""" +This module provides simpler types to use with the server for managing prompts +and tools. +""" + +from dataclasses import dataclass +from typing import Literal + +from pydantic import BaseModel + +from mcp.types import ( + ServerCapabilities, +) + + +class InitializationOptions(BaseModel): + server_name: str + server_version: str + capabilities: ServerCapabilities diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f6ed1b3..03e6882 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,7 +6,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl -from mcp.server.types import InitializationOptions +from mcp.server.models import InitializationOptions from mcp.shared.session import ( BaseSession, RequestResponder, diff --git a/src/mcp/server/types.py b/src/mcp/server/types.py deleted file mode 100644 index 7946a4b..0000000 --- a/src/mcp/server/types.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -This module provides simpler types to use with the server for managing prompts -and tools. -""" - -from dataclasses import dataclass -from typing import Literal - -from pydantic import BaseModel - -from mcp.types import ( - BlobResourceContents, - Role, - ServerCapabilities, - TextResourceContents, -) - - -@dataclass -class ImageContent: - type: Literal["image"] - data: str - mime_type: str - - -@dataclass -class EmbeddedResource: - resource: TextResourceContents | BlobResourceContents - - -@dataclass -class Message: - role: Role - content: str | ImageContent | EmbeddedResource - - -@dataclass -class PromptResponse: - messages: list[Message] - desc: str | None = None - - -class InitializationOptions(BaseModel): - server_name: str - server_version: str - capabilities: ServerCapabilities diff --git a/tests/conftest.py b/tests/conftest.py index 10f3ed6..8d792aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ from pydantic import AnyUrl from mcp.server import Server -from mcp.server.types import InitializationOptions +from mcp.server.models import InitializationOptions from mcp.types import Resource, ServerCapabilities TEST_INITIALIZATION_OPTIONS = InitializationOptions( diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 66ac58d..728d276 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -4,7 +4,7 @@ from mcp.client.session import ClientSession from mcp.server import NotificationOptions, Server from mcp.server.session import ServerSession -from mcp.server.types import InitializationOptions +from mcp.server.models import InitializationOptions from mcp.types import ( ClientNotification, InitializedNotification, diff --git a/uv.lock b/uv.lock index 35e52f4..49d4e13 100644 --- a/uv.lock +++ b/uv.lock @@ -331,15 +331,14 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.388" +version = "1.1.378" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, - { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9c/83/e9867538a794638d2d20ac3ab3106a31aca1d9cfea530c9b2921809dae03/pyright-1.1.388.tar.gz", hash = "sha256:0166d19b716b77fd2d9055de29f71d844874dbc6b9d3472ccd22df91db3dfa34", size = 21939 } +sdist = { url = "https://files.pythonhosted.org/packages/3d/f0/e8aa5555410d88f898bef04da2102b0a9bf144658c98d34872e91621ced2/pyright-1.1.378.tar.gz", hash = "sha256:78a043be2876d12d0af101d667e92c7734f3ebb9db71dccc2c220e7e7eb89ca2", size = 17486 } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/57/7fb00363b7f267a398c5bdf4f55f3e64f7c2076b2e7d2901b3373d52b6ff/pyright-1.1.388-py3-none-any.whl", hash = "sha256:c7068e9f2c23539c6ac35fc9efac6c6c1b9aa5a0ce97a9a8a6cf0090d7cbf84c", size = 18579 }, + { url = "https://files.pythonhosted.org/packages/38/c6/f0d4bc20c13b20cecfbf13c699477c825e45767f1dc5068137323f86e495/pyright-1.1.378-py3-none-any.whl", hash = "sha256:8853776138b01bc284da07ac481235be7cc89d3176b073d2dba73636cb95be79", size = 18222 }, ] [[package]] From b9b44e6dadfe0378bed59d00ab2fa3ca5b4a1a46 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 11 Nov 2024 20:14:03 +0000 Subject: [PATCH 2/3] Don't re-export types We will be a bit more low level and expect callees to import mcp.types instead of relying in re-exported types. This makes usage more explicit and avoids potential collisions in mcp.server. --- src/mcp/client/session.py | 237 ++++++++++++------------------------ src/mcp/client/sse.py | 12 +- src/mcp/client/stdio.py | 12 +- src/mcp/server/__init__.py | 204 +++++++++++-------------------- src/mcp/server/session.py | 141 ++++++++------------- src/mcp/server/sse.py | 14 +-- src/mcp/server/stdio.py | 12 +- src/mcp/server/websocket.py | 12 +- 8 files changed, 231 insertions(+), 413 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a9b8d54..0f3e313 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -5,83 +5,54 @@ from mcp.shared.session import BaseSession from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - CallToolResult, - ClientCapabilities, - ClientNotification, - ClientRequest, - ClientResult, - CompleteResult, - EmptyResult, - GetPromptResult, - Implementation, - InitializedNotification, - InitializeResult, - JSONRPCMessage, - ListPromptsResult, - ListResourcesResult, - ListToolsResult, - LoggingLevel, - PromptReference, - ReadResourceResult, - ResourceReference, - RootsCapability, - ServerNotification, - ServerRequest, -) +import mcp.types as types class ClientSession( BaseSession[ - ClientRequest, - ClientNotification, - ClientResult, - ServerRequest, - ServerNotification, + types.ClientRequest, + types.ClientNotification, + types.ClientResult, + types.ServerRequest, + types.ServerNotification, ] ): def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, ) -> None: super().__init__( read_stream, write_stream, - ServerRequest, - ServerNotification, + types.ServerRequest, + types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - async def initialize(self) -> InitializeResult: - from mcp.types import ( - InitializeRequest, - InitializeRequestParams, - ) - + async def initialize(self) -> types.InitializeResult: result = await self.send_request( - ClientRequest( - InitializeRequest( + types.ClientRequest( + types.InitializeRequest( method="initialize", - params=InitializeRequestParams( - protocolVersion=LATEST_PROTOCOL_VERSION, - capabilities=ClientCapabilities( + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( sampling=None, experimental=None, - roots=RootsCapability( + roots=types.RootsCapability( # TODO: Should this be based on whether we # _will_ send notifications, or only whether # they're supported? listChanged=True ), ), - clientInfo=Implementation(name="mcp", version="0.1.0"), + clientInfo=types.Implementation(name="mcp", version="0.1.0"), ), ) ), - InitializeResult, + types.InitializeResult, ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: @@ -91,40 +62,33 @@ async def initialize(self) -> InitializeResult: ) await self.send_notification( - ClientNotification( - InitializedNotification(method="notifications/initialized") + types.ClientNotification( + types.InitializedNotification(method="notifications/initialized") ) ) return result - async def send_ping(self) -> EmptyResult: + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - from mcp.types import PingRequest - return await self.send_request( - ClientRequest( - PingRequest( + types.ClientRequest( + types.PingRequest( method="ping", ) ), - EmptyResult, + types.EmptyResult, ) async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """Send a progress notification.""" - from mcp.types import ( - ProgressNotification, - ProgressNotificationParams, - ) - await self.send_notification( - ClientNotification( - ProgressNotification( + types.ClientNotification( + types.ProgressNotification( method="notifications/progress", - params=ProgressNotificationParams( + params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, @@ -133,180 +97,137 @@ async def send_progress_notification( ) ) - async def set_logging_level(self, level: LoggingLevel) -> EmptyResult: + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" - from mcp.types import ( - SetLevelRequest, - SetLevelRequestParams, - ) - return await self.send_request( - ClientRequest( - SetLevelRequest( + types.ClientRequest( + types.SetLevelRequest( method="logging/setLevel", - params=SetLevelRequestParams(level=level), + params=types.SetLevelRequestParams(level=level), ) ), - EmptyResult, + types.EmptyResult, ) - async def list_resources(self) -> ListResourcesResult: + async def list_resources(self) -> types.ListResourcesResult: """Send a resources/list request.""" - from mcp.types import ( - ListResourcesRequest, - ) - return await self.send_request( - ClientRequest( - ListResourcesRequest( + types.ClientRequest( + types.ListResourcesRequest( method="resources/list", ) ), - ListResourcesResult, + types.ListResourcesResult, ) - async def read_resource(self, uri: AnyUrl) -> ReadResourceResult: + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" - from mcp.types import ( - ReadResourceRequest, - ReadResourceRequestParams, - ) - return await self.send_request( - ClientRequest( - ReadResourceRequest( + types.ClientRequest( + types.ReadResourceRequest( method="resources/read", - params=ReadResourceRequestParams(uri=uri), + params=types.ReadResourceRequestParams(uri=uri), ) ), - ReadResourceResult, + types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult: + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" - from mcp.types import ( - SubscribeRequest, - SubscribeRequestParams, - ) - return await self.send_request( - ClientRequest( - SubscribeRequest( + types.ClientRequest( + types.SubscribeRequest( method="resources/subscribe", - params=SubscribeRequestParams(uri=uri), + params=types.SubscribeRequestParams(uri=uri), ) ), - EmptyResult, + types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult: + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" - from mcp.types import ( - UnsubscribeRequest, - UnsubscribeRequestParams, - ) - return await self.send_request( - ClientRequest( - UnsubscribeRequest( + types.ClientRequest( + types.UnsubscribeRequest( method="resources/unsubscribe", - params=UnsubscribeRequestParams(uri=uri), + params=types.UnsubscribeRequestParams(uri=uri), ) ), - EmptyResult, + types.EmptyResult, ) async def call_tool( self, name: str, arguments: dict | None = None - ) -> CallToolResult: + ) -> types.CallToolResult: """Send a tools/call request.""" - from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - ) - return await self.send_request( - ClientRequest( - CallToolRequest( + types.ClientRequest( + types.CallToolRequest( method="tools/call", - params=CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams(name=name, arguments=arguments), ) ), - CallToolResult, + types.CallToolResult, ) - async def list_prompts(self) -> ListPromptsResult: + async def list_prompts(self) -> types.ListPromptsResult: """Send a prompts/list request.""" - from mcp.types import ListPromptsRequest - return await self.send_request( - ClientRequest( - ListPromptsRequest( + types.ClientRequest( + types.ListPromptsRequest( method="prompts/list", ) ), - ListPromptsResult, + types.ListPromptsResult, ) async def get_prompt( self, name: str, arguments: dict[str, str] | None = None - ) -> GetPromptResult: + ) -> types.GetPromptResult: """Send a prompts/get request.""" - from mcp.types import GetPromptRequest, GetPromptRequestParams - return await self.send_request( - ClientRequest( - GetPromptRequest( + types.ClientRequest( + types.GetPromptRequest( method="prompts/get", - params=GetPromptRequestParams(name=name, arguments=arguments), + params=types.GetPromptRequestParams(name=name, arguments=arguments), ) ), - GetPromptResult, + types.GetPromptResult, ) async def complete( - self, ref: ResourceReference | PromptReference, argument: dict - ) -> CompleteResult: + self, ref: types.ResourceReference | types.PromptReference, argument: dict + ) -> types.CompleteResult: """Send a completion/complete request.""" - from mcp.types import ( - CompleteRequest, - CompleteRequestParams, - CompletionArgument, - ) - return await self.send_request( - ClientRequest( - CompleteRequest( + types.ClientRequest( + types.CompleteRequest( method="completion/complete", - params=CompleteRequestParams( + params=types.CompleteRequestParams( ref=ref, - argument=CompletionArgument(**argument), + argument=types.CompletionArgument(**argument), ), ) ), - CompleteResult, + types.CompleteResult, ) - async def list_tools(self) -> ListToolsResult: + async def list_tools(self) -> types.ListToolsResult: """Send a tools/list request.""" - from mcp.types import ListToolsRequest - return await self.send_request( - ClientRequest( - ListToolsRequest( + types.ClientRequest( + types.ListToolsRequest( method="tools/list", ) ), - ListToolsResult, + types.ListToolsResult, ) async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - from mcp.types import RootsListChangedNotification - await self.send_notification( - ClientNotification( - RootsListChangedNotification( + types.ClientNotification( + types.RootsListChangedNotification( method="notifications/roots/list_changed", ) ) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index b5c36db..c79f48a 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -9,7 +9,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -31,11 +31,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -85,7 +85,7 @@ async def sse_reader( case "message": try: message = ( - JSONRPCMessage.model_validate_json( + types.JSONRPCMessage.model_validate_json( sse.data ) ) diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 6a29138..e79a816 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -8,7 +8,7 @@ from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field -from mcp.types import JSONRPCMessage +import mcp.types as types # Environment variables to inherit by default DEFAULT_INHERITED_ENV_VARS = ( @@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters): Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -99,7 +99,7 @@ async def stdout_reader(): for line in lines: try: - message = JSONRPCMessage.model_validate_json(line) + message = types.JSONRPCMessage.model_validate_json(line) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index 29133e0..abfc40d 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -12,48 +12,7 @@ from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.session import RequestResponder -from mcp.types import ( - METHOD_NOT_FOUND, - CallToolRequest, - GetPromptResult, - GetPromptRequest, - GetPromptResult, - ImageContent, - ClientNotification, - ClientRequest, - CompleteRequest, - EmbeddedResource, - EmptyResult, - ErrorData, - JSONRPCMessage, - ListPromptsRequest, - ListPromptsResult, - ListResourcesRequest, - ListResourcesResult, - ListToolsRequest, - ListToolsResult, - LoggingCapability, - LoggingLevel, - PingRequest, - ProgressNotification, - Prompt, - PromptMessage, - PromptReference, - PromptsCapability, - ReadResourceRequest, - ReadResourceResult, - Resource, - ResourceReference, - ResourcesCapability, - ServerCapabilities, - ServerResult, - SetLevelRequest, - SubscribeRequest, - TextContent, - Tool, - ToolsCapability, - UnsubscribeRequest, -) +import mcp.types as types logger = logging.getLogger(__name__) @@ -77,8 +36,8 @@ def __init__( class Server: def __init__(self, name: str): self.name = name - self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = { - PingRequest: _ping_handler, + self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { + types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() @@ -116,7 +75,7 @@ def get_capabilities( self, notification_options: NotificationOptions, experimental_capabilities: dict[str, dict[str, Any]], - ) -> ServerCapabilities: + ) -> types.ServerCapabilities: """Convert existing handlers to a ServerCapabilities object.""" prompts_capability = None resources_capability = None @@ -124,28 +83,28 @@ def get_capabilities( logging_capability = None # Set prompt capabilities if handler exists - if ListPromptsRequest in self.request_handlers: - prompts_capability = PromptsCapability( + if types.ListPromptsRequest in self.request_handlers: + prompts_capability = types.PromptsCapability( listChanged=notification_options.prompts_changed ) # Set resource capabilities if handler exists - if ListResourcesRequest in self.request_handlers: - resources_capability = ResourcesCapability( + if types.ListResourcesRequest in self.request_handlers: + resources_capability = types.ResourcesCapability( subscribe=False, listChanged=notification_options.resources_changed ) # Set tool capabilities if handler exists - if ListToolsRequest in self.request_handlers: - tools_capability = ToolsCapability( + if types.ListToolsRequest in self.request_handlers: + tools_capability = types.ToolsCapability( listChanged=notification_options.tools_changed ) # Set logging capabilities if handler exists - if SetLevelRequest in self.request_handlers: - logging_capability = LoggingCapability() + if types.SetLevelRequest in self.request_handlers: + logging_capability = types.LoggingCapability() - return ServerCapabilities( + return types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, tools=tools_capability, @@ -159,14 +118,14 @@ def request_context(self) -> RequestContext: return request_ctx.get() def list_prompts(self): - def decorator(func: Callable[[], Awaitable[list[Prompt]]]): + def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): logger.debug("Registering handler for PromptListRequest") async def handler(_: Any): prompts = await func() - return ServerResult(ListPromptsResult(prompts=prompts)) + return types.ServerResult(types.ListPromptsResult(prompts=prompts)) - self.request_handlers[ListPromptsRequest] = handler + self.request_handlers[types.ListPromptsRequest] = handler return func return decorator @@ -174,47 +133,42 @@ async def handler(_: Any): def get_prompt(self): def decorator( func: Callable[ - [str, dict[str, str] | None], Awaitable[GetPromptResult] + [str, dict[str, str] | None], Awaitable[types.GetPromptResult] ], ): logger.debug("Registering handler for GetPromptRequest") - async def handler(req: GetPromptRequest): + async def handler(req: types.GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - return ServerResult(prompt_get) + return types.ServerResult(prompt_get) - self.request_handlers[GetPromptRequest] = handler + self.request_handlers[types.GetPromptRequest] = handler return func return decorator def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[Resource]]]): + def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): logger.debug("Registering handler for ListResourcesRequest") async def handler(_: Any): resources = await func() - return ServerResult(ListResourcesResult(resources=resources)) + return types.ServerResult(types.ListResourcesResult(resources=resources)) - self.request_handlers[ListResourcesRequest] = handler + self.request_handlers[types.ListResourcesRequest] = handler return func return decorator def read_resource(self): - from mcp.types import ( - BlobResourceContents, - TextResourceContents, - ) - def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): logger.debug("Registering handler for ReadResourceRequest") - async def handler(req: ReadResourceRequest): + async def handler(req: types.ReadResourceRequest): result = await func(req.params.uri) match result: case str(s): - content = TextResourceContents( + content = types.TextResourceContents( uri=req.params.uri, text=s, mimeType="text/plain", @@ -222,130 +176,117 @@ async def handler(req: ReadResourceRequest): case bytes(b): import base64 - content = BlobResourceContents( + content = types.BlobResourceContents( uri=req.params.uri, blob=base64.urlsafe_b64encode(b).decode(), mimeType="application/octet-stream", ) - return ServerResult( - ReadResourceResult( + return types.ServerResult( + types.ReadResourceResult( contents=[content], ) ) - self.request_handlers[ReadResourceRequest] = handler + self.request_handlers[types.ReadResourceRequest] = handler return func return decorator def set_logging_level(self): - from mcp.types import EmptyResult - - def decorator(func: Callable[[LoggingLevel], Awaitable[None]]): + def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): logger.debug("Registering handler for SetLevelRequest") - async def handler(req: SetLevelRequest): + async def handler(req: types.SetLevelRequest): await func(req.params.level) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[SetLevelRequest] = handler + self.request_handlers[types.SetLevelRequest] = handler return func return decorator def subscribe_resource(self): - from mcp.types import EmptyResult - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for SubscribeRequest") - async def handler(req: SubscribeRequest): + async def handler(req: types.SubscribeRequest): await func(req.params.uri) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[SubscribeRequest] = handler + self.request_handlers[types.SubscribeRequest] = handler return func return decorator def unsubscribe_resource(self): - from mcp.types import EmptyResult - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for UnsubscribeRequest") - async def handler(req: UnsubscribeRequest): + async def handler(req: types.UnsubscribeRequest): await func(req.params.uri) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[UnsubscribeRequest] = handler + self.request_handlers[types.UnsubscribeRequest] = handler return func return decorator def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[Tool]]]): + def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): logger.debug("Registering handler for ListToolsRequest") async def handler(_: Any): tools = await func() - return ServerResult(ListToolsResult(tools=tools)) + return types.ServerResult(types.ListToolsResult(tools=tools)) - self.request_handlers[ListToolsRequest] = handler + self.request_handlers[types.ListToolsRequest] = handler return func return decorator def call_tool(self): - from mcp.types import ( - CallToolResult, - EmbeddedResource, - ImageContent, - TextContent, - ) - def decorator( func: Callable[ ..., - Awaitable[Sequence[TextContent | ImageContent | EmbeddedResource]], + Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]], ], ): logger.debug("Registering handler for CallToolRequest") - async def handler(req: CallToolRequest): + async def handler(req: types.CallToolRequest): try: results = await func(req.params.name, (req.params.arguments or {})) content = [] for result in results: match result: case str() as text: - content.append(TextContent(type="text", text=text)) - case ImageContent() as img: + content.append(types.TextContent(type="text", text=text)) + case types.ImageContent() as img: content.append( - ImageContent( + types.ImageContent( type="image", data=img.data, mimeType=img.mimeType, ) ) - case EmbeddedResource() as resource: + case types.EmbeddedResource() as resource: content.append( - EmbeddedResource( + types.EmbeddedResource( type="resource", resource=resource.resource ) ) - return ServerResult(CallToolResult(content=content, isError=False)) + return types.ServerResult(types.CallToolResult(content=content, isError=False)) except Exception as e: - return ServerResult( - CallToolResult( - content=[TextContent(type="text", text=str(e))], + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(e))], isError=True, ) ) - self.request_handlers[CallToolRequest] = handler + self.request_handlers[types.CallToolRequest] = handler return func return decorator @@ -356,47 +297,46 @@ def decorator( ): logger.debug("Registering handler for ProgressNotification") - async def handler(req: ProgressNotification): + async def handler(req: types.ProgressNotification): await func( req.params.progressToken, req.params.progress, req.params.total ) - self.notification_handlers[ProgressNotification] = handler + self.notification_handlers[types.ProgressNotification] = handler return func return decorator def completion(self): """Provides completions for prompts and resource templates""" - from mcp.types import CompleteResult, Completion, CompletionArgument def decorator( func: Callable[ - [PromptReference | ResourceReference, CompletionArgument], - Awaitable[Completion | None], + [types.PromptReference | types.ResourceReference, types.CompletionArgument], + Awaitable[types.Completion | None], ], ): logger.debug("Registering handler for CompleteRequest") - async def handler(req: CompleteRequest): + async def handler(req: types.CompleteRequest): completion = await func(req.params.ref, req.params.argument) - return ServerResult( - CompleteResult( + return types.ServerResult( + types.CompleteResult( completion=completion if completion is not None - else Completion(values=[], total=None, hasMore=None), + else types.Completion(values=[], total=None, hasMore=None), ) ) - self.request_handlers[CompleteRequest] = handler + self.request_handlers[types.CompleteRequest] = handler return func return decorator async def run( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], initialization_options: InitializationOptions, # When True, exceptions are returned as messages to the client. # When False, exceptions are raised, which will cause the server to shut down @@ -412,7 +352,7 @@ async def run( logger.debug(f"Received message: {message}") match message: - case RequestResponder(request=ClientRequest(root=req)): + case RequestResponder(request=types.ClientRequest(root=req)): logger.info( f"Processing request of type {type(req).__name__}" ) @@ -437,7 +377,7 @@ async def run( except Exception as err: if raise_exceptions: raise err - response = ErrorData( + response = types.ErrorData( code=0, message=str(err), data=None ) finally: @@ -448,14 +388,14 @@ async def run( await message.respond(response) else: await message.respond( - ErrorData( - code=METHOD_NOT_FOUND, + types.ErrorData( + code=types.METHOD_NOT_FOUND, message="Method not found", ) ) logger.debug("Response sent") - case ClientNotification(root=notify): + case types.ClientNotification(root=notify): if type(notify) in self.notification_handlers: assert type(notify) in self.notification_handlers @@ -479,5 +419,5 @@ async def run( ) -async def _ping_handler(request: PingRequest) -> ServerResult: - return ServerResult(EmptyResult()) +async def _ping_handler(request: types.PingRequest) -> types.ServerResult: + return types.ServerResult(types.EmptyResult()) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 03e6882..97b70bd 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -11,29 +11,7 @@ BaseSession, RequestResponder, ) -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - ClientNotification, - ClientRequest, - CreateMessageResult, - EmptyResult, - Implementation, - IncludeContext, - InitializedNotification, - InitializeRequest, - InitializeResult, - JSONRPCMessage, - ListRootsResult, - LoggingLevel, - ModelPreferences, - PromptListChangedNotification, - ResourceListChangedNotification, - SamplingMessage, - ServerNotification, - ServerRequest, - ServerResult, - ToolListChangedNotification, -) +import mcp.types as types class InitializationState(Enum): @@ -44,37 +22,37 @@ class InitializationState(Enum): class ServerSession( BaseSession[ - ServerRequest, - ServerNotification, - ServerResult, - ClientRequest, - ClientNotification, + types.ServerRequest, + types.ServerNotification, + types.ServerResult, + types.ClientRequest, + types.ClientNotification, ] ): _initialized: InitializationState = InitializationState.NotInitialized def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, ) -> None: - super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) + super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) self._initialization_state = InitializationState.NotInitialized self._init_options = init_options async def _received_request( - self, responder: RequestResponder[ClientRequest, ServerResult] + self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): match responder.request.root: - case InitializeRequest(): + case types.InitializeRequest(): self._initialization_state = InitializationState.Initializing await responder.respond( - ServerResult( - InitializeResult( - protocolVersion=LATEST_PROTOCOL_VERSION, + types.ServerResult( + types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, - serverInfo=Implementation( + serverInfo=types.Implementation( name=self._init_options.server_name, version=self._init_options.server_version, ), @@ -87,11 +65,11 @@ async def _received_request( "Received request before initialization was complete" ) - async def _received_notification(self, notification: ClientNotification) -> None: + async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() match notification.root: - case InitializedNotification(): + case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: @@ -100,19 +78,14 @@ async def _received_notification(self, notification: ClientNotification) -> None ) async def send_log_message( - self, level: LoggingLevel, data: Any, logger: str | None = None + self, level: types.LoggingLevel, data: Any, logger: str | None = None ) -> None: """Send a log message notification.""" - from mcp.types import ( - LoggingMessageNotification, - LoggingMessageNotificationParams, - ) - await self.send_notification( - ServerNotification( - LoggingMessageNotification( + types.ServerNotification( + types.LoggingMessageNotification( method="notifications/message", - params=LoggingMessageNotificationParams( + params=types.LoggingMessageNotificationParams( level=level, data=data, logger=logger, @@ -123,43 +96,33 @@ async def send_log_message( async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" - from mcp.types import ( - ResourceUpdatedNotification, - ResourceUpdatedNotificationParams, - ) - await self.send_notification( - ServerNotification( - ResourceUpdatedNotification( + types.ServerNotification( + types.ResourceUpdatedNotification( method="notifications/resources/updated", - params=ResourceUpdatedNotificationParams(uri=uri), + params=types.ResourceUpdatedNotificationParams(uri=uri), ) ) ) async def create_message( self, - messages: list[SamplingMessage], + messages: list[types.SamplingMessage], *, max_tokens: int, system_prompt: str | None = None, - include_context: IncludeContext | None = None, + include_context: types.IncludeContext | None = None, temperature: float | None = None, stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - ) -> CreateMessageResult: + model_preferences: types.ModelPreferences | None = None, + ) -> types.CreateMessageResult: """Send a sampling/create_message request.""" - from mcp.types import ( - CreateMessageRequest, - CreateMessageRequestParams, - ) - return await self.send_request( - ServerRequest( - CreateMessageRequest( + types.ServerRequest( + types.CreateMessageRequest( method="sampling/createMessage", - params=CreateMessageRequestParams( + params=types.CreateMessageRequestParams( messages=messages, systemPrompt=system_prompt, includeContext=include_context, @@ -171,46 +134,40 @@ async def create_message( ), ) ), - CreateMessageResult, + types.CreateMessageResult, ) - async def list_roots(self) -> ListRootsResult: + async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" - from mcp.types import ListRootsRequest - return await self.send_request( - ServerRequest( - ListRootsRequest( + types.ServerRequest( + types.ListRootsRequest( method="roots/list", ) ), - ListRootsResult, + types.ListRootsResult, ) - async def send_ping(self) -> EmptyResult: + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - from mcp.types import PingRequest - return await self.send_request( - ServerRequest( - PingRequest( + types.ServerRequest( + types.PingRequest( method="ping", ) ), - EmptyResult, + types.EmptyResult, ) async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """Send a progress notification.""" - from mcp.types import ProgressNotification, ProgressNotificationParams - await self.send_notification( - ServerNotification( - ProgressNotification( + types.ServerNotification( + types.ProgressNotification( method="notifications/progress", - params=ProgressNotificationParams( + params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, @@ -222,8 +179,8 @@ async def send_progress_notification( async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification( - ServerNotification( - ResourceListChangedNotification( + types.ServerNotification( + types.ResourceListChangedNotification( method="notifications/resources/list_changed", ) ) @@ -232,8 +189,8 @@ async def send_resource_list_changed(self) -> None: async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification( - ServerNotification( - ToolListChangedNotification( + types.ServerNotification( + types.ToolListChangedNotification( method="notifications/tools/list_changed", ) ) @@ -242,8 +199,8 @@ async def send_tool_list_changed(self) -> None: async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification( - ServerNotification( - PromptListChangedNotification( + types.ServerNotification( + types.PromptListChangedNotification( method="notifications/prompts/list_changed", ) ) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 92ebb7a..4074fdb 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -12,7 +12,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]] + _read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]] def __init__(self, endpoint: str) -> None: """ @@ -50,11 +50,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -125,7 +125,7 @@ async def handle_post_message( logger.debug(f"Received JSON: {json}") try: - message = JSONRPCMessage.model_validate(json) + message = types.JSONRPCMessage.model_validate(json) logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.error(f"Failed to parse message: {err}") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 29f6bb6..ffe4081 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -5,7 +5,7 @@ import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.types import JSONRPCMessage +import mcp.types as types @asynccontextmanager @@ -24,11 +24,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(sys.stdout) - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -38,7 +38,7 @@ async def stdin_reader(): async with read_stream_writer: async for line in stdin: try: - message = JSONRPCMessage.model_validate_json(line) + message = types.JSONRPCMessage.model_validate_json(line) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 2a6d812..bd3d632 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -6,7 +6,7 @@ from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -21,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -35,7 +35,7 @@ async def ws_reader(): async with read_stream_writer: async for message in websocket.iter_json(): try: - client_message = JSONRPCMessage.model_validate(message) + client_message = types.JSONRPCMessage.model_validate(message) except Exception as exc: await read_stream_writer.send(exc) continue From ec8c85edea126dd6aa3260df2c1c7023e66669d5 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Mon, 11 Nov 2024 20:17:39 +0000 Subject: [PATCH 3/3] run ruff --- src/mcp/client/session.py | 2 +- src/mcp/client/sse.py | 6 ++---- src/mcp/server/__init__.py | 29 ++++++++++++++++++++++------- src/mcp/server/__main__.py | 2 +- src/mcp/server/models.py | 3 --- src/mcp/server/session.py | 10 +++++++--- src/mcp/server/sse.py | 4 +++- tests/server/test_session.py | 2 +- 8 files changed, 37 insertions(+), 21 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0f3e313..27ca74d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -3,9 +3,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +import mcp.types as types from mcp.shared.session import BaseSession from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -import mcp.types as types class ClientSession( diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index c79f48a..8a90e58 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -84,10 +84,8 @@ async def sse_reader( case "message": try: - message = ( - types.JSONRPCMessage.model_validate_json( - sse.data - ) + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data ) logger.debug( f"Received server message: {message}" diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index abfc40d..3ce8b7b 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -7,12 +7,12 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +import mcp.types as types from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.session import RequestResponder -import mcp.types as types logger = logging.getLogger(__name__) @@ -36,7 +36,9 @@ def __init__( class Server: def __init__(self, name: str): self.name = name - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { + self.request_handlers: dict[ + type, Callable[..., Awaitable[types.ServerResult]] + ] = { types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} @@ -153,7 +155,9 @@ def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): async def handler(_: Any): resources = await func() - return types.ServerResult(types.ListResourcesResult(resources=resources)) + return types.ServerResult( + types.ListResourcesResult(resources=resources) + ) self.request_handlers[types.ListResourcesRequest] = handler return func @@ -249,7 +253,11 @@ def call_tool(self): def decorator( func: Callable[ ..., - Awaitable[Sequence[types.TextContent | types.ImageContent | types.EmbeddedResource]], + Awaitable[ + Sequence[ + types.TextContent | types.ImageContent | types.EmbeddedResource + ] + ], ], ): logger.debug("Registering handler for CallToolRequest") @@ -261,7 +269,9 @@ async def handler(req: types.CallToolRequest): for result in results: match result: case str() as text: - content.append(types.TextContent(type="text", text=text)) + content.append( + types.TextContent(type="text", text=text) + ) case types.ImageContent() as img: content.append( types.ImageContent( @@ -277,7 +287,9 @@ async def handler(req: types.CallToolRequest): ) ) - return types.ServerResult(types.CallToolResult(content=content, isError=False)) + return types.ServerResult( + types.CallToolResult(content=content, isError=False) + ) except Exception as e: return types.ServerResult( types.CallToolResult( @@ -312,7 +324,10 @@ def completion(self): def decorator( func: Callable[ - [types.PromptReference | types.ResourceReference, types.CompletionArgument], + [ + types.PromptReference | types.ResourceReference, + types.CompletionArgument, + ], Awaitable[types.Completion | None], ], ): diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index 417380d..1970eca 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -4,9 +4,9 @@ import anyio +from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server -from mcp.server.models import InitializationOptions from mcp.types import ServerCapabilities if not sys.warnoptions: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index 8920597..377ed51 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -3,9 +3,6 @@ and tools. """ -from dataclasses import dataclass -from typing import Literal - from pydantic import BaseModel from mcp.types import ( diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 97b70bd..dd66307 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,12 +6,12 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +import mcp.types as types from mcp.server.models import InitializationOptions from mcp.shared.session import ( BaseSession, RequestResponder, ) -import mcp.types as types class InitializationState(Enum): @@ -37,7 +37,9 @@ def __init__( write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) + super().__init__( + read_stream, write_stream, types.ClientRequest, types.ClientNotification + ) self._initialization_state = InitializationState.NotInitialized self._init_options = init_options @@ -65,7 +67,9 @@ async def _received_request( "Received request before initialization was complete" ) - async def _received_notification(self, notification: types.ClientNotification) -> None: + async def _received_notification( + self, notification: types.ClientNotification + ) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() match notification.root: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 4074fdb..f6e90bf 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -30,7 +30,9 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]] + _read_stream_writers: dict[ + UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] + ] def __init__(self, endpoint: str) -> None: """ diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 728d276..a78ca90 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -3,8 +3,8 @@ from mcp.client.session import ClientSession from mcp.server import NotificationOptions, Server -from mcp.server.session import ServerSession from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession from mcp.types import ( ClientNotification, InitializedNotification,