diff --git a/lightberry/core/server.py b/lightberry/core/server.py index 5a64987..33918d2 100644 --- a/lightberry/core/server.py +++ b/lightberry/core/server.py @@ -1,27 +1,24 @@ import network -import ssl -import time import asyncio -from lightberry.core.communication.request import Request -from lightberry.core.communication.response import Response from lightberry.tasks.periodic_tasks import ReconnectToNetworkTask, BlinkLedTask -from lightberry.utils import common_utils, requests_utils, files_utils +from lightberry.core.sockets_servers.http import AppServer +from lightberry.core.sockets_servers.http import SslProxyServer +from lightberry.utils import common_utils from lightberry.config import ServerConfig as Config from lightberry.typing import TYPE_CHECKING if TYPE_CHECKING: from network import WLAN - from asyncio import AbstractEventLoop, StreamReader, StreamWriter + from asyncio import AbstractEventLoop from lightberry.core.app import App from lightberry.tasks.task_base import TaskBase + from lightberry.core.sockets_servers.http import HttpSocketServer class Server: def __init__(self, app: App, - host="0.0.0.0", - port=Config.SERVER_PORT, debug_mode=Config.DEBUG, wifi_ssid=Config.WIFI_SSID, wifi_password=Config.WIFI_PASSWORD, @@ -34,9 +31,6 @@ def __init__(self, self.debug_mode = debug_mode self.config = Config - self.host = host - self.port = port - self.wifi_ssid = wifi_ssid self.wifi_password = wifi_password @@ -52,6 +46,7 @@ def __init__(self, self.__mainloop: AbstractEventLoop = asyncio.get_event_loop() self.__app = app + self.__http_socket_servers: list[HttpSocketServer] = [] self.__background_tasks: list[TaskBase] = [] self.__run_as_host() if self.__hotspot_mode else self.__run_as_client() @@ -96,114 +91,27 @@ def __run_as_client(self): def __run_as_host(self): self.__setup_wlan_as_host() - async def __load_request(self, request_stream: StreamReader) -> Request | None: - try: - request_header_string = await requests_utils.load_request_header_from_stream(request_stream) - self.__print_debug(f"request header string: {request_header_string}") - - request = Request() - request.parse_header(request_header_string) - - if request.content_length: - request_body_string = await request_stream.readexactly(request.content_length) - request.parse_body(request_body_string.decode()) - - self.__print_debug(f"request body string: {request.body}") - - return request - - except Exception as e: - self.__print_debug(f"error while parsing request", exception=e) - return None - - async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter): - self.__print_debug(f"connection from: {client_w.get_extra_info('peername')}") - - try: - start_time = time.ticks_ms() if self.debug_mode else None - - request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT) - - if request: - response = await asyncio.wait_for(self.__app.requests_handler(request), self.__app.config.TIMEOUT) - - if response.is_payload_streamed: - client_w.write(bytes(f"{response.get_headers()}\r\n\r\n", "utf-8")) - await client_w.drain() - - for chunk in response.get_body(): - client_w.write(bytes(chunk, "utf-8")) - await client_w.drain() - else: - client_w.write(bytes(response.get_response_string(), "utf-8")) - await client_w.drain() - - except Exception as e: - self.__print_debug(f"error occurred: {str(e)}", exception=e) + def __init_http_socket_servers(self): + app_server = AppServer(self.__app, port=self.config.SERVER_PORT) + app_server.setup() - finally: - client_w.close() - await client_w.wait_closed() + self.__http_socket_servers.append(app_server) - if self.debug_mode: - self.__print_debug(f"request took: {time.ticks_ms() - start_time}ms") + if app_server.ssl_context is not None: + proxy_server = SslProxyServer(self.__wlan, port=self.config.SERVER_PORT) + proxy_server.setup() - async def __ssl_proxy_requests_handler(self, client_r: StreamReader, client_w: StreamWriter): - self.__print_debug(f"ssl proxy - connection from: {client_w.get_extra_info('peername')}") + self.__http_socket_servers.append(proxy_server) - try: - request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT) - - if request: - response = Response(301) - response.headers["LOCATION"] = f"https://{self.__wlan.ifconfig()[0]}{request.url}" - - client_w.write(bytes(response.get_response_string(), "utf-8")) - await client_w.drain() - - except Exception as e: - self.__print_debug(f"error occurred: {str(e)}", exception=e) - - finally: - client_w.close() - await client_w.wait_closed() - - self.__print_debug(f"ssl proxy - connection closed") - - def __init_server(self): - ssl_context = None - ssl_cert_file = self.config.get("CERT_FILE") - ssl_key_file = self.config.get("CERT_KEY") - - if ((ssl_cert_file and files_utils.file_exists(ssl_cert_file)) and - (ssl_key_file and files_utils.file_exists(ssl_key_file))): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(ssl_cert_file, ssl_key_file) - - self.__print_debug("SSL certificate has been loaded...") - self.__init_ssl_proxy() - - target_port = self.port if ssl_context is None else 443 - server_task = asyncio.start_server(self.__requests_handler, - self.host, - target_port, - ssl=ssl_context) - - self.__mainloop.create_task(server_task) - - def __init_ssl_proxy(self): - proxy_server_task = asyncio.start_server(self.__ssl_proxy_requests_handler, - self.host, - self.port) - - self.__mainloop.create_task(proxy_server_task) - self.__print_debug(f"ssl proxy running at port: {self.port}") + for server in self.__http_socket_servers: # type: HttpSocketServer + self.__mainloop.create_task(server.server_task) + self.__print_debug(f"Task for {server.__class__.__name__} has been created") def start(self): self.__print_debug("starting mainloop...") if self.__wlan is not None: - self.__init_server() + self.__init_http_socket_servers() self.__register_background_tasks() self.__setup_app() @@ -212,13 +120,15 @@ def start(self): self.__print_debug("mainloop running...") self.__mainloop.run_forever() + else: + self.__print_debug("Couldn't start server wlan is None") def stop(self): self.__mainloop.stop() self.__mainloop.close() def __setup_app(self): - self.__app.host = f"{self.__wlan.ifconfig()[0]}:{self.port}" + self.__app.host = f"{self.__wlan.ifconfig()[0]}:{self.config.SERVER_PORT}" self.__app.register_background_tasks(self.__mainloop) def __register_background_tasks(self): diff --git a/lightberry/core/sockets_servers/http/__init__.py b/lightberry/core/sockets_servers/http/__init__.py new file mode 100644 index 0000000..d1da4ac --- /dev/null +++ b/lightberry/core/sockets_servers/http/__init__.py @@ -0,0 +1,3 @@ +from lightberry.core.sockets_servers.http.http_socket_server import HttpSocketServer +from lightberry.core.sockets_servers.http.app_server import AppServer +from lightberry.core.sockets_servers.http.ssl_proxy_server import SslProxyServer diff --git a/lightberry/core/sockets_servers/http/app_server.py b/lightberry/core/sockets_servers/http/app_server.py new file mode 100644 index 0000000..c68acad --- /dev/null +++ b/lightberry/core/sockets_servers/http/app_server.py @@ -0,0 +1,71 @@ +from lightberry.config import ServerConfig as Config +from lightberry.core.sockets_servers.http import HttpSocketServer +from lightberry.utils import files_utils +import ssl +import time +import asyncio + +from lightberry.typing import TYPE_CHECKING + +if TYPE_CHECKING: + from asyncio import StreamReader, StreamWriter + from lightberry.core.app import App + + +class AppServer(HttpSocketServer): + def __init__(self, app, host="0.0.0.0", port=Config.SERVER_PORT, debug_mode=Config.DEBUG): + super().__init__(host, port, debug_mode) + + self.ssl_context: ssl.SSLContext | None = None + + self.__app: App = app + self.config = Config + + async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter): + self.__print_debug(f"connection from: {client_w.get_extra_info('peername')}") + + try: + start_time = time.ticks_ms() if self.debug_mode else None + request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT) + + if request: + response = await asyncio.wait_for(self.__app.requests_handler(request), self.__app.config.TIMEOUT) + + if response.is_payload_streamed: + client_w.write(bytes(f"{response.get_headers()}\r\n\r\n", "utf-8")) + await client_w.drain() + + for chunk in response.get_body(): + client_w.write(bytes(chunk, "utf-8")) + await client_w.drain() + else: + client_w.write(bytes(response.get_response_string(), "utf-8")) + await client_w.drain() + + except Exception as e: + self.__print_debug(f"error occurred: {str(e)}", exception=e) + + finally: + client_w.close() + await client_w.wait_closed() + + if self.debug_mode: + self.__print_debug(f"request took: {time.ticks_ms() - start_time}ms") + + def setup(self): + ssl_cert_file = self.config.get("CERT_FILE") + ssl_key_file = self.config.get("CERT_KEY") + + if ((ssl_cert_file and files_utils.file_exists(ssl_cert_file)) and + (ssl_key_file and files_utils.file_exists(ssl_key_file))): + self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + self.ssl_context.load_cert_chain(ssl_cert_file, ssl_key_file) + + self.port = 443 + + self.__print_debug("SSL certificate has been loaded...") + + self.server_task = asyncio.start_server(self.__requests_handler, + self.host, + self.port, + ssl=self.ssl_context) diff --git a/lightberry/core/sockets_servers/http/http_socket_server.py b/lightberry/core/sockets_servers/http/http_socket_server.py new file mode 100644 index 0000000..55fea38 --- /dev/null +++ b/lightberry/core/sockets_servers/http/http_socket_server.py @@ -0,0 +1,57 @@ +from lightberry.utils import common_utils +from lightberry.core.communication.request import Request +from lightberry.utils import requests_utils +import asyncio + +from lightberry.typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Type, Coroutine + from lightberry.config.base_config import BaseConfig + from asyncio import StreamReader, StreamWriter + + +class HttpSocketServer: + def __init__(self, + host="0.0.0.0", + port=80, + debug_mode=False): + self.server_task: Coroutine | None = None + + self.debug_mode = debug_mode + self.config: Type[BaseConfig] | None = None + + self.host = host + self.port = port + + async def __load_request(self, request_stream: StreamReader) -> Request | None: + try: + request_header_string = await requests_utils.load_request_header_from_stream(request_stream) + self.__print_debug(f"request header string: {request_header_string}") + + request = Request() + request.parse_header(request_header_string) + + if request.content_length: + request_body_string = await request_stream.readexactly(request.content_length) + request.parse_body(request_body_string.decode()) + + self.__print_debug(f"request body string: {request.body}") + + return request + + except Exception as e: + self.__print_debug(f"error while parsing request", exception=e) + return None + + async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter): + raise NotImplementedError("Not not implemented") + + def setup(self): + self.server_task = asyncio.start_server(self.__requests_handler, + self.host, + self.port) + + def __print_debug(self, message: str, exception: Exception | None = None): + common_utils.print_debug(message, f"SERVER - {self.__class__.__name__}", + debug_enabled=self.debug_mode, exception=exception) diff --git a/lightberry/core/sockets_servers/http/ssl_proxy_server.py b/lightberry/core/sockets_servers/http/ssl_proxy_server.py new file mode 100644 index 0000000..a12d8f3 --- /dev/null +++ b/lightberry/core/sockets_servers/http/ssl_proxy_server.py @@ -0,0 +1,40 @@ +from lightberry.config import ServerConfig as Config +from lightberry.core.communication.response import Response +from lightberry.core.sockets_servers.http import HttpSocketServer +import asyncio + +from lightberry.typing import TYPE_CHECKING + +if TYPE_CHECKING: + from asyncio import StreamReader, StreamWriter + + +class SslProxyServer(HttpSocketServer): + def __init__(self, wlan, host="0.0.0.0", port=Config.SERVER_PORT, debug_mode=Config.DEBUG): + super().__init__(host, port, debug_mode) + + self.__wlan = wlan + self.hostname = self.__wlan.ifconfig()[0] + self.config = Config + + async def __requests_handler(self, client_r: StreamReader, client_w: StreamWriter): + self.__print_debug(f"connection from: {client_w.get_extra_info('peername')}") + + try: + request = await asyncio.wait_for(self.__load_request(client_r), self.config.TIMEOUT) + + if request: + response = Response(301) + response.add_header("LOCATION", f"https://{self.hostname}{request.url}") + + client_w.write(bytes(response.get_response_string(), "utf-8")) + await client_w.drain() + + except Exception as e: + self.__print_debug(f"error occurred: {str(e)}", exception=e) + + finally: + client_w.close() + await client_w.wait_closed() + + self.__print_debug(f"connection closed")