Skip to content

Commit

Permalink
Merge pull request #60 from Pradip-p/Pradip-p-feat-project-details
Browse files Browse the repository at this point in the history
feat: add authentication and validate DB results with Pydantic schemas
  • Loading branch information
nrjadkry authored Jul 7, 2024
2 parents dbf1e82 + 20f9ed5 commit d0fddb4
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 137 deletions.
116 changes: 37 additions & 79 deletions src/backend/app/projects/project_crud.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,28 @@
import json
import uuid
from typing import List, Optional
from sqlalchemy.orm import Session
from typing import Optional
from app.projects import project_schemas
from app.db import db_models
from loguru import logger as log
import shapely.wkb as wkblib
from shapely.geometry import shape
from fastapi import HTTPException
from app.utils import merge_multipolygon, str_to_geojson
from app.utils import merge_multipolygon
from fmtm_splitter.splitter import split_by_square
from fastapi.concurrency import run_in_threadpool
from app.db import database
from fastapi import Depends
from asyncio import gather
from databases import Database


async def create_project_with_project_info(
db: Database, project_metadata: project_schemas.ProjectIn
db: Database, author_id: uuid.UUID, project_metadata: project_schemas.ProjectIn
):
"""Create a project in database."""
project_id = uuid.uuid4()
_id = uuid.uuid4()
query = """
INSERT INTO projects (
id, author_id, name, short_description, description, per_task_instructions, status, visibility, outline, dem_url, created
)
VALUES (
:project_id,
:id,
:author_id,
:name,
:short_description,
Expand All @@ -41,12 +36,11 @@ async def create_project_with_project_info(
)
RETURNING id
"""
# new_project_id = await db.execute(query)
new_project_id = await db.execute(
project_id = await db.execute(
query,
values={
"project_id": project_id,
"author_id": str(110878106282210575794), # TODO: update this
"id": _id,
"author_id": author_id,
"name": project_metadata.name,
"short_description": project_metadata.short_description,
"description": project_metadata.description,
Expand All @@ -58,96 +52,60 @@ async def create_project_with_project_info(
},
)

if not new_project_id:
raise HTTPException(status_code=500, detail="Project could not be created")
# Fetch the newly created project using the returned ID
select_query = f"""
SELECT id, name, short_description, description, per_task_instructions, outline
FROM projects
WHERE id = '{new_project_id}'
WHERE id = '{project_id}'
"""
new_project = await db.fetch_one(query=select_query)
return new_project


async def get_project_by_id(
db: Session = Depends(database.get_db), project_id: Optional[int] = None
) -> db_models.DbProject:
"""Get a single project by id."""
db_project = (
db.query(db_models.DbProject)
.filter(db_models.DbProject.id == project_id)
.first()
)
return await convert_to_app_project(db_project)
db: Database, author_id: uuid.UUID, project_id: Optional[int] = None
):
"""Get a single project & all associated tasks by ID."""
raw_sql = """
SELECT
projects.id,
projects.name,
projects.short_description,
projects.description,
projects.per_task_instructions,
projects.outline
FROM projects
WHERE projects.author_id = :author_id
LIMIT 1;
"""

project_record = await db.fetch_one(raw_sql, {"author_id": author_id})
query = """ SELECT id, project_task_index, outline FROM tasks WHERE project_id = :project_id;"""
task_records = await db.fetch_all(query, {"project_id": project_id})
project_record.tasks = task_records
project_record.task_count = len(task_records)
return project_record


async def get_projects(
db: Database,
author_id: uuid.UUID,
skip: int = 0,
limit: int = 100,
):
"""Get all projects."""
raw_sql = """
SELECT id, name, short_description, description, per_task_instructions, outline
FROM projects
WHERE author_id = :author_id
ORDER BY id DESC
OFFSET :skip
LIMIT :limit;
"""
db_projects = await db.fetch_all(raw_sql, {"skip": skip, "limit": limit})
return await convert_to_app_projects(db_projects)


# async def get_projects(
# db: Session,
# skip: int = 0,
# limit: int = 100,
# ):
# """Get all projects."""
# db_projects = (
# db.query(db_models.DbProject)
# .order_by(db_models.DbProject.id.desc())
# .offset(skip)
# .limit(limit)
# .all()
# )
# project_count = db.query(db_models.DbProject).count()
# return project_count, await convert_to_app_projects(db_projects)


async def convert_to_app_projects(
db_projects: List[db_models.DbProject],
) -> List[project_schemas.ProjectOut]:
"""Legacy function to convert db models --> Pydantic.
TODO refactor to use Pydantic model methods instead.
"""
if db_projects and len(db_projects) > 0:

async def convert_project(project):
return await convert_to_app_project(project)

app_projects = await gather(
*[convert_project(project) for project in db_projects]
)
return [project for project in app_projects if project is not None]
else:
return []


async def convert_to_app_project(db_project: db_models.DbProject):
"""Legacy function to convert db models --> Pydantic."""
if not db_project:
log.debug("convert_to_app_project called, but no project provided")
return None
app_project = db_project

if db_project.outline:
app_project.outline_geojson = str_to_geojson(
db_project.outline, {"id": db_project.id}, db_project.id
)
return app_project
db_projects = await db.fetch_all(
raw_sql, {"author_id": author_id, "skip": skip, "limit": limit}
)
return db_projects


async def create_tasks_from_geojson(
Expand Down
45 changes: 37 additions & 8 deletions src/backend/app/projects/project_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import json
import uuid
from app.users.user_deps import login_required
from app.users.user_schemas import AuthUser
import geojson
from datetime import timedelta

Expand All @@ -24,7 +26,11 @@


@router.delete("/{project_id}", tags=["Projects"])
def delete_project_by_id(project_id: int, db: Session = Depends(database.get_db)):
def delete_project_by_id(
project_id: uuid.UUID,
db: Session = Depends(database.get_db),
user: AuthUser = Depends(login_required),
):
"""
Delete a project by its ID, along with all associated tasks.
Expand Down Expand Up @@ -68,9 +74,13 @@ def delete_project_by_id(project_id: int, db: Session = Depends(database.get_db)
async def create_project(
project_info: project_schemas.ProjectIn,
db: Database = Depends(database.encode_db),
user_data: AuthUser = Depends(login_required),
):
"""Create a project in database."""
project = await project_crud.create_project_with_project_info(db, project_info)
author_id = user_data.id
project = await project_crud.create_project_with_project_info(
db, author_id, project_info
)
if not project:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Project creation failed"
Expand All @@ -83,6 +93,7 @@ async def upload_project_task_boundaries(
project_id: uuid.UUID,
task_geojson: UploadFile = File(...),
db: Database = Depends(database.encode_db),
user: AuthUser = Depends(login_required),
):
"""Set project task boundaries using split GeoJSON from frontend.
Expand All @@ -95,6 +106,13 @@ async def upload_project_task_boundaries(
Returns:
dict: JSON containing success message, project ID, and number of tasks.
"""
# check the project in Database
raw_sql = f"""SELECT id FROM projects WHERE id = '{project_id}' LIMIT 1;"""
project = await db.fetch_one(query=raw_sql)
if not project:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Project not found."
)
# read entire file
content = await task_geojson.read()
task_boundaries = json.loads(content)
Expand All @@ -108,7 +126,9 @@ async def upload_project_task_boundaries(

@router.post("/preview-split-by-square/", tags=["Projects"])
async def preview_split_by_square(
project_geojson: UploadFile = File(...), dimension: int = Form(100)
project_geojson: UploadFile = File(...),
dimension: int = Form(100),
user: AuthUser = Depends(login_required),
):
"""Preview splitting by square."""

Expand Down Expand Up @@ -164,19 +184,28 @@ async def generate_presigned_url(data: project_schemas.PresignedUrlRequest):

@router.get("/", tags=["Projects"], response_model=list[project_schemas.ProjectOut])
async def read_projects(
skip: int = 0, limit: int = 100, db: Database = Depends(database.encode_db)
skip: int = 0,
limit: int = 100,
db: Database = Depends(database.encode_db),
user_data: AuthUser = Depends(login_required),
):
"Return all projects"
projects = await project_crud.get_projects(db, skip, limit)
author_id = user_data.id
projects = await project_crud.get_projects(db, author_id, skip, limit)
return projects


@router.get(
"/{project_id}", tags=["Projects"], response_model=project_schemas.ProjectOut
)
async def read_project(
db: Session = Depends(database.get_db),
project: db_models.DbProject = Depends(project_crud.get_project_by_id),
project_id: uuid.UUID,
db: Database = Depends(database.encode_db),
user_data: AuthUser = Depends(login_required),
):
"""Get a specific project by ID."""
"""Get a specific project and all associated tasks by ID."""
author_id = user_data.id
project = await project_crud.get_project_by_id(db, author_id, project_id)
if project is None:
raise HTTPException(status_code=404, detail="Project not found")
return project
22 changes: 21 additions & 1 deletion src/backend/app/projects/project_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,25 @@ def centroid(self) -> Optional[Any]:
return write_wkb(read_wkb(self.outline).centroid)


class TaskOut(BaseModel):
"""Base project model."""

id: uuid.UUID
project_task_index: int
outline: Any = Field(exclude=True)

@computed_field
@property
def outline_geojson(self) -> Optional[Feature]:
"""Compute the geojson outline from WKBElement outline."""
if not self.outline:
return None
wkb_data = bytes.fromhex(self.outline)
geom = wkb.loads(wkb_data)
bbox = geom.bounds # Calculate bounding box
return str_to_geojson(self.outline, {"id": self.id, "bbox": bbox}, str(self.id))


class ProjectOut(BaseModel):
"""Base project model."""

Expand All @@ -65,6 +84,8 @@ class ProjectOut(BaseModel):
description: str
per_task_instructions: Optional[str] = None
outline: Any = Field(exclude=True)
tasks: list[TaskOut] = []
task_count: int = None

@computed_field
@property
Expand All @@ -74,7 +95,6 @@ def outline_geojson(self) -> Optional[Feature]:
return None
wkb_data = bytes.fromhex(self.outline)
geom = wkb.loads(wkb_data)
# geometry = wkb.loads(bytes(self.outline.data))
bbox = geom.bounds # Calculate bounding box
return str_to_geojson(self.outline, {"id": self.id, "bbox": bbox}, str(self.id))

Expand Down
25 changes: 1 addition & 24 deletions src/backend/app/users/user_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any
from passlib.context import CryptContext
from app.db import db_models
from app.users.user_schemas import UserCreate, AuthUser, ProfileUpdate
from app.users.user_schemas import AuthUser, ProfileUpdate
from databases import Database
from fastapi import HTTPException
from app.models.enums import UserRole
Expand Down Expand Up @@ -95,29 +95,6 @@ async def authenticate(
return db_user


# def authenticate(db: Session, username: str, password: str) -> db_models.DbUser | None:
# db_user = get_user_by_username(db, username)
# if not db_user:
# return None
# if not verify_password(password, db_user.password):
# return None
# return db_user


async def create_user(db: Database, user_create: UserCreate):
query = f"""
INSERT INTO users (username, password, is_active, name, email_address, is_superuser)
VALUES ('{user_create.username}', '{get_password_hash(user_create.password)}', {True}, '{user_create.name}', '{user_create.email_address}', {False})
RETURNING id
"""
_id = await db.execute(query)
raw_query = f"SELECT * from users WHERE id = {_id} LIMIT 1"
db_obj = await db.fetch_one(query=raw_query)
if not db_obj:
raise HTTPException(status_code=500, detail="User could not be created")
return db_obj


async def get_or_create_user(
db: Database,
user_data: AuthUser,
Expand Down
Loading

0 comments on commit d0fddb4

Please sign in to comment.