Skip to content

Commit

Permalink
Move the Pydantic generic resolution into its own module
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackHC committed Oct 14, 2024
1 parent 2639de5 commit 8c44c5b
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 200 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ docs/source

# From https://raw.githubusercontent.com/github/gitignore/main/Python.gitignore

Byte-compiled / optimized / DLL files
Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
Expand Down Expand Up @@ -161,4 +161,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#.idea/
logs/
wandb/
168 changes: 7 additions & 161 deletions llm_strategy/llm_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from llmtracer import TraceNodeKind, trace_calls, update_event_properties, update_name
from llmtracer.trace_builder import slicer
from pydantic import BaseModel, ValidationError, create_model, generics
from pydantic import BaseModel, ValidationError, create_model
from pydantic.fields import FieldInfo, Undefined
from pydantic.generics import replace_types

from llm_hyperparameters.track_hyperparameters import (
Hyperparameter,
track_hyperparameters,
)
from llm_strategy.chat_chain import ChatChain
from llm_strategy.pydantic_generic_type_resolution import PydanticGenericTypeMap

T = typing.TypeVar("T")
S = typing.TypeVar("S")
Expand Down Expand Up @@ -126,9 +126,9 @@ def is_not_implemented(f: typing.Callable) -> bool:

class TyperWrapper(str):
"""
A wrapper around a type that can be used to create a Pydantic model.
A wrapper we use as an annotation instead of `type` to make types serializable.
This is used to support @classmethods.
This is used to support scenarios where we need to serialize type objects in Pydantic models.
"""

@classmethod
Expand Down Expand Up @@ -206,8 +206,8 @@ def create(docstring: str, input_type: type[B], return_annotation: T, input: B)
return_info = (return_annotation, ...)

# resolve generic types
generic_type_map = LLMStructuredPrompt.resolve_generic_types(input_type, input)
return_type: type = LLMStructuredPrompt.resolve_type(return_info[0], generic_type_map)
generic_type_map = PydanticGenericTypeMap.resolve_generic_types(input_type, input)
return_type: type = generic_type_map.resolve_type(return_info[0])

if return_type is types.NoneType: # noqa: E721
raise ValueError(f"Resolve return type {return_info[0]} is None! This would be a NOP.")
Expand All @@ -224,7 +224,7 @@ def create(docstring: str, input_type: type[B], return_annotation: T, input: B)
resolved_output_model_type = Output[return_type] # noqa

# resolve input_type
resolved_input_type = LLMStructuredPrompt.resolve_type(input_type, generic_type_map)
resolved_input_type = generic_type_map.resolve_type(input_type)

return LLMStructuredPrompt(
docstring=docstring,
Expand All @@ -234,160 +234,6 @@ def create(docstring: str, input_type: type[B], return_annotation: T, input: B)
input=input,
)

@staticmethod
def resolve_type(source_type: type, generic_type_map: dict[type, type]) -> type:
    """
    Substitute the entries of `generic_type_map` into `source_type`.

    Subclasses of `generics.GenericModel` are rebuilt by re-parametrizing their
    base generic type; every other type is handed to Pydantic's `replace_types`.

    Args:
        source_type (type): The type to resolve.
        generic_type_map (dict[type, type]): Maps generic types to their resolved types.

    Returns:
        type: The resolved type.
    """
    # A direct hit in the map replaces the type before any further processing.
    if source_type in generic_type_map:
        source_type = generic_type_map[source_type]

    is_generic_model = isinstance(source_type, type) and issubclass(source_type, generics.GenericModel)
    if not is_generic_model:
        # we let Pydantic handle the rest
        return replace_types(source_type, generic_type_map)

    base = LLMStructuredPrompt.get_base_generic_type(source_type)
    parameter_targets = LLMStructuredPrompt.get_generic_type_map(source_type, base)

    # forward step: push every current target through the generic type map
    forwarded = {}
    for parameter, target in parameter_targets.items():
        forwarded[parameter] = generic_type_map.get(target, target)

    # Re-parametrize the base in the order in which it declares its parameters.
    arguments = tuple(forwarded[parameter] for parameter in base.__parameters__)
    return base[arguments]

