Skip to content

Commit

Permalink
[feature] Added API endpoint to return user's RADIUS usage #499
Browse files Browse the repository at this point in the history
Closes #499
  • Loading branch information
pandafy committed Dec 19, 2023
1 parent 90786b1 commit 5663fb7
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 21 deletions.
22 changes: 2 additions & 20 deletions openwisp_radius/api/freeradius_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ..counters.base import BaseCounter
from ..counters.exceptions import MaxQuotaReached, SkipCheck
from ..signals import radius_accounting_success
from ..utils import load_model
from ..utils import get_group_checks, load_model
from .serializers import (
AuthorizeSerializer,
RadiusAccountingSerializer,
Expand Down Expand Up @@ -302,7 +302,7 @@ def get_replies(self, user, organization_id):
for reply in self.get_group_replies(user_group.group):
data.update({reply.attribute: {'op': reply.op, 'value': reply.value}})

group_checks = self.get_group_checks(user_group.group)
group_checks = get_group_checks(user_group.group)

for counter in app_settings.COUNTERS:
group_check = group_checks.get(counter.check_name)
Expand Down Expand Up @@ -357,24 +357,6 @@ def get_user_group(self, user, organization_id):
def get_group_replies(self, group):
return group.radiusgroupreply_set.all()

def get_group_checks(self, group):
"""
Used to query the DB for group checks only once
instead of once per each counter in use.
"""
if not app_settings.COUNTERS:
return

check_attributes = []
for counter in app_settings.COUNTERS:
check_attributes.append(counter.check_name)

group_checks = group.radiusgroupcheck_set.filter(attribute__in=check_attributes)
result = {}
for group_check in group_checks:
result[group_check.attribute] = group_check
return result

def _get_user_query_conditions(self, request):
is_active = Q(is_active=True)
needs_verification = self._needs_identity_verification({'pk': request._auth})
Expand Down
41 changes: 40 additions & 1 deletion openwisp_radius/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@
AuthTokenSerializer as BaseAuthTokenSerializer,
)
from rest_framework.fields import empty
from rest_framework.generics import get_object_or_404

from openwisp_radius.api.exceptions import CrossOrgRegistrationException
from openwisp_users.backends import UsersAuthenticationBackend

from .. import settings as app_settings
from ..base.forms import PasswordResetForm
from ..counters.exceptions import MaxQuotaReached, SkipCheck
from ..registration import REGISTRATION_METHOD_CHOICES
from ..utils import get_organization_radius_settings, load_model
from ..utils import get_group_checks, get_organization_radius_settings, load_model
from .utils import ErrorDictMixin, IDVerificationHelper

logger = logging.getLogger(__name__)
Expand All @@ -42,6 +44,8 @@
RadiusAccounting = load_model('RadiusAccounting')
RadiusBatch = load_model('RadiusBatch')
RadiusToken = load_model('RadiusToken')
RadiusGroupCheck = load_model('RadiusGroupCheck')
RadiusUserGroup = load_model('RadiusUserGroup')
RegisteredUser = load_model('RegisteredUser')
OrganizationUser = swapper.load_model('openwisp_users', 'OrganizationUser')
Organization = swapper.load_model('openwisp_users', 'Organization')
Expand Down Expand Up @@ -266,6 +270,41 @@ class Meta:
read_only_fields = ('organization',)


class UserGroupCheckSerializer(serializers.ModelSerializer):
result = serializers.SerializerMethodField()

class Meta:
model = RadiusGroupCheck
fields = ('attribute', 'op', 'value', 'result')

def get_result(self, obj):
try:
counter = app_settings.CHECK_ATTRIBUTE_COUNTERS_MAP[obj.attribute]
remaining = counter(
user=self.context['user'],
group=self.context['group'],
group_check=obj,
).check()
return int(obj.value) - remaining
except MaxQuotaReached:
return int(obj.value)
except (SkipCheck, ValueError, KeyError):
return None


class UserRadiusUsageSerializer(serializers.Serializer):
def to_representation(self, obj):
organization = self.context['view'].organization
user_group = get_object_or_404(
RadiusUserGroup, group__organization=organization, user=obj
)
group_checks = get_group_checks(user_group.group).values()
checks_data = UserGroupCheckSerializer(
group_checks, many=True, context={'user': obj, 'group': user_group.group}
).data
return {'checks': checks_data}


