Skip to content

Commit

Permalink
实现 send_msg 和 send_group_msg 接口
Browse files Browse the repository at this point in the history
  • Loading branch information
This-is-XiaoDeng committed Oct 20, 2023
1 parent 480e016 commit 19d965b
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 101 deletions.
16 changes: 7 additions & 9 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from logger import get_logger

logger = get_logger()
action_list = {}
ob11_api_list = {}
action_list = {"v12": {}, "v11": {}}

'''
def register_action(name: str) -> Callable:
Expand All @@ -16,10 +15,9 @@ def decorator(func: Callable) -> None:
return decorator
'''

def register_action(func: Callable) -> None:
action_list[func.__name__] = func
logger.debug(f"成功注册动作:{func.__name__}")

def register_ob11_api(func: Callable) -> None:
ob11_api_list[func.__name__] = func
logger.debug(f"成功注册接口:{func.__name__} (OneBot V11)")
def register_action(_type: str = "v12") -> Callable:
def _(func: Callable):
action_list[_type][func.__name__] = func
logger.debug(f"成功注册动作:{func.__name__} ({_type=})")
return func
return _
50 changes: 25 additions & 25 deletions basic_actions_v12.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = get_logger()


@register_action
@register_action()
async def send_message(
detail_type: str,
message: list,
Expand Down Expand Up @@ -58,7 +58,7 @@ async def send_message(
return return_object.get(0, message_id=message_id, time=time.time())


@register_action
@register_action()
async def get_supported_actions() -> dict:
"""
获取支持的动作列表
Expand All @@ -69,7 +69,7 @@ async def get_supported_actions() -> dict:
return {
"status": "ok",
"retcode": 0,
"data": list(action_list.keys()),
"data": list(action_list["v12"].keys()),
"message": "",
}

Expand All @@ -87,12 +87,12 @@ async def get_status() -> dict:
)


@register_action
@register_action()
async def get_version() -> dict:
return return_object.get(0, impl="onedisc", version=VERSION, onebot_version="12")


@register_action
@register_action()
async def delete_message(message_id: str) -> dict:
for message in client.cached_messages[::-1]:
if str(message.id) == message_id:
Expand All @@ -108,14 +108,14 @@ async def delete_message(message_id: str) -> dict:
return return_object.get(0)


@register_action
@register_action()
async def get_self_info() -> dict:
return return_object.get(
0, user_id=str(client.user.id), user_name=client.user.name, user_displayname=""
)


@register_action
@register_action()
async def get_user_info(user_id: str) -> dict:
if not (user := client.get_user(int(user_id))):
return return_object.get(35003, "用户不存在")
Expand All @@ -128,27 +128,27 @@ async def get_user_info(user_id: str) -> dict:
)


@register_action
@register_action()
async def get_friend_list() -> dict:
return return_object._get(0, [])


@register_action
@register_action()
async def get_group_info(group_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
return return_object.get(0, group_id=str(channel.id), group_name=channel.name)


@register_action
@register_action()
async def get_group_list() -> dict:
channel_list = []
for channel in client.get_all_channels():
channel_list.append({"group_id": str(channel.id), "group_name": channel.name})
return return_object._get(0, channel_list)


@register_action
@register_action()
async def get_group_member_info(group_id: str, user_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
Expand All @@ -163,7 +163,7 @@ async def get_group_member_info(group_id: str, user_id: str) -> dict:
)


@register_action
@register_action()
async def get_group_member_list(group_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
Expand All @@ -179,40 +179,40 @@ async def get_group_member_list(group_id: str) -> dict:
return return_object._get(0, member_list)


@register_action
@register_action()
async def set_group_name(group_id: str, group_name: str) -> dict:
return return_object.get(10002, "不支持机器人修改频道名")


@register_action
@register_action()
async def leave_group(group_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
await channel.leave()
return return_object.get(0)


@register_action
@register_action()
async def get_guild_info(guild_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
return return_object.get(0, guild_id=str(guild.id), guild_name=guild.name)


@register_action
@register_action()
async def get_guild_list() -> dict:
guild_list = []
for guild in client.guilds:
guild_list.append({"guild_id": str(guild.id), "guild_name": guild.name})
return return_object._get(0, guild_list)


@register_action
@register_action()
async def set_guild_name(guild_id: str, guild_name: str) -> dict:
return return_object.get(10002, "不支持机器人修改群组名")


@register_action
@register_action()
async def get_guild_member_info(guild_id: str, user_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand All @@ -227,7 +227,7 @@ async def get_guild_member_info(guild_id: str, user_id: str) -> dict:
)


@register_action
@register_action()
async def get_guild_member_list(guild_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand All @@ -243,7 +243,7 @@ async def get_guild_member_list(guild_id: str) -> dict:
return return_object._get(0, member_list)


@register_action
@register_action()
async def leave_guild(guild_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand Down Expand Up @@ -274,7 +274,7 @@ def _parse_channel_action_data(_data: dict) -> dict:
return data


@register_action
@register_action()
async def get_channel_list(guild_id: str, joined_only: bool = False) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand All @@ -291,21 +291,21 @@ async def get_channel_list(guild_id: str, joined_only: bool = False) -> dict:
return return_object._get(0, channel_list)


@register_action
@register_action()
async def set_channel_name(guild_id: str, channel_id: str, channel_name: str) -> dict:
return return_object.get(10002, "不支持机器人修改频道名")


@register_action
@register_action()
async def get_channel_member_info(guild_id: str, channel_id: str, user_id: str) -> dict:
return _parse_channel_action_data(await get_group_member_info(channel_id, user_id))


@register_action
@register_action()
async def get_channel_member_list(guild_id: str, channel_id: str) -> dict:
return _parse_channel_action_data(await get_group_member_list(channel_id))


@register_action
@register_action()
async def leave_channel(guild_id: str, channel_id: str) -> dict:
return _parse_channel_action_data(await leave_group(channel_id))
56 changes: 56 additions & 0 deletions basic_api_v11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import basic_actions_v12
from api import register_action
import translator
import message_parser_v11


@register_action("v11")
async def send_group_msg(
group_id: int,
message: str | list[dict],
auto_escape: bool = False
) -> dict:
if isinstance(message, str) and not auto_escape:
message = message_parser_v11.parse_text(message)
elif isinstance(message, str):
message = [{
"type": "text",
"data": {
"text": message
}
}]
return translator.translate_action_response(
await basic_actions_v12.send_message(
"group",
translator.translate_message_array(message),
group_id = str(group_id)
)
)



@register_action("v11")
async def send_msg(
message: str | list,
message_type: str,
group_id: int | None = None,
user_id: int | None = None,
auto_escape: bool = False
) -> dict:
if isinstance(message, str) and not auto_escape:
message = message_parser_v11.parse_text(message)
elif isinstance(message, str):
message = [{
"type": "text",
"data": {
"text": message
}
}]
return translator.translate_action_response(
await basic_actions_v12.send_message(
message_type,
translator.translate_message_array(message),
group_id=str(group_id),
user_id=str(user_id)
)
)
58 changes: 10 additions & 48 deletions call_action.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from api import action_list, ob11_api_list
from api import action_list
from typing import Callable
import inspect
from checker import check_request_params
import return_object
from checker import BadParam
from logger import get_logger
Expand All @@ -11,33 +11,6 @@

logger = get_logger()

def check_params(func: Callable, params: dict) -> tuple[bool, dict]:
"""
检查参数及类型类型
Args:
func (Callable): 动作函数
params (dict): 实参列表
Returns:
tuple[bool] | tuple[bool, dict]: 检查结果
"""
arg_spec = inspect.getfullargspec(func)
for key in list(params.keys()):
if key not in arg_spec.args:
if config["system"].get("ignore_unneeded_args", True):
logger.warning(f"参数 {key} 未在 {func.__name__} 中定义,已忽略")
del params[key]
continue
else:
return False, return_object.get(10004, f"参数 {key} 未在 {func.__name__} 中定义")
if key in arg_spec.annotations.keys() and not isinstance(params[key], arg_spec.annotations[key]):
if not config["system"].get("ignore_error_types"):
return False, return_object.get(10001, f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确")
logger.warning(f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确,已忽略")
return True, {}


def get_action_function(action: str, protocol_version: int) -> Callable | None:
"""
获取动作函数
Expand All @@ -49,26 +22,21 @@ def get_action_function(action: str, protocol_version: int) -> Callable | None:
Returns:
Callable: 动作执行函数
"""
if protocol_version == 11 and action not in ob11_api_list.keys() and config["system"].get("allow_v12_actions", True):
logger.warning(f"接口 {action} (V11) 不存在,尝试使用 V12")
return action_list.get(action)
elif protocol_version == 11:
logger.error(f"接口 {action} (V11) 不存在")
return ob11_api_list.get(action)
elif protocol_version == 12 and action not in action_list.keys() and config["system"].get("allow_v11_actions", False):
logger.warning(f"动作 {action} 不存在,尝试使用 V11")
return ob11_api_list.get(action)
else:
return action_list.get(action)
if action in action_list.get(f"v{protocol_version}", {}).keys():
return action_list[f"v{protocol_version}"][action]
for actions in action_list.values():
if action in actions.keys():
return actions[action]
return None


async def on_call_action(action: str, params: dict, echo: str | None = None, protocol_version: int = 12, **_) -> dict:
logger.debug(f"请求执行动作:{action} ({params=}, {echo=}, {protocol_version=})")
if config['system'].get("allow_strike") and random.random() <= 0.1:
return return_object.get(36000, "I am tried.")
if not (action_function := get_action_function(action, protocol_version)):
return return_object.get(10002, "action not found")
if not (params_checking_result := check_params(action_function, params))[0]:
return return_object.get(10002, f"未定义的动作:{action}")
if not (params_checking_result := check_request_params(action_function, params))[0]:
return params_checking_result[1]
try:
return_data = await action_function(**params)
Expand All @@ -78,12 +46,6 @@ async def on_call_action(action: str, params: dict, echo: str | None = None, pro
return return_object.get(10006, str(e))
except BadParam as e:
return return_object.get(10003, str(e))
# except TypeError as e:
# if "got an unexpected keyword argument" in str(e):
# return return_object.get(10004, str(e))
# else:
# logger.error(traceback.format_exc())
# return_data = return_object.get(20002, str(e))
except Exception as e:
logger.error(traceback.format_exc())
return_data = return_object.get(20002, str(e))
Expand Down
Loading

0 comments on commit 19d965b

Please sign in to comment.