diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index c445ab01f..5c40921ad 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -353,6 +353,7 @@ def to_params(self) -> Dict[str, Any]: """Get the parameters that should be sent to Jinja templates""" params = vars(self) params['type_schema'] = Schema.from_data(self) + params['client_types'] = ClientTypes.from_data(self) # add utility functions for func in [ @@ -628,11 +629,22 @@ def engine_type_validator(cls, value: EngineType) -> EngineType: assert_never(value) +class EnumType(BaseModel): + name: str + values: List[object] + + +class EnumTypes(BaseModel): + prisma: List[EnumType] + + +class PrismaSchema(BaseModel): + enum_types: EnumTypes = FieldInfo(alias='enumTypes') + + class DMMF(BaseModel): datamodel: 'Datamodel' - - # TODO - prisma_schema: Any = FieldInfo(alias='schema') + prisma_schema: PrismaSchema = FieldInfo(alias='schema') class Datamodel(BaseModel): @@ -1182,4 +1194,4 @@ class DefaultData(GenericData[_EmptyModel]): TemplateError, PartialTypeGeneratorError, ) -from .schema import Schema +from .schema import Schema, ClientTypes diff --git a/src/prisma/generator/schema.py b/src/prisma/generator/schema.py index d04a572e0..36da712c1 100644 --- a/src/prisma/generator/schema.py +++ b/src/prisma/generator/schema.py @@ -1,9 +1,10 @@ from enum import Enum -from typing import Any, Dict, List, Type, Tuple, Union +from typing import Any, Dict, List, Type, Tuple, Union, Optional from typing_extensions import ClassVar from pydantic import BaseModel +from .utils import to_constant_case from .models import Model as ModelInfo, AnyData, PrimaryKey from .._compat import ( PYDANTIC_V2, @@ -18,6 +19,7 @@ class Kind(str, Enum): alias = 'alias' union = 'union' typeddict = 'typeddict' + enum = 'enum' class PrismaType(BaseModel): @@ -45,6 +47,11 @@ class PrismaUnion(PrismaType): subtypes: List[PrismaType] +class PrismaEnum(PrismaType): + kind: Kind = Kind.enum + members: List[Tuple[str, str]] + + class PrismaAlias(PrismaType): kind: Kind = Kind.alias to: str @@ -143,6 +150,28 @@ def order_by(self) -> PrismaType: return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput') +class ClientTypes(BaseModel): + transaction_isolation_level: Optional[PrismaEnum] + + @classmethod + def from_data(cls, data: AnyData) -> 'ClientTypes': + enum_types = data.dmmf.prisma_schema.enum_types.prisma + + tx_isolation = next((t for t in enum_types if t.name == 'TransactionIsolationLevel'), None) + if tx_isolation is not None: + tx_isolation = PrismaEnum( + name='TransactionIsolationLevel', + members=[ + (to_constant_case(str(value)), str(value)) + for value in tx_isolation.values + ], + ) + + return cls( + transaction_isolation_level=tx_isolation, + ) + + model_rebuild(Schema) model_rebuild(PrismaType) model_rebuild(PrismaDict) diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index 666bbc84f..cc9ac46a9 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -65,6 +65,11 @@ from .utils import _NoneType }, total={{ type.total }} ) +{% elif type.kind == 'enum' %} +class {{ type.name }}(StrEnum): + {% for name, value in type.members %} + {{ name }} = "{{ value }}" + {% endfor %} {% else %} {{ raise_err('Unhandled type kind: %s' % type.kind) }} {% endif %}