Generalize the EndpointServer to require less boilerplate.

Classes can now inherit from the EndpointServer and just implement the
HandleRequest function.
This commit is contained in:
Drew Galbraith 2023-08-01 15:52:08 -07:00
parent 847d37addc
commit caccb08e16
14 changed files with 215 additions and 147 deletions

View file

@ -7,8 +7,8 @@
Command::~Command() {}
DmaReadCommand::DmaReadCommand(uint64_t lba, uint64_t sector_cnt,
DmaCallback callback, z_cap_t reply_port)
: reply_port_(reply_port),
DmaCallback callback, ResponseContext& response)
: response_(response),
lba_(lba),
sector_cnt_(sector_cnt),
callback_(callback) {
@ -50,5 +50,5 @@ void DmaReadCommand::PopulatePrdt(PhysicalRegionDescriptor* prdt) {
prdt[0].byte_count = region_.size();
}
void DmaReadCommand::Callback() {
callback_(reply_port_, lba_, sector_cnt_, region_.cap());
callback_(response_, lba_, sector_cnt_, region_.cap());
}

View file

@ -1,6 +1,7 @@
#pragma once
#include <mammoth/memory_region.h>
#include <mammoth/response_context.h>
#include <stdint.h>
#include "ahci/ahci.h"
@ -15,9 +16,9 @@ class Command {
class DmaReadCommand : public Command {
public:
typedef void (*DmaCallback)(z_cap_t, uint64_t, uint64_t, z_cap_t);
typedef void (*DmaCallback)(ResponseContext&, uint64_t, uint64_t, z_cap_t);
DmaReadCommand(uint64_t lba, uint64_t sector_cnt, DmaCallback callback,
z_cap_t reply_port);
ResponseContext& reply_port);
virtual ~DmaReadCommand() override;
@ -27,7 +28,7 @@ class DmaReadCommand : public Command {
void Callback() override;
private:
z_cap_t reply_port_;
ResponseContext& response_;
uint64_t lba_;
uint64_t sector_cnt_;
DmaCallback callback_;

View file

@ -6,13 +6,14 @@
glcr::ErrorOr<MappedMemoryRegion> DenaliClient::ReadSectors(
uint64_t device_id, uint64_t lba, uint64_t num_sectors) {
DenaliRead read{
DenaliReadRequest read{
.device_id = device_id,
.lba = lba,
.size = num_sectors,
};
auto pair_or =
endpoint_->CallEndpointGetCap<DenaliRead, DenaliReadResponse>(read);
endpoint_->CallEndpointGetCap<DenaliReadRequest, DenaliReadResponse>(
read);
if (!pair_or) {
return pair_or.error();
}

View file

@ -19,14 +19,14 @@ uint64_t main(uint64_t init_port_cap) {
ASSIGN_OR_RETURN(MappedMemoryRegion ahci_region, stub.GetAhciConfig());
ASSIGN_OR_RETURN(auto driver, AhciDriver::Init(ahci_region));
ASSIGN_OR_RETURN(glcr::UniquePtr<EndpointServer> endpoint,
EndpointServer::Create());
ASSIGN_OR_RETURN(glcr::UniquePtr<DenaliServer> server,
DenaliServer::Create(*driver));
ASSIGN_OR_RETURN(glcr::UniquePtr<EndpointClient> client,
endpoint->CreateClient());
server->CreateClient());
check(stub.Register("denali", *client));
DenaliServer server(glcr::Move(endpoint), *driver);
RET_ERR(server.RunServer());
RET_ERR(server->RunServer());
// FIXME: Add thread join.
return 0;
}

View file

@ -7,61 +7,57 @@
namespace {
DenaliServer* gServer = nullptr;
void HandleResponse(z_cap_t reply_port, uint64_t lba, uint64_t size,
void HandleResponse(ResponseContext& response, uint64_t lba, uint64_t size,
z_cap_t mem) {
gServer->HandleResponse(reply_port, lba, size, mem);
gServer->HandleResponse(response, lba, size, mem);
}
} // namespace
DenaliServer::DenaliServer(glcr::UniquePtr<EndpointServer> server,
AhciDriver& driver)
: server_(glcr::Move(server)), driver_(driver) {
gServer = this;
glcr::ErrorOr<glcr::UniquePtr<DenaliServer>> DenaliServer::Create(
AhciDriver& driver) {
z_cap_t cap;
RET_ERR(ZEndpointCreate(&cap));
return glcr::UniquePtr<DenaliServer>(new DenaliServer(cap, driver));
}
glcr::ErrorCode DenaliServer::RunServer() {
while (true) {
uint64_t buff_size = kBuffSize;
z_cap_t reply_port;
RET_ERR(server_->Recieve(&buff_size, read_buffer_, &reply_port));
if (buff_size < sizeof(uint64_t)) {
dbgln("Skipping invalid message");
continue;
}
uint64_t type = *reinterpret_cast<uint64_t*>(read_buffer_);
switch (type) {
case Z_INVALID:
dbgln(reinterpret_cast<char*>(read_buffer_));
break;
case DENALI_READ: {
DenaliRead* read_req = reinterpret_cast<DenaliRead*>(read_buffer_);
uint64_t memcap = 0;
RET_ERR(HandleRead(*read_req, reply_port));
break;
glcr::ErrorCode DenaliServer::HandleRequest(RequestContext& request,
ResponseContext& response) {
switch (request.request_id()) {
case DENALI_READ: {
DenaliReadRequest* req = 0;
glcr::ErrorCode err = request.As<DenaliReadRequest>(&req);
if (err != glcr::OK) {
response.WriteError(err);
}
default:
dbgln("Invalid message type.");
return glcr::UNIMPLEMENTED;
err = HandleRead(req, response);
if (err != glcr::OK) {
response.WriteError(err);
}
break;
}
default:
response.WriteError(glcr::UNIMPLEMENTED);
break;
}
return glcr::OK;
}
glcr::ErrorCode DenaliServer::HandleRead(const DenaliRead& read,
z_cap_t reply_port) {
ASSIGN_OR_RETURN(AhciDevice * device, driver_.GetDevice(read.device_id));
glcr::ErrorCode DenaliServer::HandleRead(DenaliReadRequest* request,
ResponseContext& context) {
ASSIGN_OR_RETURN(AhciDevice * device, driver_.GetDevice(request->device_id));
device->IssueCommand(
new DmaReadCommand(read.lba, read.size, ::HandleResponse, reply_port));
device->IssueCommand(new DmaReadCommand(request->lba, request->size,
::HandleResponse, context));
return glcr::OK;
}
void DenaliServer::HandleResponse(z_cap_t reply_port, uint64_t lba,
void DenaliServer::HandleResponse(ResponseContext& response, uint64_t lba,
uint64_t size, z_cap_t mem) {
DenaliReadResponse resp{
.device_id = 0,
.lba = lba,
.size = size,
};
check(ZReplyPortSend(reply_port, sizeof(resp), &resp, 1, &mem));
check(response.WriteStructWithCap<DenaliReadResponse>(resp, mem));
}

View file

@ -6,21 +6,26 @@
#include "ahci/ahci_driver.h"
#include "denali/denali.h"
class DenaliServer {
class DenaliServer : public EndpointServer {
public:
DenaliServer(glcr::UniquePtr<EndpointServer> server, AhciDriver& driver);
static glcr::ErrorOr<glcr::UniquePtr<DenaliServer>> Create(
AhciDriver& driver);
glcr::ErrorCode RunServer();
void HandleResponse(z_cap_t reply_port, uint64_t lba, uint64_t size,
void HandleResponse(ResponseContext& response, uint64_t lba, uint64_t size,
z_cap_t cap);
virtual glcr::ErrorCode HandleRequest(RequestContext& request,
ResponseContext& response) override;
private:
static const uint64_t kBuffSize = 1024;
glcr::UniquePtr<EndpointServer> server_;
uint8_t read_buffer_[kBuffSize];
AhciDriver& driver_;
glcr::ErrorCode HandleRead(const DenaliRead& read, z_cap_t reply_port);
DenaliServer(z_cap_t endpoint_cap, AhciDriver& driver)
: EndpointServer(endpoint_cap), driver_(driver) {}
glcr::ErrorCode HandleRead(DenaliReadRequest* request,
ResponseContext& context);
};

View file

@ -5,7 +5,7 @@
#define DENALI_INVALID 0
#define DENALI_READ 100
struct DenaliRead {
struct DenaliReadRequest {
uint64_t request_type = DENALI_READ;
uint64_t device_id;

View file

@ -13,16 +13,15 @@ uint64_t main(uint64_t port_cap) {
check(ParseInitPort(port_cap));
ASSIGN_OR_RETURN(auto server, YellowstoneServer::Create());
Thread server_thread = server->RunServer();
Thread registration_thread = server->RunRegistration();
uint64_t vaddr;
check(ZAddressSpaceMap(gSelfVmasCap, 0, gBootDenaliVmmoCap, &vaddr));
ASSIGN_OR_RETURN(glcr::UniquePtr<EndpointClient> client,
server->GetServerClient());
server->CreateClient());
check(SpawnProcessFromElfRegion(vaddr, glcr::Move(client)));
check(server_thread.Join());
check(server->RunServer());
check(registration_thread.Join());
dbgln("Yellowstone Finished Successfully.");
return 0;

View file

@ -12,11 +12,6 @@
namespace {
void ServerThreadBootstrap(void* yellowstone) {
dbgln("Yellowstone server starting");
static_cast<YellowstoneServer*>(yellowstone)->ServerThread();
}
void RegistrationThreadBootstrap(void* yellowstone) {
dbgln("Yellowstone registration starting");
static_cast<YellowstoneServer*>(yellowstone)->RegistrationThread();
@ -40,71 +35,60 @@ glcr::ErrorOr<PartitionInfo> HandleDenaliRegistration(z_cap_t endpoint_cap) {
} // namespace
glcr::ErrorOr<glcr::UniquePtr<YellowstoneServer>> YellowstoneServer::Create() {
ASSIGN_OR_RETURN(auto server, EndpointServer::Create());
z_cap_t cap;
RET_ERR(ZEndpointCreate(&cap));
ASSIGN_OR_RETURN(PortServer port, PortServer::Create());
return glcr::UniquePtr<YellowstoneServer>(
new YellowstoneServer(glcr::Move(server), port));
return glcr::UniquePtr<YellowstoneServer>(new YellowstoneServer(cap, port));
}
YellowstoneServer::YellowstoneServer(glcr::UniquePtr<EndpointServer> server,
PortServer port)
: server_(glcr::Move(server)), register_port_(port) {}
Thread YellowstoneServer::RunServer() {
return Thread(ServerThreadBootstrap, this);
}
YellowstoneServer::YellowstoneServer(z_cap_t endpoint_cap, PortServer port)
: EndpointServer(endpoint_cap), register_port_(port) {}
Thread YellowstoneServer::RunRegistration() {
return Thread(RegistrationThreadBootstrap, this);
}
void YellowstoneServer::ServerThread() {
while (true) {
uint64_t num_bytes = kBufferSize;
uint64_t reply_port_cap;
// FIXME: Error handling.
check(server_->Recieve(&num_bytes, server_buffer_, &reply_port_cap));
YellowstoneGetReq* req =
reinterpret_cast<YellowstoneGetReq*>(server_buffer_);
switch (req->type) {
case kYellowstoneGetAhci: {
dbgln("Yellowstone::GetAHCI");
YellowstoneGetAhciResp resp{
.type = kYellowstoneGetAhci,
.ahci_phys_offset = pci_reader_.GetAhciPhysical(),
};
check(ZReplyPortSend(reply_port_cap, sizeof(resp), &resp, 0, nullptr));
break;
}
case kYellowstoneGetRegistration: {
dbgln("Yellowstone::GetRegistration");
auto client_or = register_port_.CreateClient();
if (!client_or.ok()) {
check(client_or.error());
}
YellowstoneGetRegistrationResp resp;
uint64_t reg_cap = client_or.value().cap();
check(ZReplyPortSend(reply_port_cap, sizeof(resp), &resp, 1, &reg_cap));
break;
}
case kYellowstoneGetDenali: {
dbgln("Yellowstone::GetDenali");
z_cap_t new_denali;
check(ZCapDuplicate(denali_cap_, &new_denali));
YellowstoneGetDenaliResp resp{
.type = kYellowstoneGetDenali,
.device_id = device_id_,
.lba_offset = lba_offset_,
};
check(ZReplyPortSend(reply_port_cap, sizeof(resp), &resp, 1,
&new_denali));
break;
}
default:
dbgln("Unknown request type: %x", req->type);
break;
glcr::ErrorCode YellowstoneServer::HandleRequest(RequestContext& request,
ResponseContext& response) {
switch (request.request_id()) {
case kYellowstoneGetAhci: {
dbgln("Yellowstone::GetAHCI");
YellowstoneGetAhciResp resp{
.type = kYellowstoneGetAhci,
.ahci_phys_offset = pci_reader_.GetAhciPhysical(),
};
RET_ERR(response.WriteStruct<YellowstoneGetAhciResp>(resp));
break;
}
case kYellowstoneGetRegistration: {
dbgln("Yellowstone::GetRegistration");
auto client_or = register_port_.CreateClient();
if (!client_or.ok()) {
check(client_or.error());
}
YellowstoneGetRegistrationResp resp;
uint64_t reg_cap = client_or.value().cap();
RET_ERR(response.WriteStructWithCap(resp, reg_cap));
break;
}
case kYellowstoneGetDenali: {
dbgln("Yellowstone::GetDenali");
z_cap_t new_denali;
check(ZCapDuplicate(denali_cap_, &new_denali));
YellowstoneGetDenaliResp resp{
.type = kYellowstoneGetDenali,
.device_id = device_id_,
.lba_offset = lba_offset_,
};
RET_ERR(response.WriteStructWithCap(resp, new_denali));
break;
}
default:
dbgln("Unknown request type: %x", request.request_id());
return glcr::UNIMPLEMENTED;
break;
}
return glcr::OK;
}
void YellowstoneServer::RegistrationThread() {
@ -127,7 +111,7 @@ void YellowstoneServer::RegistrationThread() {
uint64_t vaddr;
check(
ZAddressSpaceMap(gSelfVmasCap, 0, gBootVictoriaFallsVmmoCap, &vaddr));
auto client_or = GetServerClient();
auto client_or = CreateClient();
if (!client_or.ok()) {
check(client_or.error());
}
@ -144,8 +128,3 @@ void YellowstoneServer::RegistrationThread() {
dbgln(name.cstr());
}
}
glcr::ErrorOr<glcr::UniquePtr<EndpointClient>>
YellowstoneServer::GetServerClient() {
return server_->CreateClient();
}

View file

@ -8,20 +8,19 @@
#include "hw/pcie.h"
class YellowstoneServer {
class YellowstoneServer : public EndpointServer {
public:
static glcr::ErrorOr<glcr::UniquePtr<YellowstoneServer>> Create();
Thread RunServer();
Thread RunRegistration();
void ServerThread();
void RegistrationThread();
glcr::ErrorOr<glcr::UniquePtr<EndpointClient>> GetServerClient();
virtual glcr::ErrorCode HandleRequest(RequestContext& request,
ResponseContext& response) override;
private:
glcr::UniquePtr<EndpointServer> server_;
// FIXME: Separate this to its own service.
PortServer register_port_;
static const uint64_t kBufferSize = 128;
@ -36,5 +35,5 @@ class YellowstoneServer {
PciReader pci_reader_;
YellowstoneServer(glcr::UniquePtr<EndpointServer> server, PortServer port);
YellowstoneServer(z_cap_t endpoint_cap, PortServer port);
};