diff --git a/src/code_interpreter/services/custom_tool_executor.py b/src/code_interpreter/services/custom_tool_executor.py
index 4440e74..e1d4084 100644
--- a/src/code_interpreter/services/custom_tool_executor.py
+++ b/src/code_interpreter/services/custom_tool_executor.py
@@ -57,7 +57,8 @@ def parse(self, tool_source_code: str) -> CustomTool:
         Supported types for return value: anything that can be JSON-serialized.
         """
         try:
-            *imports, function_def = ast.parse(textwrap.dedent(tool_source_code)).body
+            tree = ast.parse(textwrap.dedent(tool_source_code))
+            *imports, function_def = tree.body
         except SyntaxError as e:
             raise CustomToolParseError([f"Syntax error: {e.msg} on line {e.lineno}"])
 
@@ -70,6 +71,16 @@ def parse(self, tool_source_code: str) -> CustomTool:
             ]
         )
 
+        # Local names the `typing` module is bound to in the tool source:
+        # {"typing"} for `import typing`, {"t"} for `import typing as t`.
+        import_aliases = {
+            alias.asname or alias.name
+            for node in imports
+            if isinstance(node, ast.Import)
+            for alias in node.names
+            if alias.name == "typing"
+        }
+
         errors = [
             x
             for x in (
@@ -104,7 +115,7 @@ def parse(self, tool_source_code: str) -> CustomTool:
                 "type": "object",
                 "title": function_def.name,
                 "properties": {
-                    arg.arg: _type_to_json_schema(arg.annotation)
+                    arg.arg: _type_to_json_schema(arg.annotation, import_aliases)
                     | (
                         {"description": param_description}
                         if (param_description := param_descriptions.get(arg.arg))
@@ -194,35 +205,78 @@ async def execute(
     return json.loads(result.stdout)
 
+
+# typing constructs matched by bare name: `from typing import Optional`
+# leaves no module prefix for _normalize_type_name to strip.
+_TYPING_CONSTRUCTS = {
+    "Optional",
+    "Union",
+    "Tuple",
+    "Any",
+    "Literal",
+    "Final",
+    "Annotated",
+}
+
+
+def _normalize_type_name(type_node_name: str, aliases: set[str]) -> str:
+    """Normalize a type name to its canonical lowercase form.
+
+    `Optional`, `typing.Optional` and `t.Optional` (after
+    `import typing as t`) all normalize to "optional"; unknown
+    names are returned unchanged.
+    """
+    for alias in aliases:
+        prefix = f"{alias}."
+        if type_node_name.startswith(prefix):
+            return type_node_name[len(prefix):].lower()
+
+    # Bare typing names (`from typing import ...`) carry no prefix;
+    # lowercase the known ones so they match their qualified spellings.
+    if type_node_name in _TYPING_CONSTRUCTS:
+        return type_node_name.lower()
+
+    return type_node_name
+
+
-def _type_to_json_schema(type_node: ast.AST) -> dict:
+def _type_to_json_schema(type_node: ast.AST, import_aliases: set[str]) -> dict:
     if isinstance(type_node, ast.Subscript):
-        type_node_name = ast.unparse(type_node.value)
+        type_node_name = _normalize_type_name(ast.unparse(type_node.value), import_aliases)
+
         if type_node_name == "list":
-            return {"type": "array", "items": _type_to_json_schema(type_node.slice)}
+            return {"type": "array", "items": _type_to_json_schema(type_node.slice, import_aliases)}
         elif type_node_name == "dict" and isinstance(type_node.slice, ast.Tuple):
             key_type_node, value_type_node = type_node.slice.elts
-            if ast.unparse(key_type_node) != "str":
-                raise ValueError(f"Unsupported type: {type_node}")
+            key_type_node_name = _normalize_type_name(ast.unparse(key_type_node), import_aliases)
+            if key_type_node_name != "str":
+                raise ValueError(f"Unsupported key type for dict: {ast.unparse(key_type_node)}")
             return {
                 "type": "object",
-                "additionalProperties": _type_to_json_schema(value_type_node),
+                "additionalProperties": _type_to_json_schema(value_type_node, import_aliases),
             }
-        elif type_node_name == "Optional" or type_node_name == "typing.Optional":
-            return {"anyOf": [{"type": "null"}, _type_to_json_schema(type_node.slice)]}
-        elif (
-            type_node_name == "Union" or type_node_name == "typing.Union"
-        ) and isinstance(type_node.slice, ast.Tuple):
-            return {"anyOf": [_type_to_json_schema(el) for el in type_node.slice.elts]}
-        elif (
-            type_node_name == "Tuple" or type_node_name == "typing.Tuple"
-        ) and isinstance(type_node.slice, ast.Tuple):
+        elif type_node_name == "optional":
+            return {"anyOf": [{"type": "null"}, _type_to_json_schema(type_node.slice, import_aliases)]}
+        elif type_node_name == "union" and isinstance(type_node.slice, ast.Tuple):
+            return {"anyOf": [_type_to_json_schema(el, import_aliases) for el in type_node.slice.elts]}
+        elif type_node_name == "tuple" and isinstance(type_node.slice, ast.Tuple):
             return {
                 "type": "array",
                 "minItems": len(type_node.slice.elts),
-                "items": [_type_to_json_schema(el) for el in type_node.slice.elts],
+                "items": [_type_to_json_schema(el, import_aliases) for el in type_node.slice.elts],
                 "additionalItems": False,
             }
-
-    type_node_name = ast.unparse(type_node)
+        elif type_node_name == "literal":
+            if isinstance(type_node.slice, ast.Tuple):
+                return {"enum": [ast.literal_eval(el) for el in type_node.slice.elts]}
+            else:
+                return {"enum": [ast.literal_eval(type_node.slice)]}
+        elif type_node_name == "final":
+            return _type_to_json_schema(type_node.slice, import_aliases)
+        elif type_node_name == "annotated":
+            # Only the first Annotated argument is a type; metadata is dropped.
+            base_type, *_annotations = type_node.slice.elts
+            return _type_to_json_schema(base_type, import_aliases)
+
+    type_node_name = _normalize_type_name(ast.unparse(type_node), import_aliases)
     if type_node_name == "int":
         return {"type": "integer"}
     elif type_node_name == "float":
@@ -233,10 +287,12 @@ def _type_to_json_schema(type_node: ast.AST) -> dict:
         return {"type": "string"}
     elif type_node_name == "bool":
         return {"type": "boolean"}
-    elif type_node_name == "Any" or type_node_name == "typing.Any":
+    # NOTE(review): mapping Any to {"type": "array"} is pre-existing
+    # behavior kept as-is; {} (no constraint) looks more correct — confirm.
+    elif type_node_name == "any":
         return {"type": "array"}
-    else:
-        raise ValueError(f"Unsupported type: {type_node_name}")
+
+    raise ValueError(f"Unsupported type: {type_node_name}")
 
 
 def _parse_docstring(docstring: str) -> typing.Tuple[str, str, dict[str, str]]: