Generalize the EndpointServer to require less boilerplate.

Classes can now inherit from the EndpointServer and just implement the
HandleRequest function.
This commit is contained in:
Drew Galbraith 2023-08-01 15:52:08 -07:00
parent 847d37addc
commit caccb08e16
14 changed files with 215 additions and 147 deletions

View file

@ -5,6 +5,8 @@
#include <ztypes.h>
#include "mammoth/endpoint_client.h"
#include "mammoth/request_context.h"
#include "mammoth/response_context.h"
class EndpointServer {
public:
@ -12,19 +14,24 @@ class EndpointServer {
EndpointServer(const EndpointServer&) = delete;
EndpointServer& operator=(const EndpointServer&) = delete;
static glcr::ErrorOr<glcr::UniquePtr<EndpointServer>> Create();
static glcr::UniquePtr<EndpointServer> Adopt(z_cap_t endpoint_cap);
glcr::ErrorOr<glcr::UniquePtr<EndpointClient>> CreateClient();
// FIXME: Release Cap here.
z_cap_t GetCap() { return endpoint_cap_; }
glcr::ErrorCode Recieve(uint64_t* num_bytes, void* data,
glcr::ErrorCode Receive(uint64_t* num_bytes, void* data,
z_cap_t* reply_port_cap);
glcr::ErrorCode RunServer();
virtual glcr::ErrorCode HandleRequest(RequestContext& request,
ResponseContext& response) = 0;
protected:
EndpointServer(z_cap_t cap) : endpoint_cap_(cap) {}
private:
z_cap_t endpoint_cap_;
EndpointServer(z_cap_t cap) : endpoint_cap_(cap) {}
static const uint64_t kBufferSize = 1024;
uint8_t recieve_buffer_[kBufferSize];
};

View file

@ -0,0 +1,32 @@
#pragma once
#include <glacier/status/error.h>
#include <stdint.h>
class RequestContext {
public:
RequestContext(void* buffer, uint64_t buffer_length)
: buffer_(buffer), buffer_length_(buffer_length) {
if (buffer_length_ < sizeof(uint64_t)) {
request_id_ = -1;
} else {
request_id_ = *reinterpret_cast<uint64_t*>(buffer);
}
}
uint64_t request_id() { return request_id_; }
template <typename T>
glcr::ErrorCode As(T** arg) {
if (buffer_length_ < sizeof(T)) {
return glcr::INVALID_ARGUMENT;
}
*arg = reinterpret_cast<T*>(buffer_);
return glcr::OK;
}
private:
uint64_t request_id_;
void* buffer_;
uint64_t buffer_length_;
};

View file

@ -0,0 +1,40 @@
#pragma once
#include <glacier/status/error.h>
#include <zcall.h>
#include <ztypes.h>
class ResponseContext {
public:
ResponseContext(z_cap_t reply_port) : reply_port_(reply_port) {}
ResponseContext(ResponseContext&) = delete;
template <typename T>
glcr::ErrorCode WriteStruct(const T& response) {
// FIXME: Here and below probably don't count as written on error.
written_ = true;
return ZReplyPortSend(reply_port_, sizeof(T), &response, 0, nullptr);
}
template <typename T>
glcr::ErrorCode WriteStructWithCap(const T& response, z_cap_t capability) {
written_ = true;
return ZReplyPortSend(reply_port_, sizeof(T), &response, 1, &capability);
}
glcr::ErrorCode WriteError(glcr::ErrorCode code) {
uint64_t response[2]{
static_cast<uint64_t>(-1),
code,
};
written_ = true;
return ZReplyPortSend(reply_port_, sizeof(response), &response, 0, nullptr);
}
bool HasWritten() { return written_; }
private:
z_cap_t reply_port_;
bool written_ = false;
};

View file

@ -1,14 +1,6 @@
#include "mammoth/endpoint_server.h"
glcr::ErrorOr<glcr::UniquePtr<EndpointServer>> EndpointServer::Create() {
uint64_t cap;
RET_ERR(ZEndpointCreate(&cap));
return glcr::UniquePtr<EndpointServer>(new EndpointServer(cap));
}
glcr::UniquePtr<EndpointServer> EndpointServer::Adopt(z_cap_t endpoint_cap) {
return glcr::UniquePtr<EndpointServer>(new EndpointServer(endpoint_cap));
}
#include "mammoth/debug.h"
glcr::ErrorOr<glcr::UniquePtr<EndpointClient>> EndpointServer::CreateClient() {
uint64_t client_cap;
@ -17,7 +9,24 @@ glcr::ErrorOr<glcr::UniquePtr<EndpointClient>> EndpointServer::CreateClient() {
return EndpointClient::AdoptEndpoint(client_cap);
}
glcr::ErrorCode EndpointServer::Recieve(uint64_t* num_bytes, void* data,
glcr::ErrorCode EndpointServer::Receive(uint64_t* num_bytes, void* data,
z_cap_t* reply_port_cap) {
return ZEndpointRecv(endpoint_cap_, num_bytes, data, reply_port_cap);
}
glcr::ErrorCode EndpointServer::RunServer() {
while (true) {
uint64_t message_size = kBufferSize;
uint64_t reply_port_cap = 0;
RET_ERR(Receive(&message_size, recieve_buffer_, &reply_port_cap));
RequestContext request(recieve_buffer_, message_size);
ResponseContext response(reply_port_cap);
// FIXME: Consider pumping these errors into the response as well.
RET_ERR(HandleRequest(request, response));
if (!response.HasWritten()) {
dbgln("Returning without having written a response. Req type %x",
request.request_id());
}
}
}