Initial commit

This commit is contained in:
zawz 2024-06-21 13:05:37 +02:00
commit 5b6d53edaf
15 changed files with 3991 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
target

64
.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,64 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'rspc'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=rspc"
],
"filter": {
"name": "rspc",
"kind": "lib"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'rspc'",
"cargo": {
"args": [
"build",
"--bin=rspc",
"--package=rspc"
],
"filter": {
"name": "rspc",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'rspc'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=rspc",
"--package=rspc"
],
"filter": {
"name": "rspc",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

6
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,6 @@
{
"rust-analyzer.checkOnSave.overrideCommand": [
"cargo", "check", "--message-format=json", "--target-dir", "target/lsp" ],
"rust-analyzer.cargo.buildScripts.overrideCommand": [
"cargo", "check", "--message-format=json", "--target-dir", "target/lsp" ]
}

1339
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

22
Cargo.toml Normal file
View file

@ -0,0 +1,22 @@
[package]
name = "rspc"
version = "0.1.0"
edition = "2021"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
futures = "0.3"
tokio = { version = "1.0", features = ["full"] }
syn = { version = "2.0", features = ["full"] }
rspc_macros = { path = "macros", version = "0.1" }
thiserror = "1.0.49"
async-trait = "0.1.73"
async-std = "1.12.0"
tokio-serde = { version = "0.8.0", features=["json","bincode"] }
pin-project = "1.1.3"
tokio-util = { version = "0.7.10", features=["codec"] }
serde_json = "1.0.108"
oneshot = { version = "0.1.6", features=["std"] }
[lib]

1333
example/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

13
example/Cargo.toml Normal file
View file

@ -0,0 +1,13 @@
[package]
name = "rspc_example"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.86"
async-std = "1.12.0"
rspc = { path = ".." }
serde = "1.0.203"
tokio = "1.38.0"

117
example/src/data.rs Normal file
View file

@ -0,0 +1,117 @@
use std::time::Duration;
use serde::{Serialize, Deserialize};
pub type RetData = (usize,Option<usize>);
pub fn dur_to_num(dur: Duration) -> (u128, &'static str) {
if dur.as_nanos() < 10000 {
(dur.as_nanos(), "ns")
} else if dur.as_micros() < 10000 {
(dur.as_micros(), "µs")
} else if dur.as_millis() < 10000 {
(dur.as_millis(), "ms")
} else {
(dur.as_secs().into(), "s")
}
}
pub fn dur_to_str(dur: Duration) -> String {
let (n,s) = dur_to_num(dur);
n.to_string() +" "+ s
}
pub fn process_data(dat: TestData) -> RetData {
(dat.len(), dat.calc())
}
pub fn process_data_ref(dat: &TestData) -> RetData {
(dat.len(), dat.calc())
}
pub const DATASIZE: usize = 1000000;
const TEST_STRINGS: [&str; 4] = [
"toto",
"tata",
"titi",
"tutu"
];
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub enum TestEnum {
NoValue,
Num(usize),
Str(String),
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct TestData {
vec: Vec<(bool, TestEnum)>,
}
use crate::transport::{ClientTransporter,ServerTransporter};
#[rspc::service]
impl TestData {
pub fn len(&self) -> usize {
self.vec.len()
}
pub fn calc(&self) -> Option<usize> {
for v in &self.vec {
if let (true, TestEnum::Num(n)) = v {
return Some(*n);
}
}
return None;
}
pub async fn slow_fct(&self) -> Option<usize> {
tokio::time::sleep(Duration::from_secs(1)).await;
self.calc()
}
fn internal_fib(n: usize) -> usize {
if n <= 1 {
n
} else {
Self::internal_fib(n-1) + Self::internal_fib(n-2)
}
}
pub fn fib(&self, n: usize) -> usize {
Self::internal_fib(n)
}
pub fn add(&self, a: usize, b: usize) -> usize {
return a+b;
}
pub fn calc_add(&self, data: TestData) -> Option<usize> {
return self.calc().and_then(|a| data.calc().map(|b| a+b));
}
pub fn push(&mut self, u: (bool, TestEnum)) {
self.vec.push(u)
}
}
pub fn make_test_data(n: usize) -> TestData {
let mut v = Vec::with_capacity(n);
let mut b = true;
for i in 0..n {
v.push(( b ,
match i%3 {
0 => TestEnum::NoValue,
1 => TestEnum::Num(i),
2 => TestEnum::Str(TEST_STRINGS[i%4].to_string()),
_ => panic!("unexpected error"),
}));
b = !b;
}
TestData {
vec: v,
}
}

73
example/src/main.rs Normal file
View file

@ -0,0 +1,73 @@
pub mod data;
use data::{make_test_data,dur_to_str,TestData,TestEnum,TestDataClient,TestDataServer,DATASIZE};
use tokio::join;
use rspc::transport::serde::TcpClient;
use rspc::transport;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// let t = TcpClient::connect("127.0.0.1:3306").await.unwrap();
// let client = t.spawn().await;
// let client = TestDataClient::new(client);
let (c,s) = transport::channel::new_async();
let data: TestData = make_test_data(DATASIZE);
let srv_thread = tokio::spawn(async move {
let mut server = TestDataServer::from(data);
server.listen(s).await
} );
let client = TestDataClient::new(c);
let clientref = &client;
let job1 = async {
let now = std::time::Instant::now();
assert_eq!(DATASIZE, client.len().await.unwrap());
println!("len: {}", dur_to_str(now.elapsed()));
let now = std::time::Instant::now();
assert_eq!(267914296, clientref.fib(42).await.unwrap());
println!("fib1: {}", dur_to_str(now.elapsed()));
};
let job2 = async {
let now = std::time::Instant::now();
assert_eq!(DATASIZE, client.len().await.unwrap());
println!("len: {}", dur_to_str(now.elapsed()));
let now = std::time::Instant::now();
assert_eq!(4, client.calc().await.unwrap().unwrap_or(0) );
println!("calc: {}", dur_to_str(now.elapsed()));
let cdat = make_test_data(DATASIZE);
let now = std::time::Instant::now();
assert_eq!(8, client.calc_add(cdat).await.unwrap().unwrap_or(0));
println!("calc_add: {}", dur_to_str(now.elapsed()));
let now = std::time::Instant::now();
assert_eq!(267914296, client.fib(42).await.unwrap());
println!("fib2: {}", dur_to_str(now.elapsed()));
let now = std::time::Instant::now();
clientref.push((false, TestEnum::NoValue)).await.unwrap();
println!("push: {}", dur_to_str(now.elapsed()));
let now = std::time::Instant::now();
assert_eq!(DATASIZE+1, client.len().await.unwrap());
println!("len: {}", dur_to_str(now.elapsed()));
};
join!(job1, job2);
client.stop().await.unwrap();
srv_thread.await.unwrap().unwrap();
Ok(())
}

16
macros/Cargo.toml Normal file
View file

@ -0,0 +1,16 @@
[package]
name = "rspc_macros"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
#anyhow = "1.0"
#futures = "0.3"
#tokio = { version = "1.0", features = ["macros", "net", "rt-multi-thread"] }
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
[lib]
proc-macro = true

344
macros/src/lib.rs Normal file
View file

@ -0,0 +1,344 @@
extern crate proc_macro;
extern crate quote;
use proc_macro::TokenStream;
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
FnArg, Ident,
Pat, PatType, ReturnType, Visibility, Receiver, ItemImpl, ImplItemFn, ImplItemType, Signature, Attribute,
};
use quote::{format_ident, quote};
/// Accumulates multiple errors into a result.
/// Only use this for recoverable errors, i.e. non-parse errors. Fatal errors should early exit to
/// avoid further complications.
macro_rules! extend_errors {
($errors: ident, $e: expr) => {
match $errors {
Ok(_) => $errors = Err($e),
Err(ref mut errors) => errors.extend($e),
}
};
}
/// Convert a snake_case string to a UpperCamelCase string.
/// Used for converting function names to enum member names
fn snake_to_upper_camel(ident_str: &str) -> String {
let mut camel_ty = String::with_capacity(ident_str.len());
let mut last_char_was_underscore = true;
for c in ident_str.chars() {
match c {
'_' => last_char_was_underscore = true,
c if last_char_was_underscore => {
camel_ty.extend(c.to_uppercase());
last_char_was_underscore = false;
}
c => camel_ty.extend(c.to_lowercase()),
}
}
camel_ty.shrink_to_fit();
camel_ty
}
struct RpcMethod {
pub attrs: Vec<Attribute>,
pub sig: Signature,
pub reciever: Option<Receiver>,
pub args: Vec<PatType>,
}
struct Service {
pub ident: Ident,
pub rpcs: Vec<RpcMethod>,
pub item: ItemImpl,
pub types: Vec<ImplItemType>,
}
impl TryFrom<ImplItemFn> for RpcMethod {
type Error = syn::Error;
fn try_from(value: ImplItemFn) -> Result<Self, Self::Error> {
let mut reciever = None;
let mut args = vec!();
for arg in &value.sig.inputs {
match arg {
FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
args.push(captured.clone());
},
FnArg::Typed(captured) => {
return Err(syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args"));
},
FnArg::Receiver(v) => {
if matches!(v.reference, None) {
return Err(syn::Error::new(v.span(), "self cannot be consumed in RPC method"));
} else {
reciever = Some(v.clone());
}
}
}
}
Ok(RpcMethod {
attrs: value.attrs,
sig: value.sig,
reciever,
args,
})
}
}
/// Parse an service instance, from a regular rust `impl`.
impl Parse for Service {
fn parse(input: ParseStream) -> syn::Result<Self> {
let item: ItemImpl = input.parse()?;
let mut errors = Ok(());
let mut typeitems: Vec<ImplItemType> = vec!();
let fns: Vec<&ImplItemFn> = item.items.iter().filter(|x| {
match x {
syn::ImplItem::Fn(fnit) => {
matches!(fnit.vis, Visibility::Public(_))
},
syn::ImplItem::Type(x) => {
typeitems.push(x.clone());
false
}
_ => false,
}
}).map(|x|
if let syn::ImplItem::Fn(fnit) = x {
fnit
} else {
unreachable!()
}
).collect();
let ident = match item.self_ty.as_ref() {
syn::Type::Path(x) => {
match x.path.get_ident() {
Some(i) => i.clone(),
None => return Err(syn::Error::new(x.span(), "generics and paths are not supported")),
}
},
_ => return Err(syn::Error::new(item.self_ty.span(), "unsupported self type")),
};
let mut rpcs = vec!();
for one_fn in fns {
match RpcMethod::try_from(one_fn.clone()) {
Ok(v) => rpcs.push(v),
Err(e) => extend_errors!(errors, e),
}
}
for rpc in &rpcs {
let fnident = &rpc.sig.ident;
if matches!(&fnident.to_string()[..], "new" | "stop") {
extend_errors!(
errors,
syn::Error::new(
fnident.span(),
format!(
"method name conflicts with generated fn `Client::{}`",
fnident
)
)
)
}
}
errors?;
Ok(Self {
ident,
rpcs,
item,
types: typeitems,
})
}
}
/// Macro used to generate the RSPC Server and Client objects.
#[proc_macro_attribute]
pub fn service(_attr: TokenStream, mut input: TokenStream) -> TokenStream {
let inputclone = input.clone();
let Service {
ident,
rpcs,
item: _,
types: _,
} = parse_macro_input!(inputclone as Service);
let transport_request = &format_ident!("{}TransportRequest", ident);
let transport_response = &format_ident!("{}TransportResponse", ident);
let server = &format_ident!("{}Server", ident);
let client = &format_ident!("{}Client", ident);
let fn_names: Vec<_> = rpcs.iter().map(|x| {
&x.sig.ident
}).collect();
let fn_names_camel: Vec<_> = fn_names.iter().map(|x| {
Ident::new(&snake_to_upper_camel(&x.unraw().to_string()), x.span())
}).collect();
let fn_locks: Vec<_> = rpcs.iter().map(|x| {
match x.reciever.as_ref() {
Some(rcv) => match &rcv.mutability {
Some(_) => quote!(write()),
None => quote!(read()),
},
None => quote!(read()),
}
}).collect();
let fn_mut: Vec<_> = rpcs.iter().map(|x| {
match x.reciever.as_ref() {
Some(rcv) => match &rcv.mutability {
Some(_) => quote!(mut),
None => quote!(),
},
None => quote!(),
}
}).collect();
let fn_await: Vec<_> = rpcs.iter().map(|x| {
match &x.sig.asyncness {
Some(_) => quote!(.await),
None => quote!(),
}
}).collect();
let args: Vec<_> = rpcs.iter().map(|x| {
let args = &x.args;
quote!{ #(#args),* }
}).collect();
let arg_types: Vec<_> = rpcs.iter().map(|x| {
let args = x.args.iter().map(|a| &a.ty);
quote!{ #(#args),* }
}).collect();
let arg_idents: Vec<_> = rpcs.iter().map(|x| {
let args = x.args.iter().map(|a| &a.pat);
quote!{ #(#args),* }
}).collect();
let outs: Vec<_> = rpcs.iter().map(|x| {
match &x.sig.output {
ReturnType::Default => quote!{()},
ReturnType::Type(_, t) => quote!{#t},
}
}).collect();
let t = quote! {
#[derive(PartialEq,Debug,Serialize,Deserialize)]
pub enum #transport_request {
#( #fn_names_camel(#arg_types) ),* ,
Stop,
}
#[derive(PartialEq,Debug,Serialize,Deserialize)]
pub enum #transport_response {
#( #fn_names_camel(#outs) ),* ,
Stop,
}
impl #transport_response {
#(
pub fn #fn_names(self) -> #outs {
if let #transport_response::#fn_names_camel(v) = self {
v
} else {
panic!()
}
}
)*
}
pub struct #server {
obj: std::sync::Arc<async_std::sync::RwLock<#ident>>,
}
impl From<#ident> for #server {
fn from(obj: #ident) -> Self {
Self {
obj: std::sync::Arc::new(async_std::sync::RwLock::new(obj)),
}
}
}
impl #server {
pub async fn listen<Tr>(&mut self, mut transport: Tr) -> Result<(), Tr::Error>
where
Tr: ServerTransporter<#transport_request,#transport_response> + Send
{
{
transport.listen( |v,obj| {
let obj = obj.clone();
async move {
match v {
#transport_request::Stop => None,
#(
#transport_request::#fn_names_camel(#arg_idents) => {
let #fn_mut obj = obj.#fn_locks.await;
Some(#transport_response::#fn_names_camel(obj.#fn_names(#arg_idents)#fn_await))
},
)*
}
}
}, Some(#transport_response::Stop), self.obj.clone()).await
}
}
}
pub struct #client<Tr>
where
Tr: ClientTransporter<#transport_request,#transport_response>,
{
transporter: Tr,
}
impl<Tr> #client<Tr>
where
Tr: ClientTransporter<#transport_request,#transport_response>
{
pub fn new(transporter: Tr) -> Self {
#client {
transporter,
}
}
#(
pub async fn #fn_names(&self, #args) -> Result<#outs, Tr::Error> {
Ok(self.transporter.request(#transport_request::#fn_names_camel(#arg_idents)).await?.#fn_names())
}
)*
// TODO: graceful stop response
pub async fn stop(&self) -> Result<(), Tr::Error> {
self.transporter.request(#transport_request::Stop).await.map(|_| ())
}
}
};
// for debugging
// println!("{}", t);
input.extend(TokenStream::from(t));
input
}