class GroupSerializer(serializers.ModelSerializer):
class Meta:
model = Group
Expand Down
5 changes: 5 additions & 0 deletions openwisp_radius/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def get_api_urls(api_views=None):
api_views.user_accounting,
name='user_accounting',
),
path(
'radius/organization/<slug:slug>/account/usage/',
api_views.user_radius_usage,
name='user_radius_usage',
),
# generate new sms phone token
path(
'radius/organization/<slug:slug>/account/phone/token/',
Expand Down
25 changes: 25 additions & 0 deletions openwisp_radius/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
ChangePhoneNumberSerializer,
RadiusAccountingSerializer,
RadiusBatchSerializer,
UserRadiusUsageSerializer,
ValidatePhoneTokenSerializer,
)
from .swagger import ObtainTokenRequest, ObtainTokenResponse, RegisterResponse
Expand All @@ -80,6 +81,8 @@
RadiusAccounting = load_model('RadiusAccounting')
RadiusToken = load_model('RadiusToken')
RadiusBatch = load_model('RadiusBatch')
RadiusUserGroup = load_model('RadiusUserGroup')
RadiusGroupCheck = load_model('RadiusGroupCheck')
auth_backend = UsersAuthenticationBackend()


Expand Down Expand Up @@ -444,6 +447,28 @@ def get_queryset(self):
user_accounting = UserAccountingView.as_view()


@method_decorator(
name='get',
decorator=swagger_auto_schema(
operation_description="""
**Requires the user auth token (Bearer Token).**
Returns the user's accounting usage and limit for the organization.
""",
),
)
class UserRadiusUsageView(ThrottledAPIMixin, DispatchOrgMixin, RetrieveAPIView):
authentication_classes = (BearerAuthentication, SessionAuthentication)
permission_classes = (IsAuthenticated,)
queryset = User.objects.none()
serializer_class = UserRadiusUsageSerializer

def get_object(self):
return self.request.user


user_radius_usage = UserRadiusUsageView.as_view()


class PasswordChangeView(ThrottledAPIMixin, DispatchOrgMixin, BasePasswordChangeView):
authentication_classes = (BearerAuthentication,)

Expand Down
2 changes: 2 additions & 0 deletions openwisp_radius/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,14 @@ def get_default_password_reset_url(urls):
raise ImproperlyConfigured(str(e))

COUNTERS = []
CHECK_ATTRIBUTE_COUNTERS_MAP = {}
for counter_path in _counters:
try:
counter_class = import_string(counter_path)
except ImportError as e: # pragma: no cover
raise ImproperlyConfigured(str(e))
COUNTERS.append(counter_class)
CHECK_ATTRIBUTE_COUNTERS_MAP[counter_class.check_name] = counter_class


# Extend the EXPORT_USERS_COMMAND_CONFIG[fields]
Expand Down
152 changes: 152 additions & 0 deletions openwisp_radius/tests/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,3 +1008,155 @@ def test_organization_registration_enabled(self):
)
self.assertEqual(r.status_code, 201)
self.assertIn('key', r.data)

def test_user_radius_usage_view(self):
auth_url = reverse('radius:user_auth_token', args=[self.default_org.slug])
usage_url = reverse('radius:user_radius_usage', args=[self.default_org.slug])
self._get_org_user()
response = self.client.post(
auth_url, {'username': 'tester', 'password': 'tester'}
)
authorization = f'Bearer {response.data["key"]}'
self.assertEqual(response.status_code, 200)
with self.subTest('Test user has not used any data'):
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)
self.assertEqual(response.status_code, 200)
self.assertIn('checks', response.data)
checks = response.data['checks']
self.assertDictEqual(
dict(checks[0]),
{
'attribute': 'Max-Daily-Session',
'op': ':=',
'value': '10800',
'result': 0,
},
)
self.assertDictEqual(
dict(checks[1]),
{
'attribute': 'Max-Daily-Session-Traffic',
'op': ':=',
'value': '3000000000',
'result': 0,
},
)

stop_time = '2018-03-02T11:43:24.020460+01:00'
data1 = self.acct_post_data
data1.update(
dict(
session_id='35000006',
unique_id='75058e50',
input_octets=1000000000,
output_octets=1000000000,
username='tester',
stop_time=stop_time,
organization=self.default_org,
)
)
self._create_radius_accounting(**data1)

with self.subTest('Test user consumed some data'):
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)
self.assertEqual(response.status_code, 200)
self.assertIn('checks', response.data)
checks = response.data['checks']
self.assertDictEqual(
dict(checks[0]),
{
'attribute': 'Max-Daily-Session',
'op': ':=',
'value': '10800',
'result': 261,
},
)
self.assertDictEqual(
dict(checks[1]),
{
'attribute': 'Max-Daily-Session-Traffic',
'op': ':=',
'value': '3000000000',
'result': 2000000000,
},
)

data2 = self.acct_post_data
data2.update(
dict(
session_id='40111116',
unique_id='12234f69',
input_octets=500000000,
output_octets=500000000,
username='tester',
)
)
self._create_radius_accounting(**data2)

