diff --git a/backend/app/models/chat.py b/backend/app/models/chat.py index 2e619b47..c86ef788 100644 --- a/backend/app/models/chat.py +++ b/backend/app/models/chat.py @@ -2,6 +2,7 @@ from uuid import UUID from typing import Optional, Dict from pydantic import BaseModel +from sqlalchemy.types import TypeDecorator, Integer from datetime import datetime from sqlmodel import ( @@ -9,7 +10,6 @@ Column, DateTime, JSON, - SmallInteger, Relationship as SQLRelationship, ) @@ -20,6 +20,31 @@ class ChatVisibility(int, enum.Enum): PRIVATE = 0 PUBLIC = 1 +# Avoid Pydantic serializer warnings: +# When fetching values from the database, SQLAlchemy provides raw integers (0 or 1), +# which leads to warnings during serialization because Pydantic requires the ChatVisibility enum type. + +# automatically handle the conversion between int and enum types +class IntEnumType(TypeDecorator): + impl = Integer + + def __init__(self, enum_class, *args, **kwargs): + super().__init__(*args, **kwargs) + self.enum_class = enum_class + + def process_bind_param(self, value, dialect): + # enum -> int + if isinstance(value, self.enum_class): + return value.value + elif value is None: + return None + raise ValueError(f"Invalid value for {self.enum_class}: {value}") + + def process_result_value(self, value, dialect): + # int -> enum + if value is not None: + return self.enum_class(value) + return None class Chat(UUIDBaseModel, UpdatableBaseModel, table=True): title: str = Field(max_length=256) @@ -43,7 +68,8 @@ class Chat(UUIDBaseModel, UpdatableBaseModel, table=True): browser_id: str = Field(max_length=50, nullable=True) origin: str = Field(max_length=256, default=None, nullable=True) visibility: ChatVisibility = Field( - sa_column=Column(SmallInteger, default=ChatVisibility.PRIVATE, nullable=False) + # sa_column=Column(SmallInteger, default=ChatVisibility.PRIVATE, nullable=False) + sa_column=Column(IntEnumType(ChatVisibility), nullable=False, default=ChatVisibility.PRIVATE) ) __tablename__ = "chats"