switch to coroutine execution model

This commit is contained in:
Martin Michelsen
2025-04-30 21:43:06 -07:00
parent f65b1f1c14
commit cc99050964
160 changed files with 269127 additions and 227736 deletions
+226 -249
View File
@@ -1,9 +1,6 @@
#include "Channel.hh"
#include <errno.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/event.h>
#include <string.h>
#include <unistd.h>
@@ -17,230 +14,17 @@ using namespace std;
extern bool use_terminal_colors;
static void flush_and_free_bufferevent(struct bufferevent* bev) {
bufferevent_flush(bev, EV_READ | EV_WRITE, BEV_FINISHED);
bufferevent_free(bev);
}
Channel::Channel(
Version version,
uint8_t language,
on_command_received_t on_command_received,
on_error_t on_error,
void* context_obj,
const string& name,
phosg::TerminalFormat terminal_send_color,
phosg::TerminalFormat terminal_recv_color)
: bev(nullptr, flush_and_free_bufferevent),
virtual_network_id(0),
version(version),
: version(version),
language(language),
name(name),
terminal_send_color(terminal_send_color),
terminal_recv_color(terminal_recv_color),
on_command_received(on_command_received),
on_error(on_error),
context_obj(context_obj) {
}
Channel::Channel(
struct bufferevent* bev,
uint64_t virtual_network_id,
Version version,
uint8_t language,
on_command_received_t on_command_received,
on_error_t on_error,
void* context_obj,
const string& name,
phosg::TerminalFormat terminal_send_color,
phosg::TerminalFormat terminal_recv_color)
: bev(nullptr, flush_and_free_bufferevent),
version(version),
language(language),
name(name),
terminal_send_color(terminal_send_color),
terminal_recv_color(terminal_recv_color),
on_command_received(on_command_received),
on_error(on_error),
context_obj(context_obj) {
this->set_bufferevent(bev, virtual_network_id);
}
void Channel::replace_with(
Channel&& other,
on_command_received_t on_command_received,
on_error_t on_error,
void* context_obj,
const std::string& name) {
this->set_bufferevent(other.bev.release(), other.virtual_network_id);
this->local_addr = other.local_addr;
this->remote_addr = other.remote_addr;
this->version = other.version;
this->language = other.language;
this->crypt_in = other.crypt_in;
this->crypt_out = other.crypt_out;
this->name = name;
this->terminal_send_color = other.terminal_send_color;
this->terminal_recv_color = other.terminal_recv_color;
this->on_command_received = on_command_received;
this->on_error = on_error;
this->context_obj = context_obj;
other.disconnect(); // Clears crypts, addrs, etc.
}
void Channel::set_bufferevent(struct bufferevent* bev, uint64_t virtual_network_id) {
this->bev.reset(bev);
this->virtual_network_id = virtual_network_id;
if (this->bev.get()) {
int fd = bufferevent_getfd(this->bev.get());
if (fd < 0) {
memset(&this->local_addr, 0, sizeof(this->local_addr));
memset(&this->remote_addr, 0, sizeof(this->remote_addr));
} else {
phosg::get_socket_addresses(fd, &this->local_addr, &this->remote_addr);
}
bufferevent_setcb(this->bev.get(), &Channel::dispatch_on_input, nullptr, &Channel::dispatch_on_error, this);
bufferevent_enable(this->bev.get(), EV_READ | EV_WRITE);
} else {
memset(&this->local_addr, 0, sizeof(this->local_addr));
memset(&this->remote_addr, 0, sizeof(this->remote_addr));
}
}
void Channel::disconnect() {
if (this->bev.get()) {
// If the output buffer is not empty, move the bufferevent into the draining
// pool instead of disconnecting it, to make sure all the data gets sent.
struct evbuffer* out_buffer = bufferevent_get_output(this->bev.get());
if (evbuffer_get_length(out_buffer) == 0) {
this->bev.reset(); // Destructor flushes and frees the bufferevent
} else {
// The callbacks will free it when all the data is sent or the client
// disconnects
auto on_output = +[](struct bufferevent* bev, void*) -> void {
flush_and_free_bufferevent(bev);
};
auto on_error = +[](struct bufferevent* bev, short events, void*) -> void {
if (events & BEV_EVENT_ERROR) {
int err = EVUTIL_SOCKET_ERROR();
channel_exceptions_log.warning(
"Disconnecting channel caused error %d (%s)", err,
evutil_socket_error_to_string(err));
}
if (events & (BEV_EVENT_EOF | BEV_EVENT_ERROR)) {
bufferevent_flush(bev, EV_WRITE, BEV_FINISHED);
bufferevent_free(bev);
}
};
struct bufferevent* bev = this->bev.release();
bufferevent_setcb(bev, nullptr, on_output, on_error, bev);
bufferevent_disable(bev, EV_READ);
}
}
memset(&this->local_addr, 0, sizeof(this->local_addr));
memset(&this->remote_addr, 0, sizeof(this->remote_addr));
this->virtual_network_id = false;
this->crypt_in.reset();
this->crypt_out.reset();
}
Channel::Message Channel::recv() {
struct evbuffer* buf = bufferevent_get_input(this->bev.get());
size_t header_size = (this->version == Version::BB_V4) ? 8 : 4;
PSOCommandHeader header;
if (evbuffer_copyout(buf, &header, header_size) < static_cast<ssize_t>(header_size)) {
throw out_of_range("no command available");
}
if (this->crypt_in.get()) {
this->crypt_in->decrypt(&header, header_size, false);
}
size_t command_logical_size = header.size(version);
// If encryption is enabled, BB pads commands to 8-byte boundaries, and this
// is not reflected in the size field. This logic does not occur if encryption
// is not yet enabled.
size_t command_physical_size = (this->crypt_in.get() && (version == Version::BB_V4))
? ((command_logical_size + 7) & ~7)
: command_logical_size;
if (evbuffer_get_length(buf) < command_physical_size) {
throw out_of_range("no command available");
}
// If we get here, then there is a full command in the buffer. Some encryption
// algorithms' advancement depends on the decrypted data, so we have to
// actually decrypt the header again (with advance=true) to keep them in a
// consistent state.
string header_data(header_size, '\0');
if (evbuffer_remove(buf, header_data.data(), header_data.size()) < static_cast<ssize_t>(header_data.size())) {
throw logic_error("enough bytes available, but could not remove them");
}
if (this->crypt_in.get()) {
this->crypt_in->decrypt(header_data.data(), header_data.size());
}
string command_data(command_physical_size - header_size, '\0');
if (evbuffer_remove(buf, command_data.data(), command_data.size()) < static_cast<ssize_t>(command_data.size())) {
throw logic_error("enough bytes available, but could not remove them");
}
if (this->crypt_in.get()) {
// Some versions of PSO DC can send commands whose sizes are not a multiple
// of 4, but the server is expected to always use a multiple of 4 bytes when
// decrypting (the extra cipher bytes are lost). To emulate this behavior,
// we have to round up the size for DC commands here.
size_t orig_size = command_data.size();
command_data.resize((orig_size + 3) & (~3), 0);
this->crypt_in->decrypt(command_data.data(), command_data.size());
command_data.resize(orig_size);
}
command_data.resize(command_logical_size - header_size);
if (command_data_log.should_log(phosg::LogLevel::INFO) && (this->terminal_recv_color != phosg::TerminalFormat::END)) {
if (use_terminal_colors && this->terminal_recv_color != phosg::TerminalFormat::NORMAL) {
print_color_escape(stderr, this->terminal_recv_color, phosg::TerminalFormat::BOLD, phosg::TerminalFormat::END);
}
if (version == Version::BB_V4) {
command_data_log.info(
"Received from %s (version=BB command=%04hX flag=%08" PRIX32 ")",
this->name.c_str(),
header.command(this->version),
header.flag(this->version));
} else {
command_data_log.info(
"Received from %s (version=%s command=%02hX flag=%02" PRIX32 ")",
this->name.c_str(),
phosg::name_for_enum(this->version),
header.command(this->version),
header.flag(this->version));
}
vector<struct iovec> iovs;
iovs.emplace_back(iovec{.iov_base = header_data.data(), .iov_len = header_data.size()});
iovs.emplace_back(iovec{.iov_base = command_data.data(), .iov_len = command_data.size()});
phosg::print_data(stderr, iovs, 0, nullptr, phosg::PrintDataFlags::PRINT_ASCII | phosg::PrintDataFlags::DISABLE_COLOR | phosg::PrintDataFlags::OFFSET_16_BITS);
if (use_terminal_colors && this->terminal_recv_color != phosg::TerminalFormat::NORMAL) {
phosg::print_color_escape(stderr, phosg::TerminalFormat::NORMAL, phosg::TerminalFormat::END);
}
}
return {
.command = header.command(this->version),
.flag = header.flag(this->version),
.data = std::move(command_data),
};
terminal_recv_color(terminal_recv_color) {
}
void Channel::send(uint16_t cmd, uint32_t flag, bool silent) {
@@ -249,7 +33,7 @@ void Channel::send(uint16_t cmd, uint32_t flag, bool silent) {
void Channel::send(uint16_t cmd, uint32_t flag, const std::vector<std::pair<const void*, size_t>> blocks, bool silent) {
if (!this->connected()) {
channel_exceptions_log.warning("Attempted to send command on closed channel; dropping data");
channel_exceptions_log.warning_f("Attempted to send command on closed channel; dropping data");
return;
}
@@ -342,16 +126,16 @@ void Channel::send(uint16_t cmd, uint32_t flag, const std::vector<std::pair<cons
}
send_data.resize(send_data_size, '\0');
if (!silent && (command_data_log.should_log(phosg::LogLevel::INFO)) && (this->terminal_send_color != phosg::TerminalFormat::END)) {
if (!silent && (command_data_log.should_log(phosg::LogLevel::L_INFO)) && (this->terminal_send_color != phosg::TerminalFormat::END)) {
if (use_terminal_colors && this->terminal_send_color != phosg::TerminalFormat::NORMAL) {
print_color_escape(stderr, phosg::TerminalFormat::FG_YELLOW, phosg::TerminalFormat::BOLD, phosg::TerminalFormat::END);
}
if (version == Version::BB_V4) {
command_data_log.info("Sending to %s (version=BB command=%04hX flag=%08" PRIX32 ")",
this->name.c_str(), cmd, flag);
command_data_log.info_f("Sending to {} (version=BB command={:04X} flag={:08X})",
this->name, cmd, flag);
} else {
command_data_log.info("Sending to %s (version=%s command=%02hX flag=%02" PRIX32 ")",
this->name.c_str(), phosg::name_for_enum(version), cmd, flag);
command_data_log.info_f("Sending to {} (version={} command={:02X} flag={:02X})",
this->name, phosg::name_for_enum(version), cmd, flag);
}
phosg::print_data(stderr, send_data.data(), logical_size, 0, nullptr, phosg::PrintDataFlags::PRINT_ASCII | phosg::PrintDataFlags::DISABLE_COLOR | phosg::PrintDataFlags::OFFSET_16_BITS);
if (use_terminal_colors && this->terminal_send_color != phosg::TerminalFormat::NORMAL) {
@@ -363,8 +147,7 @@ void Channel::send(uint16_t cmd, uint32_t flag, const std::vector<std::pair<cons
this->crypt_out->encrypt(send_data.data(), send_data.size());
}
struct evbuffer* buf = bufferevent_get_output(this->bev.get());
evbuffer_add(buf, send_data.data(), send_data.size());
this->send_raw(std::move(send_data));
}
void Channel::send(uint16_t cmd, uint32_t flag, const void* data, size_t size, bool silent) {
@@ -387,35 +170,229 @@ void Channel::send(const void* data, size_t size, bool silent) {
}
void Channel::send(const string& data, bool silent) {
return this->send(data.data(), data.size(), silent);
this->send(data.data(), data.size(), silent);
}
void Channel::dispatch_on_input(struct bufferevent*, void* ctx) {
Channel* ch = reinterpret_cast<Channel*>(ctx);
// The client can be disconnected during on_command_received, so we have to
// make sure ch->bev is valid every time before calling recv()
while (ch->bev.get()) {
Message msg;
try {
msg = ch->recv();
} catch (const out_of_range&) {
break;
} catch (const exception& e) {
channel_exceptions_log.warning("Error receiving on channel: %s", e.what());
ch->on_error(*ch, BEV_EVENT_ERROR);
break;
asio::awaitable<Channel::Message> Channel::recv() {
size_t header_size = (this->version == Version::BB_V4) ? 8 : 4;
PSOCommandHeader header;
co_await this->recv_raw(&header, header_size);
if (this->crypt_in.get()) {
this->crypt_in->decrypt(&header, header_size);
}
size_t command_logical_size = header.size(version);
if (command_logical_size < header_size) {
throw runtime_error("header size field is smaller than header");
}
// If encryption is enabled, BB pads commands to 8-byte boundaries, and this
// is not reflected in the size field. This logic does not occur if encryption
// is not yet enabled.
size_t command_physical_size = (this->crypt_in.get() && (version == Version::BB_V4))
? ((command_logical_size + 7) & ~7)
: command_logical_size;
string command_data(command_physical_size - header_size, '\0');
co_await this->recv_raw(command_data.data(), command_data.size());
if (this->crypt_in.get()) {
// Some versions of PSO DC can send commands whose sizes are not a multiple
// of 4, but the server is expected to always use a multiple of 4 bytes when
// decrypting (the extra cipher bytes are lost). To emulate this behavior,
// we have to round up the size for DC commands here.
size_t orig_size = command_data.size();
command_data.resize((orig_size + 3) & (~3), 0);
this->crypt_in->decrypt(command_data.data(), command_data.size());
command_data.resize(orig_size);
}
command_data.resize(command_logical_size - header_size);
if (command_data_log.should_log(phosg::LogLevel::L_INFO) && (this->terminal_recv_color != phosg::TerminalFormat::END)) {
if (use_terminal_colors && this->terminal_recv_color != phosg::TerminalFormat::NORMAL) {
print_color_escape(stderr, this->terminal_recv_color, phosg::TerminalFormat::BOLD, phosg::TerminalFormat::END);
}
if (ch->on_command_received) {
ch->on_command_received(*ch, msg.command, msg.flag, msg.data);
if (version == Version::BB_V4) {
command_data_log.info_f(
"Received from {} (version=BB command={:04X} flag={:08X})",
this->name,
header.command(this->version),
header.flag(this->version));
} else {
command_data_log.info_f(
"Received from {} (version={} command={:02X} flag={:02X})",
this->name,
phosg::name_for_enum(this->version),
header.command(this->version),
header.flag(this->version));
}
vector<struct iovec> iovs;
iovs.emplace_back(iovec{.iov_base = &header, .iov_len = header_size});
iovs.emplace_back(iovec{.iov_base = command_data.data(), .iov_len = command_data.size()});
phosg::print_data(stderr, iovs, 0, nullptr, phosg::PrintDataFlags::PRINT_ASCII | phosg::PrintDataFlags::DISABLE_COLOR | phosg::PrintDataFlags::OFFSET_16_BITS);
if (use_terminal_colors && this->terminal_recv_color != phosg::TerminalFormat::NORMAL) {
phosg::print_color_escape(stderr, phosg::TerminalFormat::NORMAL, phosg::TerminalFormat::END);
}
}
co_return Message{
.command = header.command(this->version),
.flag = header.flag(this->version),
.data = std::move(command_data),
};
}
shared_ptr<SocketChannel> SocketChannel::create(
std::shared_ptr<asio::io_context> io_context,
std::unique_ptr<asio::ip::tcp::socket>&& sock,
Version version,
uint8_t language,
const string& name,
phosg::TerminalFormat terminal_send_color,
phosg::TerminalFormat terminal_recv_color) {
shared_ptr<SocketChannel> ret(new SocketChannel(
io_context, std::move(sock), version, language, name, terminal_send_color, terminal_recv_color));
asio::co_spawn(*io_context, ret->send_task(), asio::detached);
return ret;
}
SocketChannel::SocketChannel(
std::shared_ptr<asio::io_context> io_context,
std::unique_ptr<asio::ip::tcp::socket>&& sock,
Version version,
uint8_t language,
const string& name,
phosg::TerminalFormat terminal_send_color,
phosg::TerminalFormat terminal_recv_color)
: Channel(version, language, name, terminal_send_color, terminal_recv_color),
sock(std::move(sock)),
local_addr(this->sock->local_endpoint()),
remote_addr(this->sock->remote_endpoint()),
send_buffer_nonempty_signal(io_context->get_executor()) {}
std::string SocketChannel::default_name() const {
return "ip:" + str_for_endpoint(this->remote_addr);
}
bool SocketChannel::connected() const {
return !this->should_disconnect && this->sock && this->sock->is_open();
}
void SocketChannel::disconnect() {
this->should_disconnect = true;
this->send_buffer_nonempty_signal.set();
}
void SocketChannel::send_raw(string&& data) {
if (this->sock && !this->should_disconnect) {
this->outbound_data.emplace_back(std::move(data));
this->send_buffer_nonempty_signal.set();
}
}
asio::awaitable<void> SocketChannel::recv_raw(void* data, size_t size) {
if (!this->sock || this->should_disconnect) {
throw runtime_error("Cannot receive on closed channel");
}
co_await asio::async_read(*this->sock, asio::buffer(data, size), asio::use_awaitable);
}
asio::awaitable<void> SocketChannel::send_task() {
// Ensure *this doesn't get deleted while the socket is open
auto this_sh = this->shared_from_this();
while (this->sock->is_open()) {
deque<string> to_send;
to_send.swap(this->outbound_data);
if (!to_send.empty()) {
vector<asio::const_buffer> bufs;
bufs.reserve(to_send.size());
for (const auto& it : to_send) {
bufs.emplace_back(asio::buffer(it.data(), it.size()));
}
co_await asio::async_write(*this->sock, bufs, asio::use_awaitable);
}
if (this->outbound_data.empty()) {
if (this->should_disconnect) {
this->sock->close();
} else {
this->send_buffer_nonempty_signal.clear();
co_await this->send_buffer_nonempty_signal.wait();
}
}
}
}
void Channel::dispatch_on_error(struct bufferevent*, short events, void* ctx) {
Channel* ch = reinterpret_cast<Channel*>(ctx);
if (ch->on_error) {
ch->on_error(*ch, events);
} else {
ch->disconnect();
PeerChannel::PeerChannel(
std::shared_ptr<asio::io_context> io_context,
Version version,
uint8_t language,
const std::string& name,
phosg::TerminalFormat terminal_send_color,
phosg::TerminalFormat terminal_recv_color)
: Channel(version, language, name, terminal_send_color, terminal_recv_color),
send_buffer_nonempty_signal(io_context->get_executor()) {}
void PeerChannel::link_peers(std::shared_ptr<PeerChannel> peer1, std::shared_ptr<PeerChannel> peer2) {
if (peer1->connected() || peer2->connected()) {
throw logic_error("Cannot link already-connected peer channels");
}
peer1->peer = peer2;
peer2->peer = peer1;
}
std::string PeerChannel::default_name() const {
return std::format("peer:{}->{}", reinterpret_cast<const void*>(this), reinterpret_cast<const void*>(this->peer.lock().get()));
}
bool PeerChannel::connected() const {
return (!this->inbound_data.empty()) || (this->peer.lock() != nullptr);
}
void PeerChannel::disconnect() {
auto peer = this->peer.lock();
if (peer) {
peer->peer.reset();
peer->send_buffer_nonempty_signal.set();
}
this->peer.reset();
this->send_buffer_nonempty_signal.set();
}
void PeerChannel::send_raw(string&& data) {
auto peer = this->peer.lock();
if (peer) {
peer->inbound_data.emplace_back(std::move(data));
peer->send_buffer_nonempty_signal.set();
}
}
asio::awaitable<void> PeerChannel::recv_raw(void* data, size_t size) {
while (size > 0) {
while (this->inbound_data.empty() && this->peer.lock()) {
this->send_buffer_nonempty_signal.clear();
co_await this->send_buffer_nonempty_signal.wait();
}
if (!this->inbound_data.empty()) {
auto& front_block = this->inbound_data.front();
if (size < front_block.size()) {
memcpy(data, front_block.data(), size);
front_block = front_block.substr(size);
size = 0;
} else {
memcpy(data, front_block.data(), front_block.size());
size -= front_block.size();
data = reinterpret_cast<uint8_t*>(data) + front_block.size();
this->inbound_data.pop_front();
}
} else if (!this->peer.lock()) {
throw runtime_error("Channel peer has disconnected");
}
}
}