fix: switch to Pydantic JSON schema
Signed-off-by: Jan Pokorný <JenomPokorny@gmail.com>
JanPokorny committed Jan 2, 2025
1 parent 6c85601 commit 9fc08a3
Showing 7 changed files with 151 additions and 87 deletions.
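The gist of the change, sketched stand-alone (not part of the diff; pydantic v2 assumed): instead of hand-mapping AST type annotations to JSON Schema, the executor now delegates to pydantic.TypeAdapter.

import typing
import pydantic

# TypeAdapter derives a JSON schema for an arbitrary annotation.
schema = pydantic.TypeAdapter(dict[str, typing.Optional[float]]).json_schema()
print(schema)
# -> {'additionalProperties': {'anyOf': [{'type': 'number'}, {'type': 'null'}]},
#     'type': 'object'}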
2 changes: 2 additions & 0 deletions executor/Dockerfile
@@ -73,6 +73,8 @@ RUN apk add --no-cache --repository=https://dl-cdn.alpinelinux.org/alpine/edge/t
py3-pillow-pyc \
py3-pip \
py3-pip-pyc \
py3-pydantic \
py3-pydantic-pyc \
py3-pypandoc \
py3-pypandoc-pyc \
py3-scipy \
1 change: 1 addition & 0 deletions executor/requirements-skip.txt
@@ -9,6 +9,7 @@ pandas
pdf2image
pikepdf
pillow
pydantic
pypandoc
scipy
sympy
175 changes: 96 additions & 79 deletions src/code_interpreter/services/custom_tool_executor.py
@@ -1,35 +1,20 @@
# Copyright 2024 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
from dataclasses import dataclass
import json
import typing
import inspect
import re
import json
import textwrap

from pydantic import validate_call

import pydantic
import pydantic.json_schema
from code_interpreter.services.kubernetes_code_executor import KubernetesCodeExecutor


@dataclass
class CustomTool:
name: str
description: str
input_schema: dict
input_schema: dict[str, typing.Any]


@dataclass
@@ -53,8 +38,8 @@ def parse(self, tool_source_code: str) -> CustomTool:
The source code must contain a single function definition, optionally preceded by imports. The function must not have positional-only arguments, *args or **kwargs.
The function arguments must have type annotations. The docstring must follow the ReST format -- :param something: and :return: directives are supported.
Supported types for input arguments: int, float, str, bool, typing.Any, list[...], dict[str, ...], typing.Tuple[...], typing.Optional[...], typing.Union[...], where ... is any of the supported types.
Supported types for return value: anything that can be JSON-serialized.
Function arguments will be converted to JSONSchema by Pydantic, so everything that can be (de)serialized through Pydantic can be used.
However, the imports that can be used in types are currently limited to `typing`, `pathlib` and `datetime` for safety reasons.
"""
try:
*imports, function_def = ast.parse(textwrap.dedent(tool_source_code)).body
@@ -99,12 +84,14 @@ def parse(self, tool_source_code: str) -> CustomTool:
ast.get_docstring(function_def) or ""
)

namespace = _build_namespace(imports)

json_schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"title": function_def.name,
"properties": {
arg.arg: _type_to_json_schema(arg.annotation)
arg.arg: _type_to_json_schema(arg.annotation, namespace)
| (
{"description": param_description}
if (param_description := param_descriptions.get(arg.arg))
@@ -153,11 +140,11 @@ def parse(self, tool_source_code: str) -> CustomTool:
input_schema=json_schema,
)

@validate_call
@pydantic.validate_call
async def execute(
self,
tool_source_code: str,
tool_input: dict[str, typing.Any],
tool_input_json: str,
) -> typing.Any:
"""
Execute the given custom tool with the given input.
Expand All @@ -166,24 +153,22 @@ async def execute(
The input is expected to be valid according to the input schema produced by the parse method.
"""

clean_tool_source_code = textwrap.dedent(tool_source_code)
*imports, function_def = ast.parse(clean_tool_source_code).body

result = await self.code_executor.execute(
source_code=f"""
# Import all tool dependencies here -- to aid the dependency detection
{"\n".join(ast.unparse(node) for node in imports if isinstance(node, (ast.Import, ast.ImportFrom)))}
import pydantic
import contextlib
import json
# Import all tool dependencies here -- to aid the dependency detection
{
"\n".join(
ast.unparse(node)
for node in ast.parse(textwrap.dedent(tool_source_code)).body
if isinstance(node, (ast.Import, ast.ImportFrom))
)
}
with contextlib.redirect_stdout(None):
inner_globals = {{}}
exec(compile({repr(textwrap.dedent(tool_source_code))}, "<string>", "exec"), inner_globals)
result = next(x for x in inner_globals.values() if getattr(x, '__module__', ...) is None)(**{repr(tool_input)})
exec(compile({repr(clean_tool_source_code)}, "<string>", "exec"), inner_globals)
result = pydantic.TypeAdapter(inner_globals[{repr(function_def.name)}]).validate_json({repr(tool_input_json)})
print(json.dumps(result))
""",
@@ -195,50 +180,6 @@ async def execute(
return json.loads(result.stdout)


def _type_to_json_schema(type_node: ast.AST) -> dict:
if isinstance(type_node, ast.Subscript):
type_node_name = ast.unparse(type_node.value)
if type_node_name == "list":
return {"type": "array", "items": _type_to_json_schema(type_node.slice)}
elif type_node_name == "dict" and isinstance(type_node.slice, ast.Tuple):
key_type_node, value_type_node = type_node.slice.elts
if ast.unparse(key_type_node) != "str":
raise ValueError(f"Unsupported type: {type_node}")
return {
"type": "object",
"additionalProperties": _type_to_json_schema(value_type_node),
}
elif type_node_name == "Optional" or type_node_name == "typing.Optional":
return {"anyOf": [{"type": "null"}, _type_to_json_schema(type_node.slice)]}
elif (
type_node_name == "Union" or type_node_name == "typing.Union"
) and isinstance(type_node.slice, ast.Tuple):
return {"anyOf": [_type_to_json_schema(el) for el in type_node.slice.elts]}
elif (
type_node_name == "Tuple" or type_node_name == "typing.Tuple"
) and isinstance(type_node.slice, ast.Tuple):
return {
"type": "array",
"minItems": len(type_node.slice.elts),
"items": [_type_to_json_schema(el) for el in type_node.slice.elts],
"additionalItems": False,
}

type_node_name = ast.unparse(type_node)
if type_node_name == "int":
return {"type": "integer"}
elif type_node_name == "float":
return {"type": "number"}
elif type_node_name == "str":
return {"type": "string"}
elif type_node_name == "bool":
return {"type": "boolean"}
elif type_node_name == "Any" or type_node_name == "typing.Any":
return {"type": "array"}
else:
raise ValueError(f"Unsupported type: {type_node_name}")


def _parse_docstring(docstring: str) -> typing.Tuple[str, str, dict[str, str]]:
"""
Parse a docstring in the ReST format and return the function description, return description and a dictionary of parameter descriptions.
@@ -262,3 +203,79 @@ def _parse_docstring(docstring: str) -> typing.Tuple[str, str, dict[str, str]]:
elif match := re.match(r"return: ((?:.|\n)+)", chunk, flags=re.MULTILINE):
return_description = match.group(1)
return fn_description, return_description, param_descriptions


def _build_namespace(
imports: list[ast.AST],
allowed_modules: set[str] = {"typing", "pathlib", "datetime"},
) -> dict[str, typing.Any]:
namespace = {
"str": str,
"int": int,
"float": float,
"bool": bool,
"list": list,
"dict": dict,
"set": set,
"tuple": tuple,
}

for node in imports:
if isinstance(node, ast.Import):
for name in node.names:
if name.name in allowed_modules:
namespace[name.asname or name.name] = __import__(name.name)
elif isinstance(node, ast.ImportFrom):
if node.module in allowed_modules:
module = __import__(node.module, fromlist=[n.name for n in node.names])
for name in node.names:
namespace[name.asname or name.name] = getattr(module, name.name)

return namespace


def _type_to_json_schema(type_ast: ast.AST, namespace: dict) -> dict:
type_str = ast.unparse(type_ast)
if not _is_safe_type_ast(type_ast):
raise CustomToolParseError([f"Invalid type annotation `{type_str}`"])
try:
return pydantic.TypeAdapter(eval(type_str, namespace)).json_schema(
schema_generator=_GenerateJsonSchema
)
except Exception as e:
raise CustomToolParseError([f"Error when parsing type `{type_str}`: {e}"])


class _GenerateJsonSchema(pydantic.json_schema.GenerateJsonSchema):
schema_dialect = "http://json-schema.org/draft-07/schema#"

def tuple_schema(self, schema):
# Use draft-07 syntax for tuples
schema = super().tuple_schema(schema)
if "prefixItems" in schema:
schema["items"] = schema.pop("prefixItems")
schema.pop("maxItems")
schema["additionalItems"] = False
return schema


def _is_safe_type_ast(node: ast.AST) -> bool:
match node:
case ast.Name():
return True
case ast.Attribute():
return _is_safe_type_ast(node.value)
case ast.Subscript():
return _is_safe_type_ast(node.value) and _is_safe_type_ast(node.slice)
case ast.Tuple() | ast.List():
return all(_is_safe_type_ast(elt) for elt in node.elts)
case ast.Constant():
return isinstance(node.value, (str, int, float, bool, type(None)))
case ast.BinOp():
return (
isinstance(node.op, ast.BitOr)
and _is_safe_type_ast(node.left)
and _is_safe_type_ast(node.right)
)
case _:
return False
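A stand-alone sketch of what the custom schema generator above buys (class name here is illustrative; pydantic v2 assumed): it rewrites Pydantic's default 2020-12 tuple encoding into the draft-07 form that the updated tests expect.

import typing
import pydantic
import pydantic.json_schema

class Draft07Tuples(pydantic.json_schema.GenerateJsonSchema):
    # Same trick as _GenerateJsonSchema above: convert the 2020-12 tuple
    # encoding ("prefixItems" + "maxItems") to draft-07 positional "items".
    schema_dialect = "http://json-schema.org/draft-07/schema#"

    def tuple_schema(self, schema):
        out = super().tuple_schema(schema)
        if "prefixItems" in out:
            out["items"] = out.pop("prefixItems")
            out.pop("maxItems")
            out["additionalItems"] = False
        return out

print(
    pydantic.TypeAdapter(typing.Tuple[typing.Optional[str], str]).json_schema(
        schema_generator=Draft07Tuples
    )
)
# -> {'minItems': 2, 'type': 'array',
#     'items': [{'anyOf': [{'type': 'string'}, {'type': 'null'}]}, {'type': 'string'}],
#     'additionalItems': False}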
@@ -120,7 +120,7 @@ async def ExecuteCustomTool(

try:
result = await self.custom_tool_executor.execute(
tool_input=json.loads(request.tool_input_json),
tool_input_json=request.tool_input_json,
tool_source_code=request.tool_source_code,
)
except CustomToolExecuteError as e:
2 changes: 1 addition & 1 deletion src/code_interpreter/services/http_server.py
@@ -141,7 +141,7 @@ async def execute_custom_tool(
"Executing custom tool with source code %s", request.tool_source_code
)
result = await custom_tool_executor.execute(
tool_input=json.loads(request.tool_input_json),
tool_input_json=request.tool_input_json,
tool_source_code=request.tool_source_code,
)
logger.info("Executed custom tool with result %s", result)
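Both server entry points now forward the raw tool_input_json string instead of pre-parsing it. A hypothetical round trip against a locally running HTTP server (base URL and port assumed; endpoint and field names are the ones exercised by the e2e tests below):

import httpx

response = httpx.post(
    "http://localhost:8000/v1/execute-custom-tool",  # base URL assumed
    json={
        "tool_source_code": "def add(a: int, b: int) -> int:\n    return a + b",
        "tool_input_json": '{"a": 1, "b": 2}',  # forwarded verbatim, validated by pydantic
    },
)
print(response.json()["tool_output_json"])  # -> "3"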
27 changes: 24 additions & 3 deletions test/e2e/test_grpc.py
@@ -110,7 +110,11 @@ def test_parse_custom_tool_success(grpc_stub: CodeInterpreterServiceStub):
response: ParseCustomToolResponse = grpc_stub.ParseCustomTool(
ParseCustomToolRequest(
tool_source_code='''
def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: typing.Union[list[str], dict[str, typing.Optional[float]]]) -> int:
import typing
import typing as banana
from typing import Optional
def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: typing.Union[list[str], dict[str, banana.Optional[float]]]) -> int:
"""
This tool is really really cool.
Very toolish experience:
@@ -149,7 +153,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
"type": "array",
"minItems": 2,
"items": [
{"anyOf": [{"type": "null"}, {"type": "string"}]},
{"anyOf": [{"type": "string"}, {"type": "null"}]},
{"type": "string"},
],
"additionalItems": False,
@@ -161,7 +165,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
{
"type": "object",
"additionalProperties": {
"anyOf": [{"type": "null"}, {"type": "number"}]
"anyOf": [{"type": "number"}, {"type": "null"}]
},
},
],
@@ -249,6 +253,23 @@ def test_execute_custom_tool_success(grpc_stub: CodeInterpreterServiceStub):
assert result.success.tool_output_json == "3"


def test_execute_custom_tool_advanced_success(grpc_stub: CodeInterpreterServiceStub):
result = grpc_stub.ExecuteCustomTool(
ExecuteCustomToolRequest(
tool_source_code="""
import datetime
def date_tool(a: datetime.datetime) -> str:
return f"The year is {a.year}"
""",
tool_input_json='{"a": "2000-01-01T00:00:00"}',
)
)

assert result.WhichOneof("response") == "success"
assert result.success.tool_output_json == "\"The year is 2000\""


def test_execute_custom_tool_error(grpc_stub: CodeInterpreterServiceStub):
result = grpc_stub.ExecuteCustomTool(
ExecuteCustomToolRequest(
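The new advanced test above passes a datetime argument as an ISO 8601 string; this works because the executor now validates inputs through pydantic, which coerces JSON primitives into the annotated Python types. A minimal sketch of just the coercion step (pydantic v2 assumed):

import datetime
import pydantic

# TypeAdapter both validates and coerces: a JSON string becomes a datetime.
adapter = pydantic.TypeAdapter(datetime.datetime)
value = adapter.validate_json('"2000-01-01T00:00:00"')
print(value.year)  # -> 2000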
29 changes: 26 additions & 3 deletions test/e2e/test_http.py
@@ -90,7 +90,11 @@ def test_parse_custom_tool_success(http_client: httpx.Client):
"/v1/parse-custom-tool",
json={
"tool_source_code": '''
def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: typing.Union[list[str], dict[str, typing.Optional[float]]]) -> int:
import typing
import typing as banana
from typing import Optional
def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *, c: typing.Union[list[str], dict[str, banana.Optional[float]]]) -> int:
"""
This tool is really really cool.
Very toolish experience:
@@ -128,7 +132,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
"type": "array",
"minItems": 2,
"items": [
{"anyOf": [{"type": "null"}, {"type": "string"}]},
{"anyOf": [{"type": "string"}, {"type": "null"}]},
{"type": "string"},
],
"additionalItems": False,
@@ -140,7 +144,7 @@ def my_tool(a: int, b: typing.Tuple[Optional[str], str] = ("hello", "world"), *,
{
"type": "object",
"additionalProperties": {
"anyOf": [{"type": "null"}, {"type": "number"}]
"anyOf": [{"type": "number"}, {"type": "null"}]
},
},
],
@@ -215,6 +219,25 @@ def test_execute_custom_tool_success(http_client: httpx.Client):
assert json.loads(response_json["tool_output_json"]) == 3


def test_execute_custom_tool_advanced_success(http_client: httpx.Client):
response = http_client.post(
"/v1/execute-custom-tool",
json={
"tool_source_code": """
import datetime
def date_tool(a: datetime.datetime) -> str:
return f"The year is {a.year}"
""",
"tool_input_json": '{"a": "2000-01-01T00:00:00"}',
},
)

assert response.status_code == 200
response_json = response.json()
assert json.loads(response_json["tool_output_json"]) == "The year is 2000"


def test_parse_custom_tool_error(http_client: httpx.Client):
response = http_client.post(
"/v1/parse-custom-tool",
