Skip to content

Commit

Permalink
client: use VecDeque<OwnedFd> to store in/out fd queues
Browse files Browse the repository at this point in the history
This fixes a potential fd leak, when the buffers are not empty but
the Connection is dropped.
  • Loading branch information
MaxVerevkin committed Feb 28, 2024
1 parent 640409d commit 0ef0043
Showing 1 changed file with 27 additions and 100 deletions.
127 changes: 27 additions & 100 deletions wayrs-client/src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::VecDeque;
use std::env;
use std::ffi::CString;
use std::io::{self, IoSlice, IoSliceMut};
use std::num::NonZeroU32;
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;

Expand All @@ -13,7 +14,7 @@ use crate::object::{Object, ObjectId};
use crate::wire::{ArgType, ArgValue, Fixed, Message, MessageHeader};
use crate::{ConnectError, IoMode};

use buf::{ArrayBuffer, RingBuffer};
use buf::RingBuffer;

pub const BYTES_OUT_LEN: usize = 4096;
pub const BYTES_IN_LEN: usize = BYTES_OUT_LEN * 2;
Expand All @@ -24,8 +25,8 @@ pub struct BufferedSocket {
socket: UnixStream,
bytes_in: RingBuffer<BYTES_IN_LEN>,
bytes_out: RingBuffer<BYTES_OUT_LEN>,
fds_in: ArrayBuffer<RawFd, FDS_IN_LEN>,
fds_out: ArrayBuffer<RawFd, FDS_OUT_LEN>,
fds_in: VecDeque<OwnedFd>,
fds_out: VecDeque<OwnedFd>,

cmsg: Vec<u8>,
pub free_msg_args: Vec<Vec<ArgValue>>,
Expand Down Expand Up @@ -55,8 +56,8 @@ impl BufferedSocket {
socket: UnixStream::connect(path)?,
bytes_in: RingBuffer::new(),
bytes_out: RingBuffer::new(),
fds_in: ArrayBuffer::new(),
fds_out: ArrayBuffer::new(),
fds_in: VecDeque::new(),
fds_out: VecDeque::new(),

cmsg: nix::cmsg_space!([RawFd; FDS_OUT_LEN]),
free_msg_args: Vec::new(),
Expand All @@ -83,8 +84,8 @@ impl BufferedSocket {
// Check size and flush if neccessary
assert!(size as usize <= BYTES_OUT_LEN);
assert!(fds_cnt <= FDS_OUT_LEN);
if (size as usize) > self.bytes_out.writable_len()
|| fds_cnt > self.fds_out.get_writable().len()
while (size as usize) > self.bytes_out.writable_len()
|| fds_cnt + self.fds_out.len() > FDS_OUT_LEN
{
if let Err(err) = self.flush(mode) {
return Err(SendMessageError { msg, err });
Expand Down Expand Up @@ -117,7 +118,7 @@ impl BufferedSocket {
self.send_array(string.to_bytes_with_nul())
}
ArgValue::Array(array) => self.send_array(&array),
ArgValue::Fd(fd) => self.fds_out.write_one(fd.into_raw_fd()),
ArgValue::Fd(fd) => self.fds_out.push_back(fd),
ArgValue::NewIdEvent(_) => panic!("NewIdEvent in request"),
}
}
Expand Down Expand Up @@ -163,9 +164,7 @@ impl BufferedSocket {
.count();
assert!(header.size as usize <= BYTES_IN_LEN);
assert!(fds_cnt <= FDS_IN_LEN);
while header.size as usize > self.bytes_in.readable_len()
|| fds_cnt > self.fds_in.get_readable().len()
{
while header.size as usize > self.bytes_in.readable_len() || fds_cnt > self.fds_in.len() {
self.fill_incoming_buf(mode)?;
}

Expand Down Expand Up @@ -193,18 +192,14 @@ impl BufferedSocket {
len => Some(self.recv_string_with_len(len)),
}),
ArgType::Array => ArgValue::Array(self.recv_array()),
ArgType::Fd => {
let fd = self.fds_in.read_one();
assert_ne!(fd, -1);
ArgValue::Fd(unsafe { OwnedFd::from_raw_fd(fd) })
}
ArgType::Fd => ArgValue::Fd(self.fds_in.pop_front().unwrap()),
}));

Ok(Message { header, args })
}

