Skip to content

Commit

Permalink
socket servers structure refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
zNitche committed May 15, 2024
1 parent 416e71d commit aea31d5
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 111 deletions.
132 changes: 21 additions & 111 deletions lightberry/core/server.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
import network
import ssl
import time
import asyncio
from lightberry.core.communication.request import Request
from lightberry.core.communication.response import Response
from lightberry.tasks.periodic_tasks import ReconnectToNetworkTask, BlinkLedTask
from lightberry.utils import common_utils, requests_utils, files_utils
from lightberry.core.sockets_servers.http import AppServer
from lightberry.core.sockets_servers.http import SslProxyServer
from lightberry.utils import common_utils
from lightberry.config import ServerConfig as Config

from lightberry.typing import TYPE_CHECKING

if TYPE_CHECKING:
from network import WLAN
from asyncio import AbstractEventLoop, StreamReader, StreamWriter
from asyncio import AbstractEventLoop
from lightberry.core.app import App
from lightberry.tasks.task_base import TaskBase
from lightberry.core.sockets_servers.http import HttpSocketServer


class Server:
def __init__(self,
app: App,
host="0.0.0.0",
port=Config.SERVER_PORT,
debug_mode=Config.DEBUG,
wifi_ssid=Config.WIFI_SSID,
wifi_password=Config.WIFI_PASSWORD,
Expand All @@ -34,9 +31,6 @@ def __init__(self,
self.debug_mode = debug_mode
self.config = Config

self.host = host
self.port = port

self.wifi_ssid = wifi_ssid
self.wifi_password = wifi_password

Expand All @@ -52,6 +46,7 @@ def __init__(self,
self.__mainloop: AbstractEventLoop = asyncio.get_event_loop()
self.__app = app

self.__http_socket_servers: list[HttpSocketServer] = []
self.__background_tasks: list[TaskBase] = []

self.__run_as_host() if self.__hotspot_mode else self.__run_as_client()
Expand Down Expand Up @@ -96,114 +91,27 @@ def __run_as_client(self):
def __run_as_host(self):
self.__setup_wlan_as_host()

async def __load_request(self, request_stream: StreamReader) -> Request | None:
try:
request_header_string = await requests_utils.load_request_header_from_stream(request_stream)
self.__print_debug(f"request header string: {request_header_string}")

request = Request()
request.parse_header(request_header_string)

if request.content_length:
request_body_string = await request_stream.readexactly(request.content_length)
request.parse_body(request_body_string.decode())

self.__print_debug(f"request body string: {request.body}")

return request

except Exception as e:
self.__print_debug(f"error while parsing request", exception=e)
return None

async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter):
self.__print_debug(f"connection from: {client_w.get_extra_info('peername')}")

try:
start_time = time.ticks_ms() if self.debug_mode else None

request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT)

if request:
response = await asyncio.wait_for(self.__app.requests_handler(request), self.__app.config.TIMEOUT)

if response.is_payload_streamed:
client_w.write(bytes(f"{response.get_headers()}\r\n\r\n", "utf-8"))
await client_w.drain()

for chunk in response.get_body():
client_w.write(bytes(chunk, "utf-8"))
await client_w.drain()
else:
client_w.write(bytes(response.get_response_string(), "utf-8"))
await client_w.drain()

except Exception as e:
self.__print_debug(f"error occurred: {str(e)}", exception=e)
def __init_http_socket_servers(self):
app_server = AppServer(self.__app, port=self.config.SERVER_PORT)
app_server.setup()

finally:
client_w.close()
await client_w.wait_closed()
self.__http_socket_servers.append(app_server)

if self.debug_mode:
self.__print_debug(f"request took: {time.ticks_ms() - start_time}ms")
if app_server.ssl_context is not None:
proxy_server = SslProxyServer(self.__wlan, port=self.config.SERVER_PORT)
proxy_server.setup()

async def __ssl_proxy_requests_handler(self, client_r: StreamReader, client_w: StreamWriter):
self.__print_debug(f"ssl proxy - connection from: {client_w.get_extra_info('peername')}")
self.__http_socket_servers.append(proxy_server)

try:
request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT)

