Skip to content

Commit

Permalink
client: make Conneccion optionaly transport-agonistic
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxVerevkin committed Jan 12, 2025
1 parent 9425699 commit d4f9542
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 3 deletions.
7 changes: 7 additions & 0 deletions wayrs-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ edition.workspace = true
rust-version.workspace = true
license.workspace = true

[features]
any_transport = []

[[example]]
name = "custom_transport"
required-features = ["any_transport"]

[dependencies]
wayrs-core = { version = "1.0", path = "../wayrs-core" }
wayrs-scanner = { version = "0.15.2", path = "../wayrs-scanner" }
Expand Down
86 changes: 86 additions & 0 deletions wayrs-client/examples/custom_transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//! An example of how to use your own custom transport implementation.
//! Here, the transport keeps track of how much bytes were sent/received.
use std::collections::VecDeque;
use std::env;
use std::io;
use std::os::fd::{OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;

use wayrs_client::core::transport::Transport;
use wayrs_client::protocol::wl_registry;
use wayrs_client::{Connection, IoMode};

fn main() {
let mut conn = Connection::with_transport(MyTransport::connect());

conn.add_registry_cb(|_conn, _state, event| match event {
wl_registry::Event::Global(g) => println!(
"global ({}) {} added",
g.name,
g.interface.to_string_lossy(),
),
wl_registry::Event::GlobalRemove(name) => println!("global ({name}) removed"),
});

loop {
conn.flush(IoMode::Blocking).unwrap();
conn.recv_events(IoMode::Blocking).unwrap();
conn.dispatch_events(&mut ());

let t = conn.transport::<MyTransport>().unwrap();
eprintln!("up: {}b down: {}b", t.bytes_sent, t.bytes_read);
}
}

struct MyTransport {
socket: UnixStream,
bytes_read: usize,
bytes_sent: usize,
}

impl Transport for MyTransport {
fn pollable_fd(&self) -> RawFd {
self.socket.pollable_fd()
}

fn send(
&mut self,
bytes: &[std::io::IoSlice],
fds: &[OwnedFd],
mode: IoMode,
) -> io::Result<usize> {
let n = self.socket.send(bytes, fds, mode)?;
self.bytes_sent += n;
Ok(n)
}

fn recv(
&mut self,
bytes: &mut [std::io::IoSliceMut],
fds: &mut VecDeque<OwnedFd>,
mode: IoMode,
) -> io::Result<usize> {
let n = self.socket.recv(bytes, fds, mode)?;
self.bytes_read += n;
Ok(n)
}
}

impl MyTransport {
fn connect() -> Self {
let runtime_dir = env::var_os("XDG_RUNTIME_DIR").unwrap();
let wayland_disp = env::var_os("WAYLAND_DISPLAY").unwrap();

let mut path = PathBuf::new();
path.push(runtime_dir);
path.push(wayland_disp);

Self {
socket: UnixStream::connect(path).unwrap(),
bytes_read: 0,
bytes_sent: 0,
}
}
}
59 changes: 59 additions & 0 deletions wayrs-client/src/any_transport.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use std::any::Any;
use std::collections::VecDeque;
use std::io;
use std::os::fd::{OwnedFd, RawFd};

use wayrs_core::transport::Transport;
use wayrs_core::IoMode;

pub struct AnyTranpsort(Box<dyn AnyTransportImp>);

impl AnyTranpsort {
pub fn new<T>(transport: T) -> Self
where
T: Transport + Send + 'static,
{
Self(Box::new(transport))
}

pub fn as_any(&self) -> &dyn Any {
self.0.as_any()
}

pub fn as_any_mut(&mut self) -> &mut dyn Any {
self.0.as_any_mut()
}
}

trait AnyTransportImp: Transport + Send {
fn as_any(&self) -> &dyn Any;
fn as_any_mut(&mut self) -> &mut dyn Any;
}

impl<T: Transport + Send + 'static> AnyTransportImp for T {
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}

