Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dnsdist: Add the ability to set custom HTTP responses over DoH3 #15029

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions pdns/dnsdistdist/dnsdist-lua-actions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2019,7 +2019,7 @@ class ContinueAction : public DNSAction
std::shared_ptr<DNSAction> d_action;
};

#ifdef HAVE_DNS_OVER_HTTPS
#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
class HTTPStatusAction : public DNSAction
{
public:
Expand All @@ -2030,17 +2030,29 @@ class HTTPStatusAction : public DNSAction

DNSAction::Action operator()(DNSQuestion* dnsquestion, std::string* ruleresult) const override
{
if (!dnsquestion->ids.du) {
return Action::None;
#if defined(HAVE_DNS_OVER_HTTPS)
if (dnsquestion->ids.du) {
dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
header.qr = true; // for good measure
setResponseHeadersFromConfig(header, d_responseConfig);
return true;
});
return Action::HeaderModify;
}

dnsquestion->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
header.qr = true; // for good measure
setResponseHeadersFromConfig(header, d_responseConfig);
return true;
});
return Action::HeaderModify;
#endif /* HAVE_DNS_OVER_HTTPS */
#if defined(HAVE_DNS_OVER_HTTP3)
if (dnsquestion->ids.doh3u) {
dnsquestion->ids.doh3u->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsquestion->getMutableData(), [this](dnsheader& header) {
header.qr = true; // for good measure
setResponseHeadersFromConfig(header, d_responseConfig);
return true;
});
return Action::HeaderModify;
}
#endif /* HAVE_DNS_OVER_HTTP3 */
return Action::None;
}

[[nodiscard]] std::string toString() const override
Expand All @@ -2059,7 +2071,7 @@ class HTTPStatusAction : public DNSAction
std::string d_contentType;
int d_code;
};
#endif /* HAVE_DNS_OVER_HTTPS */
#endif /* HAVE_DNS_OVER_HTTPS || HAVE_DNS_OVER_HTTP3 */

#if defined(HAVE_LMDB) || defined(HAVE_CDB)
class KeyValueStoreLookupAction : public DNSAction
Expand Down
13 changes: 9 additions & 4 deletions pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
#endif /* HAVE_NET_SNMP */
});

#ifdef HAVE_DNS_OVER_HTTPS
#if defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_HTTP3)
luaCtx.registerFunction<std::string (DNSQuestion::*)(void) const>("getHTTPPath", [](const DNSQuestion& dnsQuestion) {
if (dnsQuestion.ids.du) {
return dnsQuestion.ids.du->getHTTPPath();
Expand Down Expand Up @@ -563,14 +563,19 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
});

luaCtx.registerFunction<void (DNSQuestion::*)(uint64_t statusCode, const std::string& body, const boost::optional<std::string> contentType)>("setHTTPResponse", [](DNSQuestion& dnsQuestion, uint64_t statusCode, const std::string& body, const boost::optional<std::string>& contentType) {
if (dnsQuestion.ids.du == nullptr) {
if (dnsQuestion.ids.du == nullptr && dnsQuestion.ids.doh3u == nullptr) {
return;
}
checkParameterBound("DNSQuestion::setHTTPResponse", statusCode, std::numeric_limits<uint16_t>::max());
PacketBuffer vect(body.begin(), body.end());
dnsQuestion.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
if (dnsQuestion.ids.du) {
dnsQuestion.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
}
else {
dnsQuestion.ids.doh3u->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : "");
}
});
#endif /* HAVE_DNS_OVER_HTTPS */
#endif /* HAVE_DNS_OVER_HTTPS HAVE_DNS_OVER_HTTP3 */

luaCtx.registerFunction<bool (DNSQuestion::*)(bool nxd, const std::string& zone, uint64_t ttl, const std::string& mname, const std::string& rname, uint64_t serial, uint64_t refresh, uint64_t retry, uint64_t expire, uint64_t minimum)>("setNegativeAndAdditionalSOA", [](DNSQuestion& dnsQuestion, bool nxd, const std::string& zone, uint64_t ttl, const std::string& mname, const std::string& rname, uint64_t serial, uint64_t refresh, uint64_t retry, uint64_t expire, uint64_t minimum) {
checkParameterBound("setNegativeAndAdditionalSOA", ttl, std::numeric_limits<uint32_t>::max());
Expand Down
30 changes: 20 additions & 10 deletions pdns/dnsdistdist/dnsdist-lua-ffi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,17 +499,27 @@ void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const cha

void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType)
{
if (dq->dq->ids.du == nullptr) {
return;
#if defined(HAVE_DNS_OVER_HTTPS)
if (dq->dq->ids.du) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): C API
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
header.qr = true;
return true;
});
}
#endif
#if defined(HAVE_DNS_OVER_HTTP3)
if (dq->dq->ids.doh3u) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): C API
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.doh3u->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
header.qr = true;
return true;
});
}

