From 894ac30eb64c05442ec9d4257da0f3204b09ddf7 Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Mon, 6 Jan 2025 19:56:31 -0800 Subject: [PATCH] up --- js/sdk/src/v3/clients/users.ts | 4 + py/core/main/api/v3/users_router.py | 79 ++++++++ py/core/main/services/management_service.py | 193 +++++++++++++++++++- py/sdk/base/base_client.py | 1 + py/sdk/v3/users.py | 14 ++ 5 files changed, 290 insertions(+), 1 deletion(-) diff --git a/js/sdk/src/v3/clients/users.ts b/js/sdk/src/v3/clients/users.ts index 7474c87a4..4decf1e29 100644 --- a/js/sdk/src/v3/clients/users.ts +++ b/js/sdk/src/v3/clients/users.ts @@ -501,4 +501,8 @@ export class UsersClient { ); } + async getLimits(options: { id: string }): Promise { + return this.client.makeRequest("GET", `users/${options.id}/limits`); + } + } diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 57940fe69..7e197c2a3 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -1649,3 +1649,82 @@ async def delete_user_api_key( "API key not found or could not be deleted", 400 ) return {"success": True} # type: ignore + + @self.router.get( + "/users/{id}/limits", + summary="Fetch User Limits", + responses={ + 200: { + "description": "Returns system default limits, user overrides, and final effective settings." + }, + 403: { + "description": "If the requesting user is neither the same user nor a superuser." + }, + 404: {"description": "If the user ID does not exist."}, + }, + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ + from r2r import R2RClient + + client = R2RClient() + # client.login(...) + + user_limits = client.users.get_limits("550e8400-e29b-41d4-a716-446655440000") + print(user_limits) + """, + }, + { + "lang": "JavaScript", + "source": """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + // await client.users.login(...) + + async function main() { + const userLimits = await client.users.getLimits({ + id: "550e8400-e29b-41d4-a716-446655440000" + }); + console.log(userLimits); + } + + main(); + """, + }, + { + "lang": "cURL", + "source": """ + curl -X GET "https://api.example.com/v3/users/550e8400-e29b-41d4-a716-446655440000/limits" \\ + -H "Authorization: Bearer YOUR_API_KEY" + """, + }, + ] + }, + ) + @self.base_endpoint + async def get_user_limits( + id: UUID = Path( + ..., description="ID of the user to fetch limits for" + ), + auth_user=Depends(self.providers.auth.auth_wrapper()), + ) -> dict[str, dict]: + """ + Return the system default limits, user-level overrides, and final "effective" limit settings + for the specified user. + + Only superusers or the user themself may fetch these values. + """ + if (auth_user.id != id) and (not auth_user.is_superuser): + raise R2RException( + "Only the user themselves or a superuser can view these limits.", + status_code=403, + ) + + # This calls the new helper you created in ManagementService + limits_info = await self.services.management.get_all_user_limits( + id + ) + return limits_info diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index f7085921e..57094ba82 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -1,7 +1,7 @@ import logging import os from collections import defaultdict -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import IO, Any, BinaryIO, Optional, Tuple from uuid import UUID @@ -916,3 +916,194 @@ async def get_max_upload_size_by_type( # 6. Otherwise, return the global default return self.config.app.default_max_upload_size + + async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]: + """ + Return a dictionary containing: + - The system default limits (from self.config.limits) + - The user's overrides (from user.limits_overrides) + - The final 'effective' set of limits after merging + - The usage for each relevant limit (how many requests used, how many remain, etc.) + """ + # 1. Fetch the user to see if they have overrides + user = await self.providers.database.users_handler.get_user_by_id(user_id) + + # 2. System defaults + system_defaults = { + "global_per_min": self.config.database.limits.global_per_min, + "route_per_min": self.config.database.limits.route_per_min, + "monthly_limit": self.config.database.limits.monthly_limit, + # add other fields if your LimitSettings has them + } + + # 3. Grab user-level overrides + # (In your code, user.limits_overrides is a JSON field, e.g. {"global_per_min": 80, "route_overrides": {...}} ) + user_overrides = user.limits_overrides or {} + + # 4. Build effective limits by merging system_defaults with user_overrides + # For simplicity, we only directly handle "global_per_min" and "monthly_limit" at the top level + # Then route-specific overrides from user (like user_overrides["route_overrides"]) + # overshadow system route limits if they exist. + effective_limits = dict(system_defaults) + + # If the user added "global_per_min" or "monthly_limit" overrides, override them + if user_overrides.get("global_per_min") is not None: + effective_limits["global_per_min"] = user_overrides["global_per_min"] + if user_overrides.get("monthly_limit") is not None: + effective_limits["monthly_limit"] = user_overrides["monthly_limit"] + if user_overrides.get("route_per_min") is not None: + effective_limits["route_per_min"] = user_overrides["route_per_min"] + + # We'll also gather route-level overrides from: + # - self.config.route_limits (system route overrides) + # - user_overrides["route_overrides"] (user route overrides) + # So we can later show usage for each route. + system_route_limits = self.config.database.route_limits # dict[str, LimitSettings] + user_route_overrides = user_overrides.get("route_overrides", {}) # e.g. { "/api/foo": {...}, ... } + + # 5. Build usage data + usage = {} + # => We'll fill usage["global_per_min"], usage["monthly_limit"], and usage["routes"][route] ... + # We'll rely on your PostgresLimitsHandler to do the counting. + + # (a) Compute usage for global_per_min (requests in last minute) & monthly_limit + now = datetime.now(timezone.utc) + one_min_ago = now - timedelta(minutes=1) + + # Use your limits_handler to count + global_per_min_used = await self.providers.database.limits_handler._count_requests( + user_id, route=None, since=one_min_ago + ) + monthly_used = await self.providers.database.limits_handler._count_monthly_requests(user_id) + + # The final effective global/min is in `effective_limits["global_per_min"]`, etc. + usage["global_per_min"] = { + "used": global_per_min_used, + "limit": effective_limits["global_per_min"], + "remaining": ( + effective_limits["global_per_min"] - global_per_min_used + if effective_limits["global_per_min"] is not None + else None + ), + } + usage["monthly_limit"] = { + "used": monthly_used, + "limit": effective_limits["monthly_limit"], + "remaining": ( + effective_limits["monthly_limit"] - monthly_used + if effective_limits["monthly_limit"] is not None + else None + ), + } + + # (b) Build route-level usage + # We'll gather a union of the routes from system_route_limits + user_route_overrides + route_keys = set(system_route_limits.keys()) | set(user_route_overrides.keys()) + usage["routes"] = {} + for route in route_keys: + # 1) System route-limits + sys_route_lim = system_route_limits.get(route) # or None + route_global_per_min = sys_route_lim.global_per_min if sys_route_lim else system_defaults["global_per_min"] + route_route_per_min = sys_route_lim.route_per_min if sys_route_lim else system_defaults["route_per_min"] + route_monthly_limit = sys_route_lim.monthly_limit if sys_route_lim else system_defaults["monthly_limit"] + + # 2) Merge user overrides for that route + user_route_cfg = user_route_overrides.get(route, {}) # e.g. { "route_per_min": 25, "global_per_min": 80, ... } + if user_route_cfg.get("global_per_min") is not None: + route_global_per_min = user_route_cfg["global_per_min"] + if user_route_cfg.get("route_per_min") is not None: + route_route_per_min = user_route_cfg["route_per_min"] + if user_route_cfg.get("monthly_limit") is not None: + route_monthly_limit = user_route_cfg["monthly_limit"] + + # Now let's measure usage for this route over the last minute + route_per_min_used = await self.providers.database.limits_handler._count_requests( + user_id, route, one_min_ago + ) + # monthly usage is the same for all routes if there's a global monthly limit, + # but if you have route-specific monthly limits, we still want to do a global monthly count. + # (You can do something more advanced if you only want route-specific monthly usage, but + # your code currently lumps monthly usage by user_id, not by user+route.) + # We'll reuse monthly_used from above, so if there's a route-specific monthly limit, + # it compares the entire month's usage to that route limit. + # If you want only the route's monthly usage, you'd need a new function + # e.g. `_count_requests(user_id, route, start_of_month)` in your limits_handler. + + usage["routes"][route] = { + # The route-level per-minute usage (stuff relevant to route_per_min) + "route_per_min": { + "used": route_per_min_used, + "limit": route_route_per_min, + "remaining": ( + route_route_per_min - route_per_min_used + if route_route_per_min is not None + else None + ), + }, + # If you want to represent the "global_per_min" that applies to this route, + # you could put that here too if it’s route-specific. + # But typically "global_per_min" is for all requests, so usage is the same as above. + + # The route-specific monthly usage, in your code, is not specifically counted by route, + # but if you want to do it the same as route_per_min, you'd do: + # route_monthly_used = await self.providers.database.limits_handler._count_requests( + # user_id, route, start_of_month + # ) + # We'll just reuse the global monthly usage to compare to the route's monthly limit: + "monthly_limit": { + "used": monthly_used, + "limit": route_monthly_limit, + "remaining": ( + route_monthly_limit - monthly_used + if route_monthly_limit is not None + else None + ), + } + } + + # Return a structured response + return { + "system_defaults": system_defaults, + "user_overrides": user_overrides, + "effective_limits": effective_limits, + "usage": usage, + } + + # """ + # Return a dictionary containing: + # - The system default limits (from config) + # - The user overrides (from user.limits_overrides) + # - The final "effective" set of limits after merging + # """ + # # 1. Fetch the user to see if they have overrides + # user = await self.providers.database.users_handler.get_user_by_id( + # user_id + # ) + + # # 2. System defaults (example: from self.config.limits) + # # Adjust these names as needed based on your actual config + # system_defaults = { + # "global_per_min": self.config.database.limits.global_per_min, + # "route_per_min": self.config.database.limits.route_per_min, + # "monthly_limit": self.config.database.limits.monthly_limit, + # # add other fields as needed + # } + + # # 3. Grab user-level overrides + # user_overrides = ( + # user.limits_overrides or {} + # ) # In DB, typically a JSON field + + # # 4. Merge them. "Effective" means the final set of limits after user overrides + # # overshadow system defaults if present + # effective_limits = dict(system_defaults) # start with system + # for k, v in user_overrides.items(): + # # If your overrides nest like {"global_per_min": X, "route_overrides": {...}}, + # # you might need more robust merging logic. For simplicity, we do a shallow merge here. + # effective_limits[k] = v + + # return { + # "system_defaults": system_defaults, + # "user_overrides": user_overrides, + # "effective_limits": effective_limits, + # } diff --git a/py/sdk/base/base_client.py b/py/sdk/base/base_client.py index 589644ae9..6f8393d84 100644 --- a/py/sdk/base/base_client.py +++ b/py/sdk/base/base_client.py @@ -44,6 +44,7 @@ def __init__( self.timeout = timeout self.access_token: Optional[str] = None self._refresh_token: Optional[str] = None + self._user_id: Optional[str] = None self.api_key: Optional[str] = os.getenv("R2R_API_KEY", None) def _get_auth_header(self) -> dict[str, str]: diff --git a/py/sdk/v3/users.py b/py/sdk/v3/users.py index 3bd0eb190..99cf35c66 100644 --- a/py/sdk/v3/users.py +++ b/py/sdk/v3/users.py @@ -154,6 +154,13 @@ async def login(self, email: str, password: str) -> dict[str, Token]: self.client._refresh_token = response["results"]["refresh_token"][ "token" ] + user = await self.client._make_request( + "GET", + "users/me", + version="v3", + ) + + self.client._user_id = user["results"]["id"] return response # FIXME: What is going on here... @@ -518,3 +525,10 @@ async def delete_api_key( f"users/{str(id)}/api-keys/{str(key_id)}", version="v3", ) + + async def get_limits(self) -> dict[str, Any]: + return await self.client._make_request( + "GET", + f"users/{str(self.client._user_id)}/limits", + version="v3", + )