@staticmethod
def resolve_generic_types(model: type[BaseModel], instance: BaseModel) -> dict:
    """
    Build a generic type map for `model` from the field values of `instance`.

    Only fields that appear in `model.__annotations__` are inspected. A field
    annotated with a type variable binds it to the type of the field's value;
    a field annotated with a `generics.GenericModel` subclass binds the type
    variables that the annotation leaves open. All other annotations are skipped.

    Args:
        model (type[BaseModel]): The model type.
        instance (BaseModel): The instance of the model.

    Returns:
        dict: The generic type map.
    """
    resolved: dict = {}

    for field_name, attr_value in list(instance):
        declared = model.__annotations__
        if field_name not in declared:
            continue

        annotation = declared[field_name]

        # Look through Annotated[...] to the underlying type annotation.
        if typing.get_origin(annotation) is typing.Annotated:
            annotation = typing.get_args(annotation)[0]

        # A bare type variable resolves to the type of the field's value.
        if isinstance(annotation, typing.TypeVar):
            LLMStructuredPrompt.add_resolved_type(resolved, annotation, type(attr_value))
            continue

        if isinstance(annotation, types.GenericAlias):
            # Generic type aliases are not supported yet.
            # The problem is that GenericAlias types are elided:
            # e.g. type(list[str](["hello"])) -> list and not list[str].
            # Either way, we would need to resolve the types based on the actual elements and their mros.
            continue

        if not (isinstance(annotation, type) and issubclass(annotation, generics.GenericModel)):
            # Let Pydantic handle the rest
            continue

        # Compare how the annotation and the actual value's type bind the parameters.
        definition_map = LLMStructuredPrompt.get_generic_type_map(annotation)
        instance_map = LLMStructuredPrompt.get_generic_type_map(type(attr_value))

        assert list(definition_map) == list(instance_map)

        # Only targets that are still open parameters of the annotation get resolved;
        # add_resolved_type rejects a resolution that conflicts with an earlier one.
        for parameter, target in definition_map.items():
            if target in annotation.__parameters__:
                LLMStructuredPrompt.add_resolved_type(resolved, target, instance_map[parameter])

    return resolved

@staticmethod
def add_resolved_type(generic_type_map, source_type, resolved_type):
"""
Add a resolved type to the generic type map.
"""
if source_type in generic_type_map:
# TODO: support finding the common base class?
if (previous_resolution := generic_type_map[source_type]) is not resolved_type:
raise ValueError(
f"Cannot resolve generic type {source_type}, conflicting "
f"resolution: {previous_resolution} and {resolved_type}."
)
else:
generic_type_map[source_type] = resolved_type

@staticmethod
def get_generic_type_map(generic_type, base_generic_type=None):
    """Map the type variables of a generic type to the types assigned to them.

    Walks the MRO of `generic_type` and composes the parameter assignments that
    `generics._assigned_parameters` holds for each class derived from
    `base_generic_type`. When `base_generic_type` is omitted it is looked up
    with `get_base_generic_type`.
    """
    if base_generic_type is None:
        base_generic_type = LLMStructuredPrompt.get_base_generic_type(generic_type)

    mro = inspect.getmro(generic_type)

    # Start with the parameters that are still open, each mapped to itself.
    mapping = {parameter: parameter for parameter in generic_type.__parameters__}

    # we have to iterate through the base classes
    for klass in mro:
        # skip base classes that are from pydantic.generics
        # this avoids a bug that is caused by generics.GenericModel.__parameterized_bases_
        if klass.__module__ == "pydantic.generics":
            continue
        if not (issubclass(klass, base_generic_type) and klass in generics._assigned_parameters):
            continue

        # Compose this class's assignment with what has been collected so far.
        assigned = generics._assigned_parameters[klass]
        mapping = {old: mapping.get(new, new) for old, new in assigned.items()}

    return mapping

@staticmethod
def get_base_generic_type(field_type: type) -> type[generics.GenericModel]:
"""Determine the base generic type of a generic type. E.g. List[str] -> List.
Args:
field_type (type): The generic type.
Raises:
ValueError: If the base generic type cannot be found.
Returns:
type[generics.GenericModel]: The base generic type.
"""

# get the base class name from annotation (which is without [])
base_generic_name = field_type.__name__
if "[" in field_type.__name__:
base_generic_name = field_type.__name__.split("[")[0]
# get the base class from argument_type_base_classes with base_generic_name
for base_class in reversed(inspect.getmro(field_type)):
if base_class.__name__ == base_generic_name and issubclass(field_type, base_class):
base_generic_type = base_class
break
else:
raise ValueError(f"Could not find base generic type {base_generic_name} for {field_type}.")
return base_generic_type

@trace_calls(name="LLMStructuredPrompt", kind=TraceNodeKind.CHAIN, capture_args=False, capture_return=False)
def __call__(
self,
Expand Down
Loading

0 comments on commit 8c44c5b

Please sign in to comment.