diff --git a/rust/lib/yunq/src/server.rs b/rust/lib/yunq/src/server.rs index dcfd03c..57b844a 100644 --- a/rust/lib/yunq/src/server.rs +++ b/rust/lib/yunq/src/server.rs @@ -1,7 +1,13 @@ +use core::future::Future; + use crate::buffer::ByteBuffer; +use alloc::sync::Arc; use alloc::vec::Vec; use mammoth::cap::Capability; +use mammoth::sync::Mutex; use mammoth::syscall; +use mammoth::task::Executor; +use mammoth::task::Task; use mammoth::thread; use mammoth::thread::JoinHandle; use mammoth::zion::z_cap_t; @@ -59,3 +65,81 @@ where { thread::spawn(move || server.server_loop()) } + +pub trait AsyncYunqServer +where + Self: Send + Sync + 'static, +{ + fn server_loop(self: Arc, executor: Arc>) { + loop { + let mut byte_buffer = ByteBuffer::<1024>::new(); + let mut cap_buffer = vec![0; 10]; + let (_, _, reply_port_cap) = syscall::endpoint_recv( + self.endpoint_cap(), + byte_buffer.mut_slice(), + &mut cap_buffer, + ) + .expect("Failed to call endpoint recv"); + + let method = byte_buffer + .at::(8) + .expect("Failed to access request length."); + let self_clone = self.clone(); + executor.lock().spawn(Task::new((async move || { + self_clone + .handle_request_and_response(method, byte_buffer, cap_buffer, reply_port_cap) + .await; + })())); + } + } + + fn handle_request_and_response( + &self, + method: u64, + mut byte_buffer: ByteBuffer<1024>, + mut cap_buffer: Vec, + reply_port_cap: Capability, + ) -> impl Future + Sync { + async move { + let resp = self + .handle_request(method, &mut byte_buffer, &mut cap_buffer) + .await; + + match resp { + Ok(resp_len) => syscall::reply_port_send( + reply_port_cap, + byte_buffer.slice(resp_len), + &cap_buffer, + ) + .expect("Failed to reply"), + Err(err) => { + crate::message::serialize_error(&mut byte_buffer, err); + syscall::reply_port_send(reply_port_cap, &byte_buffer.slice(0x10), &[]) + .expect("Failed to reply w/ error") + } + } + + () + } + } + + fn endpoint_cap(&self) -> &Capability; + + fn create_client_cap(&self) -> Result { + self.endpoint_cap() + .duplicate(!mammoth::zion::kZionPerm_Read) + } + fn handle_request( + &self, + method_number: u64, + byte_buffer: &mut ByteBuffer<1024>, + cap_buffer: &mut Vec, + ) -> impl Future> + Sync; +} + +pub fn spawn_async_server_thread(server: Arc, executor: Arc>) +where + T: AsyncYunqServer + Send + Sync + 'static, +{ + server.server_loop(executor); +} diff --git a/yunq/rust/src/codegen.rs b/yunq/rust/src/codegen.rs index 94769f2..b022a33 100644 --- a/yunq/rust/src/codegen.rs +++ b/yunq/rust/src/codegen.rs @@ -392,13 +392,116 @@ fn generate_server(interface: &Interface) -> TokenStream { } } +fn generate_async_server_case(method: &Method) -> TokenStream { + let id = proc_macro2::Literal::u64_suffixed(method.number); + let name = ident(&method.name.to_case(Case::Snake)); + let maybe_req = method.request.clone().map(|r| ident(&r)); + let maybe_resp = method.response.clone().map(|r| ident(&r)); + match (maybe_req, maybe_resp) { + (Some(req), Some(_)) => quote! { + #id => { + let req = #req::parse_from_request(byte_buffer, cap_buffer)?; + let resp = self.handler.#name(req).await?; + cap_buffer.resize(0, 0); + let resp_len = resp.serialize_as_request(0, byte_buffer, cap_buffer)?; + Ok(resp_len) + }, + }, + (Some(req), None) => quote! { + #id => { + let req = #req::parse_from_request(byte_buffer, cap_buffer)?; + self.handler.#name(req).await?; + cap_buffer.resize(0, 0); + // TODO: Implement serialization for EmptyMessage so this is less hacky. + yunq::message::serialize_error(byte_buffer, ZError::from(0)); + Ok(0x10) + }, + }, + (None, Some(_)) => quote! { + #id => { + let resp = self.handler.#name().await?; + cap_buffer.resize(0, 0); + let resp_len = resp.serialize_as_request(0, byte_buffer, cap_buffer)?; + Ok(resp_len) + }, + }, + _ => unreachable!(), + } +} + +fn generate_async_server_method(method: &Method) -> TokenStream { + let name = ident(&method.name.to_case(Case::Snake)); + let maybe_req = method.request.clone().map(|r| ident(&r)); + let maybe_resp = method.response.clone().map(|r| ident(&r)); + match (maybe_req, maybe_resp) { + (Some(req), Some(resp)) => quote! { + fn #name (&self, req: #req) -> impl Future> + Sync; + }, + (Some(req), None) => quote! { + fn #name (&self, req: #req) -> impl Future> + Sync; + }, + (None, Some(resp)) => quote! { + fn #name (&self) -> impl Future> + Sync; + }, + _ => unreachable!(), + } +} + +fn generate_async_server(interface: &Interface) -> TokenStream { + let server_name = ident(&(String::from("Async") + &interface.name.clone() + "Server")); + let server_trait = ident(&(String::from("Async") + &interface.name.clone() + "ServerHandler")); + let server_trait_methods = interface.methods.iter().map(generate_async_server_method); + let server_match_cases = interface.methods.iter().map(generate_async_server_case); + quote! { + pub trait #server_trait { + #(#server_trait_methods)* + } + + pub struct #server_name { + endpoint_cap: Capability, + handler: T + } + + impl #server_name { + pub fn new(handler: T) -> Result { + Ok(Self { + endpoint_cap: syscall::endpoint_create()?, + handler, + }) + } + } + + impl yunq::server::AsyncYunqServer for #server_name { + fn endpoint_cap(&self) -> &Capability { + &self.endpoint_cap + } + + async fn handle_request( + &self, + method_number: u64, + byte_buffer: &mut ByteBuffer<1024>, + cap_buffer: &mut Vec, + ) -> Result { + match method_number { + #(#server_match_cases)* + + _ => Err(ZError::UNIMPLEMENTED) + } + } + } + } +} + fn generate_interface(interface: &Interface) -> TokenStream { let client = generate_client(interface); let server = generate_server(interface); + let async_server = generate_async_server(interface); quote! { #client #server + + #async_server } } @@ -428,6 +531,7 @@ pub fn generate_code(ast: &[Decl]) -> String { let interface_imports = if any_interfaces(ast) { quote! { + use core::future::Future; use mammoth::cap::Capability; use mammoth::syscall; }