#ifdef HAVE_DNS_OVER_HTTPS
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
header.qr = true;
return true;
});
#endif
}

Expand Down
94 changes: 59 additions & 35 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,40 +285,53 @@ static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, Packet
return true;
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
static void addHeaderToList(std::vector<quiche_h3_header>& headers, const char* name, size_t nameLen, const char* value, size_t valueLen)
{
headers.emplace_back((quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>(name),
.name_len = nameLen,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>(value),
.value_len = valueLen,
});
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len, const std::string& contentType = {})
{
std::string status = std::to_string(statusCode);
std::string lenStr = std::to_string(len);
std::array<quiche_h3_header, 3> headers{
(quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>(":status"),
.name_len = sizeof(":status") - 1,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>(status.data()),
.value_len = status.size(),
},
(quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>("content-length"),
.name_len = sizeof("content-length") - 1,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>(lenStr.data()),
.value_len = lenStr.size(),
},
(quiche_h3_header){
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.name = reinterpret_cast<const uint8_t*>("content-type"),
.name_len = sizeof("content-type") - 1,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
.value = reinterpret_cast<const uint8_t*>("application/dns-message"),
.value_len = sizeof("application/dns-message") - 1,
},
};
PacketBuffer location;
PacketBuffer responseBody;
std::vector<quiche_h3_header> headers;
headers.reserve(4);
addHeaderToList(headers, ":status", sizeof(":status") - 1, status.data(), status.size());

if (statusCode >= 300 && statusCode < 400) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
addHeaderToList(headers, "location", sizeof("location") - 1, reinterpret_cast<const char*>(body), len);
static const std::string s_redirectStart{"<!DOCTYPE html><TITLE>Moved</TITLE><P>The document has moved <A HREF=\""};
static const std::string s_redirectEnd{"\">here</A>"};
static const std::string s_redirectContentType("text/html; charset=utf-8");
addHeaderToList(headers, "content-type", sizeof("content-type") - 1, s_redirectContentType.data(), s_redirectContentType.size());
responseBody.reserve(s_redirectStart.size() + len + s_redirectEnd.size());
responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end());
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
responseBody.insert(responseBody.end(), body, body + len);
responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end());
body = responseBody.data();
len = responseBody.size();
}
else if (len > 0 && (statusCode == 200U || !contentType.empty())) {
// do not include content-type header info if there is no content
addHeaderToList(headers, "content-type", sizeof("content-type") - 1, contentType.empty() ? "application/dns-message" : contentType.data(), contentType.empty() ? sizeof("application/dns-message") - 1 : contentType.size());
}

const std::string lenStr = std::to_string(len);
addHeaderToList(headers, "content-length", sizeof("content-length") - 1, lenStr.data(), lenStr.size());

auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
streamID, headers.data(),
// do not include content-type header info if there is no content
(len > 0 && statusCode == 200U ? headers.size() : headers.size() - 1),
headers.size(),
len == 0);
if (returnValue != 0) {
/* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
Expand Down Expand Up @@ -350,13 +363,13 @@ static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16
}
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content = {})
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
}

static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response, const std::string& contentType)
{
if (statusCode == 200) {
++frontend.d_validResponses;
Expand All @@ -368,7 +381,7 @@ static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uin
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
}
else {
h3_send_response(conn, streamID, statusCode, &response.at(0), response.size());
h3_send_response(conn, streamID, statusCode, &response.at(0), response.size(), contentType);
}
}

Expand Down Expand Up @@ -471,7 +484,7 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) {
DEBUGLOG("handleImmediateResponse() reason=" << reason);
auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response, unit->d_contentTypeOut);
unit->ids.doh3u.reset();
};

Expand Down Expand Up @@ -658,7 +671,7 @@ static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
auto unit = std::move(*tmp);
auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
if (conn) {
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response, unit->d_contentTypeOut);
}
}
catch (const std::exception& e) {
Expand Down Expand Up @@ -1078,6 +1091,13 @@ const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const
return headers;
}

void DOH3Unit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType)
{
status_code = statusCode;
response = std::move(body);
d_contentTypeOut = contentType;
}

#else /* HAVE_DNS_OVER_HTTP3 */

std::string DOH3Unit::getHTTPPath() const
Expand Down Expand Up @@ -1106,4 +1126,8 @@ const dnsdist::doh3::h3_headers_t& DOH3Unit::getHTTPHeaders() const
return headers;
}

