Skip to content

Commit

Permalink
🐛 Fix: websockets 驱动器连接关闭 code 获取错误 (#2537)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Jan 17, 2024
1 parent c2d2169 commit 2c6affe
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 12 deletions.
5 changes: 1 addition & 4 deletions nonebot/drivers/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ async def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return await func(*args, **kwargs)
except ConnectionClosed as e:
if e.rcvd_then_sent:
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
else:
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore
raise WebSocketClosed(e.code, e.reason)

return decorator

Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ ruff = ">=0.0.272,<1.0.0"

[tool.poetry.group.test.dependencies]
nonebug = "^0.3.0"
wsproto = "^1.2.0"
pytest-cov = "^4.0.0"
pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.23.2"
Expand Down
71 changes: 69 additions & 2 deletions tests/fake_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import json
import base64
import socket
from typing import Dict, List, Union, TypeVar

from wsproto.events import Ping
from werkzeug import Request, Response
from werkzeug.datastructures import MultiDict
from wsproto.frame_protocol import CloseReason
from wsproto.events import Request as WSRequest
from wsproto import WSConnection, ConnectionType
from wsproto.events import TextMessage, BytesMessage, CloseConnection, AcceptConnection

K = TypeVar("K")
V = TypeVar("V")
Expand All @@ -29,8 +35,7 @@ def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]:
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}


@Request.application
def request_handler(request: Request) -> Response:
def http_echo(request: Request) -> Response:
try:
_json = json.loads(request.data.decode("utf-8"))
except (ValueError, TypeError):
Expand Down Expand Up @@ -67,3 +72,65 @@ def request_handler(request: Request) -> Response:
status=200,
content_type="application/json",
)


def websocket_echo(request: Request) -> Response:
stream = request.environ["werkzeug.socket"]

ws = WSConnection(ConnectionType.SERVER)

in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf-8")
for header, value in request.headers.items():
in_data += f"{header}: {value}\r\n".encode()
in_data += b"\r\n"

ws.receive_data(in_data)

running: bool = True
while True:
out_data = b""

for event in ws.events():
if isinstance(event, WSRequest):
out_data += ws.send(AcceptConnection())
elif isinstance(event, CloseConnection):
out_data += ws.send(event.response())
running = False
elif isinstance(event, Ping):
out_data += ws.send(event.response())
elif isinstance(event, TextMessage):
if event.data == "quit":
out_data += ws.send(
CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
)
running = False
else:
out_data += ws.send(TextMessage(data=event.data))
elif isinstance(event, BytesMessage):
if event.data == b"quit":
out_data += ws.send(
CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
)
running = False
else:
out_data += ws.send(BytesMessage(data=event.data))

if out_data:
stream.send(out_data)

if not running:
break

in_data = stream.recv(4096)
ws.receive_data(in_data)

stream.shutdown(socket.SHUT_RDWR)
return Response("", status=204)


@Request.application
def request_handler(request: Request) -> Response:
if request.headers.get("Connection") == "Upgrade":
return websocket_echo(request)
else:
return http_echo(request)
26 changes: 23 additions & 3 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def _handle_ws(ws: WebSocket) -> None:
assert data == b"ping"
await ws.send(b"pong")

with pytest.raises(WebSocketClosed):
with pytest.raises(WebSocketClosed, match=r"code=1000"):
await ws.receive()

ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
Expand All @@ -152,7 +152,7 @@ async def _handle_ws(ws: WebSocket) -> None:
await ws.send_bytes(b"ping")
assert await ws.receive_bytes() == b"pong"

await ws.close()
await ws.close(code=1000)

await asyncio.sleep(1)

Expand Down Expand Up @@ -315,9 +315,29 @@ async def test_http_client(driver: Driver, server_url: URL):
],
indirect=True,
)
async def test_websocket_client(driver: Driver):
async def test_websocket_client(driver: Driver, server_url: URL):
assert isinstance(driver, WebSocketClientMixin)

request = Request("GET", server_url.with_scheme("ws"))
async with driver.websocket(request) as ws:
await ws.send("test")
assert await ws.receive() == "test"

await ws.send(b"test")
assert await ws.receive() == b"test"

await ws.send_text("test")
assert await ws.receive_text() == "test"

await ws.send_bytes(b"test")
assert await ws.receive_bytes() == b"test"

await ws.send("quit")
with pytest.raises(WebSocketClosed, match=r"code=1000"):
await ws.receive()

await asyncio.sleep(1)


@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down

0 comments on commit 2c6affe

Please sign in to comment.