diff --git a/src/backend/app/drones/drone_crud.py b/src/backend/app/drones/drone_crud.py index 0bb50783..1494091f 100644 --- a/src/backend/app/drones/drone_crud.py +++ b/src/backend/app/drones/drone_crud.py @@ -1,14 +1,15 @@ from app.drones import drone_schemas from app.models.enums import HTTPStatus -from databases import Database from loguru import logger as log from fastapi import HTTPException -from asyncpg import UniqueViolationError +from psycopg import Connection + +# from asyncpg import UniqueViolationError from typing import List from app.drones.drone_schemas import DroneOut -async def read_all_drones(db: Database) -> List[DroneOut]: +async def read_all_drones(db: Connection) -> List[DroneOut]: """ Retrieves all drone records from the database. @@ -32,7 +33,7 @@ async def read_all_drones(db: Database) -> List[DroneOut]: ) from e -async def delete_drone(db: Database, drone_id: int) -> bool: +async def delete_drone(db: Connection, drone_id: int) -> bool: """ Deletes a drone record from the database, along with associated drone flights. @@ -63,7 +64,7 @@ async def delete_drone(db: Database, drone_id: int) -> bool: ) from e -async def get_drone(db: Database, drone_id: int): +async def get_drone(db: Connection, drone_id: int): """ Retrieves a drone record from the database. @@ -89,7 +90,7 @@ async def get_drone(db: Database, drone_id: int): ) from e -async def create_drone(db: Database, drone_info: drone_schemas.DroneIn): +async def create_drone(db: Connection, drone_info: drone_schemas.DroneIn): """ Creates a new drone record in the database. @@ -116,12 +117,12 @@ async def create_drone(db: Database, drone_info: drone_schemas.DroneIn): result = await db.execute(insert_query, drone_info.__dict__) return result - except UniqueViolationError as e: - log.exception("Unique constraint violation: %s", e) - raise HTTPException( - status_code=HTTPStatus.CONFLICT, - detail="A drone with this model already exists", - ) + # except UniqueViolationError as e: + # log.exception("Unique constraint violation: %s", e) + # raise HTTPException( + # status_code=HTTPStatus.CONFLICT, + # detail="A drone with this model already exists", + # ) except Exception as e: log.exception(e) diff --git a/src/backend/app/drones/drone_routes.py b/src/backend/app/drones/drone_routes.py index a3fc232e..4fee8b4f 100644 --- a/src/backend/app/drones/drone_routes.py +++ b/src/backend/app/drones/drone_routes.py @@ -1,11 +1,12 @@ +from typing import Annotated from app.users.user_deps import login_required from app.users.user_schemas import AuthUser from app.models.enums import HTTPStatus from fastapi import APIRouter, Depends, HTTPException -from app.db.database import get_db +from app.db import database from app.config import settings from app.drones import drone_schemas -from databases import Database +from psycopg import Connection from app.drones import drone_crud from typing import List @@ -18,8 +19,8 @@ @router.get("/", tags=["Drones"], response_model=List[drone_schemas.DroneOut]) async def read_drones( - db: Database = Depends(get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """ Retrieves all drone records from the database. @@ -38,8 +39,8 @@ async def read_drones( @router.delete("/{drone_id}", tags=["Drones"]) async def delete_drone( drone_id: int, - db: Database = Depends(get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """ Deletes a drone record from the database. @@ -61,8 +62,8 @@ async def delete_drone( @router.post("/create_drone", tags=["Drones"]) async def create_drone( drone_info: drone_schemas.DroneIn, - db: Database = Depends(get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """ Creates a new drone record in the database. @@ -86,8 +87,8 @@ async def create_drone( @router.get("/{drone_id}", tags=["Drones"], response_model=drone_schemas.DroneOut) async def read_drone( drone_id: int, - db: Database = Depends(get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """ Retrieves a drone record from the database. diff --git a/src/backend/app/projects/project_schemas.py b/src/backend/app/projects/project_schemas.py index e83c6a00..cb4f815d 100644 --- a/src/backend/app/projects/project_schemas.py +++ b/src/backend/app/projects/project_schemas.py @@ -47,8 +47,6 @@ class ProjectIn(BaseModel): dem_url: Optional[str] = None gsd_cm_px: float = None is_terrain_follow: bool = False - # TODO change all references outline_geojson --> outline - # TODO also no_fly_zones outline: Annotated[ FeatureCollection | Feature | Polygon, AfterValidator(validate_geojson) ] diff --git a/src/backend/app/tasks/task_crud.py b/src/backend/app/tasks/task_crud.py index 4044abc8..32a75897 100644 --- a/src/backend/app/tasks/task_crud.py +++ b/src/backend/app/tasks/task_crud.py @@ -1,11 +1,11 @@ import uuid -from databases import Database from app.models.enums import HTTPStatus, State from fastapi import HTTPException from loguru import logger as log +from psycopg import Connection -async def get_tasks_by_user(user_id: str, db: Database): +async def get_tasks_by_user(user_id: str, db: Connection): try: query = """WITH task_details AS ( SELECT @@ -42,7 +42,7 @@ async def get_tasks_by_user(user_id: str, db: Database): ) from e -async def get_all_tasks(db: Database, project_id: uuid.UUID): +async def get_all_tasks(db: Connection, project_id: uuid.UUID): query = """ SELECT id FROM tasks WHERE project_id = :project_id """ @@ -56,7 +56,7 @@ async def get_all_tasks(db: Database, project_id: uuid.UUID): return task_ids -async def all_tasks_states(db: Database, project_id: uuid.UUID): +async def all_tasks_states(db: Connection, project_id: uuid.UUID): query = """ SELECT DISTINCT ON (task_id) project_id, task_id, state FROM task_events @@ -95,7 +95,11 @@ async def all_tasks_states(db: Database, project_id: uuid.UUID): async def request_mapping( - db: Database, project_id: uuid.UUID, task_id: uuid.UUID, user_id: str, comment: str + db: Connection, + project_id: uuid.UUID, + task_id: uuid.UUID, + user_id: str, + comment: str, ): query = """ WITH last AS ( @@ -140,7 +144,7 @@ async def request_mapping( async def update_task_state( - db: Database, + db: Connection, project_id: uuid.UUID, task_id: uuid.UUID, user_id: str, @@ -183,7 +187,7 @@ async def update_task_state( async def get_requested_user_id( - db: Database, project_id: uuid.UUID, task_id: uuid.UUID + db: Connection, project_id: uuid.UUID, task_id: uuid.UUID ): query = """ SELECT user_id @@ -204,7 +208,7 @@ async def get_requested_user_id( return result["user_id"] -async def get_project_task_by_id(db: Database, user_id: str): +async def get_project_task_by_id(db: Connection, user_id: str): """Get a list of pending tasks created by a specific user (project creator).""" raw_sql = """ SELECT t.id AS task_id, te.event_id, te.user_id, te.project_id, te.comment, te.state, te.created_at diff --git a/src/backend/app/tasks/task_routes.py b/src/backend/app/tasks/task_routes.py index 87caebcf..c5d78b32 100644 --- a/src/backend/app/tasks/task_routes.py +++ b/src/backend/app/tasks/task_routes.py @@ -1,4 +1,5 @@ import uuid +from typing import Annotated from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from app.config import settings from app.models.enums import EventType, State, UserRole @@ -6,7 +7,7 @@ from app.users.user_deps import login_required from app.users.user_schemas import AuthUser from app.users.user_crud import get_user_by_id -from databases import Database +from psycopg import Connection from app.db import database from app.utils import send_notification_email, render_email_template from app.projects.project_crud import get_project_by_id @@ -21,8 +22,8 @@ @router.get("/", response_model=list[task_schemas.UserTasksStatsOut]) async def list_tasks( - db: Database = Depends(database.get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """Get all tasks for a drone user.""" @@ -31,7 +32,9 @@ async def list_tasks( @router.get("/states/{project_id}") -async def task_states(project_id: uuid.UUID, db: Database = Depends(database.get_db)): +async def task_states( + db: Annotated[Connection, Depends(database.get_db)], project_id: uuid.UUID +): """Get all tasks states for a project.""" return await task_crud.all_tasks_states(db, project_id) @@ -39,12 +42,12 @@ async def task_states(project_id: uuid.UUID, db: Database = Depends(database.get @router.post("/event/{project_id}/{task_id}") async def new_event( + db: Annotated[Connection, Depends(database.get_db)], background_tasks: BackgroundTasks, project_id: uuid.UUID, task_id: uuid.UUID, detail: task_schemas.NewEvent, - user_data: AuthUser = Depends(login_required), - db: Database = Depends(database.get_db), + user_data: Annotated[AuthUser, Depends(login_required)], ): user_id = user_data.id @@ -212,8 +215,8 @@ async def new_event( @router.get("/requested_tasks/pending") async def get_pending_tasks( - user_data: AuthUser = Depends(login_required), - db: Database = Depends(database.get_db), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """Get a list of pending tasks for a project creator.""" user_id = user_data.id diff --git a/src/backend/app/users/user_crud.py b/src/backend/app/users/user_crud.py index af521924..ec649102 100644 --- a/src/backend/app/users/user_crud.py +++ b/src/backend/app/users/user_crud.py @@ -5,9 +5,9 @@ from passlib.context import CryptContext from app.db import db_models from app.users.user_schemas import AuthUser, ProfileUpdate -from databases import Database from fastapi import HTTPException from pydantic import EmailStr +from psycopg import Connection pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -61,26 +61,26 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) -async def get_user_by_id(db: Database, id: str): +async def get_user_by_id(db: Connection, id: str): query = "SELECT * FROM users WHERE id = :id LIMIT 1;" result = await db.fetch_one(query, {"id": id}) return result -async def get_userprofile_by_userid(db: Database, user_id: str): +async def get_userprofile_by_userid(db: Connection, user_id: str): query = "SELECT * FROM user_profile WHERE user_id = :user_id LIMIT 1;" result = await db.fetch_one(query, {"user_id": user_id}) return result -async def get_user_by_email(db: Database, email: str): +async def get_user_by_email(db: Connection, email: str): query = "SELECT * FROM users WHERE email_address = :email LIMIT 1;" result = await db.fetch_one(query, {"email": email}) return result async def authenticate( - db: Database, email: EmailStr, password: str + db: Connection, email: EmailStr, password: str ) -> db_models.DbUser | None: db_user = await get_user_by_email(db, email) if not db_user: @@ -91,7 +91,7 @@ async def authenticate( async def get_or_create_user( - db: Database, + db: Connection, user_data: AuthUser, ): """Get user from User table if exists, else create.""" @@ -132,7 +132,7 @@ async def get_or_create_user( async def update_user_profile( - db: Database, user_id: int, profile_update: ProfileUpdate + db: Connection, user_id: int, profile_update: ProfileUpdate ): """ Update user profile in the database. diff --git a/src/backend/app/users/user_routes.py b/src/backend/app/users/user_routes.py index 5bef1d91..eae74869 100644 --- a/src/backend/app/users/user_routes.py +++ b/src/backend/app/users/user_routes.py @@ -12,7 +12,7 @@ from app.users import user_crud from app.db import database from app.models.enums import HTTPStatus -from databases import Database +from psycopg import Connection from fastapi.responses import JSONResponse from loguru import logger as log @@ -31,7 +31,7 @@ @router.post("/login/") async def login_access_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db: Database = Depends(database.get_db), + db: Annotated[Connection, Depends(database.get_db)], ) -> Token: """ OAuth2 compatible token login, get an access token for future requests @@ -60,8 +60,8 @@ async def login_access_token( async def update_user_profile( user_id: str, profile_update: ProfileUpdate, - db: Database = Depends(database.get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """ Update user profile based on provided user_id and profile_update data. @@ -124,7 +124,7 @@ async def callback(request: Request, google_auth=Depends(init_google_auth)): @router.get("/refresh-token", response_model=Token) -async def update_token(user_data: AuthUser = Depends(login_required)): +async def update_token(user_data: Annotated[AuthUser, Depends(login_required)]): """Refresh access token""" access_token, refresh_token = await user_crud.create_access_token( @@ -135,8 +135,8 @@ async def update_token(user_data: AuthUser = Depends(login_required)): @router.get("/my-info/") async def my_data( - db: Database = Depends(database.get_db), - user_data: AuthUser = Depends(login_required), + db: Annotated[Connection, Depends(database.get_db)], + user_data: Annotated[AuthUser, Depends(login_required)], ): """Read access token and get user details from Google"""