Commit

address comments
akaladarshi committed Jan 31, 2025
1 parent 6e031d6 commit f7ad5e9
Showing 2 changed files with 58 additions and 91 deletions.
33 changes: 15 additions & 18 deletions yamux/src/connection.rs
@@ -504,7 +504,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

fn on_drop_stream(&mut self, stream_id: StreamId) -> Option<Frame<()>> {
let s = self.streams.remove(&stream_id).expect("stream not found");
let mut s = self.streams.remove(&stream_id).expect("stream not found");

log::trace!("{}: removing dropped stream {}", self.id, stream_id);
let frame = s.with_mut(|inner| {
@@ -565,7 +565,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
&& matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
{
let id = frame.header().stream_id();
if let Some(shared) = self.streams.get(&id) {
if let Some(shared) = self.streams.get_mut(&id) {
shared.update_state(self.id, id, State::Open { acknowledged: true });
}
if let Some(waker) = self.new_outbound_stream_waker.take() {
@@ -625,16 +625,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::internal_error());
}
let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
{
stream.shared().with_mut(|inner| {
if is_finish {
inner.update_state(self.id, stream_id, State::RecvClosed);
}
inner.consume_receive_window(frame.body_len());
inner.buffer.push(frame.into_body());
})
}
let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
stream.shared_mut().with_mut(|inner| {
if is_finish {
inner.update_state(self.id, stream_id, State::RecvClosed);
}
inner.consume_receive_window(frame.body_len());
inner.buffer.push(frame.into_body());
});
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream);
}
@@ -660,7 +658,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Action::None
}
});
return action;
action
} else {
log::trace!(
"{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
@@ -675,9 +673,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
// termination for the remote.
//
// See https://github.com/paritytech/yamux/issues/110 for details.
Action::None
}

Action::None
}

fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
@@ -717,11 +714,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

let credit = frame.header().credit() + DEFAULT_CREDIT;
let stream = self.make_new_inbound_stream(stream_id, credit);
let mut stream = self.make_new_inbound_stream(stream_id, credit);

if is_finish {
stream
.shared()
.shared_mut()
.update_state(self.id, stream_id, State::RecvClosed);
}
self.streams.insert(stream_id, stream.clone_shared());
@@ -874,7 +871,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
impl<T> Active<T> {
/// Close and drop all `Stream`s and wake any pending `Waker`s.
fn drop_all_streams(&mut self) {
for (id, shared) in self.streams.drain() {
for (id, mut shared) in self.streams.drain() {
shared.with_mut(|inner| {
inner.update_state(self.id, id, State::Closed);
if let Some(w) = inner.reader.take() {
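
The recurring change in `connection.rs` is that the stream's shared-state handle is now reached through `&mut` methods (the new `shared_mut(&mut self)` and the `&mut self` version of `with_mut` in `stream.rs` below), so call sites switch to mutable bindings (`let mut s`, `let mut stream`, `for (id, mut shared) in ...`) and to `HashMap::get_mut`. The following sketch is a minimal, hypothetical illustration of why a `&mut self` receiver propagates to every call site; the types and the `set_state` method are stand-ins, not the crate's real API.

```rust
// Minimal, hypothetical sketch (not the crate's real types): once a method on
// the shared handle takes `&mut self`, every call site has to hold the handle
// mutably -- hence `let mut s = ...`, `self.streams.get_mut(&id)` and
// `for (id, mut shared) in self.streams.drain()` in the diff above.
use std::collections::HashMap;

#[derive(Clone, Copy)]
enum State {
    Open { acknowledged: bool },
    Closed,
}

struct Shared {
    state: State,
}

impl Shared {
    // `&mut self` receiver, analogous to `shared_mut`/`with_mut` after this change.
    fn set_state(&mut self, next: State) {
        self.state = next;
    }
}

fn main() {
    let mut streams: HashMap<u32, Shared> = HashMap::new();
    streams.insert(1, Shared { state: State::Open { acknowledged: false } });

    // `get` would only yield `&Shared`, which cannot call a `&mut self` method;
    // `get_mut` yields `&mut Shared`, mirroring the `self.streams.get_mut(&id)` change above.
    if let Some(shared) = streams.get_mut(&1) {
        shared.set_state(State::Open { acknowledged: true });
    }

    // Values drained out of the map must likewise be bound with `mut`
    // before they can be mutated, as in `drop_all_streams`.
    for (_id, mut shared) in streams.drain() {
        shared.set_state(State::Closed);
    }
}
```
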
116 changes: 43 additions & 73 deletions yamux/src/connection/stream.rs
@@ -7,8 +7,6 @@
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.

use crate::chunks::Chunk;
use crate::connection::rtt::Rtt;
use crate::frame::header::ACK;
use crate::{
@@ -28,7 +26,7 @@ use futures::{
ready, SinkExt,
};

use parking_lot::Mutex;
use parking_lot::{Mutex, MutexGuard};
use std::{
fmt, io,
pin::Pin,
@@ -179,14 +177,12 @@ impl Stream {
matches!(self.shared.state(), State::Closed)
}

/// Whether we are still waiting for the remote to acknowledge this stream.
pub fn is_pending_ack(&self) -> bool {
self.shared.is_pending_ack()
}

/// Returns a reference to the `Shared` concurrency wrapper.
pub(crate) fn shared(&self) -> &Shared {
&self.shared
pub(crate) fn shared_mut(&mut self) -> &mut Shared {
&mut self.shared
}

pub(crate) fn clone_shared(&self) -> Shared {
@@ -265,7 +261,8 @@ impl futures::stream::Stream for Stream {
Poll::Pending => {}
}

if let Some(bytes) = self.shared.pop_buffer() {
let mut shared = self.shared.lock();
if let Some(bytes) = shared.buffer.pop() {
let off = bytes.offset();
let mut vec = bytes.into_vec();
if off != 0 {
@@ -276,23 +273,21 @@ impl futures::stream::Stream for Stream {
log::debug!(
"{}/{}: chunk has been partially consumed",
self.conn,
self.id
self.id,
);
vec = vec.split_off(off)
}
return Poll::Ready(Some(Ok(Packet(vec))));
}

// Buffer is empty, let's check if we can expect to read more data.
if !self.shared.state().can_read() {
if !shared.state.can_read() {
log::debug!("{}/{}: eof", self.conn, self.id);
return Poll::Ready(None); // stream has been reset
}

// Since we have no more data at this point, we want to be woken up
// by the connection when more becomes available for us.
self.shared.set_reader_waker(Some(cx.waker().clone()));

shared.reader = Some(cx.waker().clone());
Poll::Pending
}
}
@@ -317,47 +312,36 @@ impl AsyncRead for Stream {
}

// Copy data from stream buffer.
let mut shared = self.shared.lock();
let mut n = 0;
let can_read = self.shared.with_mut(|inner| {
while let Some(chunk) = inner.buffer.front_mut() {
if chunk.is_empty() {
inner.buffer.pop();
continue;
}
let k = std::cmp::min(chunk.len(), buf.len() - n);
buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
n += k;
chunk.advance(k);
if n == buf.len() {
break;
}
while let Some(chunk) = shared.buffer.front_mut() {
if chunk.is_empty() {
shared.buffer.pop();
continue;
}

if n > 0 {
return true;
}

// Buffer is empty, let's check if we can expect to read more data.
if !inner.state.can_read() {
return false; // No more data available
let k = std::cmp::min(chunk.len(), buf.len() - n);
buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
n += k;
chunk.advance(k);
if n == buf.len() {
break;
}

// Since we have no more data at this point, we want to be woken up
// by the connection when more becomes available for us.
inner.reader = Some(cx.waker().clone());
true
});
}

if n > 0 {
log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
return Poll::Ready(Ok(n));
}

if !can_read {
// Buffer is empty, let's check if we can expect to read more data.
if !shared.state.can_read() {
log::debug!("{}/{}: eof", self.conn, self.id);
return Poll::Ready(Ok(0)); // stream has been reset
}

// Since we have no more data at this point, we want to be woken up
// by the connection when more becomes available for us.
shared.reader = Some(cx.waker().clone());
Poll::Pending
}
}
@@ -373,18 +357,19 @@ impl AsyncWrite for Stream {
.poll_ready(cx)
.map_err(|_| self.write_zero_err())?);

let result = self.shared.with_mut(|inner| {
if !inner.state.can_write() {
let body = {
let mut shared = self.shared.lock();
if !shared.state.can_write() {
log::debug!("{}/{}: can no longer write", self.conn, self.id);
// Return an error
return Err(self.write_zero_err());
return Poll::Ready(Err(self.write_zero_err()));
}

let window = inner.send_window();
let window = shared.send_window();
if window == 0 {
log::trace!("{}/{}: no more credit left", self.conn, self.id);
inner.writer = Some(cx.waker().clone());
return Ok(None); // means we are Pending
shared.writer = Some(cx.waker().clone());
return Poll::Pending;
}

let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX));
@@ -394,15 +379,8 @@ impl AsyncWrite for Stream {
self.config.split_send_size.try_into().unwrap_or(u32::MAX),
);

inner.consume_send_window(k);
let body = Some(Vec::from(&buf[..k as usize]));
Ok(body)
});

let body = match result {
Err(e) => return Poll::Ready(Err(e)), // can't write
Ok(None) => return Poll::Pending, // no credit => Pending
Ok(Some(b)) => b, // we have a body
shared.consume_send_window(k);
Vec::from(&buf[..k as usize])
};

let n = body.len();
@@ -415,9 +393,8 @@ impl AsyncWrite for Stream {
// a) to be consistent with outbound streams
// b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
if frame.header().flags().contains(ACK) {
self.shared.with_mut(|inner| {
inner.update_state(self.conn, self.id, State::Open { acknowledged: true });
});
self.shared
.update_state(self.conn, self.id, State::Open { acknowledged: true });
}

let cmd = StreamCommand::SendFrame(frame);
@@ -452,9 +429,8 @@ impl AsyncWrite for Stream {
self.sender
.start_send(cmd)
.map_err(|_| self.write_zero_err())?;
self.shared.with_mut(|inner| {
inner.update_state(self.conn, self.id, State::SendClosed);
});
self.shared
.update_state(self.conn, self.id, State::SendClosed);
Poll::Ready(Ok(()))
}
}
@@ -487,10 +463,6 @@ impl Shared {
self.inner.lock().state
}

pub fn pop_buffer(&self) -> Option<Chunk> {
self.with_mut(|inner| inner.buffer.pop())
}

pub fn is_pending_ack(&self) -> bool {
self.inner.lock().is_pending_ack()
}
@@ -499,17 +471,15 @@ impl Shared {
self.inner.lock().next_window_update()
}

pub fn set_reader_waker(&self, waker: Option<Waker>) {
self.with_mut(|inner| {
inner.reader = waker;
});
pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
self.inner.lock().update_state(cid, sid, next)
}

pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
self.with_mut(|inner| inner.update_state(cid, sid, next))
pub fn lock(&self) -> MutexGuard<'_, SharedInner> {
self.inner.lock()
}

pub fn with_mut<F, R>(&self, f: F) -> R
pub fn with_mut<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut SharedInner) -> R,
{
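
In `stream.rs`, the closure-based helpers (`with_mut` in the poll methods, plus `pop_buffer` and `set_reader_waker`) are replaced by taking a `parking_lot::MutexGuard` from the new `Shared::lock()` and holding it for the whole poll: one lock now covers draining the buffer, checking the state and parking the waker, and `poll_write` can return `Poll::Pending` or an error directly instead of signalling those cases through the closure's return value (the old `Ok(None)` sentinel). The standalone sketch below illustrates that trade-off; the types are hypothetical and it uses `std::sync::Mutex` instead of `parking_lot` only so it compiles without extra dependencies.

```rust
// Standalone sketch of the access-pattern change (hypothetical types; the
// crate itself uses `parking_lot::Mutex`, std's Mutex is used here only to
// keep the example dependency-free).
use std::sync::{Mutex, MutexGuard};
use std::task::Poll;

struct Inner {
    buffer: Vec<Vec<u8>>,
    can_read: bool,
}

struct Shared {
    inner: Mutex<Inner>,
}

impl Shared {
    // Old style: every access goes through a closure, so each helper call is
    // its own lock/unlock, and a `return` inside the closure cannot leave the
    // *calling* function -- pending/error cases must be encoded in the result.
    fn with_mut<R>(&self, f: impl FnOnce(&mut Inner) -> R) -> R {
        let mut guard = self.inner.lock().unwrap();
        f(&mut *guard)
    }

    // New style: hand the guard to the caller, who can perform several steps
    // under a single lock and use ordinary control flow (`return Poll::...`).
    fn lock(&self) -> MutexGuard<'_, Inner> {
        self.inner.lock().unwrap()
    }
}

// A poll-shaped function in the new style: one lock for the whole decision.
fn poll_next_chunk(shared: &Shared) -> Poll<Option<Vec<u8>>> {
    let mut guard = shared.lock();
    if let Some(chunk) = guard.buffer.pop() {
        return Poll::Ready(Some(chunk)); // data available
    }
    if !guard.can_read {
        return Poll::Ready(None); // EOF / reset
    }
    Poll::Pending // a real implementation would park a waker here first
}

fn main() {
    let shared = Shared {
        inner: Mutex::new(Inner { buffer: vec![b"hello".to_vec()], can_read: true }),
    };
    // The closure style still works for one-shot mutations...
    shared.with_mut(|inner| inner.buffer.push(b"world".to_vec()));
    // ...but multi-step poll logic reads more naturally with an explicit guard.
    assert!(matches!(poll_next_chunk(&shared), Poll::Ready(Some(_))));
}
```
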
