diff --git a/crates/futures/src/mpmc.rs b/crates/futures/src/mpmc.rs index e358a43a8f3b390cf5a8ec9522d4d1b43cb1da80..48441c5616a218cc06052fc76a0239ce88470a2b 100644 --- a/crates/futures/src/mpmc.rs +++ b/crates/futures/src/mpmc.rs @@ -1,73 +1,124 @@ -use alloc::sync::Arc; -use core::{ - future::Future, - pin::Pin, - sync::atomic::{AtomicUsize, Ordering}, - task::{Context, Poll}, -}; +use alloc::{boxed::Box, sync::Arc}; +use core::sync::atomic::{AtomicUsize, Ordering}; use crossbeam_queue::SegQueue; -use futures_util::{task::AtomicWaker, FutureExt, Stream}; +use event_listener::Event; +use futures::stream::{unfold, Stream}; + +use crate::block_on; // pub fn channel<T>() -> (Sender<T>, Receiver<T>) { - let inner = Arc::new(SplitChannel::new()); - ( - Sender { - inner: inner.clone(), - }, - Receiver { inner }, - ) + Channel::new().split() +} + +// + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SendError<T>(pub T); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecvError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TryRecvError { + Empty, + Closed, } // -#[derive(Clone)] pub struct Sender<T> { inner: Arc<SplitChannel<T>>, } impl<T> Sender<T> { - pub fn send(&self, data: T) -> Option<()> { + pub fn send(&self, data: T) -> Result<(), SendError<T>> { self.inner.send(data) } + + pub fn receiver(&self) -> Option<Receiver<T>> { + loop { + let current = self.inner.readers.load(Ordering::Acquire); + if current == 0 { + // the read end was closed + return None; + } + + // fetch_add but don't increment if it was 0 + // once the count goes to 0, the channel is permanently closed + if self + .inner + .readers + .compare_exchange(current, current + 1, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + break; + } + } + + Some(Receiver { + inner: self.inner.clone(), + }) + } +} + +impl<T> Clone for Sender<T> { + fn clone(&self) -> Self { + self.inner.writers.fetch_add(1, Ordering::Acquire); + Self { + inner: self.inner.clone(), + } + } } impl<T> Drop for Sender<T> { fn drop(&mut self) { - if self.inner.writers.fetch_sub(1, Ordering::SeqCst) == 1 { - self.inner.channel.waker.wake(); - } + self.inner.close_send(); } } // -#[derive(Clone)] pub struct Receiver<T> { inner: Arc<SplitChannel<T>>, } impl<T> Receiver<T> { - pub fn recv(&self) -> Recv<T> { - self.inner.recv() + pub async fn recv(&self) -> Result<T, RecvError> { + self.inner.recv().await } - pub fn try_recv(&self) -> Option<T> { + pub fn blocking_recv(&self) -> Result<T, RecvError> { + self.inner.blocking_recv() + } + + pub fn spin_recv(&self) -> Result<T, RecvError> { + self.inner.spin_recv() + } + + pub fn try_recv(&self) -> Result<T, TryRecvError> { self.inner.try_recv() } - pub fn race_stream(&self) -> RecvStream<T> { + pub fn race_stream(&self) -> impl Stream<Item = T> + Unpin + '_ { self.inner.race_stream() } } +impl<T> Clone for Receiver<T> { + fn clone(&self) -> Self { + self.inner.readers.fetch_add(1, Ordering::Acquire); + Self { + inner: self.inner.clone(), + } + } +} + impl<T> Drop for Receiver<T> { fn drop(&mut self) { - if self.inner.readers.fetch_sub(1, Ordering::SeqCst) == 1 { - self.inner.channel.waker.wake(); - } + self.inner.close_recv(); } } @@ -75,104 +126,69 @@ impl<T> Drop for Receiver<T> { pub struct Channel<T> { queue: SegQueue<T>, - waker: AtomicWaker, + wakers: Event, } impl<T> Channel<T> { pub const fn new() -> Self { Self { queue: SegQueue::new(), - waker: AtomicWaker::new(), + wakers: Event::new(), } } - pub fn send(&self, val: T) { - self.queue.push(val); - self.waker.wake(); + pub fn split(self) -> (Sender<T>, Receiver<T>) { + let inner = Arc::new(SplitChannel::new(self)); + let tx = Sender { + inner: inner.clone(), + }; + let rx = Receiver { inner }; + + (tx, rx) } - pub fn recv(&self) -> ChannelRecv<T> { - ChannelRecv { inner: self } + pub fn send(&self, val: T) { + self.queue.push(val); + self.wakers.notify(1); } pub fn try_recv(&self) -> Option<T> { self.queue.pop() } - pub fn race_stream(&self) -> ChannelRecvStream<T> { - ChannelRecvStream { inner: self } - } -} - -// - -pub struct ChannelRecv<'a, T> { - inner: &'a Channel<T>, -} - -impl<'a, T> Future for ChannelRecv<'a, T> { - type Output = Option<T>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { - if let Some(v) = self.inner.try_recv() { - return Poll::Ready(Some(v)); - } - - self.inner.waker.register(cx.waker()); - - if let Some(v) = self.inner.try_recv() { - self.inner.waker.take(); - return Poll::Ready(Some(v)); + pub async fn recv(&self) -> T { + if let Some(val) = self.try_recv() { + return val; } - Poll::Pending + self.recv_slow().await } -} - -pub struct ChannelRecvStream<'a, T> { - inner: &'a Channel<T>, -} - -impl<'a, T> Stream for ChannelRecvStream<'a, T> { - type Item = T; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - self.inner.recv().poll_unpin(cx) + pub fn race_stream(&self) -> impl Stream<Item = T> + Unpin + '_ { + unfold(self, |ch| { + // TODO: manual Future impl to get rid of the Box::pin + Box::pin(async move { + let item = ch.recv().await; + Some((item, ch)) + }) + }) } -} -pub struct Recv<'a, T> { - inner: &'a SplitChannel<T>, -} + #[cold] + async fn recv_slow(&self) -> T { + loop { + let l = self.wakers.listen(); -impl<'a, T> Future for Recv<'a, T> { - type Output = Option<T>; + if let Some(val) = self.try_recv() { + return val; + } - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { - let closed = self.inner.writers.load(Ordering::SeqCst) == 0; + l.await; - if let Poll::Ready(val) = self.inner.channel.recv().poll_unpin(cx) { - return Poll::Ready(val); + if let Some(val) = self.try_recv() { + return val; + } } - - if closed { - self.inner.channel.waker.take(); - Poll::Ready(None) - } else { - Poll::Pending - } - } -} - -pub struct RecvStream<'a, T> { - inner: &'a SplitChannel<T>, -} - -impl<'a, T> Stream for RecvStream<'a, T> { - type Item = T; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - self.inner.recv().poll_unpin(cx) } } @@ -185,32 +201,107 @@ struct SplitChannel<T> { } impl<T> SplitChannel<T> { - const fn new() -> Self { + const fn new(channel: Channel<T>) -> Self { Self { readers: AtomicUsize::new(1), writers: AtomicUsize::new(1), - channel: Channel::new(), + channel, + } + } + + fn close_send(&self) { + if self.writers.fetch_sub(1, Ordering::SeqCst) == 1 { + self.channel.wakers.notify(usize::MAX); + } + } + + fn close_recv(&self) { + if self.readers.fetch_sub(1, Ordering::SeqCst) == 1 { + // self.channel.wakers.notify(usize::MAX); } } - fn send(&self, val: T) -> Option<()> { - if self.readers.load(Ordering::SeqCst) == 0 { - return None; + fn is_send_closed(&self) -> bool { + self.writers.load(Ordering::SeqCst) == 0 + } + + fn is_recv_closed(&self) -> bool { + self.readers.load(Ordering::SeqCst) == 0 + } + + fn send(&self, val: T) -> Result<(), SendError<T>> { + if self.is_recv_closed() { + return Err(SendError(val)); } self.channel.send(val); - Some(()) + Ok(()) } - pub fn recv(&self) -> Recv<T> { - Recv { inner: self } + async fn recv(&self) -> Result<T, RecvError> { + match self.try_recv() { + Ok(val) => return Ok(val), + Err(TryRecvError::Empty) => {} + Err(TryRecvError::Closed) => return Err(RecvError), + } + + self.recv_slow().await } - pub fn try_recv(&self) -> Option<T> { - self.channel.try_recv() + fn blocking_recv(&self) -> Result<T, RecvError> { + block_on(self.recv()) + } + + fn spin_recv(&self) -> Result<T, RecvError> { + loop { + match self.try_recv() { + Ok(val) => return Ok(val), + Err(TryRecvError::Empty) => {} + Err(TryRecvError::Closed) => return Err(RecvError), + } + } + } + + fn try_recv(&self) -> Result<T, TryRecvError> { + if let Some(val) = self.channel.try_recv() { + return Ok(val); + } + + if self.is_send_closed() { + return Err(TryRecvError::Closed); + } + + Err(TryRecvError::Empty) } - pub fn race_stream(&self) -> RecvStream<T> { - RecvStream { inner: self } + fn race_stream(&self) -> impl Stream<Item = T> + Unpin + '_ { + unfold(self, |ch| { + // TODO: same as [`Channel::race_stream`] + Box::pin(async move { + let item = Box::pin(ch.recv()).await.ok()?; + Some((item, ch)) + }) + }) + } + + #[cold] + async fn recv_slow(&self) -> Result<T, RecvError> { + loop { + let l = self.channel.wakers.listen(); + + match self.try_recv() { + Ok(val) => return Ok(val), + Err(TryRecvError::Empty) => {} + Err(TryRecvError::Closed) => return Err(RecvError), + } + + l.await; + + match self.try_recv() { + Ok(val) => return Ok(val), + Err(TryRecvError::Empty) => {} + Err(TryRecvError::Closed) => return Err(RecvError), + } + } } } diff --git a/crates/kshell/src/shell.rs b/crates/kshell/src/shell.rs index 6767fe9a835a3d06fa4bce465dc90026d862dcc5..be7b8506e65d9d3f5d2bcae18d0970500e7b33c3 100644 --- a/crates/kshell/src/shell.rs +++ b/crates/kshell/src/shell.rs @@ -3,12 +3,13 @@ use alloc::{ string::{String, ToString}, sync::Arc, }; -use core::{fmt::Write, sync::atomic::Ordering}; +use core::{fmt::Write, str, sync::atomic::Ordering}; use anyhow::anyhow; use futures_util::stream::select; use hyperion_cpu_id::cpu_count; use hyperion_driver_acpi::apic::ApicId; +use hyperion_futures::mpmc; use hyperion_instant::Instant; use hyperion_kernel_impl::{FileDescData, FileDescriptor}; use hyperion_keyboard::{ @@ -19,7 +20,7 @@ use hyperion_mem::pmm; use hyperion_num_postfix::NumberPostfix; use hyperion_scheduler::{ idle, - ipc::pipe::pipe, + ipc::pipe::{self, pipe}, spawn, task::{processes, Pid, TASKS_READY, TASKS_RUNNING, TASKS_SLEEPING}, }; @@ -211,48 +212,30 @@ impl Shell { let (o_tx, o_rx) = hyperion_futures::mpmc::channel(); // last program's stdout (stdin) -> terminal let o_tx_2 = o_tx.clone(); - spawn(move || { - loop { - let mut buf = [0; 128]; - let Ok(len) = stderr_rx.read(&mut buf) else { - trace!("end of stream"); - break; - }; - if len == 0 { - trace!("end of stream"); - break; - } - let Ok(str) = core::str::from_utf8(&buf[..len]) else { - trace!("invalid utf8"); - break; - }; - // debug!("stderr:{str}"); - o_tx_2.send(Some(str.to_string())); - } - o_tx_2.send(None); - }); - spawn(move || { + fn try_forward(from: &dyn FileDescriptor, to: &mpmc::Sender<Option<String>>) -> Option<()> { loop { - let mut buf = [0; 128]; - let Ok(len) = stdin.read(&mut buf) else { - trace!("end of stream"); - break; - }; + // TODO: if the buffer is full, the result might not be UTF-8 + let mut buf = [0u8; 512]; + let len = from.read(&mut buf).ok()?; + if len == 0 { - trace!("end of stream"); - break; + return None; } - let Ok(str) = core::str::from_utf8(&buf[..len]) else { - trace!("invalid utf8"); - break; - }; - // debug!("stdout:{str}"); - o_tx.send(Some(str.to_string())); + + let str = str::from_utf8(&buf[..len]).ok()?.to_string(); + + to.send(Some(str)).ok()?; } + } + + fn forward(from: &dyn FileDescriptor, to: mpmc::Sender<Option<String>>) { + _ = try_forward(from, &to); + _ = to.send(None); + } - o_tx.send(None); - }); + spawn(move || _ = forward(&stderr_rx, o_tx_2)); + spawn(move || _ = forward(&*stdin, o_tx)); // start sending keyboard events to the process and read stdout into the terminal let mut events = select(KeyboardEvents.map(Ok), o_rx.race_stream().map(Err)); @@ -342,9 +325,6 @@ impl Shell { self.term.flush(); } Some(Err(None)) => { - // _ = write!(self.term, "got EOI"); - // self.term.flush(); - trace!("EOI"); break; } None => {