pub fn flush(&mut self, mode: IoMode) -> io::Result<()> {
if self.bytes_out.is_empty() && self.fds_out.get_readable().is_empty() {
if self.bytes_out.is_empty() && self.fds_out.is_empty() {
return Ok(());
}

Expand All @@ -214,22 +209,21 @@ impl BufferedSocket {
}

let b;
let cmsgs: &[ControlMessage] = match self.fds_out.get_readable() {
[] => &[],
fds => {
b = [ControlMessage::ScmRights(fds)];
&b
}
let mut fds = [0; FDS_OUT_LEN];
for (i, fd) in self.fds_out.iter().enumerate() {
fds[i] = fd.as_raw_fd();
}
let cmsgs: &[ControlMessage] = if fds.is_empty() {
&[]
} else {
b = [ControlMessage::ScmRights(&fds[..self.fds_out.len()])];
&b
};

let mut iov_buf = [IoSlice::new(&[]), IoSlice::new(&[])];
let iov = self.bytes_out.get_readable_iov(&mut iov_buf);
let sent = socket::sendmsg::<()>(self.socket.as_raw_fd(), iov, cmsgs, flags, None)?;

for fd in self.fds_out.get_readable() {
let _ = nix::unistd::close(*fd);
}

// Does this have to be true?
assert_eq!(sent, self.bytes_out.readable_len());

Expand All @@ -242,8 +236,7 @@ impl BufferedSocket {

impl BufferedSocket {
fn fill_incoming_buf(&mut self, mode: IoMode) -> io::Result<()> {
self.fds_in.relocate();
if self.bytes_in.is_full() && self.fds_in.get_writable().is_empty() {
if self.bytes_in.is_full() {
return Ok(());
}

Expand All @@ -260,7 +253,10 @@ impl BufferedSocket {

for cmsg in msg.cmsgs() {
if let ControlMessageOwned::ScmRights(fds) = cmsg {
self.fds_in.extend(&fds);
for fd in fds {
assert_ne!(fd, -1);
self.fds_in.push_back(unsafe { OwnedFd::from_raw_fd(fd) });
}
}
}

Expand Down Expand Up @@ -319,75 +315,6 @@ impl BufferedSocket {
mod buf {
use super::*;

pub struct ArrayBuffer<T, const N: usize> {
bytes: Box<[T; N]>,
offset: usize,
len: usize,
}

impl<T: Default + Copy, const N: usize> ArrayBuffer<T, N> {
pub fn new() -> Self {
Self {
bytes: Box::new([T::default(); N]),
offset: 0,
len: 0,
}
}

pub fn clear(&mut self) {
self.offset = 0;
self.len = 0;
}

pub fn get_writable(&mut self) -> &mut [T] {
&mut self.bytes[(self.offset + self.len)..]
}

pub fn get_readable(&self) -> &[T] {
&self.bytes[self.offset..][..self.len]
}

pub fn consume(&mut self, cnt: usize) {
assert!(cnt <= self.len);
self.offset += cnt;
self.len -= cnt;
}

pub fn advance(&mut self, cnt: usize) {
assert!(self.offset + self.len + cnt <= N);
self.len += cnt;
}

pub fn relocate(&mut self) {
if self.len > 0 && self.offset > 0 {
self.bytes
.copy_within(self.offset..(self.offset + self.len), 0);
}
self.offset = 0;
}

pub fn write_one(&mut self, elem: T) {
let writable = self.get_writable();
assert!(!writable.is_empty());
writable[0] = elem;
self.advance(1);
}

pub fn read_one(&mut self) -> T {
let readable = self.get_readable();
assert!(!readable.is_empty());
let elem = readable[0];
self.consume(1);
elem
}

pub fn extend(&mut self, src: &[T]) {
let writable = &mut self.get_writable()[..src.len()];
writable.copy_from_slice(src);
self.advance(src.len());
}
}

pub struct RingBuffer<const N: usize> {
bytes: Box<[u8; N]>,
offset: usize,
Expand Down

0 comments on commit 0ef0043

Please sign in to comment.