Skip to content

Commit

Permalink
Merge pull request #28 from modelcontextprotocol/justin/sampling
Browse files Browse the repository at this point in the history
Add tab and approval flow for server -> client sampling
  • Loading branch information
jspahrsummers authored Oct 28, 2024
2 parents 8ff82d7 + 7926eea commit a5606f7
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 11 deletions.
76 changes: 65 additions & 11 deletions client/src/App.tsx
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js";
import {
CallToolResultSchema,
ClientRequest,
CreateMessageRequestSchema,
CreateMessageResult,
EmptyResultSchema,
GetPromptResultSchema,
ListPromptsResultSchema,
Expand All @@ -24,16 +16,28 @@ import {
ServerNotification,
Tool,
} from "@modelcontextprotocol/sdk/types.js";
import { useEffect, useRef, useState } from "react";

import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import {
Bell,
Files,
Hammer,
Hash,
MessageSquare,
Play,
Send,
Terminal,
} from "lucide-react";
import { useEffect, useRef, useState } from "react";

import { AnyZodObject } from "zod";
import "./App.css";
Expand All @@ -43,6 +47,7 @@ import PingTab from "./components/PingTab";
import PromptsTab, { Prompt } from "./components/PromptsTab";
import RequestsTab from "./components/RequestsTabs";
import ResourcesTab from "./components/ResourcesTab";
import SamplingTab, { PendingRequest } from "./components/SamplingTab";
import Sidebar from "./components/Sidebar";
import ToolsTab from "./components/ToolsTab";

Expand Down Expand Up @@ -77,6 +82,32 @@ const App = () => {
const [mcpClient, setMcpClient] = useState<Client | null>(null);
const [notifications, setNotifications] = useState<ServerNotification[]>([]);

const [pendingSampleRequests, setPendingSampleRequests] = useState<
Array<
PendingRequest & {
resolve: (result: CreateMessageResult) => void;
reject: (error: Error) => void;
}
>
>([]);
const nextRequestId = useRef(0);

const handleApproveSampling = (id: number, result: CreateMessageResult) => {
setPendingSampleRequests((prev) => {
const request = prev.find((r) => r.id === id);
request?.resolve(result);
return prev.filter((r) => r.id !== id);
});
};

const handleRejectSampling = (id: number) => {
setPendingSampleRequests((prev) => {
const request = prev.find((r) => r.id === id);
request?.reject(new Error("Sampling request rejected"));
return prev.filter((r) => r.id !== id);
});
};

const [selectedResource, setSelectedResource] = useState<Resource | null>(
null,
);
Expand Down Expand Up @@ -229,6 +260,15 @@ const App = () => {
},
);

client.setRequestHandler(CreateMessageRequestSchema, (request) => {
return new Promise<CreateMessageResult>((resolve, reject) => {
setPendingSampleRequests((prev) => [
...prev,
{ id: nextRequestId.current++, request, resolve, reject },
]);
});
});

setMcpClient(client);
setConnectionStatus("connected");
} catch (e) {
Expand Down Expand Up @@ -314,6 +354,15 @@ const App = () => {
<Bell className="w-4 h-4 mr-2" />
Ping
</TabsTrigger>
<TabsTrigger value="sampling" className="relative">
<Hash className="w-4 h-4 mr-2" />
Sampling
{pendingSampleRequests.length > 0 && (
<span className="absolute -top-1 -right-1 bg-red-500 text-white text-xs rounded-full h-4 w-4 flex items-center justify-center">
{pendingSampleRequests.length}
</span>
)}
</TabsTrigger>
</TabsList>

<div className="w-full">
Expand Down Expand Up @@ -362,6 +411,11 @@ const App = () => {
);
}}
/>
<SamplingTab
pendingRequests={pendingSampleRequests}
onApprove={handleApproveSampling}
onReject={handleRejectSampling}
/>
</div>
</Tabs>
) : (
Expand Down
65 changes: 65 additions & 0 deletions client/src/components/SamplingTab.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import { TabsContent } from "@/components/ui/tabs";
import {
CreateMessageRequest,
CreateMessageResult,
} from "@modelcontextprotocol/sdk/types.js";

export type PendingRequest = {
id: number;
request: CreateMessageRequest;
};

export type Props = {
pendingRequests: PendingRequest[];
onApprove: (id: number, result: CreateMessageResult) => void;
onReject: (id: number) => void;
};

const SamplingTab = ({ pendingRequests, onApprove, onReject }: Props) => {
const handleApprove = (id: number) => {
// For now, just return a stub response
onApprove(id, {
model: "stub-model",
stopReason: "endTurn",
role: "assistant",
content: {
type: "text",
text: "This is a stub response.",
},
});
};

return (
<TabsContent value="sampling" className="h-96">
<Alert>
<AlertDescription>
When the server requests LLM sampling, requests will appear here for
approval.
</AlertDescription>
</Alert>
<div className="mt-4 space-y-4">
<h3 className="text-lg font-semibold">Recent Requests</h3>
{pendingRequests.map((request) => (
<div key={request.id} className="p-4 border rounded-lg space-y-4">
<pre className="bg-gray-50 p-2 rounded">
{JSON.stringify(request.request, null, 2)}
</pre>
<div className="flex space-x-2">
<Button onClick={() => handleApprove(request.id)}>Approve</Button>
<Button variant="outline" onClick={() => onReject(request.id)}>
Reject
</Button>
</div>
</div>
))}
{pendingRequests.length === 0 && (
<p className="text-gray-500">No pending requests</p>
)}
</div>
</TabsContent>
);
};

export default SamplingTab;

0 comments on commit a5606f7

Please sign in to comment.