Skip to content

Commit

Permalink
chore(internal): support rendering enum types from the DMMF
Browse files Browse the repository at this point in the history
This will be helpful for transaction isolation levels support

#878
  • Loading branch information
RobertCraigie committed Feb 24, 2024
1 parent 3de2b22 commit 68083ee
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 5 deletions.
20 changes: 16 additions & 4 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def to_params(self) -> Dict[str, Any]:
"""Get the parameters that should be sent to Jinja templates"""
params = vars(self)
params['type_schema'] = Schema.from_data(self)
params['client_types'] = ClientTypes.from_data(self)

# add utility functions
for func in [
Expand Down Expand Up @@ -628,11 +629,22 @@ def engine_type_validator(cls, value: EngineType) -> EngineType:
assert_never(value)


class EnumType(BaseModel):
name: str
values: List[object]


class EnumTypes(BaseModel):
prisma: List[EnumType]


class PrismaSchema(BaseModel):
enum_types: EnumTypes = FieldInfo(alias='enumTypes')


class DMMF(BaseModel):
datamodel: 'Datamodel'

# TODO
prisma_schema: Any = FieldInfo(alias='schema')
prisma_schema: PrismaSchema = FieldInfo(alias='schema')


class Datamodel(BaseModel):
Expand Down Expand Up @@ -1182,4 +1194,4 @@ class DefaultData(GenericData[_EmptyModel]):
TemplateError,
PartialTypeGeneratorError,
)
from .schema import Schema
from .schema import Schema, ClientTypes
31 changes: 30 additions & 1 deletion src/prisma/generator/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from enum import Enum
from typing import Any, Dict, List, Type, Tuple, Union
from typing import Any, Dict, List, Type, Tuple, Union, Optional
from typing_extensions import ClassVar

from pydantic import BaseModel

from .utils import to_constant_case
from .models import Model as ModelInfo, AnyData, PrimaryKey
from .._compat import (
PYDANTIC_V2,
Expand All @@ -18,6 +19,7 @@ class Kind(str, Enum):
alias = 'alias'
union = 'union'
typeddict = 'typeddict'
enum = 'enum'


class PrismaType(BaseModel):
Expand Down Expand Up @@ -45,6 +47,11 @@ class PrismaUnion(PrismaType):
subtypes: List[PrismaType]


class PrismaEnum(PrismaType):
kind: Kind = Kind.enum
members: List[Tuple[str, str]]


class PrismaAlias(PrismaType):
kind: Kind = Kind.alias
to: str
Expand Down Expand Up @@ -143,6 +150,28 @@ def order_by(self) -> PrismaType:
return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput')


class ClientTypes(BaseModel):
transaction_isolation_level: Optional[PrismaEnum]

@classmethod
def from_data(cls, data: AnyData) -> 'ClientTypes':
enum_types = data.dmmf.prisma_schema.enum_types.prisma

tx_isolation = next((t for t in enum_types if t.name == 'TransactionIsolationLevel'), None)
if tx_isolation is not None:
tx_isolation = PrismaEnum(
name='TransactionIsolationLevel',
members=[
(to_constant_case(str(value)), str(value))
for value in tx_isolation.values
],
)

return cls(
transaction_isolation_level=tx_isolation,
)


model_rebuild(Schema)
model_rebuild(PrismaType)
model_rebuild(PrismaDict)
Expand Down
5 changes: 5 additions & 0 deletions src/prisma/generator/templates/types.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ from .utils import _NoneType
},
total={{ type.total }}
)
{% elif type.kind == 'enum' %}
class {{ type.name }}(StrEnum):
{% for name, value in type.members %}
{{ name }} = "{{ value }}"
{% endfor %}
{% else %}
{{ raise_err('Unhandled type kind: %s' % type.kind) }}
{% endif %}
Expand Down

0 comments on commit 68083ee

Please sign in to comment.