Skip to content

Commit

Permalink
Add support for the PROXY protocol (#964)
Browse files Browse the repository at this point in the history
* Add option to enable proxy protocol on contexsts

* Add NAME to LISTEN config option
  • Loading branch information
Askaholic authored May 15, 2023
1 parent 592c575 commit e4e41df
Show file tree
Hide file tree
Showing 8 changed files with 447 additions and 260 deletions.
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ humanize = ">=2.6.0"
maxminddb = "*"
oauthlib = "*"
prometheus_client = "*"
proxy-protocol = "*"
pyjwt = {version = ">=2.4.0", extras = ["crypto"]}
pyyaml = "*"
sortedcontainers = "*"
Expand Down
422 changes: 207 additions & 215 deletions Pipfile.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,16 @@ async def restart_control_server():
host = cfg["ADDRESS"]
port = cfg["PORT"]
proto_class_name = cfg["PROTOCOL"]
name = cfg.get("NAME")
proxy = cfg.get("PROXY", False)

proto_class = PROTO_CLASSES[proto_class_name]

await instance.listen(
address=(host, port),
protocol_class=proto_class
name=name,
protocol_class=proto_class,
proxy=proxy
)
except Exception as e:
raise RuntimeError(f"Error with server instance config: {cfg}") from e
Expand Down
16 changes: 13 additions & 3 deletions server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,31 @@ async def initialize(service):
async def listen(
self,
address: tuple[str, int],
protocol_class: type[Protocol] = QDataStreamProtocol
name: Optional[str] = None,
protocol_class: type[Protocol] = QDataStreamProtocol,
proxy: bool = False,
) -> ServerContext:
"""
Start listening on a new address.
# Params
- `address`: Tuple indicating the host, port to listen on.
- `name`: String used to identify this context in log messages. The
default is to use the `protocol_class` name.
- `protocol_class`: The protocol class implementation to use.
- `proxy`: Boolean indicating whether or not to use the PROXY protocol.
See: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
"""
if not self.started:
await self.start_services()

ctx = ServerContext(
f"{self.name}[{protocol_class.__name__}]",
f"{self.name}[{name or protocol_class.__name__}]",
self.connection_factory,
list(self.services.values()),
protocol_class
)
await ctx.listen(*address)
await ctx.listen(*address, proxy=proxy)

self.contexts.add(ctx)

Expand Down
6 changes: 5 additions & 1 deletion server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ def __init__(self):
{
"ADDRESS": "",
"PORT": 8001,
"NAME": None,
"PROTOCOL": "QDataStreamProtocol",
"PROXY": False,
},
{
"ADDRESS": "",
"PORT": 8002,
"PROTOCOL": "SimpleJsonProtocol"
"NAME": None,
"PROTOCOL": "SimpleJsonProtocol",
"PROXY": False
}
]
self.LOG_LEVEL = "DEBUG"
Expand Down
86 changes: 74 additions & 12 deletions server/servercontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

import asyncio
import socket
from asyncio import StreamReader, StreamWriter
from contextlib import contextmanager
from typing import Callable, Iterable, Optional

import humanize
from proxyprotocol.detect import ProxyProtocolDetect
from proxyprotocol.reader import ProxyProtocolReader
from proxyprotocol.sock import SocketInfo

import server.metrics as metrics