impl Transport for AnyTranpsort {
fn pollable_fd(&self) -> RawFd {
self.0.as_ref().pollable_fd()
}

fn send(&mut self, bytes: &[io::IoSlice], fds: &[OwnedFd], mode: IoMode) -> io::Result<usize> {
self.0.as_mut().send(bytes, fds, mode)
}

fn recv(
&mut self,
bytes: &mut [io::IoSliceMut],
fds: &mut VecDeque<OwnedFd>,
mode: IoMode,
) -> io::Result<usize> {
self.0.as_mut().recv(bytes, fds, mode)
}
}
82 changes: 79 additions & 3 deletions wayrs-client/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ use wayrs_core::{ArgType, ArgValue, Interface, IoMode, Message, MessageBuffersPo
#[cfg(feature = "tokio")]
use tokio::io::unix::AsyncFd;

#[cfg(feature = "any_transport")]
use wayrs_core::transport::Transport;
#[cfg(feature = "any_transport")]
type TransportImp = crate::any_transport::AnyTranpsort;
#[cfg(not(feature = "any_transport"))]
type TransportImp = UnixStream;

/// An error that can occur while connecting to a Wayland socket.
#[derive(Debug)]
pub enum ConnectError {
Expand Down Expand Up @@ -62,7 +69,7 @@ pub struct Connection<D> {
#[cfg(feature = "tokio")]
async_fd: Option<AsyncFd<RawFd>>,

socket: BufferedSocket<UnixStream>,
socket: BufferedSocket<TransportImp>,
msg_buffers_pool: MessageBuffersPool,

object_mgr: ObjectManager<D>,
Expand Down Expand Up @@ -110,11 +117,27 @@ impl<D> Connection<D> {
path.push(runtime_dir);
path.push(wayland_disp);

#[cfg(feature = "any_transport")]
let transport = TransportImp::new(UnixStream::connect(path)?);
#[cfg(not(feature = "any_transport"))]
let transport = UnixStream::connect(path)?;

Ok(Self::with_transport_imp(transport))
}

/// Use a custom transport
#[cfg(feature = "any_transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "any_transport")))]
pub fn with_transport<T: Transport + Send + 'static>(transport: T) -> Self {
Self::with_transport_imp(TransportImp::new(transport))
}

fn with_transport_imp(transport: TransportImp) -> Self {
let mut this = Self {
#[cfg(feature = "tokio")]
async_fd: None,

socket: BufferedSocket::from(UnixStream::connect(path)?),
socket: BufferedSocket::from(transport),
msg_buffers_pool: MessageBuffersPool::default(),

object_mgr: ObjectManager::new(),
Expand All @@ -132,7 +155,25 @@ impl<D> Connection<D> {

this.registry = WlDisplay::INSTANCE.get_registry(&mut this);

Ok(this)
this
}

/// Try to get a reference to the underlying transport.
///
/// Returns `None` if the type of the transport is not `T`.
#[cfg(feature = "any_transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "any_transport")))]
pub fn transport<T: 'static>(&self) -> Option<&T> {
self.socket.transport().as_any().downcast_ref()
}

/// Try to get a mutable reference to the underlying transport.
///
/// Returns `None` if the type of the transport is not `T`.
#[cfg(feature = "any_transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "any_transport")))]
pub fn transport_mut<T: 'static>(&mut self) -> Option<&mut T> {
self.socket.transport_mut().as_any_mut().downcast_mut()
}

/// [`connect`](Self::connect) and collect the initial set of advertised globals.
Expand Down Expand Up @@ -706,4 +747,39 @@ mod tests {
fn send() {
assert_send::<Connection<()>>();
}

#[test]
#[cfg(feature = "any_transport")]
fn transport_downcast() {
use std::os::fd::OwnedFd;

struct T;
impl Transport for T {
fn pollable_fd(&self) -> RawFd {
todo!()
}
fn send(
&mut self,
_bytes: &[io::IoSlice],
_fds: &[OwnedFd],
_mode: IoMode,
) -> io::Result<usize> {
todo!()
}
fn recv(
&mut self,
_bytes: &mut [io::IoSliceMut],
_fds: &mut VecDeque<OwnedFd>,
_mode: IoMode,
) -> io::Result<usize> {
todo!()
}
}

let mut conn = Connection::<()>::with_transport(T);
assert!(conn.transport::<T>().is_some());
assert!(conn.transport_mut::<T>().is_some());
assert!(conn.transport::<()>().is_none());
assert!(conn.transport_mut::<()>().is_none());
}
}
3 changes: 3 additions & 0 deletions wayrs-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub mod protocol;
mod connection;
mod debug_message;

#[cfg(feature = "any_transport")]
mod any_transport;

pub use connection::{ConnectError, Connection};

#[doc(hidden)]
Expand Down

0 comments on commit d4f9542

Please sign in to comment.