with self.subTest('Test user exhausted limits'):
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)
self.assertEqual(response.status_code, 200)
self.assertIn('checks', response.data)
checks = response.data['checks']
self.assertDictEqual(
dict(checks[0]),
{
'attribute': 'Max-Daily-Session',
'op': ':=',
'value': '10800',
'result': 522,
},
)
self.assertDictEqual(
dict(checks[1]),
{
'attribute': 'Max-Daily-Session-Traffic',
'op': ':=',
'value': '3000000000',
'result': 3000000000,
},
)

# data3 = self.acct_post_data
# data3.update(
# dict(
# session_id='89897654',
# unique_id='99144d60',
# input_octets=4440909,
# output_octets=1119074409,
# username='admin',
# stop_time=stop_time,
# )
# )
# self._create_radius_accounting(**data3)
# url = reverse('radius:user_accounting', args=[self.default_org.slug])
# response = self.client.get(
# f'{url}?page_size=1&page=1',
# HTTP_AUTHORIZATION=authorization,
# )
# self.assertEqual(len(response.json()), 1)
# self.assertEqual(response.status_code, 200)
# item = response.data[0]
# self.assertEqual(item['output_octets'], data2['output_octets'])
# self.assertEqual(item['input_octets'], data2['input_octets'])
# self.assertEqual(item['nas_ip_address'], '172.16.64.91')
# self.assertEqual(item['calling_station_id'], '5c:7d:c1:72:a7:3b')
# self.assertIsNone(item['stop_time'])
# response = self.client.get(
# f'{url}?page_size=1&page=2',
# HTTP_AUTHORIZATION=authorization,
# )
# self.assertEqual(len(response.json()), 1)
# self.assertEqual(response.status_code, 200)
# item = response.data[0]
# self.assertEqual(item['output_octets'], data1['output_octets'])
# self.assertEqual(item['nas_ip_address'], '172.16.64.91')
# self.assertEqual(item['input_octets'], data1['input_octets'])
# self.assertEqual(item['called_station_id'], '00-27-22-F3-FA-F1:hostname')
# self.assertIsNotNone(item['stop_time'])
# response = self.client.get(
# f'{url}?page_size=1&page=3',
# HTTP_AUTHORIZATION=authorization,
# )
# self.assertEqual(len(response.json()), 1)
# self.assertEqual(response.status_code, 404)
26 changes: 26 additions & 0 deletions openwisp_radius/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,29 @@ def get_organization_radius_settings(organization, radius_setting):
raise APIException(
_('Could not complete operation because of an internal misconfiguration')
)


def get_group_checks(group):
"""
Retrieves a dictionary of checks for the given group.
Parameters:
group (Group): The group object for which to retrieve the checks.
Returns:
dict: A dictionary of group checks with the attribute as the key and
the corresponding group check object as the value.
Used to query the DB for group checks only once
instead of once per each counter in use.
"""

if not app_settings.COUNTERS:
return

check_attributes = app_settings.CHECK_ATTRIBUTE_COUNTERS_MAP.keys()
group_checks = group.radiusgroupcheck_set.filter(attribute__in=check_attributes)
result = {}
for group_check in group_checks:
result[group_check.attribute] = group_check
return result
6 changes: 6 additions & 0 deletions tests/openwisp2/sample_radius/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from openwisp_radius.api.views import PasswordResetView as BasePasswordResetView
from openwisp_radius.api.views import RegisterView as BaseRegisterView
from openwisp_radius.api.views import UserAccountingView as BaseUserAccountingView
from openwisp_radius.api.views import UserRadiusUsageView as BaseUserRadiusUsageView
from openwisp_radius.api.views import ValidateAuthTokenView as BaseValidateAuthTokenView
from openwisp_radius.api.views import (
ValidatePhoneTokenView as BaseValidatePhoneTokenView,
Expand Down Expand Up @@ -56,6 +57,10 @@ class UserAccountingView(BaseUserAccountingView):
pass


class UserRadiusUsageView(BaseUserRadiusUsageView):
pass


class PasswordChangeView(BasePasswordChangeView):
pass

Expand Down Expand Up @@ -96,6 +101,7 @@ class DownloadRadiusBatchPdfView(BaseDownloadRadiusBatchPdfView):
obtain_auth_token = ObtainAuthTokenView.as_view()
validate_auth_token = ValidateAuthTokenView.as_view()
user_accounting = UserAccountingView.as_view()
user_radius_usage = UserRadiusUsageView.as_view()
password_change = PasswordChangeView.as_view()
password_reset = PasswordResetView.as_view()
password_reset_confirm = PasswordResetConfirmView.as_view()
Expand Down

0 comments on commit 5663fb7

Please sign in to comment.