if request:
response = Response(301)
response.headers["LOCATION"] = f"https://{self.__wlan.ifconfig()[0]}{request.url}"

client_w.write(bytes(response.get_response_string(), "utf-8"))
await client_w.drain()

except Exception as e:
self.__print_debug(f"error occurred: {str(e)}", exception=e)

finally:
client_w.close()
await client_w.wait_closed()

self.__print_debug(f"ssl proxy - connection closed")

def __init_server(self):
ssl_context = None
ssl_cert_file = self.config.get("CERT_FILE")
ssl_key_file = self.config.get("CERT_KEY")

if ((ssl_cert_file and files_utils.file_exists(ssl_cert_file)) and
(ssl_key_file and files_utils.file_exists(ssl_key_file))):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(ssl_cert_file, ssl_key_file)

self.__print_debug("SSL certificate has been loaded...")
self.__init_ssl_proxy()

target_port = self.port if ssl_context is None else 443
server_task = asyncio.start_server(self.__requests_handler,
self.host,
target_port,
ssl=ssl_context)

self.__mainloop.create_task(server_task)

def __init_ssl_proxy(self):
proxy_server_task = asyncio.start_server(self.__ssl_proxy_requests_handler,
self.host,
self.port)

self.__mainloop.create_task(proxy_server_task)
self.__print_debug(f"ssl proxy running at port: {self.port}")
for server in self.__http_socket_servers: # type: HttpSocketServer
self.__mainloop.create_task(server.server_task)
self.__print_debug(f"Task for {server.__class__.__name__} has been created")

def start(self):
self.__print_debug("starting mainloop...")

if self.__wlan is not None:
self.__init_server()
self.__init_http_socket_servers()

self.__register_background_tasks()
self.__setup_app()
Expand All @@ -212,13 +120,15 @@ def start(self):
self.__print_debug("mainloop running...")

self.__mainloop.run_forever()
else:
self.__print_debug("Couldn't start server wlan is None")

def stop(self):
self.__mainloop.stop()
self.__mainloop.close()

def __setup_app(self):
self.__app.host = f"{self.__wlan.ifconfig()[0]}:{self.port}"
self.__app.host = f"{self.__wlan.ifconfig()[0]}:{self.config.SERVER_PORT}"
self.__app.register_background_tasks(self.__mainloop)

def __register_background_tasks(self):
Expand Down
3 changes: 3 additions & 0 deletions lightberry/core/sockets_servers/http/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from lightberry.core.sockets_servers.http.http_socket_server import HttpSocketServer
from lightberry.core.sockets_servers.http.app_server import AppServer
from lightberry.core.sockets_servers.http.ssl_proxy_server import SslProxyServer
71 changes: 71 additions & 0 deletions lightberry/core/sockets_servers/http/app_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from lightberry.config import ServerConfig as Config
from lightberry.core.sockets_servers.http import HttpSocketServer
from lightberry.utils import files_utils
import ssl
import time
import asyncio

from lightberry.typing import TYPE_CHECKING

if TYPE_CHECKING:
from asyncio import StreamReader, StreamWriter
from lightberry.core.app import App


class AppServer(HttpSocketServer):
def __init__(self, app, host="0.0.0.0", port=Config.SERVER_PORT, debug_mode=Config.DEBUG):
super().__init__(host, port, debug_mode)

self.ssl_context: ssl.SSLContext | None = None

self.__app: App = app
self.config = Config

async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter):
self.__print_debug(f"connection from: {client_w.get_extra_info('peername')}")

try:
start_time = time.ticks_ms() if self.debug_mode else None
request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT)

if request:
response = await asyncio.wait_for(self.__app.requests_handler(request), self.__app.config.TIMEOUT)

