Skip to content
Snippets Groups Projects
Commit 268889ca authored by Eemeli's avatar Eemeli
Browse files

fix futures mpmc channel

parent 8e8443d1
Branches
No related tags found
No related merge requests found
use alloc::sync::Arc;
use core::{
future::Future,
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
task::{Context, Poll},
};
use alloc::{boxed::Box, sync::Arc};
use core::sync::atomic::{AtomicUsize, Ordering};
use crossbeam_queue::SegQueue;
use futures_util::{task::AtomicWaker, FutureExt, Stream};
use event_listener::Event;
use futures::stream::{unfold, Stream};
use crate::block_on;
//
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
let inner = Arc::new(SplitChannel::new());
(
Sender {
inner: inner.clone(),
},
Receiver { inner },
)
Channel::new().split()
}
//
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SendError<T>(pub T);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecvError;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
Empty,
Closed,
}
//
#[derive(Clone)]
pub struct Sender<T> {
inner: Arc<SplitChannel<T>>,
}
impl<T> Sender<T> {
pub fn send(&self, data: T) -> Option<()> {
pub fn send(&self, data: T) -> Result<(), SendError<T>> {
self.inner.send(data)
}
pub fn receiver(&self) -> Option<Receiver<T>> {
loop {
let current = self.inner.readers.load(Ordering::Acquire);
if current == 0 {
// the read end was closed
return None;
}
// fetch_add but don't increment if it was 0
// once the count goes to 0, the channel is permanently closed
if self
.inner
.readers
.compare_exchange(current, current + 1, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
{
break;
}
}
Some(Receiver {
inner: self.inner.clone(),
})
}
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.inner.writers.fetch_add(1, Ordering::Acquire);
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
if self.inner.writers.fetch_sub(1, Ordering::SeqCst) == 1 {
self.inner.channel.waker.wake();
}
self.inner.close_send();
}
}
//
#[derive(Clone)]
pub struct Receiver<T> {
inner: Arc<SplitChannel<T>>,
}
impl<T> Receiver<T> {
pub fn recv(&self) -> Recv<T> {
self.inner.recv()
pub async fn recv(&self) -> Result<T, RecvError> {
self.inner.recv().await
}
pub fn try_recv(&self) -> Option<T> {
pub fn blocking_recv(&self) -> Result<T, RecvError> {
self.inner.blocking_recv()
}
pub fn spin_recv(&self) -> Result<T, RecvError> {
self.inner.spin_recv()
}
pub fn try_recv(&self) -> Result<T, TryRecvError> {
self.inner.try_recv()
}
pub fn race_stream(&self) -> RecvStream<T> {
pub fn race_stream(&self) -> impl Stream<Item = T> + Unpin + '_ {
self.inner.race_stream()
}
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.inner.readers.fetch_add(1, Ordering::Acquire);
Self {
inner: self.inner.clone(),
}
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
if self.inner.readers.fetch_sub(1, Ordering::SeqCst) == 1 {
self.inner.channel.waker.wake();
}
self.inner.close_recv();
}
}
......@@ -75,142 +126,182 @@ impl<T> Drop for Receiver<T> {
pub struct Channel<T> {
queue: SegQueue<T>,
waker: AtomicWaker,
wakers: Event,
}
impl<T> Channel<T> {
pub const fn new() -> Self {
Self {
queue: SegQueue::new(),
waker: AtomicWaker::new(),
wakers: Event::new(),
}
}
pub fn send(&self, val: T) {
self.queue.push(val);
self.waker.wake();
pub fn split(self) -> (Sender<T>, Receiver<T>) {
let inner = Arc::new(SplitChannel::new(self));
let tx = Sender {
inner: inner.clone(),
};
let rx = Receiver { inner };
(tx, rx)
}
pub fn recv(&self) -> ChannelRecv<T> {
ChannelRecv { inner: self }
pub fn send(&self, val: T) {
self.queue.push(val);
self.wakers.notify(1);
}
pub fn try_recv(&self) -> Option<T> {
self.queue.pop()
}
pub fn race_stream(&self) -> ChannelRecvStream<T> {
ChannelRecvStream { inner: self }
}
pub async fn recv(&self) -> T {
if let Some(val) = self.try_recv() {
return val;
}
//
self.recv_slow().await
}
pub struct ChannelRecv<'a, T> {
inner: &'a Channel<T>,
pub fn race_stream(&self) -> impl Stream<Item = T> + Unpin + '_ {
unfold(self, |ch| {
// TODO: manual Future impl to get rid of the Box::pin
Box::pin(async move {
let item = ch.recv().await;
Some((item, ch))
})
})
}
impl<'a, T> Future for ChannelRecv<'a, T> {
type Output = Option<T>;
#[cold]
async fn recv_slow(&self) -> T {
loop {
let l = self.wakers.listen();
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
if let Some(v) = self.inner.try_recv() {
return Poll::Ready(Some(v));
if let Some(val) = self.try_recv() {
return val;
}
self.inner.waker.register(cx.waker());
l.await;
if let Some(v) = self.inner.try_recv() {
self.inner.waker.take();
return Poll::Ready(Some(v));
if let Some(val) = self.try_recv() {
return val;
}
Poll::Pending
}
}
pub struct ChannelRecvStream<'a, T> {
inner: &'a Channel<T>,
}
impl<'a, T> Stream for ChannelRecvStream<'a, T> {
type Item = T;
//
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.recv().poll_unpin(cx)
}
struct SplitChannel<T> {
readers: AtomicUsize,
writers: AtomicUsize,
channel: Channel<T>,
}
pub struct Recv<'a, T> {
inner: &'a SplitChannel<T>,
impl<T> SplitChannel<T> {
const fn new(channel: Channel<T>) -> Self {
Self {
readers: AtomicUsize::new(1),
writers: AtomicUsize::new(1),
channel,
}
}
impl<'a, T> Future for Recv<'a, T> {
type Output = Option<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let closed = self.inner.writers.load(Ordering::SeqCst) == 0;
if let Poll::Ready(val) = self.inner.channel.recv().poll_unpin(cx) {
return Poll::Ready(val);
fn close_send(&self) {
if self.writers.fetch_sub(1, Ordering::SeqCst) == 1 {
self.channel.wakers.notify(usize::MAX);
}
}
if closed {
self.inner.channel.waker.take();
Poll::Ready(None)
} else {
Poll::Pending
fn close_recv(&self) {
if self.readers.fetch_sub(1, Ordering::SeqCst) == 1 {
// self.channel.wakers.notify(usize::MAX);
}
}
fn is_send_closed(&self) -> bool {
self.writers.load(Ordering::SeqCst) == 0
}
pub struct RecvStream<'a, T> {
inner: &'a SplitChannel<T>,
fn is_recv_closed(&self) -> bool {
self.readers.load(Ordering::SeqCst) == 0
}
impl<'a, T> Stream for RecvStream<'a, T> {
type Item = T;
fn send(&self, val: T) -> Result<(), SendError<T>> {
if self.is_recv_closed() {
return Err(SendError(val));
}
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.recv().poll_unpin(cx)
self.channel.send(val);
Ok(())
}
async fn recv(&self) -> Result<T, RecvError> {
match self.try_recv() {
Ok(val) => return Ok(val),
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Closed) => return Err(RecvError),
}
//
self.recv_slow().await
}
struct SplitChannel<T> {
readers: AtomicUsize,
writers: AtomicUsize,
channel: Channel<T>,
fn blocking_recv(&self) -> Result<T, RecvError> {
block_on(self.recv())
}
impl<T> SplitChannel<T> {
const fn new() -> Self {
Self {
readers: AtomicUsize::new(1),
writers: AtomicUsize::new(1),
channel: Channel::new(),
fn spin_recv(&self) -> Result<T, RecvError> {
loop {
match self.try_recv() {
Ok(val) => return Ok(val),
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Closed) => return Err(RecvError),
}
}
}
fn send(&self, val: T) -> Option<()> {
if self.readers.load(Ordering::SeqCst) == 0 {
return None;
fn try_recv(&self) -> Result<T, TryRecvError> {
if let Some(val) = self.channel.try_recv() {
return Ok(val);
}
self.channel.send(val);
Some(())
if self.is_send_closed() {
return Err(TryRecvError::Closed);
}
pub fn recv(&self) -> Recv<T> {
Recv { inner: self }
Err(TryRecvError::Empty)
}
pub fn try_recv(&self) -> Option<T> {
self.channel.try_recv()
fn race_stream(&self) -> impl Stream<Item = T> + Unpin + '_ {
unfold(self, |ch| {
// TODO: same as [`Channel::race_stream`]
Box::pin(async move {
let item = Box::pin(ch.recv()).await.ok()?;
Some((item, ch))
})
})
}
#[cold]
async fn recv_slow(&self) -> Result<T, RecvError> {
loop {
let l = self.channel.wakers.listen();
match self.try_recv() {
Ok(val) => return Ok(val),
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Closed) => return Err(RecvError),
}
pub fn race_stream(&self) -> RecvStream<T> {
RecvStream { inner: self }
l.await;
match self.try_recv() {
Ok(val) => return Ok(val),
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Closed) => return Err(RecvError),
}
}
}
}
......@@ -3,12 +3,13 @@ use alloc::{
string::{String, ToString},
sync::Arc,
};
use core::{fmt::Write, sync::atomic::Ordering};
use core::{fmt::Write, str, sync::atomic::Ordering};
use anyhow::anyhow;
use futures_util::stream::select;
use hyperion_cpu_id::cpu_count;
use hyperion_driver_acpi::apic::ApicId;
use hyperion_futures::mpmc;
use hyperion_instant::Instant;
use hyperion_kernel_impl::{FileDescData, FileDescriptor};
use hyperion_keyboard::{
......@@ -19,7 +20,7 @@ use hyperion_mem::pmm;
use hyperion_num_postfix::NumberPostfix;
use hyperion_scheduler::{
idle,
ipc::pipe::pipe,
ipc::pipe::{self, pipe},
spawn,
task::{processes, Pid, TASKS_READY, TASKS_RUNNING, TASKS_SLEEPING},
};
......@@ -211,48 +212,30 @@ impl Shell {
let (o_tx, o_rx) = hyperion_futures::mpmc::channel();
// last program's stdout (stdin) -> terminal
let o_tx_2 = o_tx.clone();
spawn(move || {
fn try_forward(from: &dyn FileDescriptor, to: &mpmc::Sender<Option<String>>) -> Option<()> {
loop {
let mut buf = [0; 128];
let Ok(len) = stderr_rx.read(&mut buf) else {
trace!("end of stream");
break;
};
// TODO: if the buffer is full, the result might not be UTF-8
let mut buf = [0u8; 512];
let len = from.read(&mut buf).ok()?;
if len == 0 {
trace!("end of stream");
break;
}
let Ok(str) = core::str::from_utf8(&buf[..len]) else {
trace!("invalid utf8");
break;
};
// debug!("stderr:{str}");
o_tx_2.send(Some(str.to_string()));
return None;
}
o_tx_2.send(None);
});
spawn(move || {
loop {
let mut buf = [0; 128];
let Ok(len) = stdin.read(&mut buf) else {
trace!("end of stream");
break;
};
if len == 0 {
trace!("end of stream");
break;
let str = str::from_utf8(&buf[..len]).ok()?.to_string();
to.send(Some(str)).ok()?;
}
let Ok(str) = core::str::from_utf8(&buf[..len]) else {
trace!("invalid utf8");
break;
};
// debug!("stdout:{str}");
o_tx.send(Some(str.to_string()));
}
o_tx.send(None);
});
fn forward(from: &dyn FileDescriptor, to: mpmc::Sender<Option<String>>) {
_ = try_forward(from, &to);
_ = to.send(None);
}
spawn(move || _ = forward(&stderr_rx, o_tx_2));
spawn(move || _ = forward(&*stdin, o_tx));
// start sending keyboard events to the process and read stdout into the terminal
let mut events = select(KeyboardEvents.map(Ok), o_rx.race_stream().map(Err));
......@@ -342,9 +325,6 @@ impl Shell {
self.term.flush();
}
Some(Err(None)) => {
// _ = write!(self.term, "got EOI");
// self.term.flush();
trace!("EOI");
break;
}
None => {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment