Skip to content

Commit

Permalink
Avoid Pydantic serializer warnings for Chat
Browse files Browse the repository at this point in the history
  • Loading branch information
sszgwdk committed Jan 18, 2025
1 parent ab557b2 commit 803437e
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions backend/app/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from uuid import UUID
from typing import Optional, Dict
from pydantic import BaseModel
from sqlalchemy.types import TypeDecorator, Integer
from datetime import datetime

from sqlmodel import (
Field,
Column,
DateTime,
JSON,
SmallInteger,
Relationship as SQLRelationship,
)

Expand All @@ -20,6 +20,31 @@ class ChatVisibility(int, enum.Enum):
PRIVATE = 0
PUBLIC = 1

# Avoid Pydantic serializer warnings:
# When fetching values from the database, SQLAlchemy provides raw integers (0 or 1),
# which leads to warnings during serialization because Pydantic requires the ChatVisibility enum type.

# automatically handle the conversion between int and enum types
class IntEnumType(TypeDecorator):
impl = Integer

def __init__(self, enum_class, *args, **kwargs):
super().__init__(*args, **kwargs)
self.enum_class = enum_class

def process_bind_param(self, value, dialect):
# enum -> int
if isinstance(value, self.enum_class):
return value.value
elif value is None:
return None
raise ValueError(f"Invalid value for {self.enum_class}: {value}")

def process_result_value(self, value, dialect):
# int -> enum
if value is not None:
return self.enum_class(value)
return None

class Chat(UUIDBaseModel, UpdatableBaseModel, table=True):
title: str = Field(max_length=256)
Expand All @@ -43,7 +68,8 @@ class Chat(UUIDBaseModel, UpdatableBaseModel, table=True):
browser_id: str = Field(max_length=50, nullable=True)
origin: str = Field(max_length=256, default=None, nullable=True)
visibility: ChatVisibility = Field(
sa_column=Column(SmallInteger, default=ChatVisibility.PRIVATE, nullable=False)
# sa_column=Column(SmallInteger, default=ChatVisibility.PRIVATE, nullable=False)
sa_column=Column(IntEnumType(ChatVisibility), nullable=False, default=ChatVisibility.PRIVATE)
)

__tablename__ = "chats"
Expand Down

0 comments on commit 803437e

Please sign in to comment.