Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue:4044453:Replace the flask framework with aiohttp in PDR plugin #243

Merged
merged 6 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/pdr_deterministic_plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ The ports are added or removed to blacklist via PDR plugin Rest API.
Add ports to exclude list (to be excluded from the analysis):
curl -k -i -X PUT 'http://<host_ip>/excluded' -d '[<formatted_ports_list>]'
TTL (time to live in blacklist) can optionally follow the port after the comma (if zero or not specified, then port is excluded forever)
Example: curl -k -i -X PUT 'http://127.0.0.1:8977/excluded' -d '[["9c0591030085ac80_45"],["9c0591030085ac80_46",300]' (first port is added forever, second - just for 300 seconds)
Example: curl -k -i -X PUT 'http://127.0.0.1:8977/excluded' -d '[["9c0591030085ac80_45"],["9c0591030085ac80_46",300]]' (first port is added forever, second - just for 300 seconds)

Remove ports from exclude list
curl -k -i -X DELETE 'http://<host_ip>/excluded' -d '[<comma_separated_port_mames>]'
Expand Down
2 changes: 1 addition & 1 deletion plugins/pdr_deterministic_plugin/build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ EXPOSE 9007

RUN apt-get update && apt-get -y install supervisor python3 python3-pip rsyslog vim curl sudo

RUN python3 -m pip install flask flask_restful requests twisted jsonschema pandas numpy
RUN python3 -m pip install requests jsonschema pandas numpy aiohttp

# remove an unused library that caused a high CVE vulnerability issue https://redmine.mellanox.com/issues/3837452
RUN apt-get remove -y linux-libc-dev
Expand Down
4 changes: 1 addition & 3 deletions plugins/pdr_deterministic_plugin/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
flask<=3.0.3
numpy<=1.26.4
pandas<=2.2.2
pytest<=8.2.0
requests<=2.31.0
twisted<=22.1.0
flask_restful<=0.3.10
tzlocal<=4.2
jsonschema<=4.5.1
aiohttp<=3.9.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#
# Copyright © 2013-2024 NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# This software product is a proprietary product of Nvidia Corporation and its affiliates
# (the "Company") and all right, title, and interest in and to the software
# product, including all associated intellectual property rights, are and
# shall remain exclusively with the Company.
#
# This software product is governed by the End User License Agreement
# provided with the software product.
#

import asyncio
from aiohttp import web

class BaseAiohttpAPI:
"""
Base class for API implemented with aiohttp
"""
def __init__(self):
"""
Initialize a new instance of the BaseAiohttpAPI class.
"""
self.app = web.Application()

@property
def application(self):
"""
Read-only property for the application instance.
"""
return self.app

def add_route(self, method, path, handler):
"""
Add route to API.
"""
self.app.router.add_route(method, path, handler)

def web_response(self, text, status):
"""
Create response object.
"""
return web.json_response(text=text, status=status)


class BaseAiohttpServer:
"""
Base class for HTTP server implemented with aiohttp
"""
def __init__(self, logger):
"""
Initialize a new instance of the BaseAiohttpAPI class.
"""
self.logger = logger

def run(self, app, host, port):
"""
Run the server on the specified host and port.
"""
loop = asyncio.get_event_loop()
loop.run_until_complete(self._run_server(app, host, port))

async def _run_server(self, app, host, port):
"""
Asynchronously run the server and handle shutdown.
"""
# Run server
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
self.logger.info(f"Server started at {host}:{port}")

# Wait for shutdown signal
shutdown_event = asyncio.Event()
try:
await shutdown_event.wait()
except KeyboardInterrupt:
self.logger.info(f"Shutting down server {host}:{port}...")
finally:
await runner.cleanup()
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
# provided with the software product.
#

import json
import time
from http import HTTPStatus
from json import JSONDecodeError
from flask import json, request
from utils.flask_server.base_flask_api_server import BaseAPIApplication
from api.base_aiohttp_api import BaseAiohttpAPI

ERROR_INCORRECT_INPUT_FORMAT = "Incorrect input format"
EOL = '\n'

class PDRPluginAPI(BaseAPIApplication):
class PDRPluginAPI(BaseAiohttpAPI):
'''
class PDRPluginAPI
'''
Expand All @@ -31,42 +31,36 @@ def __init__(self, isolation_mgr):
super(PDRPluginAPI, self).__init__()
self.isolation_mgr = isolation_mgr


