Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

使用数据库代替部分 cached_messages #29

Merged
merged 3 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions actions/v11/basic.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import asyncio
from utils import discord_api
import sys
import httpx
from utils.db import get_session, Message
import actions.v12.basic as basic
from actions import register_action
from discord.abc import PrivateChannel
import utils.node2image as node2image
from discord.channel import CategoryChannel, ForumChannel
from utils.logger import discord_api_failed, get_logger
from utils.logger import get_logger
import os
import utils.translator as translator
import utils.message.v11.parser as parser
import utils.return_object as return_object
from utils.config import config
from utils.client import client
import utils.message.v11.parser as parser
from utils.message.v12 import parser as v12_parser
import version
import discord
import actions.v12.file as file
Expand Down Expand Up @@ -151,25 +153,25 @@ async def get_login_info() -> dict:

@register_action("v11")
async def get_msg(message_id: int) -> dict:
for msg in client.cached_messages:
if msg.id == message_id:
return return_object.get(
0,
time=-1,
message_type="normal",
message_id=msg.id,
real_id=msg.id,
sender={
"user_id": msg.author.id,
"nickname": msg.author.name,
"card": msg.author.display_name,
"sex": "unknown"
},
message=parser.parse_string_to_array(msg.content)
)
async with get_session() as session:
message = await session.get_one(
Message,
message_id
)
message_data = await discord_api.call("GET", f"/channels/{message.channel}/messages/{message.id}")
return return_object.get(
1404,
"消息不存在!"
0,
time=message.time,
message_type="private" if message.channel == message_data["author"]["id"] else "group",
message_id=message.id,
real_id=message.id,
sender={
"user_id": message_data["author"]["id"],
"nickname": message_data["author"]["name"],
"card": message_data["author"]["display_name"],
"sex": "unknown"
},
message=translator.translate_message_array(v12_parser.parse_dict_message(message_data))
)