55
src/lib.rs Normal file
View file

@ -0,0 +1,55 @@
pub mod transport;
pub use rspc_macros::service;
#[cfg(test)]
mod tests {
use super::transport::{channel, ClientTransporter,ServerTransporter};
use super::service;
use serde::{Deserialize, Serialize};
// use rspc::transport::{ClientTransporter,ServerTransporter};
#[derive(Serialize,Deserialize)]
pub struct MyStruct {
my_vec: Vec<String>,
}
#[service]
impl MyStruct
{
pub fn len(&self) -> usize {
self.my_vec.len()
}
pub fn push(&mut self, val: String) {
self.my_vec.push(val)
}
pub fn pop(&mut self) -> Option<String> {
self.my_vec.pop()
}
}
#[tokio::test]
async fn test() {
let my_data = MyStruct {
my_vec: Vec::new(),
};
let (c,s) = channel::new_async();
let srv_thread = tokio::spawn(async move {
let mut server = MyStructServer::from(my_data);
server.listen(s).await
} );
let client = MyStructClient::new(c);
assert_eq!(client.len().await.unwrap(), 0);
client.push("Hello world!".to_string()).await.unwrap();
assert_eq!(client.len().await.unwrap(), 1);
assert_eq!(client.pop().await.unwrap(), Some("Hello world!".to_string()));
client.stop().await.unwrap();
srv_thread.await.unwrap().unwrap();
}
}