def _get_routes(self):
"""
Map URLs to function calls
"""
return {
self.get_excluded_ports: dict(urls=["/excluded"], methods=["GET"]),
self.exclude_ports: dict(urls=["/excluded"], methods=["PUT"]),
self.include_ports: dict(urls=["/excluded"], methods=["DELETE"])
}
# Define routes using the base class's method
self.add_route("GET", "/excluded", self.get_excluded_ports)
self.add_route("PUT", "/excluded", self.exclude_ports)
self.add_route("DELETE", "/excluded", self.include_ports)


def get_excluded_ports(self):
async def get_excluded_ports(self, request):
"""
Return ports from exclude list as comma separated port names
"""
items = self.isolation_mgr.exclude_list.items()
formatted_items = [f"{item.port_name}: {'infinite' if item.ttl_seconds == 0 else int(max(0, item.remove_time - time.time()))}" for item in items]
response = EOL.join(formatted_items) + ('' if not formatted_items else EOL)
return response, HTTPStatus.OK
return self.web_response(response, HTTPStatus.OK)


def exclude_ports(self):
async def exclude_ports(self, request):
"""
Parse input ports and add them to exclude list (or just update TTL)
Input string example: [["0c42a10300756a04_1"],["98039b03006c73ba_2",300]]
TTL that follows port name after the colon is optional
"""

try:
pairs = self.get_request_data()
pairs = await self.get_request_data(request)
except (JSONDecodeError, ValueError):
return ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST
return self.web_response(ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST)

if not isinstance(pairs, list) or not all(isinstance(pair, list) for pair in pairs):
return ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST
return self.web_response(ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST)

response = ""
for pair in pairs:
Expand All @@ -80,23 +74,23 @@ def exclude_ports(self):
response += f"Port {port_name} added to exclude list for {ttl} seconds"

response += self.get_port_warning(port_name) + EOL

return response, HTTPStatus.OK

return self.web_response(response, HTTPStatus.OK)

def include_ports(self):

async def include_ports(self, request):
"""
Remove ports from exclude list
Input string: comma separated port names list
Example: ["0c42a10300756a04_1","98039b03006c73ba_2"]
"""
try:
port_names = self.get_request_data()
port_names = await self.get_request_data(request)
except (JSONDecodeError, ValueError):
return ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST
return self.web_response(ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST)

if not isinstance(port_names, list):
return ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST
return self.web_response(ERROR_INCORRECT_INPUT_FORMAT + EOL, HTTPStatus.BAD_REQUEST)

response = ""
for port_name in port_names:
Expand All @@ -108,19 +102,25 @@ def include_ports(self):

response += self.get_port_warning(port_name) + EOL

return response, HTTPStatus.OK
return self.web_response(response, HTTPStatus.OK)


def get_request_data(self):
async def get_request_data(self, request):
"""
Deserialize request json data into object
Deserialize request data into object for aiohttp
"""
if request.is_json:
# Directly convert JSON data into Python object
return request.get_json()
else:
# Attempt to load plain data text as JSON
return json.loads(request.get_data(as_text=True))
try:
# Try to get JSON data
return await request.json()
except json.JSONDecodeError:
# Try to get plain text data
text = await request.text()
try:
# Try to parse the text as JSON
return json.loads(text)
except json.JSONDecodeError:
# Return the raw text data
return text


def fix_port_name(self, port_name):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@
import os
import logging
from logging.handlers import RotatingFileHandler
import threading
from constants import PDRConstants as Constants
from isolation_mgr import IsolationMgr
from api.base_aiohttp_api import BaseAiohttpServer
from ufm_communication_mgr import UFMCommunicator
from api.pdr_plugin_api import PDRPluginAPI
from twisted.web.wsgi import WSGIResource
from twisted.internet import reactor
from twisted.web import server
from utils.flask_server import run_api
from utils.flask_server.base_flask_api_app import BaseFlaskAPIApp
from utils.utils import Utils


Expand Down Expand Up @@ -81,19 +78,16 @@ def main():
logger = create_logger(Constants.LOG_FILE)

algo_loop = IsolationMgr(ufm_client, logger)
reactor.callInThread(algo_loop.main_flow)
threading.Thread(target=algo_loop.main_flow).start()

try:
plugin_port = Utils.get_plugin_port(
port_conf_file='/config/pdr_deterministic_httpd_proxy.conf',
default_port_value=8977)

routes = {
"": PDRPluginAPI(algo_loop).application
}

app = BaseFlaskAPIApp(routes)
run_api(app=app, port_number=int(plugin_port))
api = PDRPluginAPI(algo_loop)
server = BaseAiohttpServer(logger)
server.run(api.application, "127.0.0.1", int(plugin_port))

except Exception as ex:
print(f'Failed to run the app: {str(ex)}')
Expand Down