Skip to content

Commit

Permalink
支持 WebSocket (OneBot V11)、反向 WebSocket (OneBot V11)
Browse files Browse the repository at this point in the history
  • Loading branch information
This-is-XiaoDeng committed Oct 15, 2023
1 parent b99764b commit 1f42f59
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 15 deletions.
55 changes: 45 additions & 10 deletions connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import asyncio
from http_webhook import HttpWebhookConnect
from http_server_v11 import HTTPServer4OB11
from http_post_v11 import HTTPPost4OB11
from ws_v11 import WebSocket4OB11
from ws_reverse_v11 import WebSocketClient4OB11
from logger import get_logger
from ws import WebSocketServer
from ws_reverse import WebSocketClient
Expand All @@ -11,11 +14,13 @@


async def init_connections(connection_list: list[dict]) -> None:
logger.debug(connection_list)
for obc_config in connection_list:
logger.debug(obc_config)

if "type" not in obc_config:
logger.error(f"无效的连接配置:{obc_config}")
continue

match obc_config["type"], obc_config.get("protocol_version", 12):

Expand All @@ -28,16 +33,6 @@ async def init_connections(connection_list: list[dict]) -> None:
})
await tmp.start_server()
del tmp

case "http", 11:
connection_list.append({
"type": "http",
"config": obc_config,
"object": (tmp := HTTPServer4OB11(obc_config)),
"add_event_func": tmp.push_event
})
await tmp.start_server()
del tmp

