implement IPv4 range bans

This commit is contained in:
Martin Michelsen
2024-04-21 01:12:51 -07:00
parent 79bf6b3fa9
commit de42135532
18 changed files with 296 additions and 71 deletions
+1
View File
@@ -92,6 +92,7 @@ set(SOURCES
src/HTTPServer.cc src/HTTPServer.cc
src/IPFrameInfo.cc src/IPFrameInfo.cc
src/IPStackSimulator.cc src/IPStackSimulator.cc
src/IPV4RangeSet.cc
src/ItemCreator.cc src/ItemCreator.cc
src/ItemData.cc src/ItemData.cc
src/ItemNameIndex.cc src/ItemNameIndex.cc
+10 -7
View File
@@ -19,10 +19,13 @@ using namespace std;
DNSServer::DNSServer( DNSServer::DNSServer(
shared_ptr<struct event_base> base, shared_ptr<struct event_base> base,
uint32_t local_connect_address, uint32_t external_connect_address) uint32_t local_connect_address,
uint32_t external_connect_address,
shared_ptr<const IPV4RangeSet> banned_ipv4_ranges)
: base(base), : base(base),
local_connect_address(local_connect_address), local_connect_address(local_connect_address),
external_connect_address(external_connect_address) {} external_connect_address(external_connect_address),
banned_ipv4_ranges(banned_ipv4_ranges) {}
DNSServer::~DNSServer() { DNSServer::~DNSServer() {
for (const auto& it : this->fd_to_receive_event) { for (const auto& it : this->fd_to_receive_event) {
@@ -55,8 +58,7 @@ void DNSServer::dispatch_on_receive_message(evutil_socket_t fd,
reinterpret_cast<DNSServer*>(ctx)->on_receive_message(fd, events); reinterpret_cast<DNSServer*>(ctx)->on_receive_message(fd, events);
} }
string DNSServer::response_for_query( string DNSServer::response_for_query(const void* vdata, size_t size, uint32_t resolved_address) {
const void* vdata, size_t size, uint32_t resolved_address) {
if (size < 0x0C) { if (size < 0x0C) {
throw invalid_argument("query too small"); throw invalid_argument("query too small");
} }
@@ -82,7 +84,7 @@ string DNSServer::response_for_query(
void DNSServer::on_receive_message(int fd, short) { void DNSServer::on_receive_message(int fd, short) {
for (;;) { for (;;) {
sockaddr_in remote; struct sockaddr_storage remote;
socklen_t remote_size = sizeof(sockaddr_in); socklen_t remote_size = sizeof(sockaddr_in);
memset(&remote, 0, remote_size); memset(&remote, 0, remote_size);
@@ -104,9 +106,10 @@ void DNSServer::on_receive_message(int fd, short) {
dns_server_log.warning("input query too small"); dns_server_log.warning("input query too small");
print_data(stderr, input.data(), bytes); print_data(stderr, input.data(), bytes);
} else { } else if (!this->banned_ipv4_ranges->check(remote)) {
input.resize(bytes); input.resize(bytes);
uint32_t remote_address = ntohl(remote.sin_addr.s_addr); const sockaddr_in* remote_sin = reinterpret_cast<const sockaddr_in*>(&remote);
uint32_t remote_address = ntohl(remote_sin->sin_addr.s_addr);
uint32_t connect_address = is_local_address(remote_address) uint32_t connect_address = is_local_address(remote_address)
? this->local_connect_address ? this->local_connect_address
: this->external_connect_address; : this->external_connect_address;
+14 -6
View File
@@ -7,29 +7,37 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "IPV4RangeSet.hh"
class DNSServer { class DNSServer {
public: public:
DNSServer(std::shared_ptr<struct event_base> base, DNSServer(
uint32_t local_connect_address, uint32_t external_connect_address); std::shared_ptr<struct event_base> base,
uint32_t local_connect_address,
uint32_t external_connect_address,
std::shared_ptr<const IPV4RangeSet> banned_ipv4_ranges);
DNSServer(const DNSServer&) = delete; DNSServer(const DNSServer&) = delete;
DNSServer(DNSServer&&) = delete; DNSServer(DNSServer&&) = delete;
virtual ~DNSServer(); virtual ~DNSServer();
inline void set_banned_ipv4_ranges(std::shared_ptr<const IPV4RangeSet> banned_ipv4_ranges) {
this->banned_ipv4_ranges = banned_ipv4_ranges;
}
void listen(const std::string& socket_path); void listen(const std::string& socket_path);
void listen(const std::string& addr, int port); void listen(const std::string& addr, int port);
void listen(int port); void listen(int port);
void add_socket(int fd); void add_socket(int fd);
static std::string response_for_query( static std::string response_for_query(const void* vdata, size_t size, uint32_t resolved_address);
const void* vdata, size_t size, uint32_t resolved_address); static std::string response_for_query(const std::string& query, uint32_t resolved_address);
static std::string response_for_query(
const std::string& query, uint32_t resolved_address);
private: private:
std::shared_ptr<struct event_base> base; std::shared_ptr<struct event_base> base;
std::unordered_map<int, std::unique_ptr<struct event, void (*)(struct event*)>> fd_to_receive_event; std::unordered_map<int, std::unique_ptr<struct event, void (*)(struct event*)>> fd_to_receive_event;
uint32_t local_connect_address; uint32_t local_connect_address;
uint32_t external_connect_address; uint32_t external_connect_address;
std::shared_ptr<const IPV4RangeSet> banned_ipv4_ranges;
static void dispatch_on_receive_message(evutil_socket_t fd, short events, void* ctx); static void dispatch_on_receive_message(evutil_socket_t fd, short events, void* ctx);
void on_receive_message(int fd, short event); void on_receive_message(int fd, short event);
+9 -4
View File
@@ -235,12 +235,17 @@ void IPStackSimulator::disconnect_client(struct bufferevent* bev) {
void IPStackSimulator::dispatch_on_listen_accept( void IPStackSimulator::dispatch_on_listen_accept(
struct evconnlistener* listener, evutil_socket_t fd, struct evconnlistener* listener, evutil_socket_t fd,
struct sockaddr* address, int socklen, void* ctx) { struct sockaddr* address, int socklen, void* ctx) {
reinterpret_cast<IPStackSimulator*>(ctx)->on_listen_accept( reinterpret_cast<IPStackSimulator*>(ctx)->on_listen_accept(listener, fd, address, socklen);
listener, fd, address, socklen);
} }
void IPStackSimulator::on_listen_accept(struct evconnlistener* listener, void IPStackSimulator::on_listen_accept(struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr*, int) {
evutil_socket_t fd, struct sockaddr*, int) { struct sockaddr_storage remote_addr;
get_socket_addresses(fd, nullptr, &remote_addr);
if (this->state->banned_ipv4_ranges->check(remote_addr)) {
close(fd);
return;
}
int listen_fd = evconnlistener_get_fd(listener); int listen_fd = evconnlistener_get_fd(listener);
const ListeningSocket* listening_socket; const ListeningSocket* listening_socket;
+25 -24
View File
@@ -22,22 +22,6 @@ public:
HDLC_RAW, HDLC_RAW,
}; };
IPStackSimulator(
std::shared_ptr<struct event_base> base,
std::shared_ptr<ServerState> state);
~IPStackSimulator();
void listen(const std::string& name, const std::string& socket_path, Protocol protocol);
void listen(const std::string& name, const std::string& addr, int port, Protocol protocol);
void listen(const std::string& name, int port, Protocol protocol);
void add_socket(const std::string& name, int fd, Protocol protocol);
static uint32_t connect_address_for_remote_address(uint32_t remote_addr);
private:
std::shared_ptr<struct event_base> base;
std::shared_ptr<ServerState> state;
using unique_listener = std::unique_ptr<struct evconnlistener, void (*)(struct evconnlistener*)>; using unique_listener = std::unique_ptr<struct evconnlistener, void (*)(struct evconnlistener*)>;
using unique_bufferevent = std::unique_ptr<struct bufferevent, void (*)(struct bufferevent*)>; using unique_bufferevent = std::unique_ptr<struct bufferevent, void (*)(struct bufferevent*)>;
using unique_evbuffer = std::unique_ptr<struct evbuffer, void (*)(struct evbuffer*)>; using unique_evbuffer = std::unique_ptr<struct evbuffer, void (*)(struct evbuffer*)>;
@@ -92,6 +76,28 @@ private:
void on_idle_timeout(); void on_idle_timeout();
}; };
IPStackSimulator(
std::shared_ptr<struct event_base> base,
std::shared_ptr<ServerState> state);
~IPStackSimulator();
void listen(const std::string& name, const std::string& socket_path, Protocol protocol);
void listen(const std::string& name, const std::string& addr, int port, Protocol protocol);
void listen(const std::string& name, int port, Protocol protocol);
void add_socket(const std::string& name, int fd, Protocol protocol);
static uint32_t connect_address_for_remote_address(uint32_t remote_addr);
inline const std::unordered_map<struct bufferevent*, std::shared_ptr<IPClient>>& all_clients() const {
return this->bev_to_client;
}
void disconnect_client(struct bufferevent* bev);
private:
std::shared_ptr<struct event_base> base;
std::shared_ptr<ServerState> state;
struct ListeningSocket { struct ListeningSocket {
std::string name; std::string name;
Protocol protocol; Protocol protocol;
@@ -111,20 +117,16 @@ private:
FILE* pcap_text_log_file; FILE* pcap_text_log_file;
void disconnect_client(struct bufferevent* bev);
static uint64_t tcp_conn_key_for_connection(const IPClient::TCPConnection& conn); static uint64_t tcp_conn_key_for_connection(const IPClient::TCPConnection& conn);
static uint64_t tcp_conn_key_for_client_frame(const IPv4Header& ipv4, const TCPHeader& tcp); static uint64_t tcp_conn_key_for_client_frame(const IPv4Header& ipv4, const TCPHeader& tcp);
static uint64_t tcp_conn_key_for_client_frame(const FrameInfo& fi); static uint64_t tcp_conn_key_for_client_frame(const FrameInfo& fi);
static std::string str_for_ipv4_netloc(uint32_t addr, uint16_t port); static std::string str_for_ipv4_netloc(uint32_t addr, uint16_t port);
static std::string str_for_tcp_connection(std::shared_ptr<const IPClient> c, static std::string str_for_tcp_connection(std::shared_ptr<const IPClient> c, const IPClient::TCPConnection& conn);
const IPClient::TCPConnection& conn);
static void dispatch_on_listen_accept(struct evconnlistener* listener, static void dispatch_on_listen_accept(struct evconnlistener* listener,
evutil_socket_t fd, struct sockaddr* address, int socklen, void* ctx); evutil_socket_t fd, struct sockaddr* address, int socklen, void* ctx);
void on_listen_accept(struct evconnlistener* listener, evutil_socket_t fd, void on_listen_accept(struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr* address, int socklen);
struct sockaddr* address, int socklen);
static void dispatch_on_listen_error(struct evconnlistener* listener, void* ctx); static void dispatch_on_listen_error(struct evconnlistener* listener, void* ctx);
void on_listen_error(struct evconnlistener* listener); void on_listen_error(struct evconnlistener* listener);
@@ -158,8 +160,7 @@ private:
struct evbuffer* src_buf = nullptr, struct evbuffer* src_buf = nullptr,
size_t src_bytes = 0); size_t src_bytes = 0);
void open_server_connection( void open_server_connection(std::shared_ptr<IPClient> c, IPClient::TCPConnection& conn);
std::shared_ptr<IPClient> c, IPClient::TCPConnection& conn);
void log_frame(const std::string& data) const; void log_frame(const std::string& data) const;
}; };
+73
View File
@@ -0,0 +1,73 @@
#include "IPV4RangeSet.hh"
#include <arpa/inet.h>
using namespace std;
IPV4RangeSet::IPV4RangeSet(const JSON& json) {
for (const auto& it : json.as_list()) {
// String should be of the form a.b.c.d or a.b.c.d/e
auto tokens = split(it->as_string(), '/');
size_t mask_bits;
if (tokens.size() == 1) {
mask_bits = 32;
} else if (tokens.size() == 2) {
mask_bits = stoul(tokens[1], nullptr, 10);
if (mask_bits > 32) {
throw runtime_error("invalid IPv4 address range");
}
} else {
throw runtime_error("invalid IPv4 address range");
}
auto addr_tokens = split(tokens[0], '.');
if (addr_tokens.size() != 4) {
throw runtime_error("invalid IPv4 address");
}
uint32_t addr = 0;
for (size_t z = 0; z < 4; z++) {
size_t end_pos = 0;
size_t new_byte = stoul(addr_tokens[z], &end_pos, 10);
if (end_pos != addr_tokens[z].size() || new_byte > 0xFF) {
throw runtime_error("invalid IPv4 address");
}
addr = (addr << 8) | new_byte;
}
addr &= (0xFFFFFFFF << (32 - mask_bits));
this->ranges.emplace(addr, mask_bits);
}
}
JSON IPV4RangeSet::json() const {
auto ret = JSON::list();
for (const auto& it : this->ranges) {
uint32_t addr = it.first;
uint8_t mask_bits = it.second;
ret.emplace_back(string_printf("%hhu.%hhu.%hhu.%hhu/%hhu",
static_cast<uint8_t>((addr >> 24) & 0xFF),
static_cast<uint8_t>((addr >> 16) & 0xFF),
static_cast<uint8_t>((addr >> 8) & 0xFF),
static_cast<uint8_t>(addr & 0xFF),
mask_bits));
}
return ret;
}
bool IPV4RangeSet::check(uint32_t addr) const {
auto it = this->ranges.upper_bound(addr);
if (it == this->ranges.begin()) {
return false; // addr is before any range
}
const auto& range = *(--it);
return (((range.first ^ addr) & (0xFFFFFFFF << (32 - range.second))) == 0);
}
bool IPV4RangeSet::check(const struct sockaddr_storage& ss) const {
if (ss.ss_family != AF_INET) {
return false;
}
const sockaddr_in* sin = reinterpret_cast<const sockaddr_in*>(&ss);
return this->check(ntohl(sin->sin_addr.s_addr));
}
+18
View File
@@ -0,0 +1,18 @@
#pragma once
#include <phosg/JSON.hh>
#include <set>
class IPV4RangeSet {
public:
IPV4RangeSet() = default;
explicit IPV4RangeSet(const JSON& json);
JSON json() const;
bool check(uint32_t addr) const;
bool check(const struct sockaddr_storage& ss) const;
protected:
std::map<uint32_t, uint8_t> ranges; // {addr: mask_bits}
};
+10 -8
View File
@@ -2458,22 +2458,21 @@ Action a_run_server_replay_log(
auto state = make_shared<ServerState>(base, get_config_filename(args), is_replay); auto state = make_shared<ServerState>(base, get_config_filename(args), is_replay);
state->load_all(); state->load_all();
shared_ptr<DNSServer> dns_server;
if (state->dns_server_port && !is_replay) { if (state->dns_server_port && !is_replay) {
if (!state->dns_server_addr.empty()) { if (!state->dns_server_addr.empty()) {
config_log.info("Starting DNS server on %s:%hu", state->dns_server_addr.c_str(), state->dns_server_port); config_log.info("Starting DNS server on %s:%hu", state->dns_server_addr.c_str(), state->dns_server_port);
} else { } else {
config_log.info("Starting DNS server on port %hu", state->dns_server_port); config_log.info("Starting DNS server on port %hu", state->dns_server_port);
} }
dns_server = make_shared<DNSServer>(base, state->local_address, state->external_address); state->dns_server = make_shared<DNSServer>(
dns_server->listen(state->dns_server_addr, state->dns_server_port); base, state->local_address, state->external_address, state->banned_ipv4_ranges);
state->dns_server->listen(state->dns_server_addr, state->dns_server_port);
} else { } else {
config_log.info("DNS server is disabled"); config_log.info("DNS server is disabled");
} }
shared_ptr<ServerShell> shell; shared_ptr<ServerShell> shell;
shared_ptr<ReplaySession> replay_session; shared_ptr<ReplaySession> replay_session;
shared_ptr<IPStackSimulator> ip_stack_simulator;
shared_ptr<HTTPServer> http_server; shared_ptr<HTTPServer> http_server;
if (is_replay) { if (is_replay) {
config_log.info("Starting proxy server"); config_log.info("Starting proxy server");
@@ -2548,21 +2547,24 @@ Action a_run_server_replay_log(
if (!state->ip_stack_addresses.empty() || !state->ppp_stack_addresses.empty() || !state->ppp_raw_addresses.empty()) { if (!state->ip_stack_addresses.empty() || !state->ppp_stack_addresses.empty() || !state->ppp_raw_addresses.empty()) {
config_log.info("Starting IP/PPP stack simulator"); config_log.info("Starting IP/PPP stack simulator");
ip_stack_simulator = make_shared<IPStackSimulator>(base, state); state->ip_stack_simulator = make_shared<IPStackSimulator>(base, state);
for (const auto& it : state->ip_stack_addresses) { for (const auto& it : state->ip_stack_addresses) {
auto netloc = parse_netloc(it); auto netloc = parse_netloc(it);
string spec = (netloc.second == 0) ? ("T-IPS-" + netloc.first) : string_printf("T-IPS-%hu", netloc.second); string spec = (netloc.second == 0) ? ("T-IPS-" + netloc.first) : string_printf("T-IPS-%hu", netloc.second);
ip_stack_simulator->listen(spec, netloc.first, netloc.second, IPStackSimulator::Protocol::ETHERNET_TAPSERVER); state->ip_stack_simulator->listen(
spec, netloc.first, netloc.second, IPStackSimulator::Protocol::ETHERNET_TAPSERVER);
} }
for (const auto& it : state->ppp_stack_addresses) { for (const auto& it : state->ppp_stack_addresses) {
auto netloc = parse_netloc(it); auto netloc = parse_netloc(it);
string spec = (netloc.second == 0) ? ("T-PPPST-" + netloc.first) : string_printf("T-PPPST-%hu", netloc.second); string spec = (netloc.second == 0) ? ("T-PPPST-" + netloc.first) : string_printf("T-PPPST-%hu", netloc.second);
ip_stack_simulator->listen(spec, netloc.first, netloc.second, IPStackSimulator::Protocol::HDLC_TAPSERVER); state->ip_stack_simulator->listen(
spec, netloc.first, netloc.second, IPStackSimulator::Protocol::HDLC_TAPSERVER);
} }
for (const auto& it : state->ppp_raw_addresses) { for (const auto& it : state->ppp_raw_addresses) {
auto netloc = parse_netloc(it); auto netloc = parse_netloc(it);
string spec = (netloc.second == 0) ? ("T-PPPSR-" + netloc.first) : string_printf("T-PPPSR-%hu", netloc.second); string spec = (netloc.second == 0) ? ("T-PPPSR-" + netloc.first) : string_printf("T-PPPSR-%hu", netloc.second);
ip_stack_simulator->listen(spec, netloc.first, netloc.second, IPStackSimulator::Protocol::HDLC_RAW); state->ip_stack_simulator->listen(
spec, netloc.first, netloc.second, IPStackSimulator::Protocol::HDLC_RAW);
if (netloc.second) { if (netloc.second) {
if (state->local_address == state->external_address) { if (state->local_address == state->external_address) {
config_log.info( config_log.info(
+14 -3
View File
@@ -320,6 +320,13 @@ void PatchServer::dispatch_on_listen_error(
} }
void PatchServer::on_listen_accept(struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr*, int) { void PatchServer::on_listen_accept(struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr*, int) {
struct sockaddr_storage remote_addr;
get_socket_addresses(fd, nullptr, &remote_addr);
if (this->config->banned_ipv4_ranges->check(remote_addr)) {
close(fd);
return;
}
int listen_fd = evconnlistener_get_fd(listener); int listen_fd = evconnlistener_get_fd(listener);
ListeningSocket* listening_socket; ListeningSocket* listening_socket;
try { try {
@@ -461,7 +468,11 @@ void PatchServer::thread_fn() {
} }
void PatchServer::set_config(std::shared_ptr<const Config> config) { void PatchServer::set_config(std::shared_ptr<const Config> config) {
forward_to_event_thread(this->base, [s = this->shared_from_this(), config = std::move(config)]() { if (this->base_is_shared) {
s->config = config; this->config = config;
}); } else {
forward_to_event_thread(this->base, [s = this->shared_from_this(), config = std::move(config)]() {
s->config = config;
});
}
} }
+2
View File
@@ -10,6 +10,7 @@
#include "Account.hh" #include "Account.hh"
#include "Channel.hh" #include "Channel.hh"
#include "IPV4RangeSet.hh"
#include "PatchFileIndex.hh" #include "PatchFileIndex.hh"
#include "Version.hh" #include "Version.hh"
@@ -22,6 +23,7 @@ public:
std::string message; std::string message;
std::shared_ptr<AccountIndex> account_index; std::shared_ptr<AccountIndex> account_index;
std::shared_ptr<const PatchFileIndex> patch_file_index; std::shared_ptr<const PatchFileIndex> patch_file_index;
std::shared_ptr<const IPV4RangeSet> banned_ipv4_ranges;
std::shared_ptr<struct event_base> shared_base; std::shared_ptr<struct event_base> shared_base;
}; };
+7
View File
@@ -98,6 +98,13 @@ void ProxyServer::ListeningSocket::dispatch_on_listen_error(
} }
void ProxyServer::ListeningSocket::on_listen_accept(int fd) { void ProxyServer::ListeningSocket::on_listen_accept(int fd) {
struct sockaddr_storage remote_addr;
get_socket_addresses(fd, nullptr, &remote_addr);
if (this->server->state->banned_ipv4_ranges->check(remote_addr)) {
close(fd);
return;
}
this->log.info("Client connected on fd %d (port %hu, version %s)", fd, this->port, name_for_enum(this->version)); this->log.info("Client connected on fd %d (port %hu, version %s)", fd, this->port, name_for_enum(this->version));
auto* bev = bufferevent_socket_new(this->server->base.get(), fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS); auto* bev = bufferevent_socket_new(this->server->base.get(), fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS);
this->server->on_client_connect(bev, this->port, this->version, this->server->on_client_connect(bev, this->port, this->version,
+2 -2
View File
@@ -5181,8 +5181,8 @@ typedef void (*on_command_t)(shared_ptr<Client> c, uint16_t command, uint32_t fl
// Command handler table, indexed by command number and game version. Null // Command handler table, indexed by command number and game version. Null
// entries in this table cause on_unimplemented_command to be called, which // entries in this table cause on_unimplemented_command to be called, which
// disconnects the client. // disconnects the client.
static_assert(NUM_VERSIONS - 2 == 12, "Don\'t forget to update the ReceiveCommands handler table"); static_assert(NUM_NON_PATCH_VERSIONS == 12, "Don\'t forget to update the ReceiveCommands handler table");
static on_command_t handlers[0x100][NUM_VERSIONS - 2] = { static on_command_t handlers[0x100][NUM_NON_PATCH_VERSIONS] = {
// clang-format off // clang-format off
// DC_NTE DC_112000 DCV1 DCV2 PC_NTE PC GCNTE GC EP3TE EP3 XB BB // DC_NTE DC_112000 DCV1 DCV2 PC_NTE PC GCNTE GC EP3TE EP3 XB BB
/* 00 */ {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, /* 00 */ {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
+25 -15
View File
@@ -85,17 +85,20 @@ void Server::destroy_clients() {
void Server::dispatch_on_listen_accept( void Server::dispatch_on_listen_accept(
struct evconnlistener* listener, evutil_socket_t fd, struct evconnlistener* listener, evutil_socket_t fd,
struct sockaddr* address, int socklen, void* ctx) { struct sockaddr* address, int socklen, void* ctx) {
reinterpret_cast<Server*>(ctx)->on_listen_accept(listener, fd, address, reinterpret_cast<Server*>(ctx)->on_listen_accept(listener, fd, address, socklen);
socklen);
} }
void Server::dispatch_on_listen_error( void Server::dispatch_on_listen_error(struct evconnlistener* listener, void* ctx) {
struct evconnlistener* listener, void* ctx) {
reinterpret_cast<Server*>(ctx)->on_listen_error(listener); reinterpret_cast<Server*>(ctx)->on_listen_error(listener);
} }
void Server::on_listen_accept( void Server::on_listen_accept(struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr*, int) {
struct evconnlistener* listener, evutil_socket_t fd, struct sockaddr*, int) { struct sockaddr_storage remote_addr;
get_socket_addresses(fd, nullptr, &remote_addr);
if (this->state->banned_ipv4_ranges->check(remote_addr)) {
close(fd);
return;
}
int listen_fd = evconnlistener_get_fd(listener); int listen_fd = evconnlistener_get_fd(listener);
ListeningSocket* listening_socket; ListeningSocket* listening_socket;
@@ -108,8 +111,7 @@ void Server::on_listen_accept(
return; return;
} }
struct bufferevent* bev = bufferevent_socket_new(this->base.get(), fd, struct bufferevent* bev = bufferevent_socket_new(this->base.get(), fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS);
BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS);
auto c = make_shared<Client>(this->shared_from_this(), bev, listening_socket->version, listening_socket->behavior); auto c = make_shared<Client>(this->shared_from_this(), bev, listening_socket->version, listening_socket->behavior);
c->channel.on_command_received = Server::on_client_input; c->channel.on_command_received = Server::on_client_input;
c->channel.on_error = Server::on_client_error; c->channel.on_error = Server::on_client_error;
@@ -298,34 +300,34 @@ vector<shared_ptr<Client>> Server::get_clients_by_identifier(const string& ident
for (const auto& it : this->state->channel_to_client) { for (const auto& it : this->state->channel_to_client) {
auto c = it.second; auto c = it.second;
if (c->login && c->login->account->account_id == account_id_hex) { if (c->login && c->login->account->account_id == account_id_hex) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
if (c->login && c->login->account->account_id == account_id_dec) { if (c->login && c->login->account->account_id == account_id_dec) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
if (c->login && c->login->xb_license && c->login->xb_license->gamertag == ident) { if (c->login && c->login->xb_license && c->login->xb_license->gamertag == ident) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
if (c->login && c->login->bb_license && c->login->bb_license->username == ident) { if (c->login && c->login->bb_license && c->login->bb_license->username == ident) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
auto p = c->character(false, false); auto p = c->character(false, false);
if (p && p->disp.name.eq(ident, p->inventory.language)) { if (p && p->disp.name.eq(ident, p->inventory.language)) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
if (c->channel.name == ident) { if (c->channel.name == ident) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
if (starts_with(c->channel.name, ident + " ")) { if (starts_with(c->channel.name, ident + " ")) {
results.emplace_back(std::move(c)); results.emplace_back(c);
continue; continue;
} }
} }
@@ -333,6 +335,14 @@ vector<shared_ptr<Client>> Server::get_clients_by_identifier(const string& ident
return results; return results;
} }
vector<shared_ptr<Client>> Server::all_clients() const {
vector<shared_ptr<Client>> ret;
for (const auto& it : this->state->channel_to_client) {
ret.emplace_back(it.second);
}
return ret;
}
shared_ptr<struct event_base> Server::get_base() const { shared_ptr<struct event_base> Server::get_base() const {
return this->base; return this->base;
} }
+1
View File
@@ -31,6 +31,7 @@ public:
std::shared_ptr<Client> get_client() const; std::shared_ptr<Client> get_client() const;
std::vector<std::shared_ptr<Client>> get_clients_by_identifier(const std::string& ident) const; std::vector<std::shared_ptr<Client>> get_clients_by_identifier(const std::string& ident) const;
std::vector<std::shared_ptr<Client>> all_clients() const;
std::shared_ptr<struct event_base> get_base() const; std::shared_ptr<struct event_base> get_base() const;
inline std::shared_ptr<ServerState> get_state() const { inline std::shared_ptr<ServerState> get_state() const {
+64 -1
View File
@@ -682,6 +682,13 @@ void ServerState::load_config_early() {
this->all_addresses.erase("<external>"); this->all_addresses.erase("<external>");
this->all_addresses.emplace("<external>", this->external_address); this->all_addresses.emplace("<external>", this->external_address);
try {
this->banned_ipv4_ranges = make_shared<IPV4RangeSet>(this->config_json->at("BannedIPV4Ranges"));
this->disconnect_all_banned_clients();
} catch (const out_of_range&) {
this->banned_ipv4_ranges = make_shared<IPV4RangeSet>();
}
this->client_ping_interval_usecs = this->config_json->get_int("ClientPingInterval", 30000000); this->client_ping_interval_usecs = this->config_json->get_int("ClientPingInterval", 30000000);
this->client_idle_timeout_usecs = this->config_json->get_int("ClientIdleTimeout", 60000000); this->client_idle_timeout_usecs = this->config_json->get_int("ClientIdleTimeout", 60000000);
this->patch_client_idle_timeout_usecs = this->config_json->get_int("PatchClientIdleTimeout", 300000000); this->patch_client_idle_timeout_usecs = this->config_json->get_int("PatchClientIdleTimeout", 300000000);
@@ -1108,6 +1115,8 @@ void ServerState::load_config_early() {
} }
} catch (const out_of_range&) { } catch (const out_of_range&) {
} }
this->update_dependent_server_configs();
} }
void ServerState::load_config_late() { void ServerState::load_config_late() {
@@ -1278,6 +1287,7 @@ void ServerState::load_accounts(bool from_non_event_thread) {
auto set = [s = this->shared_from_this(), new_index = std::move(new_index)]() { auto set = [s = this->shared_from_this(), new_index = std::move(new_index)]() {
s->account_index = std::move(new_index); s->account_index = std::move(new_index);
s->update_dependent_server_configs();
}; };
this->forward_or_call(from_non_event_thread, std::move(set)); this->forward_or_call(from_non_event_thread, std::move(set));
} }
@@ -1324,6 +1334,7 @@ void ServerState::load_patch_indexes(bool from_non_event_thread) {
s->bb_data_gsl = std::move(bb_data_gsl); s->bb_data_gsl = std::move(bb_data_gsl);
s->pc_patch_file_index = std::move(pc_patch_file_index); s->pc_patch_file_index = std::move(pc_patch_file_index);
s->bb_patch_file_index = std::move(bb_patch_file_index); s->bb_patch_file_index = std::move(bb_patch_file_index);
s->update_dependent_server_configs();
}; };
this->forward_or_call(from_non_event_thread, std::move(set)); this->forward_or_call(from_non_event_thread, std::move(set));
} }
@@ -1850,15 +1861,67 @@ shared_ptr<PatchServer::Config> ServerState::generate_patch_server_config(bool i
ret->idle_timeout_usecs = this->patch_client_idle_timeout_usecs; ret->idle_timeout_usecs = this->patch_client_idle_timeout_usecs;
ret->message = is_bb ? this->bb_patch_server_message : this->pc_patch_server_message; ret->message = is_bb ? this->bb_patch_server_message : this->pc_patch_server_message;
ret->account_index = this->account_index; ret->account_index = this->account_index;
ret->banned_ipv4_ranges = this->banned_ipv4_ranges;
ret->patch_file_index = is_bb ? this->bb_patch_file_index : this->pc_patch_file_index; ret->patch_file_index = is_bb ? this->bb_patch_file_index : this->pc_patch_file_index;
return ret; return ret;
} }
void ServerState::update_patch_server_configs() const { void ServerState::update_dependent_server_configs() const {
if (this->pc_patch_server) { if (this->pc_patch_server) {
this->pc_patch_server->set_config(this->generate_patch_server_config(false)); this->pc_patch_server->set_config(this->generate_patch_server_config(false));
} }
if (this->bb_patch_server) { if (this->bb_patch_server) {
this->bb_patch_server->set_config(this->generate_patch_server_config(true)); this->bb_patch_server->set_config(this->generate_patch_server_config(true));
} }
if (this->dns_server) {
this->dns_server->set_banned_ipv4_ranges(this->banned_ipv4_ranges);
}
}
void ServerState::disconnect_all_banned_clients() {
uint64_t now_usecs = now();
if (this->game_server) {
for (const auto& c : this->game_server->all_clients()) {
if ((c->login && (c->login->account->ban_end_time > now_usecs)) ||
this->banned_ipv4_ranges->check(c->channel.remote_addr)) {
this->game_server->disconnect_client(c);
}
}
}
// Proxy server
if (this->proxy_server) {
vector<uint32_t> sessions_to_close;
for (const auto& it : this->proxy_server->all_sessions()) {
auto ses = it.second;
if ((ses->login && (ses->login->account->ban_end_time > now_usecs)) ||
this->banned_ipv4_ranges->check(ses->client_channel.remote_addr)) {
sessions_to_close.emplace_back(it.first);
}
}
for (uint32_t ses_id : sessions_to_close) {
this->proxy_server->delete_session(ses_id);
}
}
// IP stack simulator (IP bans only; account bans will presumably be handled
// by one of the above cases)
if (this->ip_stack_simulator) {
vector<bufferevent*> bevs_to_disconnect;
for (const auto& it : this->ip_stack_simulator->all_clients()) {
int fd = bufferevent_getfd(it.first);
if (fd < 0) {
continue;
}
struct sockaddr_storage remote_ss;
get_socket_addresses(fd, nullptr, &remote_ss);
if (this->banned_ipv4_ranges->check(remote_ss)) {
bevs_to_disconnect.emplace_back(it.first);
}
}
for (auto* bev : bevs_to_disconnect) {
this->ip_stack_simulator->disconnect_client(bev);
}
}
} }
+9 -1
View File
@@ -14,11 +14,13 @@
#include "Account.hh" #include "Account.hh"
#include "Client.hh" #include "Client.hh"
#include "CommonItemSet.hh" #include "CommonItemSet.hh"
#include "DNSServer.hh"
#include "Episode3/DataIndexes.hh" #include "Episode3/DataIndexes.hh"
#include "Episode3/Tournament.hh" #include "Episode3/Tournament.hh"
#include "EventUtils.hh" #include "EventUtils.hh"
#include "FunctionCompiler.hh" #include "FunctionCompiler.hh"
#include "GSLArchive.hh" #include "GSLArchive.hh"
#include "IPV4RangeSet.hh"
#include "ItemNameIndex.hh" #include "ItemNameIndex.hh"
#include "ItemParameterTable.hh" #include "ItemParameterTable.hh"
#include "LevelTable.hh" #include "LevelTable.hh"
@@ -33,6 +35,7 @@
// Forward declarations due to reference cycles // Forward declarations due to reference cycles
class ProxyServer; class ProxyServer;
class Server; class Server;
class IPStackSimulator;
struct PortConfiguration { struct PortConfiguration {
std::string name; std::string name;
@@ -215,6 +218,7 @@ struct ServerState : public std::enable_shared_from_this<ServerState> {
std::vector<Ep3LobbyBannerEntry> ep3_lobby_banners; std::vector<Ep3LobbyBannerEntry> ep3_lobby_banners;
std::shared_ptr<AccountIndex> account_index; std::shared_ptr<AccountIndex> account_index;
std::shared_ptr<IPV4RangeSet> banned_ipv4_ranges;
std::shared_ptr<TeamIndex> team_index; std::shared_ptr<TeamIndex> team_index;
JSON team_reward_defs_json; JSON team_reward_defs_json;
@@ -253,6 +257,8 @@ struct ServerState : public std::enable_shared_from_this<ServerState> {
bool proxy_allow_save_files = true; bool proxy_allow_save_files = true;
bool proxy_enable_login_options = false; bool proxy_enable_login_options = false;
std::shared_ptr<IPStackSimulator> ip_stack_simulator;
std::shared_ptr<DNSServer> dns_server;
std::shared_ptr<ProxyServer> proxy_server; std::shared_ptr<ProxyServer> proxy_server;
std::shared_ptr<Server> game_server; std::shared_ptr<Server> game_server;
std::shared_ptr<PatchServer> pc_patch_server; std::shared_ptr<PatchServer> pc_patch_server;
@@ -343,7 +349,7 @@ struct ServerState : public std::enable_shared_from_this<ServerState> {
} }
std::shared_ptr<PatchServer::Config> generate_patch_server_config(bool is_bb) const; std::shared_ptr<PatchServer::Config> generate_patch_server_config(bool is_bb) const;
void update_patch_server_configs() const; void update_dependent_server_configs() const;
// The following functions may only be called from a non-event thread if they // The following functions may only be called from a non-event thread if they
// take a from_non_event_thread argument; any function that does not have this // take a from_non_event_thread argument; any function that does not have this
@@ -379,4 +385,6 @@ struct ServerState : public std::enable_shared_from_this<ServerState> {
void enqueue_destroy_lobbies(); void enqueue_destroy_lobbies();
static void dispatch_destroy_lobbies(evutil_socket_t, short, void* ctx); static void dispatch_destroy_lobbies(evutil_socket_t, short, void* ctx);
void disconnect_all_banned_clients();
}; };
+11
View File
@@ -176,6 +176,17 @@
// entries in this list is the same as for IPStackListen and PPPStackListen. // entries in this list is the same as for IPStackListen and PPPStackListen.
"HTTPListen": [], "HTTPListen": [],
// Banned IP address ranges. If a client whose remote IPv4 address is in any
// of these ranges connects to the server, they are immediately disconnected
// with no message. Entries in this list may be individiual IP addresses
// (e.g. "1.2.3.4") or CIDR ranges (e.g. "1.2.3.0/24"). This field takes
// effect immediately when `reload config` is run in the shell; any existing
// clients who are now banned are disconnected.
// Note that this setting does not apply to the HTTP server; it is not
// possible to ban IP ranges from the HTTP server. (It is also inadvisable
// to expose the HTTP server to the public Internet.)
"BannedIPV4Ranges": [],
// Other servers to support proxying to. If this is empty for any game // Other servers to support proxying to. If this is empty for any game
// version, the proxy server is disabled for that version. Entries in these // version, the proxy server is disabled for that version. Entries in these
// dictionaries should be of the form "name": "address:port"; the names are // dictionaries should be of the form "name": "address:port"; the names are
+1
View File
@@ -46,6 +46,7 @@
"IPStackListen": [], "IPStackListen": [],
"PPPStackListen": [], "PPPStackListen": [],
"HTTPListen": [], "HTTPListen": [],
"BannedIPV4Ranges": [],
"Episode3BehaviorFlags": 0xFA, "Episode3BehaviorFlags": 0xFA,
"EnableEpisode3SendFunctionCall": false, "EnableEpisode3SendFunctionCall": false,
"EnableV3V4ProtectedSubcommands": false, "EnableV3V4ProtectedSubcommands": false,