@register_action("v11")
Expand Down Expand Up @@ -203,13 +205,7 @@ async def set_group_ban(group_id: int, user_id: int, duration: int = 1800, reaso
@register_action("v11")
async def set_group_leave(group_id: int, is_dismiss: bool = False) -> dict:
if is_dismiss:
async with httpx.AsyncClient(proxies=config["system"].get("proxy")) as client:
response = await client.delete(
f"https://discord.com/api/v10/channels/{group_id}",
headers={"Authorization": f"Bot {config['account_token']}"}
)
if response.status_code == 400:
return discord_api_failed(response)
await discord_api.call("DELETE", f"/channels/{group_id}")
return return_object.get(0)
return translator.translate_action_response(await basic.leave_group(str(group_id)))

Expand Down
47 changes: 12 additions & 35 deletions actions/v12/basic.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import httpx
from utils.client import client
import discord
from version import VERSION
import utils.return_object as return_object
import time
import traceback
from utils.db import commit_message, get_session, Message
import utils.message.v12.parser as parser
from utils.config import config
from utils.logger import get_logger, discord_api_failed
from utils.logger import get_logger
from actions import register_action, action_list
from utils import commands
from utils import commands, discord_api

logger = get_logger()

Expand Down Expand Up @@ -69,6 +68,7 @@ async def send_message(
except discord.HTTPException as e:
logger.debug(traceback.format_exc())
return return_object.get(34000, str(e))
await commit_message(message_id, channel.id, int(time.time()))
return return_object.get(0, message_id=message_id, time=time.time())


Expand Down Expand Up @@ -103,31 +103,13 @@ async def get_status() -> dict:

@register_action()
async def set_group_name(group_id: str, group_name: str) -> dict:
async with httpx.AsyncClient(proxies=config["system"].get("proxy")) as client:
response = await client.patch(
f"https://discord.com/api/v10/channels/{group_id}",
data={"name": group_name},
headers={
"Authorization": f"Bot {config['account_token']}"
}
)
if response.status_code == 400:
return discord_api_failed(response)
await discord_api.call("PATCH", f"/channels/{group_id}", {"name": group_name})
return return_object.get(0)


@register_action()
async def set_guild_name(guild_id: str, guild_name: str) -> dict:
async with httpx.AsyncClient(proxies=config["system"].get("proxy")) as client:
response = await client.patch(
f"https://discord.com/api/v10/guilds/{guild_id}",
data={"name": guild_name},
headers={
"Authorization": f"Bot {config['account_token']}"
}
)
if response.status_code == 400:
return discord_api_failed(response)
await discord_api.call("PATCH", f"/guilds/{guild_id}", {"name": guild_name})
return return_object.get(0)

@register_action()
Expand All @@ -141,17 +123,12 @@ async def get_version() -> dict:

@register_action()
async def delete_message(message_id: str) -> dict:
for message in client.cached_messages[::-1]:
if str(message.id) == message_id:
try:
await message.delete()
except discord.Forbidden:
return return_object.get(34001, "权限错误")
except discord.NotFound:
return return_object.get(35002, "消息已被撤回")
break
else:
return return_object.get(35002, "消息不存在")
async with get_session() as session:
message = await session.get_one(
Message,
int(message_id)
)
await discord_api.call("DELETE", f"/channels/{message.channel}/messages/{message.id}")
return return_object.get(0)


Expand Down
32 changes: 17 additions & 15 deletions actions/v12/extra.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
from utils.config import config
from utils import discord_api
import utils.message.v12.parser as parser
import utils.return_object as return_object
from actions import register_action
from utils.client import client
from utils.db import get_session, Message


@register_action()
async def edit_message(message_id: str, content: list) -> dict:
for message in client.cached_messages:
if message.id == int(message_id):
await message.edit(content=(await parser.parse_message(content))["content"])
return return_object.get(0)
return return_object.get(35002, f"消息 {message_id} 不存在")
async with get_session() as session:
message = await session.get_one(
Message,
int(message_id)
)
await discord_api.call(
"PATCH",
f"/channels/{message.channel}/messages/{message.id}",
await parser.parse_message(content)
)
return return_object.get(0)



@register_action()
async def call_api(endpoint: str, method: str, data: dict) -> dict:
async with httpx.AsyncClient(proxies=config["system"]["proxy"]) as client:
response = await client.request(
method,
f"https://discord.com/api/v9/{endpoint}",
json=data
)
response = await discord_api.call(method, endpoint, data)
return return_object.get(
0,
status_code=response.status_code,
response=response.json()
status_code=response["code"],
response=response
)
6 changes: 6 additions & 0 deletions call_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from utils.type_checker import BadParam
from utils.logger import get_logger
import traceback
from utils.discord_api import DiscordApiException
from utils.message.v12.parser import UnsupportedSegment, BadSegmentData
import random
from utils.config import config
Expand Down Expand Up @@ -48,6 +49,11 @@ async def on_call_action(action: str, params: dict, echo: str | None = None, pro
return return_object.get(10006, str(e))
except BadParam as e:
return return_object.get(10003, str(e))
except DiscordApiException as e:
return return_object.get(
34002,
e.message
)
except Exception as e:
logger.error(traceback.format_exc())
return_data = return_object.get(20002, str(e))
Expand Down
15 changes: 15 additions & 0 deletions docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ OneDisc 高级设置(无特殊需要不建议更改)

将合并转发消息渲染为图片缓存并发送时使用的图片类型


### 数据库地址(`database`)

| 类型 | 必须 | 默认值 |
|:----------:|:----:|:----------------------:|
| 字符串 | 否 | `sqlite+aiosqlite:///:memory:` |

OneDisc 缓存消息使用的数据库地址

参考 [Engine Configuration — SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls)

不支持自动创建数据库

> 目前可执行版只支持 SQLite3,源码版使用其他数据库需要手动安装依赖

### 使用静态表情(`use_static_face`)

| 类型 | 必须 | 默认值 |
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ httpx
aiohttp>=3.7.4,<4
imgkit
websockets
sqlalchemy
aiosqlite
1 change: 1 addition & 0 deletions utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
proxy=config["system"]["proxy"]
)
tree = app_commands.CommandTree(client)

42 changes: 42 additions & 0 deletions utils/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from traceback import format_exc
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer
from .config import config
from .logger import get_logger

logger = get_logger()
Base = declarative_base()
db_url = config["system"].get("database", "sqlite+aiosqlite:///:memory:")
logger.debug(f'使用数据库: {db_url}')
engine = create_async_engine(db_url)
del db_url


class Message(Base):
__tablename__ = "message"
id = Column(Integer, primary_key=True)
channel = Column(Integer)
time = Column(Integer)

async def init_database() -> None:
async with engine.connect() as conn:
await conn.run_sync(Message.metadata.create_all)
logger.info("数据库初始化完成!")

def get_session() -> AsyncSession:
return AsyncSession(engine)

async def commit_message(id_: int, channel: int, time_: int) -> None:
async with get_session() as session:
try:
session.add(Message(
id=id_,
channel=channel,
time=time_
))
except Exception:
await session.rollback()
logger.warning(f"写入数据库失败: {format_exc()}")
else:
await session.commit()
31 changes: 31 additions & 0 deletions utils/discord_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import json
import httpx
from .config import config
from .logger import get_logger

logger = get_logger()

class DiscordApiException(Exception):
def __init__(self, data: dict) -> None:
super().__init__(data)
self.code = data["code"]
self.message = data["message"]
self.data = data
logger.warning(f"调用 Discord API 时出现错误({self.code}):\n{json.dumps(self.data, indent=4)}")

async def call(method: str, path: str, data: dict | None = None, **params) -> dict:
async with httpx.AsyncClient(proxies=config["system"].get("proxy"), base_url="https://discord.com/api/v10") as client:
response = await client.request(
method,
path,
data=data,
headers={"Authorization": f"Bot {config['account_token']}"},
**params
)
if response.status_code == 400:
raise DiscordApiException(response.json())
elif response.status_code == 204:
return {
"code": 204
}
return response.json()
4 changes: 3 additions & 1 deletion utils/event/discord_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import discord
from discord import Object
import asyncio

from ..db import commit_message, init_database
from utils.update_checker import check_update
from actions.v12.basic import get_status
from actions.v11.basic import get_role
Expand Down Expand Up @@ -38,13 +38,15 @@ async def on_ready() -> None:
)
if config["system"].get("check_update", True):
asyncio.create_task(check_update())
await init_database()


@client.event
async def on_message(message: discord.Message) -> None:
if message.author == client.user and config["system"].get("ignore_self_events", True):
return
print_message_log(message)
await commit_message(message.id, message.channel.id, int(message.created_at.timestamp()))
if message.guild and config["system"].get("enable_channel_event"):
event.new_event(
_type="message",
Expand Down
18 changes: 9 additions & 9 deletions utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def print_message_delete_log(message: discord.Message) -> None:
except AttributeError:
logger.info(f"{message.author.name}({message.author.id}) 撤回了消息:{message.content}")

def discord_api_failed(response: httpx.Response) -> dict:
body = json.loads(response.read())
logger.warning(f"调用 Discord API 时出现错误({body['cdoe']}):\n{json.dumps(body, indent=4)}")
return {
"status": "failed",
"retcode": 34002,
"data": None,
"message": body["message"]
}
# def discord_api_failed(response: httpx.Response) -> dict:
# body = json.loads(response.read())
# logger.warning(f"调用 Discord API 时出现错误({body['code']}):\n{json.dumps(body, indent=4)}")
# return {
# "status": "failed",
# "retcode": 34002,
# "data": None,
# "message": body["message"]
# }
Loading