case "http-webhook", 12:
connections.append({
Expand Down Expand Up @@ -68,6 +63,46 @@ async def init_connections(connection_list: list[dict]) -> None:
asyncio.create_task(tmp.reconnect())
del tmp


case "http", 11:
connections.append({
"type": "http",
"config": obc_config,
"object": (tmp := HTTPServer4OB11(obc_config)),
"add_event_func": tmp.push_event
})
await tmp.start_server()
del tmp


case "http-post", 11:
connections.append({
"type": "http-post",
"config": obc_config,
"object": (tmp := HTTPPost4OB11(obc_config)),
"add_event_func": tmp.push_event
})
del tmp


case "ws", 11:
connections.append({
"type": "ws",
"config": obc_config,
"object": (tmp := WebSocket4OB11(obc_config)),
"add_event_func": tmp.push_event
})
await tmp.start()
del tmp

case "ws-reverse", 11:
connections.append({
"type": "ws-reverse",
"config": obc_config,
"object": (tmp := WebSocketClient4OB11(obc_config)),
"add_event_func": tmp.push_event
})
del tmp

case _:
logger.warning(f"无效的连接类型或协议版本,已忽略: {obc_config['type']} (协议版本: {obc_config.get('protocol_version', 12)}")
6 changes: 5 additions & 1 deletion event_12_to_11.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ def translate_event(_event: dict) -> dict:
event = _event.copy()
# 键名替换
event["time"] = int(event["time"])
event["self_id"] = int(event["self"]["user_id"])
event["self_id"] = int(event.pop("self")["user_id"])
event["post_type"] = event.pop("type")
if event["post_type"] == "meta":
event["post_type"] = "meta_event"
event[f"{event['post_type']}_type"] = event.pop("detail_type").replace("channel", "group")
event.pop("id")
if event[f"{event['post_type']}_type"] == "channel":
event[f"{event['post_type']}_type"] = "group"
if not event["sub_type"]:
Expand All @@ -33,6 +34,9 @@ def translate_event(_event: dict) -> dict:
if event["post_type"] == "message":
if event["message_type"] == "private":
event["sub_type"] = config["system"].get("default_message_sub_type", "group")
elif event["message_type"] == "group":
event["sub_type"] = "normal"
event["anonymous"] = None
event["raw_message"] = event.pop("alt_message")
sender = client.get_user(event["user_id"])
event["sender"] = {
Expand Down
15 changes: 15 additions & 0 deletions http_post_v11.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from logger import get_logger
import asyncio
import event
import quick_reply
import json
from client import client
Expand All @@ -25,6 +27,19 @@ def __init__(self, config: dict) -> None:
config (dict): 连接配置
"""
self.config = BASE_CONFIG | config
asyncio.create_task(self.push_event(event.get_event_object(
"meta",
"lifecycle",
"enable"
)))


def __del__(self) -> None:
asyncio.create_task(self.push_event(event.get_event_object(
"meta",
"lifecycle",
"disable"
)))

async def push_event(self, _event: dict) -> None:
"""
Expand Down
16 changes: 12 additions & 4 deletions ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, config: dict) -> None:
"""
self.config = BASE_CONFIG.copy()
self.config.update(config)
self.clients: list[fastapi.WebSocket] = []
self.app = fastapi.FastAPI()
self.app.add_websocket_route("/", self.handle_ws_connect)

Expand All @@ -33,9 +34,10 @@ async def start_server(self) -> None:

async def handle_ws_connect(self, websocket: fastapi.WebSocket) -> None:
if self.config["access_token"] and not verify_access_token(websocket, self.config["access_token"]):
raise fastapi.HTTPException(fastapi.status.HTTP_401_UNAUTHORIZED)
await websocket.close(fastapi.status.HTTP_401_UNAUTHORIZED)
return
await websocket.accept()
self.websocket = websocket
self.clients.append(websocket)
await websocket.send(event.get_event_object(
"meta",
"connect",
Expand All @@ -47,8 +49,14 @@ async def handle_ws_connect(self, websocket: fastapi.WebSocket) -> None:
while True:
recv_data = await websocket.receive_json()
logger.debug(recv_data)
await self.websocket.send_json(call_action.on_call_action(**recv_data))
await websocket.send_json(call_action.on_call_action(**recv_data))

async def push_event(self, event: dict) -> None:
await self.websocket.send_json(event)
for websocket in self.clients:
try:
await websocket.send_json(event)
except Exception as e:
logger.error(f"在 {websocket} 推送事件失败:{e}")
await websocket.close()
self.clients.remove(websocket)

123 changes: 123 additions & 0 deletions ws_reverse_v11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from client import client
import event_12_to_11
import websockets
import json
import json
import event
import call_action
import asyncio
import traceback
import websockets.client
import websockets.exceptions
from logger import get_logger
from version import VERSION

BASE_CONFIG = {
"url": None,
"api_url": None,
"event_url": None,
"reconnect_interval": 3000,
"use_universal_client": False,
"access_token": None
}
logger = get_logger()


class WebSocketClient4OB11:

def __init__(self, config: dict) -> None:
self.config = BASE_CONFIG | config
if not self.config["use_universal_client"]:
self.config["api_url"] = self.config["api_url"] or self.config["url"]
self.config["event_url"] = self.config["event_url"] or self.config["url"]
self.reconnect_task = asyncio.create_task(self.connect())

async def connect(self) -> None:
if self.config["use_universal_client"]:
self.api_ws = self.event_ws = await self.create_websocker_connection(
self.config["url"], "Universal"
)
else:
self.api_ws = await self.create_websocker_connection(
self.config["api_url"], "API"
)
self.event_ws = await self.create_websocker_connection(
self.config["event_url"], "Event"
)
try:
del self.reconnect_task
except Exception:
pass
await self.event_ws.send(json.dumps(
event_12_to_11.translate_event(event.get_event_object(
"meta",
"lifecycle",
"connect"
))))

async def reconnect(self) -> None:
if hasattr(self, "reconnect_task"):
await self.reconnect_task
return
await self.close()
self.reconnect_task = asyncio.create_task(self.connect())

async def close(self) -> None:
try:
await self.api_ws.close()
except Exception:
pass
try:
await self.event_ws.close()
except Exception:
pass

async def create_websocker_connection(self, url: str, role: str) -> websockets.client.WebSocketClientProtocol:
while True:
try:
logger.info(f"正在连接到反向 WebSocket {role} 服务器:{url}")
return await websockets.client.connect(
url, extra_headers=self.get_headers(role)
)
except Exception as e:
logger.warning(f"连接到反向 WebSocket {role} 时出现错误:{e}")
await asyncio.sleep(self.config["reconnect_interval"] / 1000)

async def push_event(self, event: dict) -> None:
try:
await self.event_ws.send(
json.dumps(
event_12_to_11.translate_event(
event
)
)
)
except Exception:
logger.warning(f"推送事件时出现错误:{traceback.format_exc()}")
await self.reconnect()
await self.push_event(event)

async def handle_api_requests(self) -> None:
while True:
try:
recv_data = json.loads(await self.api_ws.recv())
resp_data = await call_action.on_call_action(
**recv_data,
protocol_version=11
)
resp_data["retcode"] = {
10001: 1400,
10002: 1404
}.get(resp_data["retcode"], resp_data["retcode"])
await self.api_ws.send(json.dumps(resp_data))
except Exception:
logger.warning(f"处理 API 请求时出现错误:{traceback.format_exc()}")
break
await self.reconnect()


def get_headers(self, role: str) -> dict:
return {
"X-Self-ID": client.user.id,
"X-Client-Role": role
} | ({"Authorization": f'Bearer {self.config["access_token"]}'} if self.config["access_token"] else {})
96 changes: 96 additions & 0 deletions ws_v11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import event_12_to_11
import event
from http_server import verify_access_token
from logger import get_logger
import uvicorn_server
import fastapi
import call_action

logger = get_logger()
BASE_CONFIG = {
"host": "0.0.0.0",
"port": 6700,
"access_token": None
}

class WebSocket4OB11:

def __init__(self, config: dict) -> None:
self.config = BASE_CONFIG | config
self.clients_on_event_route = []
self.app = fastapi.FastAPI()
self.app.add_websocket_route("/", self.handle_root_route)
self.app.add_websocket_route("/event", self.handle_event_route)
self.app.add_websocket_route("/api", self.handle_api_route)

async def start(self) -> None:
await uvicorn_server.run(
self.app,
host=self.config["host"],
port=self.config["port"]
)


async def handle_event_route(self, websocket: fastapi.WebSocket) -> None:
if not verify_access_token(websocket, self.config["access_token"]):
if "Authorization" in websocket.headers.keys() or websocket.query_params.get("access_token"):
await websocket.close(403, "Invalid access token")
else:
await websocket.close(401, "Missing access token")
return
await websocket.accept()
self.clients_on_event_route.append(websocket)
await websocket.send_json(event_12_to_11.translate_event(event.get_event_object(
"meta",
"lifecycle",
"connect"
)))

async def handle_api_route(self, websocket: fastapi.WebSocket) -> None:
if not verify_access_token(websocket, self.config["access_token"]):
if "Authorization" in websocket.headers.keys() or websocket.query_params.get("access_token"):
await websocket.close(403, "Invalid access token")
else:
await websocket.close(401, "Missing access token")
return
await websocket.accept()
while True:
resp_data = await call_action.on_call_action(
**(await websocket.receive_json()),
protocol_version=11
)
resp_data["retcode"] = {
10001: 1400,
10002: 1404
}.get(resp_data["retcode"], resp_data["retcode"])
await websocket.send_json(resp_data)

async def handle_root_route(self, websocket: fastapi.WebSocket) -> None:
if not verify_access_token(websocket, self.config["access_token"]):
if "Authorization" in websocket.headers.keys() or websocket.query_params.get("access_token"):
await websocket.close(403, "Invalid access token")
else:
await websocket.close(401, "Missing access token")
return
await websocket.accept()
self.clients_on_event_route.append(websocket)
while True:
resp_data = await call_action.on_call_action(
**(await websocket.receive_json()),
protocol_version=12
)
resp_data["retcode"] = {
10001: 1400,
10002: 1404
}.get(resp_data["retcode"], resp_data["retcode"])
await websocket.send_json(resp_data)

async def push_event(self, _event: dict) -> None:
event = event_12_to_11.translate_event(_event)
for client in self.clients_on_event_route:
try:
await client.send_json(event)
except Exception as e:
logger.error(f"向 {client} 推送事件时出现错误:{e}")
await client.close()
self.clients_on_event_route.remove(client)

0 comments on commit 1f42f59

Please sign in to comment.