Expand Down Expand Up @@ -45,11 +49,28 @@ def __init__(
def __repr__(self):
return f"ServerContext({self.name})"

async def listen(self, host, port):
self._logger.debug("%s: listen(%r, %r)", self.name, host, port)
async def listen(
self,
host: str,
port: Optional[int],
proxy: bool = False
):
self._logger.debug(
"%s: listen(%r, %r, proxy=%r)",
self.name,
host,
port,
proxy
)

callback = self.client_connected_callback
if proxy:
pp_detect = ProxyProtocolDetect()
pp_reader = ProxyProtocolReader(pp_detect)
callback = pp_reader.get_callback(callback)

self._server = await asyncio.start_server(
self.client_connected,
callback,
host=host,
port=port,
limit=LIMIT,
Expand Down Expand Up @@ -113,15 +134,56 @@ def write_broadcast_raw(self, data, validate_fn=lambda _: True):
conn
)

async def client_connected(self, stream_reader, stream_writer):
peername = Address(*stream_writer.get_extra_info("peername"))
self._logger.info(
"%s: Client connected from %s:%s",
self.name,
peername.host,
peername.port
)
protocol = self.protocol_class(stream_reader, stream_writer)
async def client_connected_callback(
self,
reader: StreamReader,
writer: StreamWriter,
proxy_info: Optional[SocketInfo] = None,
):
if proxy_info:
peername_writer = Address(*writer.get_extra_info("peername"))

if not proxy_info.peername:
# See security considerations:
# https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
self._logger.warning(
"%s: Client connected from %s:%s to a context in proxy "
"mode! The connection will be ignored, however this may "
"indicate a misconfiguration in your firewall.",
self.name,
peername_writer.host,
peername_writer.port
)
writer.close()
return

peername = Address(*proxy_info.peername)
self._logger.info(
"%s: Client connected from %s:%s via proxy %s:%s",
self.name,
peername.host,
peername.port,
peername_writer.host,
peername_writer.port
)
else:
peername = Address(*writer.get_extra_info("peername"))
self._logger.info(
"%s: Client connected from %s:%s",
self.name,
peername.host,
peername.port
)

await self.handle_client_connected(reader, writer, peername)

async def handle_client_connected(
self,
reader: StreamReader,
writer: StreamWriter,
peername: Address,
):
protocol = self.protocol_class(reader, writer)
connection = self._connection_factory()
self.connections[connection] = protocol

Expand Down
130 changes: 104 additions & 26 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import logging
import textwrap
from collections import defaultdict
from typing import Any, Callable
from typing import Any, Callable, Optional
from unittest import mock

import aio_pika
import proxyprotocol.dnsbl
import proxyprotocol.server
import proxyprotocol.server.protocol
import pytest
from aiohttp import web

Expand Down Expand Up @@ -115,7 +118,7 @@ def jwk_kid():


@pytest.fixture
async def lobby_contexts(
async def lobby_server_factory(
event_loop,
database,
broadcast_service,
Expand All @@ -131,11 +134,9 @@ async def lobby_contexts(
policy_server,
jwks_server
):
mock_policy = mock.patch(
"server.lobbyconnection.config.FAF_POLICY_SERVER_BASE_URL",
f"http://{policy_server.host}:{policy_server.port}"
)
with mock_policy:
all_contexts = []

async def make_lobby_server(config):
instance = ServerInstance(
"UnitTestServer",
database,
Expand All @@ -157,36 +158,81 @@ async def lobby_contexts(
broadcast_service.server = instance

contexts = {
"qstream": await instance.listen(
("127.0.0.1", None),
protocol_class=QDataStreamProtocol
),
"json": await instance.listen(
("127.0.0.1", None),
protocol_class=SimpleJsonProtocol
name: await instance.listen(
(cfg["ADDRESS"], cfg["PORT"]),
protocol_class=cfg["PROTOCOL"],
proxy=cfg.get("PROXY", False)
)
for name, cfg in config.items()
}
all_contexts.extend(contexts.values())
for context in contexts.values():
context.__connected_client_protos = []
player_service.is_uniqueid_exempt = lambda id: True

yield contexts
return contexts

for context in contexts.values():
await context.stop()
await context.shutdown()
# Close connected protocol objects
# https://github.com/FAForever/server/issues/717
for proto in context.__connected_client_protos:
proto.abort()
await exhaust_callbacks(event_loop)
mock_policy = mock.patch(
"server.lobbyconnection.config.FAF_POLICY_SERVER_BASE_URL",
f"http://{policy_server.host}:{policy_server.port}"
)
with mock_policy:
yield make_lobby_server

for context in all_contexts:
await context.stop()
await context.shutdown()
# Close connected protocol objects
# https://github.com/FAForever/server/issues/717
for proto in context.__connected_client_protos:
proto.abort()
await exhaust_callbacks(event_loop)


@pytest.fixture
async def lobby_contexts(lobby_server_factory):
return await lobby_server_factory({
"qstream": {
"ADDRESS": "127.0.0.1",
"PORT": None,
"PROTOCOL": QDataStreamProtocol
},
"json": {
"ADDRESS": "127.0.0.1",
"PORT": None,
"PROTOCOL": SimpleJsonProtocol
}
})


@pytest.fixture
async def lobby_contexts_proxy(lobby_server_factory):
return await lobby_server_factory({
"qstream": {
"ADDRESS": "127.0.0.1",
"PORT": None,
"PROTOCOL": QDataStreamProtocol,
"PROXY": True
},
"json": {
"ADDRESS": "127.0.0.1",
"PORT": None,
"PROTOCOL": SimpleJsonProtocol,
"PROXY": True
}
})


@pytest.fixture(params=("qstream", "json"))
def lobby_server(request, lobby_contexts):
yield lobby_contexts[request.param]


@pytest.fixture(params=("qstream", "json"))
def lobby_server_proxy(request, lobby_contexts_proxy):
yield lobby_contexts_proxy[request.param]


@pytest.fixture
async def control_server(player_service, game_service):
server = ControlServer(
Expand Down Expand Up @@ -286,6 +332,33 @@ async def get(request):
await runner.cleanup()


@pytest.fixture
async def proxy_server(lobby_server_proxy, event_loop):
buf_len = 262144
dnsbl = proxyprotocol.dnsbl.NoopDnsbl()

host, port = lobby_server_proxy.sockets[0].getsockname()
dest = proxyprotocol.server.Address(f"{host}:{port}")

server = await event_loop.create_server(
lambda: proxyprotocol.server.protocol.DownstreamProtocol(
proxyprotocol.server.protocol.UpstreamProtocol,
event_loop,
buf_len,
dnsbl,
dest
),
"127.0.0.1",
None,
)
await server.start_serving()

yield server

server.close()
await server.wait_closed()


@pytest.fixture
def tmp_user(database):
user_ids = defaultdict(lambda: 1)
Expand All @@ -307,9 +380,13 @@ async def make_user(name="TempUser"):
return make_user


async def connect_client(server: ServerContext) -> Protocol:
async def connect_client(
server: ServerContext,
address: Optional[tuple[str, int]] = None
) -> Protocol:
address = address or server.sockets[0].getsockname()
proto = server.protocol_class(
*(await asyncio.open_connection(*server.sockets[0].getsockname()))
*(await asyncio.open_connection(*address))
)
if hasattr(server, "__connected_client_protos"):
server.__connected_client_protos.append(proto)
Expand Down Expand Up @@ -388,8 +465,9 @@ async def get_session(proto):
async def connect_and_sign_in(
credentials,
lobby_server: ServerContext,
address: Optional[tuple[str, int]] = None
):
proto = await connect_client(lobby_server)
proto = await connect_client(lobby_server, address)
session = await get_session(proto)
await perform_login(proto, credentials)
hello = await read_until_command(proto, "welcome", timeout=120)
Expand Down
Loading

0 comments on commit e4e41df

Please sign in to comment.