Skip to content

Commit

Permalink
Add Web Socket support for StreamProxy (#3)
Browse files Browse the repository at this point in the history
* Add Web Socket support for StreamProxy

* update Cargo.lock
  • Loading branch information
hippalus authored Sep 13, 2024
1 parent e53029f commit e800446
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 56 deletions.
33 changes: 30 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 25 additions & 1 deletion config/umay.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ stream:
servers:
- address: "localhost"
port: 1994
ws_localhost:
load_balancer: round_robin
service_discovery: dns
servers:
- address: "localhost"
port: 1984

servers:
- name: "secure_tcp_server"
Expand All @@ -31,4 +37,22 @@ stream:
proxy_tls_protocols:
- TLSv1.2
- TLSv1.3
proxy_tls_ciphers: "TLS13_AES_256_GCM_SHA384"
proxy_tls_ciphers: "TLS13_AES_256_GCM_SHA384"
- name: "secure_ws_server"
listen:
port: 9984
protocol: ws
proxy_pass: ws_localhost
tls:
enabled: true
proxy_tls: on
proxy_tls_certificate: "/Users/hakanisler/Workspace/Github/hippalus/umay/certs/crt.der"
proxy_tls_certificate_key: "/Users/hakanisler/Workspace/Github/hippalus/umay/certs/key.pem"
proxy_tls_trusted_certificate: "/Users/hakanisler/Workspace/Github/hippalus/umay/certs/ca.pem"
proxy_tls_verify: on
proxy_tls_verify_depth: 2
proxy_tls_session_reuse: on
proxy_tls_protocols:
- TLSv1.2
- TLSv1.3
proxy_tls_ciphers: "TLS13_AES_256_GCM_SHA384"
85 changes: 65 additions & 20 deletions scripts/echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import signal
import sys
import websockets


class AsyncTCPEchoServer:
Expand All @@ -10,18 +11,19 @@ def __init__(self, host, port, server_id):
self.port = port
self.server_id = server_id
self.clients = set()
self.server = None

async def handle_client(self, reader, writer):
self.clients.add(writer)
addr = writer.get_extra_info('peername')
print(f"New connection from {addr}")
print(f"New TCP connection from {addr}")
try:
while True:
data = await reader.read(1024)
if not data:
break
message = data.decode()
response = f"Echo from server {self.server_id} on port {self.port}: {message}"
response = f"Echo from TCP server {self.server_id} on port {self.port}: {message}"
writer.write(response.encode())
await writer.drain()
except asyncio.CancelledError:
Expand All @@ -30,44 +32,87 @@ async def handle_client(self, reader, writer):
self.clients.remove(writer)
writer.close()
await writer.wait_closed()
print(f"Connection closed for {addr}")
print(f"TCP connection closed for {addr}")

async def run_server(self):
server = await asyncio.start_server(
self.server = await asyncio.start_server(
self.handle_client, self.host, self.port)

addr = server.sockets[0].getsockname()
print(f'Serving on {addr}')

async with server:
await server.serve_forever()
addr = self.server.sockets[0].getsockname()
print(f'Serving TCP on {addr}')
await self.server.serve_forever()

async def shutdown(self):
print("Shutting down the server...")
print(f"Shutting down the TCP server...")
self.server.close()
await self.server.wait_closed()
for client in self.clients:
client.close()
await asyncio.gather(*[client.wait_closed() for client in self.clients])
print("All connections closed")
await client.wait_closed()
print("All TCP connections closed")


class AsyncWSEchoServer:
def __init__(self, host, port, server_id):
self.host = host
self.port = port
self.server_id = server_id
self.server = None

async def handle_client(self, websocket, path):
addr = websocket.remote_address
print(f"New WebSocket connection from {addr}")
try:
async for message in websocket:
response = f"Echo from WS server {self.server_id} on port {self.port}: {message}"
await websocket.send(response)
except websockets.exceptions.ConnectionClosed:
pass
finally:
print(f"WebSocket connection closed for {addr}")

async def run_server(self):
self.server = await websockets.serve(
self.handle_client, self.host, self.port)
print(f'Serving WebSocket on {self.host}:{self.port}')
await self.server.wait_closed()

async def shutdown(self):
print(f"Shutting down the WebSocket server...")
self.server.close()
await self.server.wait_closed()
print("All WebSocket connections closed")


async def main():
host = '0.0.0.0'
port = int(os.environ.get('PORT', 1994))
tcp_port = int(os.environ.get('TCP_PORT', 1994))
ws_port = int(os.environ.get('WS_PORT', 1984))
server_id = os.environ.get('SERVER_ID', 'Unknown')

if len(sys.argv) > 1:
port = int(sys.argv[1])
if len(sys.argv) > 2:
tcp_port = int(sys.argv[1])
ws_port = int(sys.argv[2])

server = AsyncTCPEchoServer(host, port, server_id)
tcp_server = AsyncTCPEchoServer(host, tcp_port, f"{server_id}-TCP")
ws_server = AsyncWSEchoServer(host, ws_port, f"{server_id}-WS")

loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, lambda: asyncio.create_task(server.shutdown()))
loop.add_signal_handler(sig, lambda: asyncio.create_task(tcp_server.shutdown()))
loop.add_signal_handler(sig, lambda: asyncio.create_task(ws_server.shutdown()))

try:
await server.run_server()
await asyncio.gather(
tcp_server.run_server(),
ws_server.run_server()
)
except asyncio.CancelledError:
pass
finally:
await server.shutdown()
await asyncio.gather(
tcp_server.shutdown(),
ws_server.shutdown()
)


if __name__ == "__main__":
Expand Down
13 changes: 7 additions & 6 deletions umay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@ version = "0.1.0"
edition = "2021"

[dependencies]
tokio = { version = "1.39", features = ["full"] }
tokio-stream = "0.1"
tokio = { version = "1.40.0", features = ["full"] }
tokio-stream = { version = "0.1", features = ["full"] }
tokio-tls = { version = "0.3" }
tokio-rustls = { version = "0.26" }
tokio-util = { version = "0.7" }
tokio-test = "0.4"
rustls = "0.23.12"
rustls = "0.23"
rustls-pemfile = "2.1"
rustls-webpki = "0.102"
webpki = "0.22.4"
rcgen = "0.13.1"
tower = { version = "0.5.0", features = ["full"] }
tower = { version = "0.5", features = ["full"] }
hyper = { version = "1.4", features = ["full"] }
hyper-util = "0.1"
http = "1.1.0"
futures = "0.3"
bytes = "1.7"
prometheus-client = "0.22.3"
ipnet = "2.9.0"
ipnet = "2.10.0"
socket2 = { version = "0.5", features = ["all"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Expand All @@ -37,12 +37,13 @@ thiserror = "1.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
arc-swap = "1.7.1"
rand = "0.8.5"
tokio-tungstenite = "0.23.1"
tokio-tungstenite = { version = "0.23.1", features = ["stream", "__rustls-tls"] }
config = "0.14.0"
drain = "0.1.2"
base64 = "0.22.1"
tokio-tower = "0.7.0-rc4"
chrono = "0.4.38"
tungstenite = { version = "0.24.0", features = ["__rustls-tls"] }


[lib]
Expand Down
4 changes: 2 additions & 2 deletions umay/src/app/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ pub struct LocationConfig {
pub enum Protocol {
Tcp,
Udp,
Wss,
Https,
Ws,
Http,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
Expand Down
31 changes: 31 additions & 0 deletions umay/src/app/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,35 @@
use socket2::TcpKeepalive;
use std::time::Duration;
use tokio::net::TcpStream;

pub mod config;
pub mod metric;
pub mod server;
pub mod signal;

fn set_nodelay_or_warn(socket: &TcpStream) {
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!("failed to set nodelay: {}", e);
}
}

fn set_keepalive_or_warn(
tcp: tokio::net::TcpStream,
keepalive_duration: Option<Duration>,
) -> eyre::Result<tokio::net::TcpStream> {
let sock = {
let stream: std::net::TcpStream = tokio::net::TcpStream::into_std(tcp)?;
socket2::Socket::from(stream)
};

let ka = keepalive_duration
.into_iter()
.fold(TcpKeepalive::new(), |k, t| k.with_time(t));

if let Err(e) = sock.set_tcp_keepalive(&ka) {
tracing::warn!("failed to set keepalive: {}", e);
}

let stream: std::net::TcpStream = socket2::Socket::into(sock);
Ok(tokio::net::TcpStream::from_std(stream)?)
}
Loading

0 comments on commit e800446

Please sign in to comment.