rust prototype for a maelstrom server

This commit is contained in:
J. Fernando Sánchez
2025-01-21 13:20:13 +01:00
parent d025eab6e3
commit 8e7ed006e7
4 changed files with 483 additions and 44 deletions

View File

@@ -1,18 +1,18 @@
use tokio::{
task,
time::{Duration, sleep},
time::{Instant, Duration, sleep},
sync::{oneshot,
mpsc::{Receiver, Sender, channel}
}
};
use std::collections::{VecDeque, HashMap};
use std::collections::{HashMap};
use anyhow::Result;
use anyhow::{bail, Context};
use tracing::{info, debug, instrument, Level};
use tracing_subscriber::{fmt, Registry, EnvFilter};
use tracing::{info, debug, instrument};
use tracing_subscriber::{fmt, EnvFilter};
use tracing_subscriber::prelude::*;
use tracing_subscriber::fmt::format::FmtSpan;
use tokio::runtime::Handle;
use std::future::Future;
struct Service {
t: Transport,
@@ -24,10 +24,15 @@ struct Event {
in_reply_to: Option<usize>,
}
type CB = dyn (FnMut(Event, &mut Transport) -> Result<()>) + Send;
//type CB<'a> = Box<dyn Fn(Event, &'a mut Transport) -> Box<dyn Future<Output = Result<()>>>>;
//type CB<'a> = dyn Fn(Event, &mut Transport) -> BoxFuture<'a, Result<()>>;
type BoxFuture<'a, T> = Box<dyn Future<Output = T> + 'a>;
type CallbackDyn = dyn for<'a> Fn(Event, &'a mut Transport) -> BoxFuture<'a, Result<()>>;
enum Handler {
Callback(Box<CB>),
Callback(Box<CallbackDyn>),
Channel(oneshot::Sender<Event>)
}
@@ -35,7 +40,7 @@ struct Transport {
msg_id: usize,
outbox: Sender<Event>,
inbox: Receiver<Option<Event>>,
callbacks: HashMap<usize, Handler>,
callbacks: HashMap<usize, (Handler, Instant)>,
}
@@ -55,28 +60,45 @@ impl Transport {
#[instrument(level="debug",skip(self),ret)]
fn close(&mut self) -> Result<()> {
self.inbox.close();
Ok(())
let i = Instant::now();
self.callbacks.retain(|_, (_, t)| *t >= i);
if self.callbacks.is_empty() {
Ok(())
} else {
bail!("some callbacks failed to launch.")
}
}
#[instrument(level="debug",skip(self),ret)]
async fn send(&mut self, e: Event) -> Result<()> {
async fn send(&self, e: Event) -> Result<()> {
self.outbox.send(e).await.context("unable to send message")
}
#[instrument(level="debug",skip(self),ret)]
async fn rpc(&mut self, e: Event) -> Result<Event> {
let id = e.id;
self.send(e).await?;
let (tx, rx) = oneshot::channel();
self.callbacks.insert(id, Handler::Channel(tx));
let rx = self.rpc_chan(e).await?;
let resp = rx.await?;
Ok(resp)
}
async fn rpc_callback<C: FnMut(Event, &mut Transport) -> Result<()> + Send + Sized + 'static>(&mut self, e: Event, cb: C) -> Result<()> {
self.callbacks.insert(e.id, Handler::Callback(Box::new(cb)));
#[instrument(level="debug",skip(self),ret)]
async fn rpc_chan(&mut self, e: Event) -> Result<oneshot::Receiver<Event>> {
let id = e.id;
self.send(e).await?;
let (tx, rx) = oneshot::channel();
self.callbacks.insert(id, (Handler::Channel(tx), Instant::now() + Duration::from_millis(1000)));
Ok(rx)
}
async fn rpc_callback<C>(&mut self, e: Event, cb: C) -> Result<()>
where
C: for<'a> Fn(Event, &'a mut Transport) -> BoxFuture<'a, Result<()>> + 'static
{
self.callbacks.insert(e.id, (Handler::Callback(Box::new(cb)), (Instant::now() + Duration::from_millis(1000))));
self.send(e).await
}
/// #Cancellation safety: This method is **not** cancellation safe, because it deals with callbacks
#[instrument(level="debug",skip(self),ret)]
async fn recv(&mut self) -> Result<Option<Event>> {
loop {
@@ -89,16 +111,25 @@ impl Transport {
let Some(original) = nxt.in_reply_to else {
return Ok(Some(nxt));
};
match self.callbacks.remove(&nxt.id) {
let sp = match self.callbacks.remove(&original) {
None => {
return Ok(Some(nxt));
},
Some(Handler::Callback(mut cb)) => {
cb(nxt, self)?;
Some((h, t)) => {
if t < Instant::now() {
continue;
}
h
}
};
match sp {
Handler::Callback(cb) => {
let f = Box::into_pin(cb(nxt, self));
f.await?;
},
Some(Handler::Channel(c)) => {
Handler::Channel(c) => {
if let Err(e) = c.send(nxt) {
bail!("could not send event to callback")
bail!("could not send event to callback {e:?}")
}
}
}
@@ -106,6 +137,17 @@ impl Transport {
}
}
macro_rules! callback {
(|$e:ident, $svc:ident| $blk:block) => {
|$e, $svc| {
Box::new(async move {
$blk
})
}
}
}
impl Service {
#[instrument(level="debug", skip(self),ret)]
async fn process(&mut self, event: Event) -> Result<()>{
@@ -117,16 +159,9 @@ impl Service {
#[instrument(level="debug", skip(self),ret)]
async fn process_nonblocking(&mut self, event: Event) -> Result<()>{
let resp = self.t.rpc_callback(Event{id: event.id, in_reply_to: None},
|e, svc| {
let handle = Handle::current();
task::block_in_place(move || {
handle.block_on(
async move {
svc.send(Event{id: 199, in_reply_to: Some(e.id)}).await
})
})
}
).await?;
callback!(|e, svc| {
svc.send(Event{id: 199, in_reply_to: Some(e.id)}).await
})).await?;
debug!("{:?}", &resp);
Ok(())
}
@@ -136,7 +171,7 @@ impl Service {
while let Some(e) = self.t.recv().await? {
self.process_nonblocking(e).await?;
}
Ok::<(), anyhow::Error>(())
self.t.close()
}
}
@@ -148,9 +183,11 @@ async fn main() -> Result<()> {
.with(EnvFilter::from_default_env())
.init();
info!("Starting example server");
let (t, tx, mut rx) = Transport::new();
let mut serv = Service{t};
let serv = Service{t};
let tx1 = tx.clone();
let s = tokio::spawn(async move {
@@ -162,28 +199,29 @@ async fn main() -> Result<()> {
});
debug!("Serving");
let r = tokio::spawn(async move {
for i in 0..5 {
for i in 0..2 {
let Some(msg) = rx.recv().await else {
break;
};
if msg.in_reply_to.is_some() {
continue;
}
sleep(Duration::from_millis(1000)).await;
let reply = Event{id: 99, in_reply_to: Some(msg.id)};
debug!("Sending reply: {:?}", &reply);
tx.send(Some(reply)).await.context("could not send reply")?;
sleep(Duration::from_millis(200)).await;
if i < 0 {
let reply = Event{id: 99, in_reply_to: Some(msg.id)};
debug!("Sending reply: {:?}", &reply);
tx.send(Some(reply)).await.context("could not send reply")?;
} else {
println!("DONE WITH REPLIES");
}
}
println!("DONE WITH REPLIES");
sleep(Duration::from_millis(5000)).await;
println!("REALLY DONE WITH REPLIES");
sleep(Duration::from_millis(100)).await;
tx.send(None).await.context("could not send closing message")?;
Ok(())
});
serv.serve().await?;
debug!("Served");
//serv.close();
info!("Served");
s.await.context("error joining")??;
r.await.context("error joining send")?
}