312
src/transport/channel.rs Normal file
View file

@ -0,0 +1,312 @@
use std::collections::BTreeMap;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::{Future, StreamExt};
use thiserror::Error;
use tokio::sync::oneshot;
use tokio::sync::mpsc::{self, UnboundedSender, UnboundedReceiver};
use async_trait::async_trait;
use super::{ClientTransporter,ServerTransporter};
#[derive(Error,Debug)]
pub enum Error {
#[error("channel recv error")]
ChannelRecvError,
#[error("channel resp error")]
ChannelRespError,
#[error("channel send error")]
ChannelSendError,
#[error(transparent)]
OneshotRecvError(#[from] oneshot::error::RecvError),
}
#[derive(Clone)]
/// Client endpoint of any channel.
pub struct ChannelClient<T,R> {
channel: UnboundedSender<(T,oneshot::Sender<R>)>,
}
/// Server endpoint of a synchronous channel
pub struct SyncChannelServer<T,R> {
channel: UnboundedReceiver<(T,oneshot::Sender<R>)>,
}
/// Create a new synchronous channel client/server instance.
///
/// Synchronous channels can only process one job at a time.
/// If you want job concurrency refer to new_async() and Asynchronous channels
///
/// Unpon recieving a stop signal, jobs currently pending recv() will not be processed.
/// Said jobs will continue waiting until either the server listens again, or the server is dropped.
///
/// Example:
/// ```no_run
/// use serde::{Serialize,Deserialize};
/// use rspc::transport::{channel, ClientTransporter,ServerTransporter};
///
/// #[derive(Serialize,Deserialize)]
/// pub struct MyStruct;
///
/// #[rspc::service]
/// impl MyStruct {
/// pub fn dummy(&self) -> () { () }
/// }
///
/// let (client_channel,server_channel) = channel::new_sync();
///
/// let server_thread = tokio::spawn(async move {
/// let mut server = MyStructServer::from( MyStruct );
/// server.listen(server_channel).await
/// });
///
/// let client = MyStructClient::new(client_channel);
/// ```
pub fn new_sync<T,R>() -> (ChannelClient<T,R>,SyncChannelServer<T,R>) {
let (c,s) = mpsc::unbounded_channel();
(
ChannelClient { channel: c },
SyncChannelServer { channel: s },
)
}
impl<T,R> ChannelClient<T,R>
where
T: Send + Sync,
R: Send + Sync,
{
pub async fn internal_request(&self, data: T) -> Result<R, Error> {
let (t,r) = oneshot::channel();
self.channel.send( (data, t) ).map_err(|_| Error::ChannelSendError)?;
let output = r.await?;
Ok(output)
}
}
#[async_trait]
impl<T,R> ClientTransporter<T,R> for ChannelClient<T,R>
where
T: Send + Sync,
R: Send + Sync,
{
type Error = Error;
async fn request(&self, data: T) -> Result<R, Self::Error> {
self.internal_request(data).await
}
}
#[async_trait]
impl<T,R> ServerTransporter<T,R> for SyncChannelServer<T,R>
where
T: Send + Sync,
R: Send + Sync,
{
type Error = Error;
async fn listen<F, FR, D>(&mut self, handler: F, stop_response: Option<R>, userdata: D) -> Result<(), Self::Error>
where
FR: Future<Output = Option<R>> + Send,
F: Fn(T, &D) -> FR + Send + Sync,
D: Send+Sync,
{
while let Some(msg) = self.channel.recv().await {
match handler(msg.0, &userdata).await {
Some(r) => msg.1.send(r).map_err(|_| Error::ChannelRespError)?,
None => {
if let Some(v) = stop_response {
msg.1.send(v).map_err(|_| Error::ChannelRespError)?;
}
break;
},
};
}
Ok(())
}
}
pub struct AsyncChannelServer<T,R> {
channel: UnboundedReceiver<(T,oneshot::Sender<R>)>,
}
/// Create a new asynchronous channel client/server instance.
///
/// Can process any number of jobs in parallel.
///
/// Unpon recieving a stop signal, pending jobs are finished and reponded to, but new jobs are not processed.
/// Said new jobs will continue waiting until either the server listens again, or the server is dropped.
///
/// Example:
/// ```no_run
/// use serde::{Serialize,Deserialize};
/// use rspc::transport::{channel, ClientTransporter,ServerTransporter};
///
/// #[derive(Serialize,Deserialize)]
/// pub struct MyStruct;
///
/// #[rspc::service]
/// impl MyStruct {
/// pub fn dummy(&self) -> () { () }
/// }
///
/// let (client_channel,server_channel) = channel::new_async();
///
/// let server_thread = tokio::spawn(async move {
/// let mut server = MyStructServer::from( MyStruct );
/// server.listen(server_channel).await
/// });
///
/// let client = MyStructClient::new(client_channel);
/// ```
pub fn new_async<T,R>() -> (ChannelClient<T,R>,AsyncChannelServer<T,R>) {
let (c,s) = mpsc::unbounded_channel();
(
ChannelClient { channel: c },
AsyncChannelServer { channel: s },
)
}
impl<T,R> AsyncChannelServer<T,R>
where
T: Send + Sync,
R: Send + Sync + 'static,
{
async fn internal_listen<F, FR, D>(&mut self, handler: F, stop_response: Option<R>, userdata: D) -> Result<(), Error>
where
FR: Future<Output = Option<R>> + Send + 'static,
F: Fn(T, &D) -> FR + Send + Sync,
D: Send + Sync,
{
let mut pending = FuturesUnordered::new();
loop {
tokio::select!{
Some(rcv) = self.channel.recv() => {
pending.push(
async {
(
tokio::spawn(handler(rcv.0, &userdata)).await.unwrap(),
rcv.1,
)
}
);
},
Some(r) = pending.next() => {
match r {
(Some(r),sender) => sender.send(r).map_err(|_| Error::ChannelRespError)?,
(None,sender) => {
if let Some(v) = stop_response {
sender.send(v).map_err(|_| Error::ChannelRespError)?;
}
break;
},
}
},
else => break,
}
}
let results: Vec<_> = pending.collect().await;
results.into_iter().map(|r| -> Result<(), Error> {
match r {
(Some(r),sender) => sender.send(r).map_err(|_| Error::ChannelRespError),
_ => Ok(()),
}
}).collect::<Result<Vec<_>, Error>>()?;
Ok(())
}
}
#[async_trait]
impl<T,R> ServerTransporter<T,R> for AsyncChannelServer<T,R>
where
T: Send + Sync,
R: Send + Sync + 'static,
{
type Error = Error;
async fn listen<F, FR, D>(&mut self, handler: F, stop_response: Option<R>, userdata: D) -> Result<(), Self::Error>
where
FR: Future<Output = Option<R>> + Send + 'static,
F: Fn(T, &D) -> FR + Send + Sync,
D: Send + Sync,
{
self.internal_listen(handler, stop_response, userdata).await
}
}
/// Create a channel multiplexer.
///
/// This is intended to be used with self-mutable clients to provide immutable clients to it
pub fn new_multiplexer<T,R>() -> (ChannelClient<T,R>,Multiplexer<T,R>) {
let (c,s) = mpsc::unbounded_channel();
(
ChannelClient { channel: c },
Multiplexer { channel: s },
)
}
pub struct Multiplexer<T,R> {
channel: UnboundedReceiver<(T,oneshot::Sender<R>)>,
}
impl<T,R> Multiplexer<T,R>
where
T: Send + Sync,
R: Send + Sync + 'static,
{
/// Start multiplexing with a 3rd party.
///
/// While this is running, the associated ChannelClient can then call request() to send data, and recieve the final response from the 3rd party
///
/// Requirements:
/// - The 3rd party must be handling (usize,T) as input and (usize,R) as output
/// - sender_send(sender) is the function used to send data to the 3rd party
/// - listener_recv(listener) is the function used to recieve data from the 3rd party
pub async fn start<S, SF, L, LF>(mut self, mut sender: S, sender_send: SF, mut listener: L, listener_recv: LF) -> Result<(), Error>
where
SF: Fn(&mut S, (usize,T)) -> BoxFuture<bool> + Send + Sync + 'static,
LF: Fn(&mut L) -> BoxFuture<Option<(usize,R)>> + 'static,
{
let mut pending: BTreeMap<usize, oneshot::Sender<R>> = BTreeMap::new();
let mut id_counter: usize = 0;
loop {
tokio::select!{
q = self.channel.recv() => {
match q {
Some(rcv) => {
if sender_send(&mut sender, (id_counter, rcv.0)).await {
pending.insert(id_counter, rcv.1);
id_counter+=1;
};
},
None => break,
}
},
q = listener_recv(&mut listener) => {
match q {
Some((id,r)) => {
match pending.remove(&id) {
Some(s) => s.send(r).map_err(|_| Error::ChannelRespError)?,
None => todo!(),
};
}
None => break,
}
},
else => break,
}
}
Ok(())
}
}

98
src/transport/mod.rs Normal file
View file

@ -0,0 +1,98 @@
use async_trait::async_trait;
use futures::{Future, future::BoxFuture, stream::FuturesUnordered, StreamExt};
pub mod channel;
pub mod serde;
/// Definition of a client transporter for RSPC.
///
/// ### Implementation specifications
/// - ClientTransporter is a producer object, can be single-producer or multi-producer.
/// - ClientTransporter must be immutable. If a client implementation requires self mutability,
/// use Mutex, RwLock, or similar tools to mutate values without requiring self mutability
#[async_trait]
pub trait ClientTransporter<T,R> {
type Error: std::fmt::Debug;
async fn request(&self, data: T) -> Result<R, Self::Error>;
}
/// Definition of a server transporter for RSPC.
///
/// ### Implementation specifications
/// - ServerTransporter is a single-consumer object
/// - Upon recieving a stop request (handler function return None), server must respond with stop_response if specified to only this request and none other.
/// Finishing and responding to pending jobs is optional.
#[async_trait]
pub trait ServerTransporter<T,R>
{
type Error: std::fmt::Debug;
async fn listen<F, FR, D>(&mut self, handler: F, stop_response: Option<R>, userdata: D) -> Result<(), Self::Error>
where
FR: Future<Output = Option<R>> + Send + 'static,
F: Fn(T, &D) -> FR + Send + Sync + Copy + 'static,
D: Send + Sync + 'static,
;
}
pub
async fn async_listener<T, R, C, L, LF, S, SF, F, FR, D, E>(
listener: &mut L, listener_recv: LF,
sender: &mut S, sender_send: SF,
handler: F, stop_response: Option<R>, userdata: &D) -> Result<(), E>
where
T: Send + Sync,
R: Send + Sync + 'static,
FR: Future<Output = Option<R>> + Send + 'static,
F: Fn(T, &D) -> FR + Send + Sync,
D: Send + Sync + 'static,
C: Send + Sync + 'static,
SF: Fn(&mut S, (C,R)) -> BoxFuture<Result<(), E>> + Send + Sync + 'static,
LF: Fn(&mut L) -> BoxFuture<Result<Option<(C,T)>, E>> + 'static,
{
let mut pending = FuturesUnordered::new();
loop {
tokio::select!{
rcv = listener_recv(listener) => {
match rcv? {
Some((id, data)) => {
pending.push(
async {
(
id,
tokio::spawn(handler(data, &userdata)).await.unwrap(),
)
}
);
}
None => break,
}
},
Some(r) = pending.next() => {
match r {
(id,Some(r)) => {
sender_send(sender, (id,r)).await?;
},
(id,None) => {
if let Some(v) = stop_response {
sender_send(sender, (id,v)).await?;
}
break;
},
}
},
else => break,
}
}
let results: Vec<_> = pending.collect().await;
for it in results {
match it {
(id,Some(r)) => {
sender_send(sender, (id,r)).await?;
},
_ => (),
}
}
Ok(())
}

198
src/transport/serde.rs Normal file
View file

@ -0,0 +1,198 @@
use async_trait::async_trait;
use futures::future::BoxFuture;
use thiserror::Error;
use futures::prelude::*;
use serde::{Deserialize, Serialize};
use std::net::Ipv4Addr;
use std::sync::atomic::AtomicUsize;
use tokio::net::{TcpListener, TcpStream};
//use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::io::{ReadHalf,WriteHalf};
use tokio_serde::SymmetricallyFramed;
use tokio_serde::formats::*;
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use super::channel::{self, ChannelClient, Multiplexer};
use super::{ClientTransporter,ServerTransporter};
#[derive(Debug, Error)]
pub enum Error {
#[error(transparent)]
IO(std::io::Error),
#[error("reader error")]
ReaderError,
}
type SymmetricalReader<T> = SymmetricallyFramed<
FramedRead<ReadHalf<TcpStream>, LengthDelimitedCodec>,
T,
SymmetricalBincode<T>>;
type SymmetricalWriter<T> = SymmetricallyFramed<
FramedWrite<WriteHalf<TcpStream>, LengthDelimitedCodec>,
T,
SymmetricalBincode<T>>;
pub struct Receiver<T> {
pub reader: SymmetricalReader<T>,
}
impl<T> Receiver<T> where
SymmetricalReader<T> : TryStream<Ok=T> + Unpin,
{
pub async fn recv(&mut self) -> Result<Option<T>, Error> {
if let Ok(msg) = self.reader.try_next().await {
Ok(msg)
} else {
Err(Error::ReaderError)
}
}
}
pub struct Sender<T> {
pub writer: SymmetricalWriter<T>,
}
impl<T> Sender<T> where
T: for<'a> Deserialize<'a> + Serialize + Unpin
{
pub async fn send(&mut self, item: T) -> Result<(), Error> {
self.writer.send(item).await.map_err(Error::IO)
}
}
async fn split<T,R>(socket: TcpStream) -> (Sender<(usize, T)>, Receiver<(usize, R)>) {
let (reader, writer) = tokio::io::split(socket);
let reader: FramedRead<
ReadHalf<TcpStream>,
LengthDelimitedCodec,
> = FramedRead::new(reader, LengthDelimitedCodec::new());
let reader: SymmetricalReader<(usize, R)> = SymmetricallyFramed::new(
reader, SymmetricalBincode::default());
let writer: FramedWrite<
WriteHalf<TcpStream>,
LengthDelimitedCodec,
> = FramedWrite::new(writer, LengthDelimitedCodec::new());
let writer: SymmetricalWriter<(usize, T)> = SymmetricallyFramed::new(
writer, SymmetricalBincode::default());
(Sender{ writer }, Receiver{ reader })
}
pub struct TcpClient<T,R> {
sender: Sender<(usize, T)>,
receiver: Receiver<(usize, R)>,
multiplexer: Option<Multiplexer<T,R>>,
req_id: AtomicUsize,
ghost: std::marker::PhantomData<(T, R)>,
}
pub struct TcpServer<T,R> {
listener: TcpListener,
ghost: std::marker::PhantomData<(T, R)>,
}
impl<T,R> TcpClient<T,R>
where
T: for<'a> Deserialize<'a> + Serialize + Send + Sync + Unpin + 'static,
R: for<'a> Deserialize<'a> + Serialize + Send + Sync + Unpin + 'static,
{
pub async fn connect<A>(address: &A) ->
Result<TcpClient<T, R>, Error>
where
A: tokio::net::ToSocketAddrs + ?Sized,
{
let socket = TcpStream::connect(&address).await.map_err(Error::IO)?;
let (sender,receiver) = split(socket).await;
Ok(TcpClient{ sender, receiver, multiplexer: None, req_id: AtomicUsize::new(0), ghost: Default::default() })
}
pub async fn multiplex(self) -> (ChannelClient<T,R>, BoxFuture<'static, Result<(), channel::Error>>) {
let (client,multiplexer) = channel::new_multiplexer::<T,R>();
let fut = multiplexer.start(
self.sender,
|sender, data| { Box::pin(async {sender.send(data).await.is_ok()}) },
self.receiver,
|receiver| { Box::pin(async { receiver.recv().await.map_or_else(|_| None, |x| x) }) },
);
(client, Box::pin(fut))
}
pub async fn spawn(self) -> ChannelClient<T,R> {
let (client, job) = self.multiplex().await;
tokio::spawn(job);
client
}
}
impl<T,R> TcpServer<T,R> where
T: for<'a> Deserialize<'a> + Serialize,
R: for<'a> Deserialize<'a> + Serialize,
{
pub async fn new(address: &Ipv4Addr, port: u16) ->
Result<TcpServer<T,R>, Error>
{
let address = format!("{}:{}", address, port);
let listener = TcpListener::bind(&address).await.map_err(Error::IO)?;
Ok(TcpServer{ listener, ghost: Default::default() })
}
async fn accept(&mut self) -> Result<TcpStream, Error>
{
let (socket, address) = self.listener.accept().await.map_err(Error::IO)?;
println!("connection accepted: {:?}", address);
Ok(socket)
}
}
#[async_trait]
impl<T,R> ServerTransporter<T,R> for TcpServer<T,R>
where
T: for<'a> Deserialize<'a> + Serialize + Send + Sync + Unpin,
R: for<'a> Deserialize<'a> + Serialize + Send + Sync + Unpin,
{
type Error = Error;
async fn listen<F, FR, D>(&mut self, handler: F, stop_response: Option<R>, userdata: D) -> Result<(), Self::Error>
where
FR: Future<Output = Option<R>> + Send + 'static,
F: Fn(T, &D) -> FR + Send + Sync,
D: Send + Sync,
{
let (client,fut) = channel::new_multiplexer::<R,T>();
// super::async_listener(
// &mut receiver, |_self| { Box::pin(async {
// _self.recv().await
// }) },
// &mut sender, |_self, data| { Box::pin(async {
// _self.send(data).await
// }) },
// handler, stop_response, &userdata);
while let Ok(mut stream) = self.accept().await {
let (sender,receiver) = split::<R,T>(stream).await;
// tokio::spawn(async move {
// super::async_listener(
// &mut receiver, |_self| { Box::pin(async {
// _self.recv().await
// }) },
// &mut sender, |_self, data| { Box::pin(async {
// _self.send(data).await
// }) },
// handler, stop_response, &userdata)
// });
}
todo!()
}
}