- Address PR feedback.
- Add SecureTransport backend for macOS.
This commit is contained in:
comex
2023-07-01 15:02:25 -07:00
parent 98685d48e3
commit 0e191c2711
8 changed files with 279 additions and 213 deletions

View File

@ -64,7 +64,7 @@ public:
std::shared_ptr<SslContextSharedData>& shared_data,
std::unique_ptr<SSLConnectionBackend>&& backend)
: ServiceFramework{system_, "ISslConnection"}, ssl_version{version},
shared_data_{shared_data}, backend_{std::move(backend)} {
shared_data{shared_data}, backend{std::move(backend)} {
// clang-format off
static const FunctionInfo functions[] = {
{0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
@ -112,10 +112,10 @@ public:
}
~ISslConnection() {
shared_data_->connection_count--;
if (fd_to_close_.has_value()) {
const s32 fd = *fd_to_close_;
if (!do_not_close_socket_) {
shared_data->connection_count--;
if (fd_to_close.has_value()) {
const s32 fd = *fd_to_close;
if (!do_not_close_socket) {
LOG_ERROR(Service_SSL,
"do_not_close_socket was changed after setting socket; is this right?");
} else {
@ -132,30 +132,30 @@ public:
private:
SslVersion ssl_version;
std::shared_ptr<SslContextSharedData> shared_data_;
std::unique_ptr<SSLConnectionBackend> backend_;
std::optional<int> fd_to_close_;
bool do_not_close_socket_ = false;
bool get_server_cert_chain_ = false;
std::shared_ptr<Network::SocketBase> socket_;
bool did_set_host_name_ = false;
bool did_handshake_ = false;
std::shared_ptr<SslContextSharedData> shared_data;
std::unique_ptr<SSLConnectionBackend> backend;
std::optional<int> fd_to_close;
bool do_not_close_socket = false;
bool get_server_cert_chain = false;
std::shared_ptr<Network::SocketBase> socket;
bool did_set_host_name = false;
bool did_handshake = false;
ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
LOG_DEBUG(Service_SSL, "called, fd={}", fd);
ASSERT(!did_handshake_);
ASSERT(!did_handshake);
auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
s32 ret_fd;
// Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
if (do_not_close_socket_) {
if (do_not_close_socket) {
auto res = bsd->DuplicateSocketImpl(fd);
if (!res.has_value()) {
LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
return ResultInvalidSocket;
}
fd = *res;
fd_to_close_ = fd;
fd_to_close = fd;
ret_fd = fd;
} else {
ret_fd = -1;
@ -165,34 +165,34 @@ private:
LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
return ResultInvalidSocket;
}
socket_ = std::move(*sock);
backend_->SetSocket(socket_);
socket = std::move(*sock);
backend->SetSocket(socket);
return ret_fd;
}
Result SetHostNameImpl(const std::string& hostname) {
LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
ASSERT(!did_handshake_);
Result res = backend_->SetHostName(hostname);
ASSERT(!did_handshake);
Result res = backend->SetHostName(hostname);
if (res == ResultSuccess) {
did_set_host_name_ = true;
did_set_host_name = true;
}
return res;
}
Result SetVerifyOptionImpl(u32 option) {
ASSERT(!did_handshake_);
ASSERT(!did_handshake);
LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
return ResultSuccess;
}
Result SetIOModeImpl(u32 _mode) {
auto mode = static_cast<IoMode>(_mode);
Result SetIoModeImpl(u32 input_mode) {
auto mode = static_cast<IoMode>(input_mode);
ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; });
ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
const bool non_block = mode == IoMode::NonBlocking;
const Network::Errno error = socket_->SetNonBlock(non_block);
const Network::Errno error = socket->SetNonBlock(non_block);
if (error != Network::Errno::SUCCESS) {
LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
}
@ -200,18 +200,18 @@ private:
}
Result SetSessionCacheModeImpl(u32 mode) {
ASSERT(!did_handshake_);
ASSERT(!did_handshake);
LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
return ResultSuccess;
}
Result DoHandshakeImpl() {
ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; });
ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
ASSERT_OR_EXECUTE_MSG(
did_set_host_name_, { return ResultInternalError; },
did_set_host_name, { return ResultInternalError; },
"Expected SetHostName before DoHandshake");
Result res = backend_->DoHandshake();
did_handshake_ = res.IsSuccess();
Result res = backend->DoHandshake();
did_handshake = res.IsSuccess();
return res;
}
@ -225,7 +225,7 @@ private:
u32 size;
u32 offset;
};
if (!get_server_cert_chain_) {
if (!get_server_cert_chain) {
// Just return the first one, unencoded.
ASSERT_OR_EXECUTE_MSG(
!certs.empty(), { return {}; }, "Should be at least one server cert");
@ -248,9 +248,9 @@ private:
}
ResultVal<std::vector<u8>> ReadImpl(size_t size) {
ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
std::vector<u8> res(size);
ResultVal<size_t> actual = backend_->Read(res);
ResultVal<size_t> actual = backend->Read(res);
if (actual.Failed()) {
return actual.Code();
}
@ -259,8 +259,8 @@ private:
}
ResultVal<size_t> WriteImpl(std::span<const u8> data) {
ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
return backend_->Write(data);
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
return backend->Write(data);
}
ResultVal<s32> PendingImpl() {
@ -295,7 +295,7 @@ private:
void SetIoMode(HLERequestContext& ctx) {
IPC::RequestParser rp{ctx};
const u32 mode = rp.Pop<u32>();
const Result res = SetIOModeImpl(mode);
const Result res = SetIoModeImpl(mode);
IPC::ResponseBuilder rb{ctx, 2};
rb.Push(res);
}
@ -307,22 +307,26 @@ private:
}
void DoHandshakeGetServerCert(HLERequestContext& ctx) {
struct OutputParameters {
u32 certs_size;
u32 certs_count;
};
static_assert(sizeof(OutputParameters) == 0x8);
const Result res = DoHandshakeImpl();
u32 certs_count = 0;
u32 certs_size = 0;
OutputParameters out{};
if (res == ResultSuccess) {
auto certs = backend_->GetServerCerts();
auto certs = backend->GetServerCerts();
if (certs.Succeeded()) {
const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
ctx.WriteBuffer(certs_buf);
certs_count = static_cast<u32>(certs->size());
certs_size = static_cast<u32>(certs_buf.size());
out.certs_count = static_cast<u32>(certs->size());
out.certs_size = static_cast<u32>(certs_buf.size());
}
}
IPC::ResponseBuilder rb{ctx, 4};
rb.Push(res);
rb.Push(certs_size);
rb.Push(certs_count);
rb.PushRaw(out);
}
void Read(HLERequestContext& ctx) {
@ -371,10 +375,10 @@ private:
switch (parameters.option) {
case OptionType::DoNotCloseSocket:
do_not_close_socket_ = static_cast<bool>(parameters.value);
do_not_close_socket = static_cast<bool>(parameters.value);
break;
case OptionType::GetServerCertChain:
get_server_cert_chain_ = static_cast<bool>(parameters.value);
get_server_cert_chain = static_cast<bool>(parameters.value);
break;
default:
LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
@ -390,7 +394,7 @@ class ISslContext final : public ServiceFramework<ISslContext> {
public:
explicit ISslContext(Core::System& system_, SslVersion version)
: ServiceFramework{system_, "ISslContext"}, ssl_version{version},
shared_data_{std::make_shared<SslContextSharedData>()} {
shared_data{std::make_shared<SslContextSharedData>()} {
static const FunctionInfo functions[] = {
{0, &ISslContext::SetOption, "SetOption"},
{1, nullptr, "GetOption"},
@ -412,7 +416,7 @@ public:
private:
SslVersion ssl_version;
std::shared_ptr<SslContextSharedData> shared_data_;
std::shared_ptr<SslContextSharedData> shared_data;
void SetOption(HLERequestContext& ctx) {
struct Parameters {
@ -439,17 +443,17 @@ private:
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
rb.Push(backend_res.Code());
if (backend_res.Succeeded()) {
rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_,
rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
std::move(*backend_res));
}
}
void GetConnectionCount(HLERequestContext& ctx) {
LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count);
LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess);
rb.Push(shared_data_->connection_count);
rb.Push(shared_data->connection_count);
}
void ImportServerPki(HLERequestContext& ctx) {