Skip to content

Commit

Permalink
Merge pull request #1773 from SciPhi-AI/feature/add-limit-checks-and-…
Browse files Browse the repository at this point in the history
…sdk-tweaks

commit
  • Loading branch information
emrgnt-cmplxty authored Jan 8, 2025
2 parents e9fc549 + 1f8b3c3 commit 0caece6
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 28 deletions.
2 changes: 1 addition & 1 deletion js/sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.4.10",
"version": "0.4.11",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
Expand Down
3 changes: 3 additions & 0 deletions js/sdk/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ export interface User {
name?: string;
bio?: string;
profilePicture?: string;
metadata?: Record<string, any>;
limitOverrides?: Record<string, any>;
documentIds?: string[];
}

// Generic Responses
Expand Down
38 changes: 36 additions & 2 deletions js/sdk/src/v3/clients/users.ts
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,19 @@ export class UsersClient {
* @returns WrappedAPIKeyResponse
*/
@feature("users.createApiKey")
async createApiKey(options: { id: string }): Promise<WrappedAPIKeyResponse> {
return this.client.makeRequest("POST", `users/${options.id}/api-keys`);
async createApiKey(options: {
id: string;
name?: string;
description?: string;
}): Promise<WrappedAPIKeyResponse> {
const data = {
...(options.name && { name: options.name }),
...(options.description && { description: options.description }),
};

return this.client.makeRequest("POST", `users/${options.id}/api-keys`, {
data,
});
}

/**
Expand Down Expand Up @@ -503,4 +514,27 @@ export class UsersClient {
async getLimits(options: { id: string }): Promise<any> {
return this.client.makeRequest("GET", `users/${options.id}/limits`);
}


/**
* **Patch metadata** for a user using a Stripe-like approach.
*
* The `metadata` parameter merges existing metadata with new keys and values:
* - `metadata[key] = "some value"` => sets or updates the key
* - `metadata[key] = ""` => removes the key
* - empty `{}` => removes all metadata keys
*
* @param id The user ID to patch
* @param metadata Partial metadata updates
* @returns WrappedUserResponse
*/
@feature("users.patchMetadata")
async patchMetadata(options: {
id: string;
metadata: Record<string, string | null>;
}): Promise<WrappedUserResponse> {
return this.client.makeRequest("PATCH", `users/${options.id}/metadata`, {
data: options.metadata,
});
}
}
83 changes: 73 additions & 10 deletions py/core/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,39 @@
from .collections import PostgresCollectionsHandler


def _merge_metadata(
existing_metadata: dict[str, str], new_metadata: dict[str, Optional[str]]
) -> dict[str, str]:
"""
Merges the new metadata with the existing metadata in the Stripe-style approach:
- new_metadata[key] = <string> => update or add that key
- new_metadata[key] = "" => remove that key
- if new_metadata is empty => remove all keys
"""
# If new_metadata is an empty dict, it signals removal of all keys.
if new_metadata == {}:
return {}

# Copy so we don't mutate the original
final_metadata = dict(existing_metadata)

for key, value in new_metadata.items():
# If the user sets the key to an empty string, it means "delete" that key
if value == "":
if key in final_metadata:
del final_metadata[key]
# If not None and not empty, set or override
elif value is not None:
final_metadata[key] = value
else:
# If the user sets the value to None in some contexts, decide if you want to remove or ignore
# For now we might treat None same as empty string => remove
if key in final_metadata:
del final_metadata[key]

return final_metadata


class PostgresUserHandler(Handler):
TABLE_NAME = "users"
API_KEYS_TABLE_NAME = "users_api_keys"
Expand Down Expand Up @@ -47,6 +80,7 @@ async def create_tables(self):
reset_token_expiry TIMESTAMPTZ,
collection_ids UUID[] NULL,
limits_overrides JSONB,
metadata JSONB,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
Expand All @@ -60,6 +94,7 @@ async def create_tables(self):
public_key TEXT UNIQUE NOT NULL,
hashed_key TEXT NOT NULL,
name TEXT,
description TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
Expand Down Expand Up @@ -92,6 +127,7 @@ async def get_user_by_id(self, id: UUID) -> User:
"bio",
"collection_ids",
"limits_overrides",
"metadata",
]
)
.where("id = $1")
Expand All @@ -116,6 +152,7 @@ async def get_user_by_id(self, id: UUID) -> User:
bio=result["bio"],
collection_ids=result["collection_ids"],
limits_overrides=json.loads(result["limits_overrides"] or "{}"),
metadata=json.loads(result["metadata"] or "{}"),
)

async def get_user_by_email(self, email: str) -> User:
Expand All @@ -135,6 +172,7 @@ async def get_user_by_email(self, email: str) -> User:
"profile_picture",
"bio",
"collection_ids",
"metadata",
"limits_overrides",
]
)
Expand All @@ -159,6 +197,7 @@ async def get_user_by_email(self, email: str) -> User:
bio=result["bio"],
collection_ids=result["collection_ids"],
limits_overrides=json.loads(result["limits_overrides"] or "{}"),
metadata=json.loads(result["metadata"] or "{}"),
)

