Skip to content

Commit

Permalink
Merge pull request #42 from modelcontextprotocol/davidsp/types
Browse files Browse the repository at this point in the history
Types Rework
  • Loading branch information
dsp-ant authored Nov 11, 2024
2 parents 837309c + ec8c85e commit 99c402d
Show file tree
Hide file tree
Showing 14 changed files with 279 additions and 503 deletions.
237 changes: 79 additions & 158 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,85 +3,56 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

import mcp.types as types
from mcp.shared.session import BaseSession
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
LATEST_PROTOCOL_VERSION,
CallToolResult,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompleteResult,
EmptyResult,
GetPromptResult,
Implementation,
InitializedNotification,
InitializeResult,
JSONRPCMessage,
ListPromptsResult,
ListResourcesResult,
ListToolsResult,
LoggingLevel,
PromptReference,
ReadResourceResult,
ResourceReference,
RootsCapability,
ServerNotification,
ServerRequest,
)


class ClientSession(
BaseSession[
ClientRequest,
ClientNotification,
ClientResult,
ServerRequest,
ServerNotification,
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)

async def initialize(self) -> InitializeResult:
from mcp.types import (
InitializeRequest,
InitializeRequestParams,
)

async def initialize(self) -> types.InitializeResult:
result = await self.send_request(
ClientRequest(
InitializeRequest(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=InitializeRequestParams(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
experimental=None,
roots=RootsCapability(
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
),
clientInfo=Implementation(name="mcp", version="0.1.0"),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
)
),
InitializeResult,
types.InitializeResult,
)

if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
Expand All @@ -91,40 +62,33 @@ async def initialize(self) -> InitializeResult:
)

await self.send_notification(
ClientNotification(
InitializedNotification(method="notifications/initialized")
types.ClientNotification(
types.InitializedNotification(method="notifications/initialized")
)
)

return result

async def send_ping(self) -> EmptyResult:
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
from mcp.types import PingRequest

return await self.send_request(
ClientRequest(
PingRequest(
types.ClientRequest(
types.PingRequest(
method="ping",
)
),
EmptyResult,
types.EmptyResult,
)

async def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
from mcp.types import (
ProgressNotification,
ProgressNotificationParams,
)

await self.send_notification(
ClientNotification(
ProgressNotification(
types.ClientNotification(
types.ProgressNotification(
method="notifications/progress",
params=ProgressNotificationParams(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
Expand All @@ -133,180 +97,137 @@ async def send_progress_notification(
)
)

async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
from mcp.types import (
SetLevelRequest,
SetLevelRequestParams,
)

return await self.send_request(
ClientRequest(
SetLevelRequest(
types.ClientRequest(
types.SetLevelRequest(
method="logging/setLevel",
params=SetLevelRequestParams(level=level),
params=types.SetLevelRequestParams(level=level),
)
),
EmptyResult,
types.EmptyResult,
)

async def list_resources(self) -> ListResourcesResult:
async def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request."""
from mcp.types import (
ListResourcesRequest,
)

return await self.send_request(
ClientRequest(
ListResourcesRequest(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
)
),
ListResourcesResult,
types.ListResourcesResult,
)

async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
from mcp.types import (
ReadResourceRequest,
ReadResourceRequestParams,
)

return await self.send_request(
ClientRequest(
ReadResourceRequest(
types.ClientRequest(
types.ReadResourceRequest(
method="resources/read",
params=ReadResourceRequestParams(uri=uri),
params=types.ReadResourceRequestParams(uri=uri),
)
),
ReadResourceResult,
types.ReadResourceResult,
)

async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
from mcp.types import (
SubscribeRequest,
SubscribeRequestParams,
)

return await self.send_request(
ClientRequest(
SubscribeRequest(
types.ClientRequest(
types.SubscribeRequest(
method="resources/subscribe",
params=SubscribeRequestParams(uri=uri),
params=types.SubscribeRequestParams(uri=uri),
)
),
EmptyResult,
types.EmptyResult,
)

async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
from mcp.types import (
UnsubscribeRequest,
UnsubscribeRequestParams,
)

return await self.send_request(
ClientRequest(
UnsubscribeRequest(
types.ClientRequest(
types.UnsubscribeRequest(
method="resources/unsubscribe",
params=UnsubscribeRequestParams(uri=uri),
params=types.UnsubscribeRequestParams(uri=uri),
)
),
EmptyResult,
types.EmptyResult,
)

async def call_tool(
self, name: str, arguments: dict | None = None
) -> CallToolResult:
) -> types.CallToolResult:
"""Send a tools/call request."""
from mcp.types import (
CallToolRequest,
CallToolRequestParams,
)

return await self.send_request(
ClientRequest(
CallToolRequest(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=CallToolRequestParams(name=name, arguments=arguments),
params=types.CallToolRequestParams(name=name, arguments=arguments),
)
),
CallToolResult,
types.CallToolResult,
)

async def list_prompts(self) -> ListPromptsResult:
async def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request."""
from mcp.types import ListPromptsRequest

return await self.send_request(
ClientRequest(
ListPromptsRequest(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
)
),
ListPromptsResult,
types.ListPromptsResult,
)

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult:
) -> types.GetPromptResult:
"""Send a prompts/get request."""
from mcp.types import GetPromptRequest, GetPromptRequestParams

return await self.send_request(
ClientRequest(
GetPromptRequest(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=GetPromptRequestParams(name=name, arguments=arguments),
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
GetPromptResult,
types.GetPromptResult,
)

async def complete(
self, ref: ResourceReference | PromptReference, argument: dict
) -> CompleteResult:
self, ref: types.ResourceReference | types.PromptReference, argument: dict
) -> types.CompleteResult:
"""Send a completion/complete request."""
from mcp.types import (
CompleteRequest,
CompleteRequestParams,
CompletionArgument,
)

return await self.send_request(
ClientRequest(
CompleteRequest(
types.ClientRequest(
types.CompleteRequest(
method="completion/complete",
params=CompleteRequestParams(
params=types.CompleteRequestParams(
ref=ref,
argument=CompletionArgument(**argument),
argument=types.CompletionArgument(**argument),
),
)
),
CompleteResult,
types.CompleteResult,
)

async def list_tools(self) -> ListToolsResult:
async def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request."""
from mcp.types import ListToolsRequest

return await self.send_request(
ClientRequest(
ListToolsRequest(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
)
),
ListToolsResult,
types.ListToolsResult,
)

async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
from mcp.types import RootsListChangedNotification

await self.send_notification(
ClientNotification(
RootsListChangedNotification(
types.ClientNotification(
types.RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
Expand Down
Loading

0 comments on commit 99c402d

Please sign in to comment.