From 1f42f5962724b693834ec0a24ff5fd8547626310 Mon Sep 17 00:00:00 2001 From: This is XiaoDeng <1744793737@qq.com> Date: Sun, 15 Oct 2023 15:48:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=20WebSocket=20(OneBot=20V11)?= =?UTF-8?q?=E3=80=81=E5=8F=8D=E5=90=91=20WebSocket=20(OneBot=20V11)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- connection.py | 55 +++++++++++++++++---- event_12_to_11.py | 6 ++- http_post_v11.py | 15 ++++++ ws.py | 16 ++++-- ws_reverse_v11.py | 123 ++++++++++++++++++++++++++++++++++++++++++++++ ws_v11.py | 96 ++++++++++++++++++++++++++++++++++++ 6 files changed, 296 insertions(+), 15 deletions(-) create mode 100644 ws_reverse_v11.py create mode 100644 ws_v11.py diff --git a/connection.py b/connection.py index 123d88c..3d1ce85 100644 --- a/connection.py +++ b/connection.py @@ -2,6 +2,9 @@ import asyncio from http_webhook import HttpWebhookConnect from http_server_v11 import HTTPServer4OB11 +from http_post_v11 import HTTPPost4OB11 +from ws_v11 import WebSocket4OB11 +from ws_reverse_v11 import WebSocketClient4OB11 from logger import get_logger from ws import WebSocketServer from ws_reverse import WebSocketClient @@ -11,11 +14,13 @@ async def init_connections(connection_list: list[dict]) -> None: + logger.debug(connection_list) for obc_config in connection_list: logger.debug(obc_config) if "type" not in obc_config: logger.error(f"无效的连接配置:{obc_config}") + continue match obc_config["type"], obc_config.get("protocol_version", 12): @@ -28,16 +33,6 @@ async def init_connections(connection_list: list[dict]) -> None: }) await tmp.start_server() del tmp - - case "http", 11: - connection_list.append({ - "type": "http", - "config": obc_config, - "object": (tmp := HTTPServer4OB11(obc_config)), - "add_event_func": tmp.push_event - }) - await tmp.start_server() - del tmp case "http-webhook", 12: connections.append({ @@ -68,6 +63,46 @@ async def init_connections(connection_list: list[dict]) -> None: asyncio.create_task(tmp.reconnect()) del tmp + + case "http", 11: + connections.append({ + "type": "http", + "config": obc_config, + "object": (tmp := HTTPServer4OB11(obc_config)), + "add_event_func": tmp.push_event + }) + await tmp.start_server() + del tmp + + + case "http-post", 11: + connections.append({ + "type": "http-post", + "config": obc_config, + "object": (tmp := HTTPPost4OB11(obc_config)), + "add_event_func": tmp.push_event + }) + del tmp + + + case "ws", 11: + connections.append({ + "type": "ws", + "config": obc_config, + "object": (tmp := WebSocket4OB11(obc_config)), + "add_event_func": tmp.push_event + }) + await tmp.start() + del tmp + + case "ws-reverse", 11: + connections.append({ + "type": "ws-reverse", + "config": obc_config, + "object": (tmp := WebSocketClient4OB11(obc_config)), + "add_event_func": tmp.push_event + }) + del tmp case _: logger.warning(f"无效的连接类型或协议版本,已忽略: {obc_config['type']} (协议版本: {obc_config.get('protocol_version', 12)}") diff --git a/event_12_to_11.py b/event_12_to_11.py index 5f64347..cfa52af 100644 --- a/event_12_to_11.py +++ b/event_12_to_11.py @@ -10,11 +10,12 @@ def translate_event(_event: dict) -> dict: event = _event.copy() # 键名替换 event["time"] = int(event["time"]) - event["self_id"] = int(event["self"]["user_id"]) + event["self_id"] = int(event.pop("self")["user_id"]) event["post_type"] = event.pop("type") if event["post_type"] == "meta": event["post_type"] = "meta_event" event[f"{event['post_type']}_type"] = event.pop("detail_type").replace("channel", "group") + event.pop("id") if event[f"{event['post_type']}_type"] == "channel": event[f"{event['post_type']}_type"] = "group" if not event["sub_type"]: @@ -33,6 +34,9 @@ def translate_event(_event: dict) -> dict: if event["post_type"] == "message": if event["message_type"] == "private": event["sub_type"] = config["system"].get("default_message_sub_type", "group") + elif event["message_type"] == "group": + event["sub_type"] = "normal" + event["anonymous"] = None event["raw_message"] = event.pop("alt_message") sender = client.get_user(event["user_id"]) event["sender"] = { diff --git a/http_post_v11.py b/http_post_v11.py index 0a7a126..9850c1e 100644 --- a/http_post_v11.py +++ b/http_post_v11.py @@ -1,4 +1,6 @@ from logger import get_logger +import asyncio +import event import quick_reply import json from client import client @@ -25,6 +27,19 @@ def __init__(self, config: dict) -> None: config (dict): 连接配置 """ self.config = BASE_CONFIG | config + asyncio.create_task(self.push_event(event.get_event_object( + "meta", + "lifecycle", + "enable" + ))) + + + def __del__(self) -> None: + asyncio.create_task(self.push_event(event.get_event_object( + "meta", + "lifecycle", + "disable" + ))) async def push_event(self, _event: dict) -> None: """ diff --git a/ws.py b/ws.py index c1310f2..0ad70bd 100644 --- a/ws.py +++ b/ws.py @@ -25,6 +25,7 @@ def __init__(self, config: dict) -> None: """ self.config = BASE_CONFIG.copy() self.config.update(config) + self.clients: list[fastapi.WebSocket] = [] self.app = fastapi.FastAPI() self.app.add_websocket_route("/", self.handle_ws_connect) @@ -33,9 +34,10 @@ async def start_server(self) -> None: async def handle_ws_connect(self, websocket: fastapi.WebSocket) -> None: if self.config["access_token"] and not verify_access_token(websocket, self.config["access_token"]): - raise fastapi.HTTPException(fastapi.status.HTTP_401_UNAUTHORIZED) + await websocket.close(fastapi.status.HTTP_401_UNAUTHORIZED) + return await websocket.accept() - self.websocket = websocket + self.clients.append(websocket) await websocket.send(event.get_event_object( "meta", "connect", @@ -47,8 +49,14 @@ async def handle_ws_connect(self, websocket: fastapi.WebSocket) -> None: while True: recv_data = await websocket.receive_json() logger.debug(recv_data) - await self.websocket.send_json(call_action.on_call_action(**recv_data)) + await websocket.send_json(call_action.on_call_action(**recv_data)) async def push_event(self, event: dict) -> None: - await self.websocket.send_json(event) + for websocket in self.clients: + try: + await websocket.send_json(event) + except Exception as e: + logger.error(f"在 {websocket} 推送事件失败:{e}") + await websocket.close() + self.clients.remove(websocket) \ No newline at end of file diff --git a/ws_reverse_v11.py b/ws_reverse_v11.py new file mode 100644 index 0000000..ba8655e --- /dev/null +++ b/ws_reverse_v11.py @@ -0,0 +1,123 @@ +from client import client +import event_12_to_11 +import websockets +import json +import json +import event +import call_action +import asyncio +import traceback +import websockets.client +import websockets.exceptions +from logger import get_logger +from version import VERSION + +BASE_CONFIG = { + "url": None, + "api_url": None, + "event_url": None, + "reconnect_interval": 3000, + "use_universal_client": False, + "access_token": None +} +logger = get_logger() + + +class WebSocketClient4OB11: + + def __init__(self, config: dict) -> None: + self.config = BASE_CONFIG | config + if not self.config["use_universal_client"]: + self.config["api_url"] = self.config["api_url"] or self.config["url"] + self.config["event_url"] = self.config["event_url"] or self.config["url"] + self.reconnect_task = asyncio.create_task(self.connect()) + + async def connect(self) -> None: + if self.config["use_universal_client"]: + self.api_ws = self.event_ws = await self.create_websocker_connection( + self.config["url"], "Universal" + ) + else: + self.api_ws = await self.create_websocker_connection( + self.config["api_url"], "API" + ) + self.event_ws = await self.create_websocker_connection( + self.config["event_url"], "Event" + ) + try: + del self.reconnect_task + except Exception: + pass + await self.event_ws.send(json.dumps( + event_12_to_11.translate_event(event.get_event_object( + "meta", + "lifecycle", + "connect" + )))) + + async def reconnect(self) -> None: + if hasattr(self, "reconnect_task"): + await self.reconnect_task + return + await self.close() + self.reconnect_task = asyncio.create_task(self.connect()) + + async def close(self) -> None: + try: + await self.api_ws.close() + except Exception: + pass + try: + await self.event_ws.close() + except Exception: + pass + + async def create_websocker_connection(self, url: str, role: str) -> websockets.client.WebSocketClientProtocol: + while True: + try: + logger.info(f"正在连接到反向 WebSocket {role} 服务器:{url}") + return await websockets.client.connect( + url, extra_headers=self.get_headers(role) + ) + except Exception as e: + logger.warning(f"连接到反向 WebSocket {role} 时出现错误:{e}") + await asyncio.sleep(self.config["reconnect_interval"] / 1000) + + async def push_event(self, event: dict) -> None: + try: + await self.event_ws.send( + json.dumps( + event_12_to_11.translate_event( + event + ) + ) + ) + except Exception: + logger.warning(f"推送事件时出现错误:{traceback.format_exc()}") + await self.reconnect() + await self.push_event(event) + + async def handle_api_requests(self) -> None: + while True: + try: + recv_data = json.loads(await self.api_ws.recv()) + resp_data = await call_action.on_call_action( + **recv_data, + protocol_version=11 + ) + resp_data["retcode"] = { + 10001: 1400, + 10002: 1404 + }.get(resp_data["retcode"], resp_data["retcode"]) + await self.api_ws.send(json.dumps(resp_data)) + except Exception: + logger.warning(f"处理 API 请求时出现错误:{traceback.format_exc()}") + break + await self.reconnect() + + + def get_headers(self, role: str) -> dict: + return { + "X-Self-ID": client.user.id, + "X-Client-Role": role + } | ({"Authorization": f'Bearer {self.config["access_token"]}'} if self.config["access_token"] else {}) \ No newline at end of file diff --git a/ws_v11.py b/ws_v11.py new file mode 100644 index 0000000..a023c9c --- /dev/null +++ b/ws_v11.py @@ -0,0 +1,96 @@ +import event_12_to_11 +import event +from http_server import verify_access_token +from logger import get_logger +import uvicorn_server +import fastapi +import call_action + +logger = get_logger() +BASE_CONFIG = { + "host": "0.0.0.0", + "port": 6700, + "access_token": None +} + +class WebSocket4OB11: + + def __init__(self, config: dict) -> None: + self.config = BASE_CONFIG | config + self.clients_on_event_route = [] + self.app = fastapi.FastAPI() + self.app.add_websocket_route("/", self.handle_root_route) + self.app.add_websocket_route("/event", self.handle_event_route) + self.app.add_websocket_route("/api", self.handle_api_route) + + async def start(self) -> None: + await uvicorn_server.run( + self.app, + host=self.config["host"], + port=self.config["port"] + ) + + + async def handle_event_route(self, websocket: fastapi.WebSocket) -> None: + if not verify_access_token(websocket, self.config["access_token"]): + if "Authorization" in websocket.headers.keys() or websocket.query_params.get("access_token"): + await websocket.close(403, "Invalid access token") + else: + await websocket.close(401, "Missing access token") + return + await websocket.accept() + self.clients_on_event_route.append(websocket) + await websocket.send_json(event_12_to_11.translate_event(event.get_event_object( + "meta", + "lifecycle", + "connect" + ))) + + async def handle_api_route(self, websocket: fastapi.WebSocket) -> None: + if not verify_access_token(websocket, self.config["access_token"]): + if "Authorization" in websocket.headers.keys() or websocket.query_params.get("access_token"): + await websocket.close(403, "Invalid access token") + else: + await websocket.close(401, "Missing access token") + return + await websocket.accept() + while True: + resp_data = await call_action.on_call_action( + **(await websocket.receive_json()), + protocol_version=11 + ) + resp_data["retcode"] = { + 10001: 1400, + 10002: 1404 + }.get(resp_data["retcode"], resp_data["retcode"]) + await websocket.send_json(resp_data) + + async def handle_root_route(self, websocket: fastapi.WebSocket) -> None: + if not verify_access_token(websocket, self.config["access_token"]): + if "Authorization" in websocket.headers.keys() or websocket.query_params.get("access_token"): + await websocket.close(403, "Invalid access token") + else: + await websocket.close(401, "Missing access token") + return + await websocket.accept() + self.clients_on_event_route.append(websocket) + while True: + resp_data = await call_action.on_call_action( + **(await websocket.receive_json()), + protocol_version=12 + ) + resp_data["retcode"] = { + 10001: 1400, + 10002: 1404 + }.get(resp_data["retcode"], resp_data["retcode"]) + await websocket.send_json(resp_data) + + async def push_event(self, _event: dict) -> None: + event = event_12_to_11.translate_event(_event) + for client in self.clients_on_event_route: + try: + await client.send_json(event) + except Exception as e: + logger.error(f"向 {client} 推送事件时出现错误:{e}") + await client.close() + self.clients_on_event_route.remove(client)