Skip to content

Commit

Permalink
feat: improve code parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D committed Dec 30, 2024
1 parent 6c85601 commit c2ce87b
Showing 1 changed file with 46 additions and 23 deletions.
69 changes: 46 additions & 23 deletions src/code_interpreter/services/custom_tool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def parse(self, tool_source_code: str) -> CustomTool:
Supported types for return value: anything that can be JSON-serialized.
"""
try:
*imports, function_def = ast.parse(textwrap.dedent(tool_source_code)).body
tree = ast.parse(textwrap.dedent(tool_source_code))
*imports, function_def = tree.body
except SyntaxError as e:
raise CustomToolParseError([f"Syntax error: {e.msg} on line {e.lineno}"])

Expand All @@ -70,6 +71,13 @@ def parse(self, tool_source_code: str) -> CustomTool:
]
)

import_aliases = {
alias.asname or alias.name
for node in imports
for alias in node.names
if alias.name == "typing"
}

errors = [
x
for x in (
Expand Down Expand Up @@ -104,7 +112,7 @@ def parse(self, tool_source_code: str) -> CustomTool:
"type": "object",
"title": function_def.name,
"properties": {
arg.arg: _type_to_json_schema(arg.annotation)
arg.arg: _type_to_json_schema(arg.annotation, import_aliases)
| (
{"description": param_description}
if (param_description := param_descriptions.get(arg.arg))
Expand Down Expand Up @@ -194,37 +202,52 @@ async def execute(

return json.loads(result.stdout)

def _normalize_type_name(type_node_name: str, aliases: set) -> str:
"""Normalize type names to their canonical forms."""
for alias in aliases:
if type_node_name.startswith(f"{alias}."):
return type_node_name[len(alias) + 1:].lower()

return type_node_name

def _type_to_json_schema(type_node: ast.AST, import_aliases: set) -> dict:
    """Translate a Python type-annotation AST node into a JSON Schema fragment.

    Args:
        type_node: The annotation expression node (e.g. from ``arg.annotation``).
        import_aliases: Names under which the ``typing`` module was imported;
            used to normalize prefixed names like ``typing.Optional``.

    Returns:
        A JSON Schema dict describing the annotated type.

    Raises:
        ValueError: If the annotation uses an unsupported type, or a dict
            with a non-``str`` key type.
    """
    if isinstance(type_node, ast.Subscript):
        type_node_name = _normalize_type_name(ast.unparse(type_node.value), import_aliases)

        if type_node_name == "list":
            return {"type": "array", "items": _type_to_json_schema(type_node.slice, import_aliases)}
        elif type_node_name == "dict" and isinstance(type_node.slice, ast.Tuple):
            key_type_node, value_type_node = type_node.slice.elts
            # JSON object keys are always strings, so only dict[str, ...] maps cleanly.
            key_type_node_name = _normalize_type_name(ast.unparse(key_type_node), import_aliases)
            if key_type_node_name != "str":
                raise ValueError(f"Unsupported key type for dict: {ast.unparse(key_type_node)}")
            return {
                "type": "object",
                "additionalProperties": _type_to_json_schema(value_type_node, import_aliases),
            }
        elif type_node_name == "optional":
            return {"anyOf": [{"type": "null"}, _type_to_json_schema(type_node.slice, import_aliases)]}
        elif type_node_name == "union" and isinstance(type_node.slice, ast.Tuple):
            return {"anyOf": [_type_to_json_schema(el, import_aliases) for el in type_node.slice.elts]}
        elif type_node_name == "tuple" and isinstance(type_node.slice, ast.Tuple):
            # Fixed-length heterogeneous tuple: positional item schemas
            # (draft-07 style ``items`` array with ``additionalItems``).
            return {
                "type": "array",
                "minItems": len(type_node.slice.elts),
                "items": [_type_to_json_schema(el, import_aliases) for el in type_node.slice.elts],
                "additionalItems": False,
            }
        elif type_node_name == "literal":
            # Literal values become an ``enum``; ast.literal_eval rejects
            # anything that is not a plain Python literal.
            if isinstance(type_node.slice, ast.Tuple):
                return {"enum": [ast.literal_eval(el) for el in type_node.slice.elts]}
            else:
                return {"enum": [ast.literal_eval(type_node.slice)]}
        elif type_node_name == "final":
            # Final[X] constrains reassignment, not the value's shape — unwrap.
            return _type_to_json_schema(type_node.slice, import_aliases)
        elif type_node_name == "annotated":
            # Annotated[X, meta...] — only the base type matters for the
            # schema; the metadata entries are deliberately discarded.
            base_type, *_annotations = type_node.slice.elts
            return _type_to_json_schema(base_type, import_aliases)

    # Not a (supported) subscripted generic: treat as a plain type name.
    type_node_name = _normalize_type_name(ast.unparse(type_node), import_aliases)
    if type_node_name == "int":
        return {"type": "integer"}
    elif type_node_name == "float":
        return {"type": "number"}
    elif type_node_name == "str":
        return {"type": "string"}
    elif type_node_name == "bool":
        return {"type": "boolean"}
    elif type_node_name == "any":
        # NOTE(review): mapping Any to {"type": "array"} looks wrong — Any
        # should presumably be the unconstrained schema {} — but this mirrors
        # the pre-existing behavior; confirm with callers before changing.
        return {"type": "array"}

    raise ValueError(f"Unsupported type: {type_node_name}")


def _parse_docstring(docstring: str) -> typing.Tuple[str, str, dict[str, str]]:
Expand Down

0 comments on commit c2ce87b

Please sign in to comment.