void DOH3Unit::setHTTPResponse(uint16_t, PacketBuffer&&, const std::string&)
{
}

#endif /* HAVE_DNS_OVER_HTTP3 */
5 changes: 4 additions & 1 deletion pdns/dnsdistdist/doh3.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
#include <unordered_map>

#include "config.h"
#include "noinitvector.hh"

#ifdef HAVE_DNS_OVER_HTTP3
#include "channel.hh"
#include "iputils.hh"
#include "libssl.hh"
#include "noinitvector.hh"
#include "stat_t.hh"
#include "dnsdist-idstate.hh"

Expand Down Expand Up @@ -93,13 +93,15 @@ struct DOH3Unit
[[nodiscard]] std::string getHTTPHost() const;
[[nodiscard]] std::string getHTTPScheme() const;
[[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "");

InternalQueryState ids;
PacketBuffer query;
PacketBuffer response;
PacketBuffer serverConnID;
dnsdist::doh3::h3_headers_t headers;
std::shared_ptr<DownstreamState> downstream{nullptr};
std::string d_contentTypeOut;
DOH3ServerConfig* dsc{nullptr};
uint64_t streamID{0};
size_t proxyProtocolPayloadSize{0};
Expand All @@ -126,6 +128,7 @@ struct DOH3Unit
[[nodiscard]] std::string getHTTPHost() const;
[[nodiscard]] std::string getHTTPScheme() const;
[[nodiscard]] const dnsdist::doh3::h3_headers_t& getHTTPHeaders() const;
void setHTTPResponse(uint16_t, PacketBuffer&&, const std::string&);
};

struct DOH3Frontend
Expand Down
7 changes: 5 additions & 2 deletions regression-tests.dnsdist/dnsdisttests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,15 +1151,18 @@ def sendDOQQuery(cls, port, query, response=None, timeout=2.0, caFile=None, useQ
return (receivedQuery, message)

@classmethod
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None):
def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None, post=False, customHeaders=None, rawResponse=False):

if response:
if toQueue:
toQueue.put(response, True, timeout)
else:
cls._toResponderQueue.put(response, True, timeout)

message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders)
if rawResponse:
return doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)

message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName, post=post, additional_headers=customHeaders, raw_response=rawResponse)

receivedQuery = None

Expand Down
14 changes: 10 additions & 4 deletions regression-tests.dnsdist/doh3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,16 @@ async def perform_http_request(
elapsed = time.time() - start

result = bytes()
headers = {}
for http_event in http_events:
if isinstance(http_event, DataReceived):
result += http_event.data
if isinstance(http_event, StreamReset):
result = http_event
return result
if isinstance(http_event, HeadersReceived):
for k, v in http_event.headers:
headers[k] = v
return (result, headers)


async def async_h3_query(
Expand Down Expand Up @@ -220,15 +224,15 @@ async def async_h3_query(

return answer
except asyncio.TimeoutError as e:
return e
return (e,{})


def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None):
def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None, post=False, additional_headers=None, raw_response=False):
configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True)
if verify:
configuration.load_verify_locations(verify)

result = asyncio.run(
(result, headers) = asyncio.run(
async_h3_query(
configuration=configuration,
baseurl=baseurl,
Expand All @@ -245,4 +249,6 @@ def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname
raise StreamResetError(result.error_code)
if (isinstance(result, asyncio.TimeoutError)):
raise TimeoutError()
if raw_response:
return (result, headers)
return dns.message.from_wire(result)
Loading
Loading