From 19d965bb8db5859921fcde082e9626ba68151192 Mon Sep 17 00:00:00 2001 From: This is XiaoDeng <1744793737@qq.com> Date: Fri, 20 Oct 2023 20:17:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20send=5Fmsg=20=E5=92=8C=20s?= =?UTF-8?q?end=5Fgroup=5Fmsg=20=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 16 ++++----- basic_actions_v12.py | 50 +++++++++++++------------- basic_api_v11.py | 56 +++++++++++++++++++++++++++++ call_action.py | 58 ++++++------------------------ checker.py | 44 ++++++++++++++++++++++- file.py | 8 ++--- http_post_v11.py | 4 +-- main.py | 2 +- message_parser.py | 6 ++-- event_12_to_11.py => translator.py | 35 +++++++++++++++++- ws_reverse_v11.py | 10 +++--- ws_v11.py | 6 ++-- 12 files changed, 194 insertions(+), 101 deletions(-) create mode 100644 basic_api_v11.py rename event_12_to_11.py => translator.py (63%) diff --git a/api.py b/api.py index c6f1d4c..9c4d98c 100644 --- a/api.py +++ b/api.py @@ -2,8 +2,7 @@ from logger import get_logger logger = get_logger() -action_list = {} -ob11_api_list = {} +action_list = {"v12": {}, "v11": {}} ''' def register_action(name: str) -> Callable: @@ -16,10 +15,9 @@ def decorator(func: Callable) -> None: return decorator ''' -def register_action(func: Callable) -> None: - action_list[func.__name__] = func - logger.debug(f"成功注册动作:{func.__name__}") - -def register_ob11_api(func: Callable) -> None: - ob11_api_list[func.__name__] = func - logger.debug(f"成功注册接口:{func.__name__} (OneBot V11)") +def register_action(_type: str = "v12") -> Callable: + def _(func: Callable): + action_list[_type][func.__name__] = func + logger.debug(f"成功注册动作:{func.__name__} ({_type=})") + return func + return _ diff --git a/basic_actions_v12.py b/basic_actions_v12.py index 5adccd7..7f7add8 100644 --- a/basic_actions_v12.py +++ b/basic_actions_v12.py @@ -10,7 +10,7 @@ logger = get_logger() -@register_action +@register_action() async def send_message( detail_type: str, message: list, @@ -58,7 +58,7 @@ async def send_message( return return_object.get(0, message_id=message_id, time=time.time()) -@register_action +@register_action() async def get_supported_actions() -> dict: """ 获取支持的动作列表 @@ -69,7 +69,7 @@ async def get_supported_actions() -> dict: return { "status": "ok", "retcode": 0, - "data": list(action_list.keys()), + "data": list(action_list["v12"].keys()), "message": "", } @@ -87,12 +87,12 @@ async def get_status() -> dict: ) -@register_action +@register_action() async def get_version() -> dict: return return_object.get(0, impl="onedisc", version=VERSION, onebot_version="12") -@register_action +@register_action() async def delete_message(message_id: str) -> dict: for message in client.cached_messages[::-1]: if str(message.id) == message_id: @@ -108,14 +108,14 @@ async def delete_message(message_id: str) -> dict: return return_object.get(0) -@register_action +@register_action() async def get_self_info() -> dict: return return_object.get( 0, user_id=str(client.user.id), user_name=client.user.name, user_displayname="" ) -@register_action +@register_action() async def get_user_info(user_id: str) -> dict: if not (user := client.get_user(int(user_id))): return return_object.get(35003, "用户不存在") @@ -128,19 +128,19 @@ async def get_user_info(user_id: str) -> dict: ) -@register_action +@register_action() async def get_friend_list() -> dict: return return_object._get(0, []) -@register_action +@register_action() async def get_group_info(group_id: str) -> dict: if not (channel := client.get_channel(int(group_id))): return return_object.get(35001, "频道不存在") return return_object.get(0, group_id=str(channel.id), group_name=channel.name) -@register_action +@register_action() async def get_group_list() -> dict: channel_list = [] for channel in client.get_all_channels(): @@ -148,7 +148,7 @@ async def get_group_list() -> dict: return return_object._get(0, channel_list) -@register_action +@register_action() async def get_group_member_info(group_id: str, user_id: str) -> dict: if not (channel := client.get_channel(int(group_id))): return return_object.get(35001, "频道不存在") @@ -163,7 +163,7 @@ async def get_group_member_info(group_id: str, user_id: str) -> dict: ) -@register_action +@register_action() async def get_group_member_list(group_id: str) -> dict: if not (channel := client.get_channel(int(group_id))): return return_object.get(35001, "频道不存在") @@ -179,12 +179,12 @@ async def get_group_member_list(group_id: str) -> dict: return return_object._get(0, member_list) -@register_action +@register_action() async def set_group_name(group_id: str, group_name: str) -> dict: return return_object.get(10002, "不支持机器人修改频道名") -@register_action +@register_action() async def leave_group(group_id: str) -> dict: if not (channel := client.get_channel(int(group_id))): return return_object.get(35001, "频道不存在") @@ -192,14 +192,14 @@ async def leave_group(group_id: str) -> dict: return return_object.get(0) -@register_action +@register_action() async def get_guild_info(guild_id: str) -> dict: if not (guild := client.get_guild(int(guild_id))): return return_object.get(35004, "服务器不存在") return return_object.get(0, guild_id=str(guild.id), guild_name=guild.name) -@register_action +@register_action() async def get_guild_list() -> dict: guild_list = [] for guild in client.guilds: @@ -207,12 +207,12 @@ async def get_guild_list() -> dict: return return_object._get(0, guild_list) -@register_action +@register_action() async def set_guild_name(guild_id: str, guild_name: str) -> dict: return return_object.get(10002, "不支持机器人修改群组名") -@register_action +@register_action() async def get_guild_member_info(guild_id: str, user_id: str) -> dict: if not (guild := client.get_guild(int(guild_id))): return return_object.get(35004, "服务器不存在") @@ -227,7 +227,7 @@ async def get_guild_member_info(guild_id: str, user_id: str) -> dict: ) -@register_action +@register_action() async def get_guild_member_list(guild_id: str) -> dict: if not (guild := client.get_guild(int(guild_id))): return return_object.get(35004, "服务器不存在") @@ -243,7 +243,7 @@ async def get_guild_member_list(guild_id: str) -> dict: return return_object._get(0, member_list) -@register_action +@register_action() async def leave_guild(guild_id: str) -> dict: if not (guild := client.get_guild(int(guild_id))): return return_object.get(35004, "服务器不存在") @@ -274,7 +274,7 @@ def _parse_channel_action_data(_data: dict) -> dict: return data -@register_action +@register_action() async def get_channel_list(guild_id: str, joined_only: bool = False) -> dict: if not (guild := client.get_guild(int(guild_id))): return return_object.get(35004, "服务器不存在") @@ -291,21 +291,21 @@ async def get_channel_list(guild_id: str, joined_only: bool = False) -> dict: return return_object._get(0, channel_list) -@register_action +@register_action() async def set_channel_name(guild_id: str, channel_id: str, channel_name: str) -> dict: return return_object.get(10002, "不支持机器人修改频道名") -@register_action +@register_action() async def get_channel_member_info(guild_id: str, channel_id: str, user_id: str) -> dict: return _parse_channel_action_data(await get_group_member_info(channel_id, user_id)) -@register_action +@register_action() async def get_channel_member_list(guild_id: str, channel_id: str) -> dict: return _parse_channel_action_data(await get_group_member_list(channel_id)) -@register_action +@register_action() async def leave_channel(guild_id: str, channel_id: str) -> dict: return _parse_channel_action_data(await leave_group(channel_id)) diff --git a/basic_api_v11.py b/basic_api_v11.py new file mode 100644 index 0000000..416434f --- /dev/null +++ b/basic_api_v11.py @@ -0,0 +1,56 @@ +import basic_actions_v12 +from api import register_action +import translator +import message_parser_v11 + + +@register_action("v11") +async def send_group_msg( + group_id: int, + message: str | list[dict], + auto_escape: bool = False +) -> dict: + if isinstance(message, str) and not auto_escape: + message = message_parser_v11.parse_text(message) + elif isinstance(message, str): + message = [{ + "type": "text", + "data": { + "text": message + } + }] + return translator.translate_action_response( + await basic_actions_v12.send_message( + "group", + translator.translate_message_array(message), + group_id = str(group_id) + ) + ) + + + +@register_action("v11") +async def send_msg( + message: str | list, + message_type: str, + group_id: int | None = None, + user_id: int | None = None, + auto_escape: bool = False +) -> dict: + if isinstance(message, str) and not auto_escape: + message = message_parser_v11.parse_text(message) + elif isinstance(message, str): + message = [{ + "type": "text", + "data": { + "text": message + } + }] + return translator.translate_action_response( + await basic_actions_v12.send_message( + message_type, + translator.translate_message_array(message), + group_id=str(group_id), + user_id=str(user_id) + ) + ) diff --git a/call_action.py b/call_action.py index 09b0738..433accc 100644 --- a/call_action.py +++ b/call_action.py @@ -1,6 +1,6 @@ -from api import action_list, ob11_api_list +from api import action_list from typing import Callable -import inspect +from checker import check_request_params import return_object from checker import BadParam from logger import get_logger @@ -11,33 +11,6 @@ logger = get_logger() -def check_params(func: Callable, params: dict) -> tuple[bool, dict]: - """ - 检查参数及类型类型 - - Args: - func (Callable): 动作函数 - params (dict): 实参列表 - - Returns: - tuple[bool] | tuple[bool, dict]: 检查结果 - """ - arg_spec = inspect.getfullargspec(func) - for key in list(params.keys()): - if key not in arg_spec.args: - if config["system"].get("ignore_unneeded_args", True): - logger.warning(f"参数 {key} 未在 {func.__name__} 中定义,已忽略") - del params[key] - continue - else: - return False, return_object.get(10004, f"参数 {key} 未在 {func.__name__} 中定义") - if key in arg_spec.annotations.keys() and not isinstance(params[key], arg_spec.annotations[key]): - if not config["system"].get("ignore_error_types"): - return False, return_object.get(10001, f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确") - logger.warning(f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确,已忽略") - return True, {} - - def get_action_function(action: str, protocol_version: int) -> Callable | None: """ 获取动作函数 @@ -49,17 +22,12 @@ def get_action_function(action: str, protocol_version: int) -> Callable | None: Returns: Callable: 动作执行函数 """ - if protocol_version == 11 and action not in ob11_api_list.keys() and config["system"].get("allow_v12_actions", True): - logger.warning(f"接口 {action} (V11) 不存在,尝试使用 V12") - return action_list.get(action) - elif protocol_version == 11: - logger.error(f"接口 {action} (V11) 不存在") - return ob11_api_list.get(action) - elif protocol_version == 12 and action not in action_list.keys() and config["system"].get("allow_v11_actions", False): - logger.warning(f"动作 {action} 不存在,尝试使用 V11") - return ob11_api_list.get(action) - else: - return action_list.get(action) + if action in action_list.get(f"v{protocol_version}", {}).keys(): + return action_list[f"v{protocol_version}"][action] + for actions in action_list.values(): + if action in actions.keys(): + return actions[action] + return None async def on_call_action(action: str, params: dict, echo: str | None = None, protocol_version: int = 12, **_) -> dict: @@ -67,8 +35,8 @@ async def on_call_action(action: str, params: dict, echo: str | None = None, pro if config['system'].get("allow_strike") and random.random() <= 0.1: return return_object.get(36000, "I am tried.") if not (action_function := get_action_function(action, protocol_version)): - return return_object.get(10002, "action not found") - if not (params_checking_result := check_params(action_function, params))[0]: + return return_object.get(10002, f"未定义的动作:{action}") + if not (params_checking_result := check_request_params(action_function, params))[0]: return params_checking_result[1] try: return_data = await action_function(**params) @@ -78,12 +46,6 @@ async def on_call_action(action: str, params: dict, echo: str | None = None, pro return return_object.get(10006, str(e)) except BadParam as e: return return_object.get(10003, str(e)) - # except TypeError as e: - # if "got an unexpected keyword argument" in str(e): - # return return_object.get(10004, str(e)) - # else: - # logger.error(traceback.format_exc()) - # return_data = return_object.get(20002, str(e)) except Exception as e: logger.error(traceback.format_exc()) return_data = return_object.get(20002, str(e)) diff --git a/checker.py b/checker.py index 0e54686..7006d5a 100644 --- a/checker.py +++ b/checker.py @@ -1,5 +1,47 @@ +import inspect +import traceback +import return_object +from config import config +from typing import Callable +from logger import get_logger + +logger = get_logger() + class BadParam(Exception): pass def check_aruments(*args) -> None: if None in args: - raise BadParam("None is not allowed") \ No newline at end of file + raise BadParam("None is not allowed") + + +def check_request_params(func: Callable, params: dict) -> tuple[bool, dict]: + """ + 检查参数及类型类型 + + Args: + func (Callable): 动作函数 + params (dict): 实参列表 + + Returns: + tuple[bool] | tuple[bool, dict]: 检查结果 + """ + arg_spec = inspect.getfullargspec(func) + for key in list(params.keys()): + if key not in arg_spec.args: + if config["system"].get("ignore_unneeded_args", True): + logger.warning(f"参数 {key} 未在 {func.__name__} 中定义,已忽略") + del params[key] + continue + else: + return False, return_object.get(10004, f"参数 {key} 未在 {func.__name__} 中定义") + if config["system"].get("skip_params_type_checking", False): + continue + try: + if key in arg_spec.annotations.keys() and not isinstance(params[key], arg_spec.annotations[key]): + if not config["system"].get("ignore_error_types"): + return False, return_object.get(10001, f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确") + logger.warning(f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确,已忽略") + except TypeError: + logger.warning(f"检查参数 {key} 的类型时出现错误:{traceback.format_exc()}") + return True, {} + diff --git a/file.py b/file.py index b14f755..1a80707 100644 --- a/file.py +++ b/file.py @@ -88,7 +88,7 @@ def upload_file_from_path(name: str, path: str) -> tuple[bool, str]: return False, str(e) -@register_action +@register_action() async def get_file_fragmented( stage: str, file_id: str, @@ -120,7 +120,7 @@ async def get_file_fragmented( uploading_files = {} -@register_action +@register_action() async def upload_file_fragmented( stage: str, name: str | None = None, @@ -161,7 +161,7 @@ async def upload_file_fragmented( ) return return_object.get(10003, f"无效的 stage 参数:{stage}") -@register_action +@register_action() async def upload_file( type: str, name: str, @@ -218,7 +218,7 @@ def get_file_name_by_id(file_id: str) -> str: def get_file_path(file_name: str) -> str: return os.path.abspath(f".cache/files/{file_name}") -@register_action +@register_action() async def get_file(file_id: str, type: str) -> dict: """ 获取文件 diff --git a/http_post_v11.py b/http_post_v11.py index 9850c1e..515a509 100644 --- a/http_post_v11.py +++ b/http_post_v11.py @@ -6,7 +6,7 @@ from client import client import httpx import hmac -import event_12_to_11 +import translator logger = get_logger() BASE_CONFIG = { @@ -48,7 +48,7 @@ async def push_event(self, _event: dict) -> None: Args: event (dict): 事件 """ - event = event_12_to_11.translate_event(_event) + event = translator.translate_event(_event) async with httpx.AsyncClient(timeout=self.config["timeout"]) as client: response = await client.post( self.config["url"], diff --git a/main.py b/main.py index 92f7a00..bc42c0d 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ import discord_event import file import basic_actions_v12 - +import basic_api_v11 client.run(config["account_token"], log_handler=None) diff --git a/message_parser.py b/message_parser.py index 450ccf0..ff4ce76 100644 --- a/message_parser.py +++ b/message_parser.py @@ -70,9 +70,9 @@ def parse_message(message: list) -> dict: if config["system"].get("ignore_unsupported_segment"): logger.warning(f"不支持的消息段类型:{segment['type']},已忽略") else: - raise UnsupportedSegment(segment["type"]) - except KeyError: - raise BadSegmentData(segment["type"]) + raise UnsupportedSegment(f'不支持的消息段: {segment["type"]}') + except KeyError as e: + raise BadSegmentData(f"无效的参数:{e} (在 {segment['type']} 中)") if not message_data["files"]: message_data.pop("files") logger.debug(message_data) diff --git a/event_12_to_11.py b/translator.py similarity index 63% rename from event_12_to_11.py rename to translator.py index cfa52af..ef78049 100644 --- a/event_12_to_11.py +++ b/translator.py @@ -37,6 +37,7 @@ def translate_event(_event: dict) -> dict: elif event["message_type"] == "group": event["sub_type"] = "normal" event["anonymous"] = None + event["font"] = 0 event["raw_message"] = event.pop("alt_message") sender = client.get_user(event["user_id"]) event["sender"] = { @@ -49,4 +50,36 @@ def translate_event(_event: dict) -> dict: }) event["message"] = message_parser_v11.parse_text(event["raw_message"]) logger.debug(event) - return event \ No newline at end of file + return event + + +def translate_action_response(_response: dict) -> dict: + response = _response.copy() + if isinstance(response["data"], dict): + for key, value in response["data"].items(): + if key.endswith("_id"): + try: + response["data"][key] = int(value) + except ValueError: + pass + elif isinstance(response["data"], list): + length = 0 + for item in response["data"]: + response["data"][length] = translate_action_response(item) + length += 1 + return response + +def translate_message_array(_message: list) -> list: + message = _message.copy() + length = -1 + for item in message: + length += 1 + match item["type"]: + case "at": + message[length]["type"] = "mention" + message[length]["data"]["user_id"] = message[length]["data"].pop("qq") + case "reply": + message[length]["data"]["message_id"] = message[length]["data"].pop("id") + return message + + diff --git a/ws_reverse_v11.py b/ws_reverse_v11.py index ba8655e..ae6f2ab 100644 --- a/ws_reverse_v11.py +++ b/ws_reverse_v11.py @@ -1,5 +1,5 @@ from client import client -import event_12_to_11 +import translator import websockets import json import json @@ -49,11 +49,12 @@ async def connect(self) -> None: except Exception: pass await self.event_ws.send(json.dumps( - event_12_to_11.translate_event(event.get_event_object( + translator.translate_event(event.get_event_object( "meta", "lifecycle", "connect" )))) + asyncio.create_task(self.handle_api_requests()) async def reconnect(self) -> None: if hasattr(self, "reconnect_task"): @@ -87,13 +88,14 @@ async def push_event(self, event: dict) -> None: try: await self.event_ws.send( json.dumps( - event_12_to_11.translate_event( + translator.translate_event( event ) ) ) except Exception: - logger.warning(f"推送事件时出现错误:{traceback.format_exc()}") + if not hasattr(self, "event_ws"): + logger.warning(f"推送事件时出现错误:{traceback.format_exc()}") await self.reconnect() await self.push_event(event) diff --git a/ws_v11.py b/ws_v11.py index a023c9c..be65632 100644 --- a/ws_v11.py +++ b/ws_v11.py @@ -1,4 +1,4 @@ -import event_12_to_11 +import translator import event from http_server import verify_access_token from logger import get_logger @@ -40,7 +40,7 @@ async def handle_event_route(self, websocket: fastapi.WebSocket) -> None: return await websocket.accept() self.clients_on_event_route.append(websocket) - await websocket.send_json(event_12_to_11.translate_event(event.get_event_object( + await websocket.send_json(translator.translate_event(event.get_event_object( "meta", "lifecycle", "connect" @@ -86,7 +86,7 @@ async def handle_root_route(self, websocket: fastapi.WebSocket) -> None: await websocket.send_json(resp_data) async def push_event(self, _event: dict) -> None: - event = event_12_to_11.translate_event(_event) + event = translator.translate_event(_event) for client in self.clients_on_event_route: try: await client.send_json(event)