async def create_user(
Expand Down Expand Up @@ -187,6 +226,7 @@ async def create_user(
"hashed_password": hashed_password,
"collection_ids": [],
"limits_overrides": None,
"metadata": None,
}
)
.returning(
Expand All @@ -200,6 +240,7 @@ async def create_user(
"updated_at",
"collection_ids",
"limits_overrides",
"metadata",
]
)
.build()
Expand All @@ -223,13 +264,17 @@ async def create_user(
collection_ids=result["collection_ids"] or [],
hashed_password=hashed_password,
limits_overrides=json.loads(result["limits_overrides"] or "{}"),
metadata=json.loads(result["metadata"] or "{}"),
name=None,
bio=None,
profile_picture=None,
)

async def update_user(
self, user: User, merge_limits: bool = False
self,
user: User,
merge_limits: bool = False,
new_metadata: dict[str, Optional[str]] | None = None,
) -> User:
"""
Update user information including limits_overrides.
Expand All @@ -242,13 +287,19 @@ async def update_user(
Returns:
Updated User object
"""

# Get current user if we need to merge limits or get hashed password
current_user = None
try:
current_user = await self.get_user_by_id(user.id)
except R2RException:
raise R2RException(status_code=404, message="User not found")

# Merge or replace metadata if provided
final_metadata = current_user.metadata or {}
if new_metadata is not None:
final_metadata = _merge_metadata(final_metadata, new_metadata)

# Merge or replace limits_overrides
final_limits = user.limits_overrides
if (
Expand All @@ -271,11 +322,12 @@ async def update_user(
profile_picture = $6,
bio = $7,
collection_ids = $8,
limits_overrides = $9::jsonb
WHERE id = $10
limits_overrides = $9::jsonb,
metadata = $10::jsonb
WHERE id = $11
RETURNING id, email, is_superuser, is_active, is_verified,
created_at, updated_at, name, profile_picture, bio,
collection_ids, limits_overrides, hashed_password
collection_ids, limits_overrides, metadata, hashed_password
"""
result = await self.connection_manager.fetchrow_query(
query,
Expand All @@ -289,6 +341,7 @@ async def update_user(
user.bio,
user.collection_ids or [],
json.dumps(final_limits),
json.dumps(final_metadata),
user.id,
],
)
Expand Down Expand Up @@ -318,6 +371,7 @@ async def update_user(
limits_overrides=json.loads(
result["limits_overrides"] or "{}"
), # Can be null
metadata=json.loads(result["metadata"] or "{}"),
)

async def delete_user_relational(self, id: UUID) -> None:
Expand Down Expand Up @@ -389,6 +443,7 @@ async def get_all_users(self) -> list[User]:
"collection_ids",
"hashed_password",
"limits_overrides",
"metadata",
"name",
"bio",
"profile_picture",
Expand All @@ -412,6 +467,7 @@ async def get_all_users(self) -> list[User]:
limits_overrides=json.loads(
result["limits_overrides"] or "{}"
),
metadata=json.loads(result["metadata"] or "{}"),
name=result["name"],
bio=result["bio"],
profile_picture=result["profile_picture"],
Expand Down Expand Up @@ -586,6 +642,7 @@ async def get_users_in_collection(
"profile_picture",
"hashed_password",
"limits_overrides",
"metadata",
"COUNT(*) OVER() AS total_entries",
]
)
Expand Down Expand Up @@ -617,6 +674,7 @@ async def get_users_in_collection(
profile_picture=row["profile_picture"],
hashed_password=row["hashed_password"],
limits_overrides=json.loads(row["limits_overrides"] or "{}"),
metadata=json.loads(row["metadata"] or "{}"),
)
for row in results
]
Expand Down Expand Up @@ -808,16 +866,19 @@ async def store_user_api_key(
key_id: str,
hashed_key: str,
name: Optional[str] = None,
description: Optional[str] = None,
) -> UUID:
"""Store a new API key for a user."""
"""
Store a new API key for a user with optional name and description.
"""
query = f"""
INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
(user_id, public_key, hashed_key, name)
VALUES ($1, $2, $3, $4)
(user_id, public_key, hashed_key, name, description)
VALUES ($1, $2, $3, $4, $5)
RETURNING id
"""
result = await self.connection_manager.fetchrow_query(
query, [user_id, key_id, hashed_key, name]
query, [user_id, key_id, hashed_key, name or "", description or ""]
)
if not result:
raise R2RException(
Expand Down Expand Up @@ -847,7 +908,7 @@ async def get_api_key_record(self, key_id: str) -> Optional[dict]:
async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
"""Get all API keys for a user."""
query = f"""
SELECT id, public_key, name, created_at, updated_at
SELECT id, public_key, name, description, created_at, updated_at
FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
WHERE user_id = $1
ORDER BY created_at DESC
Expand All @@ -858,6 +919,7 @@ async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
"key_id": str(row["id"]),
"public_key": row["public_key"],
"name": row["name"] or "",
"description": row["description"] or "",
"updated_at": row["updated_at"],
}
for row in results
Expand All @@ -868,7 +930,7 @@ async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
query = f"""
DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
WHERE id = $1 AND user_id = $2
RETURNING id, public_key, name
RETURNING id, public_key, name, description
"""
result = await self.connection_manager.fetchrow_query(
query, [key_id, user_id]
Expand All @@ -880,6 +942,7 @@ async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
"key_id": str(result["id"]),
"public_key": str(result["public_key"]),
"name": result["name"] or "",
"description": result["description"] or "",
}

async def update_api_key_name(
Expand Down
Loading

0 comments on commit 0caece6

Please sign in to comment.