diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a9b8d54..27ca74d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -3,85 +3,56 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +import mcp.types as types from mcp.shared.session import BaseSession from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - CallToolResult, - ClientCapabilities, - ClientNotification, - ClientRequest, - ClientResult, - CompleteResult, - EmptyResult, - GetPromptResult, - Implementation, - InitializedNotification, - InitializeResult, - JSONRPCMessage, - ListPromptsResult, - ListResourcesResult, - ListToolsResult, - LoggingLevel, - PromptReference, - ReadResourceResult, - ResourceReference, - RootsCapability, - ServerNotification, - ServerRequest, -) class ClientSession( BaseSession[ - ClientRequest, - ClientNotification, - ClientResult, - ServerRequest, - ServerNotification, + types.ClientRequest, + types.ClientNotification, + types.ClientResult, + types.ServerRequest, + types.ServerNotification, ] ): def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, ) -> None: super().__init__( read_stream, write_stream, - ServerRequest, - ServerNotification, + types.ServerRequest, + types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - async def initialize(self) -> InitializeResult: - from mcp.types import ( - InitializeRequest, - InitializeRequestParams, - ) - + async def initialize(self) -> types.InitializeResult: result = await self.send_request( - ClientRequest( - InitializeRequest( + types.ClientRequest( + types.InitializeRequest( method="initialize", - params=InitializeRequestParams( - protocolVersion=LATEST_PROTOCOL_VERSION, - capabilities=ClientCapabilities( + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( sampling=None, experimental=None, - roots=RootsCapability( + roots=types.RootsCapability( # TODO: Should this be based on whether we # _will_ send notifications, or only whether # they're supported? listChanged=True ), ), - clientInfo=Implementation(name="mcp", version="0.1.0"), + clientInfo=types.Implementation(name="mcp", version="0.1.0"), ), ) ), - InitializeResult, + types.InitializeResult, ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: @@ -91,40 +62,33 @@ async def initialize(self) -> InitializeResult: ) await self.send_notification( - ClientNotification( - InitializedNotification(method="notifications/initialized") + types.ClientNotification( + types.InitializedNotification(method="notifications/initialized") ) ) return result - async def send_ping(self) -> EmptyResult: + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - from mcp.types import PingRequest - return await self.send_request( - ClientRequest( - PingRequest( + types.ClientRequest( + types.PingRequest( method="ping", ) ), - EmptyResult, + types.EmptyResult, ) async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """Send a progress notification.""" - from mcp.types import ( - ProgressNotification, - ProgressNotificationParams, - ) - await self.send_notification( - ClientNotification( - ProgressNotification( + types.ClientNotification( + types.ProgressNotification( method="notifications/progress", - params=ProgressNotificationParams( + params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, @@ -133,180 +97,137 @@ async def send_progress_notification( ) ) - async def set_logging_level(self, level: LoggingLevel) -> EmptyResult: + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" - from mcp.types import ( - SetLevelRequest, - SetLevelRequestParams, - ) - return await self.send_request( - ClientRequest( - SetLevelRequest( + types.ClientRequest( + types.SetLevelRequest( method="logging/setLevel", - params=SetLevelRequestParams(level=level), + params=types.SetLevelRequestParams(level=level), ) ), - EmptyResult, + types.EmptyResult, ) - async def list_resources(self) -> ListResourcesResult: + async def list_resources(self) -> types.ListResourcesResult: """Send a resources/list request.""" - from mcp.types import ( - ListResourcesRequest, - ) - return await self.send_request( - ClientRequest( - ListResourcesRequest( + types.ClientRequest( + types.ListResourcesRequest( method="resources/list", ) ), - ListResourcesResult, + types.ListResourcesResult, ) - async def read_resource(self, uri: AnyUrl) -> ReadResourceResult: + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: """Send a resources/read request.""" - from mcp.types import ( - ReadResourceRequest, - ReadResourceRequestParams, - ) - return await self.send_request( - ClientRequest( - ReadResourceRequest( + types.ClientRequest( + types.ReadResourceRequest( method="resources/read", - params=ReadResourceRequestParams(uri=uri), + params=types.ReadResourceRequestParams(uri=uri), ) ), - ReadResourceResult, + types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult: + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/subscribe request.""" - from mcp.types import ( - SubscribeRequest, - SubscribeRequestParams, - ) - return await self.send_request( - ClientRequest( - SubscribeRequest( + types.ClientRequest( + types.SubscribeRequest( method="resources/subscribe", - params=SubscribeRequestParams(uri=uri), + params=types.SubscribeRequestParams(uri=uri), ) ), - EmptyResult, + types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult: + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: """Send a resources/unsubscribe request.""" - from mcp.types import ( - UnsubscribeRequest, - UnsubscribeRequestParams, - ) - return await self.send_request( - ClientRequest( - UnsubscribeRequest( + types.ClientRequest( + types.UnsubscribeRequest( method="resources/unsubscribe", - params=UnsubscribeRequestParams(uri=uri), + params=types.UnsubscribeRequestParams(uri=uri), ) ), - EmptyResult, + types.EmptyResult, ) async def call_tool( self, name: str, arguments: dict | None = None - ) -> CallToolResult: + ) -> types.CallToolResult: """Send a tools/call request.""" - from mcp.types import ( - CallToolRequest, - CallToolRequestParams, - ) - return await self.send_request( - ClientRequest( - CallToolRequest( + types.ClientRequest( + types.CallToolRequest( method="tools/call", - params=CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams(name=name, arguments=arguments), ) ), - CallToolResult, + types.CallToolResult, ) - async def list_prompts(self) -> ListPromptsResult: + async def list_prompts(self) -> types.ListPromptsResult: """Send a prompts/list request.""" - from mcp.types import ListPromptsRequest - return await self.send_request( - ClientRequest( - ListPromptsRequest( + types.ClientRequest( + types.ListPromptsRequest( method="prompts/list", ) ), - ListPromptsResult, + types.ListPromptsResult, ) async def get_prompt( self, name: str, arguments: dict[str, str] | None = None - ) -> GetPromptResult: + ) -> types.GetPromptResult: """Send a prompts/get request.""" - from mcp.types import GetPromptRequest, GetPromptRequestParams - return await self.send_request( - ClientRequest( - GetPromptRequest( + types.ClientRequest( + types.GetPromptRequest( method="prompts/get", - params=GetPromptRequestParams(name=name, arguments=arguments), + params=types.GetPromptRequestParams(name=name, arguments=arguments), ) ), - GetPromptResult, + types.GetPromptResult, ) async def complete( - self, ref: ResourceReference | PromptReference, argument: dict - ) -> CompleteResult: + self, ref: types.ResourceReference | types.PromptReference, argument: dict + ) -> types.CompleteResult: """Send a completion/complete request.""" - from mcp.types import ( - CompleteRequest, - CompleteRequestParams, - CompletionArgument, - ) - return await self.send_request( - ClientRequest( - CompleteRequest( + types.ClientRequest( + types.CompleteRequest( method="completion/complete", - params=CompleteRequestParams( + params=types.CompleteRequestParams( ref=ref, - argument=CompletionArgument(**argument), + argument=types.CompletionArgument(**argument), ), ) ), - CompleteResult, + types.CompleteResult, ) - async def list_tools(self) -> ListToolsResult: + async def list_tools(self) -> types.ListToolsResult: """Send a tools/list request.""" - from mcp.types import ListToolsRequest - return await self.send_request( - ClientRequest( - ListToolsRequest( + types.ClientRequest( + types.ListToolsRequest( method="tools/list", ) ), - ListToolsResult, + types.ListToolsResult, ) async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" - from mcp.types import RootsListChangedNotification - await self.send_notification( - ClientNotification( - RootsListChangedNotification( + types.ClientNotification( + types.RootsListChangedNotification( method="notifications/roots/list_changed", ) ) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index b5c36db..8a90e58 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -9,7 +9,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -31,11 +31,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -84,10 +84,8 @@ async def sse_reader( case "message": try: - message = ( - JSONRPCMessage.model_validate_json( - sse.data - ) + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data ) logger.debug( f"Received server message: {message}" diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 6a29138..e79a816 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -8,7 +8,7 @@ from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field -from mcp.types import JSONRPCMessage +import mcp.types as types # Environment variables to inherit by default DEFAULT_INHERITED_ENV_VARS = ( @@ -72,11 +72,11 @@ async def stdio_client(server: StdioServerParameters): Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -99,7 +99,7 @@ async def stdout_reader(): for line in lines: try: - message = JSONRPCMessage.model_validate_json(line) + message = types.JSONRPCMessage.model_validate_json(line) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/__init__.py b/src/mcp/server/__init__.py index f7e66a3..3ce8b7b 100644 --- a/src/mcp/server/__init__.py +++ b/src/mcp/server/__init__.py @@ -7,49 +7,12 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl -from mcp.server import types +import mcp.types as types +from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.session import RequestResponder -from mcp.types import ( - METHOD_NOT_FOUND, - CallToolRequest, - ClientNotification, - ClientRequest, - CompleteRequest, - EmbeddedResource, - EmptyResult, - ErrorData, - JSONRPCMessage, - ListPromptsRequest, - ListPromptsResult, - ListResourcesRequest, - ListResourcesResult, - ListToolsRequest, - ListToolsResult, - LoggingCapability, - LoggingLevel, - PingRequest, - ProgressNotification, - Prompt, - PromptMessage, - PromptReference, - PromptsCapability, - ReadResourceRequest, - ReadResourceResult, - Resource, - ResourceReference, - ResourcesCapability, - ServerCapabilities, - ServerResult, - SetLevelRequest, - SubscribeRequest, - TextContent, - Tool, - ToolsCapability, - UnsubscribeRequest, -) logger = logging.getLogger(__name__) @@ -73,8 +36,10 @@ def __init__( class Server: def __init__(self, name: str): self.name = name - self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = { - PingRequest: _ping_handler, + self.request_handlers: dict[ + type, Callable[..., Awaitable[types.ServerResult]] + ] = { + types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() @@ -84,7 +49,7 @@ def create_initialization_options( self, notification_options: NotificationOptions | None = None, experimental_capabilities: dict[str, dict[str, Any]] | None = None, - ) -> types.InitializationOptions: + ) -> InitializationOptions: """Create initialization options from this server instance.""" def pkg_version(package: str) -> str: @@ -99,7 +64,7 @@ def pkg_version(package: str) -> str: return "unknown" - return types.InitializationOptions( + return InitializationOptions( server_name=self.name, server_version=pkg_version("mcp"), capabilities=self.get_capabilities( @@ -112,7 +77,7 @@ def get_capabilities( self, notification_options: NotificationOptions, experimental_capabilities: dict[str, dict[str, Any]], - ) -> ServerCapabilities: + ) -> types.ServerCapabilities: """Convert existing handlers to a ServerCapabilities object.""" prompts_capability = None resources_capability = None @@ -120,28 +85,28 @@ def get_capabilities( logging_capability = None # Set prompt capabilities if handler exists - if ListPromptsRequest in self.request_handlers: - prompts_capability = PromptsCapability( + if types.ListPromptsRequest in self.request_handlers: + prompts_capability = types.PromptsCapability( listChanged=notification_options.prompts_changed ) # Set resource capabilities if handler exists - if ListResourcesRequest in self.request_handlers: - resources_capability = ResourcesCapability( + if types.ListResourcesRequest in self.request_handlers: + resources_capability = types.ResourcesCapability( subscribe=False, listChanged=notification_options.resources_changed ) # Set tool capabilities if handler exists - if ListToolsRequest in self.request_handlers: - tools_capability = ToolsCapability( + if types.ListToolsRequest in self.request_handlers: + tools_capability = types.ToolsCapability( listChanged=notification_options.tools_changed ) # Set logging capabilities if handler exists - if SetLevelRequest in self.request_handlers: - logging_capability = LoggingCapability() + if types.SetLevelRequest in self.request_handlers: + logging_capability = types.LoggingCapability() - return ServerCapabilities( + return types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, tools=tools_capability, @@ -155,96 +120,59 @@ def request_context(self) -> RequestContext: return request_ctx.get() def list_prompts(self): - def decorator(func: Callable[[], Awaitable[list[Prompt]]]): + def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): logger.debug("Registering handler for PromptListRequest") async def handler(_: Any): prompts = await func() - return ServerResult(ListPromptsResult(prompts=prompts)) + return types.ServerResult(types.ListPromptsResult(prompts=prompts)) - self.request_handlers[ListPromptsRequest] = handler + self.request_handlers[types.ListPromptsRequest] = handler return func return decorator def get_prompt(self): - from mcp.types import ( - GetPromptRequest, - GetPromptResult, - ImageContent, - ) - from mcp.types import ( - Role as Role, - ) - def decorator( func: Callable[ - [str, dict[str, str] | None], Awaitable[types.PromptResponse] + [str, dict[str, str] | None], Awaitable[types.GetPromptResult] ], ): logger.debug("Registering handler for GetPromptRequest") - async def handler(req: GetPromptRequest): + async def handler(req: types.GetPromptRequest): prompt_get = await func(req.params.name, req.params.arguments) - messages: list[PromptMessage] = [] - for message in prompt_get.messages: - match message.content: - case str() as text_content: - content = TextContent(type="text", text=text_content) - case types.ImageContent() as img_content: - content = ImageContent( - type="image", - data=img_content.data, - mimeType=img_content.mime_type, - ) - case types.EmbeddedResource() as resource: - content = EmbeddedResource( - type="resource", resource=resource.resource - ) - case _: - raise ValueError( - f"Unexpected content type: {type(message.content)}" - ) - - prompt_message = PromptMessage(role=message.role, content=content) - messages.append(prompt_message) + return types.ServerResult(prompt_get) - return ServerResult( - GetPromptResult(description=prompt_get.desc, messages=messages) - ) - - self.request_handlers[GetPromptRequest] = handler + self.request_handlers[types.GetPromptRequest] = handler return func return decorator def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[Resource]]]): + def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): logger.debug("Registering handler for ListResourcesRequest") async def handler(_: Any): resources = await func() - return ServerResult(ListResourcesResult(resources=resources)) + return types.ServerResult( + types.ListResourcesResult(resources=resources) + ) - self.request_handlers[ListResourcesRequest] = handler + self.request_handlers[types.ListResourcesRequest] = handler return func return decorator def read_resource(self): - from mcp.types import ( - BlobResourceContents, - TextResourceContents, - ) - def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]): logger.debug("Registering handler for ReadResourceRequest") - async def handler(req: ReadResourceRequest): + async def handler(req: types.ReadResourceRequest): result = await func(req.params.uri) match result: case str(s): - content = TextResourceContents( + content = types.TextResourceContents( uri=req.params.uri, text=s, mimeType="text/plain", @@ -252,130 +180,125 @@ async def handler(req: ReadResourceRequest): case bytes(b): import base64 - content = BlobResourceContents( + content = types.BlobResourceContents( uri=req.params.uri, blob=base64.urlsafe_b64encode(b).decode(), mimeType="application/octet-stream", ) - return ServerResult( - ReadResourceResult( + return types.ServerResult( + types.ReadResourceResult( contents=[content], ) ) - self.request_handlers[ReadResourceRequest] = handler + self.request_handlers[types.ReadResourceRequest] = handler return func return decorator def set_logging_level(self): - from mcp.types import EmptyResult - - def decorator(func: Callable[[LoggingLevel], Awaitable[None]]): + def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): logger.debug("Registering handler for SetLevelRequest") - async def handler(req: SetLevelRequest): + async def handler(req: types.SetLevelRequest): await func(req.params.level) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[SetLevelRequest] = handler + self.request_handlers[types.SetLevelRequest] = handler return func return decorator def subscribe_resource(self): - from mcp.types import EmptyResult - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for SubscribeRequest") - async def handler(req: SubscribeRequest): + async def handler(req: types.SubscribeRequest): await func(req.params.uri) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[SubscribeRequest] = handler + self.request_handlers[types.SubscribeRequest] = handler return func return decorator def unsubscribe_resource(self): - from mcp.types import EmptyResult - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for UnsubscribeRequest") - async def handler(req: UnsubscribeRequest): + async def handler(req: types.UnsubscribeRequest): await func(req.params.uri) - return ServerResult(EmptyResult()) + return types.ServerResult(types.EmptyResult()) - self.request_handlers[UnsubscribeRequest] = handler + self.request_handlers[types.UnsubscribeRequest] = handler return func return decorator def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[Tool]]]): + def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): logger.debug("Registering handler for ListToolsRequest") async def handler(_: Any): tools = await func() - return ServerResult(ListToolsResult(tools=tools)) + return types.ServerResult(types.ListToolsResult(tools=tools)) - self.request_handlers[ListToolsRequest] = handler + self.request_handlers[types.ListToolsRequest] = handler return func return decorator def call_tool(self): - from mcp.types import ( - CallToolResult, - EmbeddedResource, - ImageContent, - TextContent, - ) - def decorator( func: Callable[ ..., - Awaitable[Sequence[str | types.ImageContent | types.EmbeddedResource]], + Awaitable[ + Sequence[ + types.TextContent | types.ImageContent | types.EmbeddedResource + ] + ], ], ): logger.debug("Registering handler for CallToolRequest") - async def handler(req: CallToolRequest): + async def handler(req: types.CallToolRequest): try: results = await func(req.params.name, (req.params.arguments or {})) content = [] for result in results: match result: case str() as text: - content.append(TextContent(type="text", text=text)) + content.append( + types.TextContent(type="text", text=text) + ) case types.ImageContent() as img: content.append( - ImageContent( + types.ImageContent( type="image", data=img.data, - mimeType=img.mime_type, + mimeType=img.mimeType, ) ) case types.EmbeddedResource() as resource: content.append( - EmbeddedResource( + types.EmbeddedResource( type="resource", resource=resource.resource ) ) - return ServerResult(CallToolResult(content=content, isError=False)) + return types.ServerResult( + types.CallToolResult(content=content, isError=False) + ) except Exception as e: - return ServerResult( - CallToolResult( - content=[TextContent(type="text", text=str(e))], + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(e))], isError=True, ) ) - self.request_handlers[CallToolRequest] = handler + self.request_handlers[types.CallToolRequest] = handler return func return decorator @@ -386,48 +309,50 @@ def decorator( ): logger.debug("Registering handler for ProgressNotification") - async def handler(req: ProgressNotification): + async def handler(req: types.ProgressNotification): await func( req.params.progressToken, req.params.progress, req.params.total ) - self.notification_handlers[ProgressNotification] = handler + self.notification_handlers[types.ProgressNotification] = handler return func return decorator def completion(self): """Provides completions for prompts and resource templates""" - from mcp.types import CompleteResult, Completion, CompletionArgument def decorator( func: Callable[ - [PromptReference | ResourceReference, CompletionArgument], - Awaitable[Completion | None], + [ + types.PromptReference | types.ResourceReference, + types.CompletionArgument, + ], + Awaitable[types.Completion | None], ], ): logger.debug("Registering handler for CompleteRequest") - async def handler(req: CompleteRequest): + async def handler(req: types.CompleteRequest): completion = await func(req.params.ref, req.params.argument) - return ServerResult( - CompleteResult( + return types.ServerResult( + types.CompleteResult( completion=completion if completion is not None - else Completion(values=[], total=None, hasMore=None), + else types.Completion(values=[], total=None, hasMore=None), ) ) - self.request_handlers[CompleteRequest] = handler + self.request_handlers[types.CompleteRequest] = handler return func return decorator async def run( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], - initialization_options: types.InitializationOptions, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + initialization_options: InitializationOptions, # When True, exceptions are returned as messages to the client. # When False, exceptions are raised, which will cause the server to shut down # but also make tracing exceptions much easier during testing and when using @@ -442,7 +367,7 @@ async def run( logger.debug(f"Received message: {message}") match message: - case RequestResponder(request=ClientRequest(root=req)): + case RequestResponder(request=types.ClientRequest(root=req)): logger.info( f"Processing request of type {type(req).__name__}" ) @@ -467,7 +392,7 @@ async def run( except Exception as err: if raise_exceptions: raise err - response = ErrorData( + response = types.ErrorData( code=0, message=str(err), data=None ) finally: @@ -478,14 +403,14 @@ async def run( await message.respond(response) else: await message.respond( - ErrorData( - code=METHOD_NOT_FOUND, + types.ErrorData( + code=types.METHOD_NOT_FOUND, message="Method not found", ) ) logger.debug("Response sent") - case ClientNotification(root=notify): + case types.ClientNotification(root=notify): if type(notify) in self.notification_handlers: assert type(notify) in self.notification_handlers @@ -509,5 +434,5 @@ async def run( ) -async def _ping_handler(request: PingRequest) -> ServerResult: - return ServerResult(EmptyResult()) +async def _ping_handler(request: types.PingRequest) -> types.ServerResult: + return types.ServerResult(types.EmptyResult()) diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index 2313f46..1970eca 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -4,9 +4,9 @@ import anyio +from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server -from mcp.server.types import InitializationOptions from mcp.types import ServerCapabilities if not sys.warnoptions: diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py new file mode 100644 index 0000000..377ed51 --- /dev/null +++ b/src/mcp/server/models.py @@ -0,0 +1,16 @@ +""" +This module provides simpler types to use with the server for managing prompts +and tools. +""" + +from pydantic import BaseModel + +from mcp.types import ( + ServerCapabilities, +) + + +class InitializationOptions(BaseModel): + server_name: str + server_version: str + capabilities: ServerCapabilities diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f6ed1b3..dd66307 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -6,34 +6,12 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl -from mcp.server.types import InitializationOptions +import mcp.types as types +from mcp.server.models import InitializationOptions from mcp.shared.session import ( BaseSession, RequestResponder, ) -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - ClientNotification, - ClientRequest, - CreateMessageResult, - EmptyResult, - Implementation, - IncludeContext, - InitializedNotification, - InitializeRequest, - InitializeResult, - JSONRPCMessage, - ListRootsResult, - LoggingLevel, - ModelPreferences, - PromptListChangedNotification, - ResourceListChangedNotification, - SamplingMessage, - ServerNotification, - ServerRequest, - ServerResult, - ToolListChangedNotification, -) class InitializationState(Enum): @@ -44,37 +22,39 @@ class InitializationState(Enum): class ServerSession( BaseSession[ - ServerRequest, - ServerNotification, - ServerResult, - ClientRequest, - ClientNotification, + types.ServerRequest, + types.ServerNotification, + types.ServerResult, + types.ClientRequest, + types.ClientNotification, ] ): _initialized: InitializationState = InitializationState.NotInitialized def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, ) -> None: - super().__init__(read_stream, write_stream, ClientRequest, ClientNotification) + super().__init__( + read_stream, write_stream, types.ClientRequest, types.ClientNotification + ) self._initialization_state = InitializationState.NotInitialized self._init_options = init_options async def _received_request( - self, responder: RequestResponder[ClientRequest, ServerResult] + self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): match responder.request.root: - case InitializeRequest(): + case types.InitializeRequest(): self._initialization_state = InitializationState.Initializing await responder.respond( - ServerResult( - InitializeResult( - protocolVersion=LATEST_PROTOCOL_VERSION, + types.ServerResult( + types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, - serverInfo=Implementation( + serverInfo=types.Implementation( name=self._init_options.server_name, version=self._init_options.server_version, ), @@ -87,11 +67,13 @@ async def _received_request( "Received request before initialization was complete" ) - async def _received_notification(self, notification: ClientNotification) -> None: + async def _received_notification( + self, notification: types.ClientNotification + ) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() match notification.root: - case InitializedNotification(): + case types.InitializedNotification(): self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: @@ -100,19 +82,14 @@ async def _received_notification(self, notification: ClientNotification) -> None ) async def send_log_message( - self, level: LoggingLevel, data: Any, logger: str | None = None + self, level: types.LoggingLevel, data: Any, logger: str | None = None ) -> None: """Send a log message notification.""" - from mcp.types import ( - LoggingMessageNotification, - LoggingMessageNotificationParams, - ) - await self.send_notification( - ServerNotification( - LoggingMessageNotification( + types.ServerNotification( + types.LoggingMessageNotification( method="notifications/message", - params=LoggingMessageNotificationParams( + params=types.LoggingMessageNotificationParams( level=level, data=data, logger=logger, @@ -123,43 +100,33 @@ async def send_log_message( async def send_resource_updated(self, uri: AnyUrl) -> None: """Send a resource updated notification.""" - from mcp.types import ( - ResourceUpdatedNotification, - ResourceUpdatedNotificationParams, - ) - await self.send_notification( - ServerNotification( - ResourceUpdatedNotification( + types.ServerNotification( + types.ResourceUpdatedNotification( method="notifications/resources/updated", - params=ResourceUpdatedNotificationParams(uri=uri), + params=types.ResourceUpdatedNotificationParams(uri=uri), ) ) ) async def create_message( self, - messages: list[SamplingMessage], + messages: list[types.SamplingMessage], *, max_tokens: int, system_prompt: str | None = None, - include_context: IncludeContext | None = None, + include_context: types.IncludeContext | None = None, temperature: float | None = None, stop_sequences: list[str] | None = None, metadata: dict[str, Any] | None = None, - model_preferences: ModelPreferences | None = None, - ) -> CreateMessageResult: + model_preferences: types.ModelPreferences | None = None, + ) -> types.CreateMessageResult: """Send a sampling/create_message request.""" - from mcp.types import ( - CreateMessageRequest, - CreateMessageRequestParams, - ) - return await self.send_request( - ServerRequest( - CreateMessageRequest( + types.ServerRequest( + types.CreateMessageRequest( method="sampling/createMessage", - params=CreateMessageRequestParams( + params=types.CreateMessageRequestParams( messages=messages, systemPrompt=system_prompt, includeContext=include_context, @@ -171,46 +138,40 @@ async def create_message( ), ) ), - CreateMessageResult, + types.CreateMessageResult, ) - async def list_roots(self) -> ListRootsResult: + async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" - from mcp.types import ListRootsRequest - return await self.send_request( - ServerRequest( - ListRootsRequest( + types.ServerRequest( + types.ListRootsRequest( method="roots/list", ) ), - ListRootsResult, + types.ListRootsResult, ) - async def send_ping(self) -> EmptyResult: + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" - from mcp.types import PingRequest - return await self.send_request( - ServerRequest( - PingRequest( + types.ServerRequest( + types.PingRequest( method="ping", ) ), - EmptyResult, + types.EmptyResult, ) async def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None ) -> None: """Send a progress notification.""" - from mcp.types import ProgressNotification, ProgressNotificationParams - await self.send_notification( - ServerNotification( - ProgressNotification( + types.ServerNotification( + types.ProgressNotification( method="notifications/progress", - params=ProgressNotificationParams( + params=types.ProgressNotificationParams( progressToken=progress_token, progress=progress, total=total, @@ -222,8 +183,8 @@ async def send_progress_notification( async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification( - ServerNotification( - ResourceListChangedNotification( + types.ServerNotification( + types.ResourceListChangedNotification( method="notifications/resources/list_changed", ) ) @@ -232,8 +193,8 @@ async def send_resource_list_changed(self) -> None: async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification( - ServerNotification( - ToolListChangedNotification( + types.ServerNotification( + types.ToolListChangedNotification( method="notifications/tools/list_changed", ) ) @@ -242,8 +203,8 @@ async def send_tool_list_changed(self) -> None: async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification( - ServerNotification( - PromptListChangedNotification( + types.ServerNotification( + types.PromptListChangedNotification( method="notifications/prompts/list_changed", ) ) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 92ebb7a..f6e90bf 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -12,7 +12,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -30,7 +30,9 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[JSONRPCMessage | Exception]] + _read_stream_writers: dict[ + UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] + ] def __init__(self, endpoint: str) -> None: """ @@ -50,11 +52,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -125,7 +127,7 @@ async def handle_post_message( logger.debug(f"Received JSON: {json}") try: - message = JSONRPCMessage.model_validate(json) + message = types.JSONRPCMessage.model_validate(json) logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.error(f"Failed to parse message: {err}") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 29f6bb6..ffe4081 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -5,7 +5,7 @@ import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.types import JSONRPCMessage +import mcp.types as types @asynccontextmanager @@ -24,11 +24,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(sys.stdout) - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -38,7 +38,7 @@ async def stdin_reader(): async with read_stream_writer: async for line in stdin: try: - message = JSONRPCMessage.model_validate_json(line) + message = types.JSONRPCMessage.model_validate_json(line) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/types.py b/src/mcp/server/types.py deleted file mode 100644 index 7946a4b..0000000 --- a/src/mcp/server/types.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -This module provides simpler types to use with the server for managing prompts -and tools. -""" - -from dataclasses import dataclass -from typing import Literal - -from pydantic import BaseModel - -from mcp.types import ( - BlobResourceContents, - Role, - ServerCapabilities, - TextResourceContents, -) - - -@dataclass -class ImageContent: - type: Literal["image"] - data: str - mime_type: str - - -@dataclass -class EmbeddedResource: - resource: TextResourceContents | BlobResourceContents - - -@dataclass -class Message: - role: Role - content: str | ImageContent | EmbeddedResource - - -@dataclass -class PromptResponse: - messages: list[Message] - desc: str | None = None - - -class InitializationOptions(BaseModel): - server_name: str - server_version: str - capabilities: ServerCapabilities diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 2a6d812..bd3d632 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -6,7 +6,7 @@ from starlette.types import Receive, Scope, Send from starlette.websockets import WebSocket -from mcp.types import JSONRPCMessage +import mcp.types as types logger = logging.getLogger(__name__) @@ -21,11 +21,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -35,7 +35,7 @@ async def ws_reader(): async with read_stream_writer: async for message in websocket.iter_json(): try: - client_message = JSONRPCMessage.model_validate(message) + client_message = types.JSONRPCMessage.model_validate(message) except Exception as exc: await read_stream_writer.send(exc) continue diff --git a/tests/conftest.py b/tests/conftest.py index 10f3ed6..8d792aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ from pydantic import AnyUrl from mcp.server import Server -from mcp.server.types import InitializationOptions +from mcp.server.models import InitializationOptions from mcp.types import Resource, ServerCapabilities TEST_INITIALIZATION_OPTIONS = InitializationOptions( diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 66ac58d..a78ca90 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -3,8 +3,8 @@ from mcp.client.session import ClientSession from mcp.server import NotificationOptions, Server +from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.server.types import InitializationOptions from mcp.types import ( ClientNotification, InitializedNotification, diff --git a/uv.lock b/uv.lock index 35e52f4..49d4e13 100644 --- a/uv.lock +++ b/uv.lock @@ -331,15 +331,14 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.388" +version = "1.1.378" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, - { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9c/83/e9867538a794638d2d20ac3ab3106a31aca1d9cfea530c9b2921809dae03/pyright-1.1.388.tar.gz", hash = "sha256:0166d19b716b77fd2d9055de29f71d844874dbc6b9d3472ccd22df91db3dfa34", size = 21939 } +sdist = { url = "https://files.pythonhosted.org/packages/3d/f0/e8aa5555410d88f898bef04da2102b0a9bf144658c98d34872e91621ced2/pyright-1.1.378.tar.gz", hash = "sha256:78a043be2876d12d0af101d667e92c7734f3ebb9db71dccc2c220e7e7eb89ca2", size = 17486 } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/57/7fb00363b7f267a398c5bdf4f55f3e64f7c2076b2e7d2901b3373d52b6ff/pyright-1.1.388-py3-none-any.whl", hash = "sha256:c7068e9f2c23539c6ac35fc9efac6c6c1b9aa5a0ce97a9a8a6cf0090d7cbf84c", size = 18579 }, + { url = "https://files.pythonhosted.org/packages/38/c6/f0d4bc20c13b20cecfbf13c699477c825e45767f1dc5068137323f86e495/pyright-1.1.378-py3-none-any.whl", hash = "sha256:8853776138b01bc284da07ac481235be7cc89d3176b073d2dba73636cb95be79", size = 18222 }, ] [[package]]