if response.is_payload_streamed:
client_w.write(bytes(f"{response.get_headers()}\r\n\r\n", "utf-8"))
await client_w.drain()

for chunk in response.get_body():
client_w.write(bytes(chunk, "utf-8"))
await client_w.drain()
else:
client_w.write(bytes(response.get_response_string(), "utf-8"))
await client_w.drain()

except Exception as e:
self.__print_debug(f"error occurred: {str(e)}", exception=e)

finally:
client_w.close()
await client_w.wait_closed()

if self.debug_mode:
self.__print_debug(f"request took: {time.ticks_ms() - start_time}ms")

def setup(self):
ssl_cert_file = self.config.get("CERT_FILE")
ssl_key_file = self.config.get("CERT_KEY")

if ((ssl_cert_file and files_utils.file_exists(ssl_cert_file)) and
(ssl_key_file and files_utils.file_exists(ssl_key_file))):
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.ssl_context.load_cert_chain(ssl_cert_file, ssl_key_file)

self.port = 443

self.__print_debug("SSL certificate has been loaded...")

self.server_task = asyncio.start_server(self.__requests_handler,
self.host,
self.port,
ssl=self.ssl_context)
57 changes: 57 additions & 0 deletions lightberry/core/sockets_servers/http/http_socket_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from lightberry.utils import common_utils
from lightberry.core.communication.request import Request
from lightberry.utils import requests_utils
import asyncio

from lightberry.typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Type, Coroutine
from lightberry.config.base_config import BaseConfig
from asyncio import StreamReader, StreamWriter


class HttpSocketServer:
def __init__(self,
host="0.0.0.0",
port=80,
debug_mode=False):
self.server_task: Coroutine | None = None

self.debug_mode = debug_mode
self.config: Type[BaseConfig] | None = None

self.host = host
self.port = port

async def __load_request(self, request_stream: StreamReader) -> Request | None:
try:
request_header_string = await requests_utils.load_request_header_from_stream(request_stream)
self.__print_debug(f"request header string: {request_header_string}")

request = Request()
request.parse_header(request_header_string)

if request.content_length:
request_body_string = await request_stream.readexactly(request.content_length)
request.parse_body(request_body_string.decode())

self.__print_debug(f"request body string: {request.body}")

return request

except Exception as e:
self.__print_debug(f"error while parsing request", exception=e)
return None

async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter):
raise NotImplementedError("Not not implemented")

def setup(self):
self.server_task = asyncio.start_server(self.__requests_handler,
self.host,
self.port)

def __print_debug(self, message: str, exception: Exception | None = None):
common_utils.print_debug(message, f"SERVER - {self.__class__.__name__}",
debug_enabled=self.debug_mode, exception=exception)
40 changes: 40 additions & 0 deletions lightberry/core/sockets_servers/http/ssl_proxy_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from lightberry.config import ServerConfig as Config
from lightberry.core.communication.response import Response
from lightberry.core.sockets_servers.http import HttpSocketServer
import asyncio

from lightberry.typing import TYPE_CHECKING

if TYPE_CHECKING:
from asyncio import StreamReader, StreamWriter


class SslProxyServer(HttpSocketServer):
def __init__(self, wlan, host="0.0.0.0", port=Config.SERVER_PORT, debug_mode=Config.DEBUG):
super().__init__(host, port, debug_mode)

self.__wlan = wlan
self.hostname = self.__wlan.ifconfig()[0]
self.config = Config

async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter):
self.__print_debug(f"connection from: {client_w.get_extra_info('peername')}")

try:
request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT)

if request:
response = Response(301)
response.add_header("LOCATION", f"https://{self.hostname}{request.url}")

client_w.write(bytes(response.get_response_string(), "utf-8"))
await client_w.drain()

except Exception as e:
self.__print_debug(f"error occurred: {str(e)}", exception=e)

finally:
client_w.close()
await client_w.wait_closed()

self.__print_debug(f"connection closed")

0 comments on commit aea31d5

Please sign in to comment.