Skip to content

Commit

Permalink
client: make Connection generic over transport
Browse files Browse the repository at this point in the history
  • Loading branch information
madushan1000 authored and MaxVerevkin committed Apr 16, 2024
1 parent 4319194 commit c08e893
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 106 deletions.
132 changes: 82 additions & 50 deletions wayrs-client/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,37 @@
//! Wayland connection
use std::collections::VecDeque;
use std::env;
use std::io;
use std::num::NonZeroU32;
use std::os::fd::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;

use crate::debug_message::DebugMessage;
use crate::object::{Object, ObjectManager, Proxy};
use crate::protocol::wl_registry::GlobalArgs;
use crate::protocol::*;
use crate::EventCtx;
use crate::{ClientTransport, EventCtx, Transport};

use wayrs_core::transport::{BufferedSocket, PeekHeaderError, RecvMessageError, SendMessageError};
use wayrs_core::{ArgType, ArgValue, Interface, IoMode, Message, MessageBuffersPool, ObjectId};

#[cfg(feature = "tokio")]
use tokio::io::unix::AsyncFd;

/// An error that can occur while connecting to a Wayland socket.
#[derive(Debug, thiserror::Error)]
pub enum ConnectError {
/// Either `$XDG_RUNTIME_DIR` or `$WAYLAND_DISPLAY` was not available.
#[error("both $XDG_RUNTIME_DIR and $WAYLAND_DISPLAY must be set")]
NotEnoughEnvVars,
/// Some IO error.
#[error(transparent)]
Io(#[from] io::Error),
}

/// Wayland connection state.
///
/// This struct manages a buffered Wayland socket, keeps track of objects and request/event queues
/// and dispatches object events.
///
/// Set `WAYLAND_DEBUG=1` environment variable to get debug messages.
pub struct Connection<D> {
pub struct Connection<D, T = UnixStream> {
#[cfg(feature = "tokio")]
async_fd: Option<AsyncFd<RawFd>>,

socket: BufferedSocket<UnixStream>,
socket: BufferedSocket<T>,
msg_buffers_pool: MessageBuffersPool,

object_mgr: ObjectManager<D>,
object_mgr: ObjectManager<D, T>,

event_queue: VecDeque<QueuedEvent>,
requests_queue: VecDeque<Message>,
Expand All @@ -53,7 +40,7 @@ pub struct Connection<D> {
registry: WlRegistry,

// This is `None` while dispatching registry events, to prevent mutation from registry callbacks.
registry_cbs: Option<Vec<RegistryCb<D>>>,
registry_cbs: Option<Vec<RegistryCb<D, T>>>,

debug: bool,
}
Expand All @@ -64,35 +51,33 @@ enum QueuedEvent {
Message(Message),
}

pub(crate) type GenericCallback<D> =
Box<dyn FnMut(&mut Connection<D>, &mut D, Object, Message) + Send>;
pub(crate) type GenericCallback<D, T> =
Box<dyn FnMut(&mut Connection<D, T>, &mut D, Object, Message) + Send>;

type RegistryCb<D> = Box<dyn FnMut(&mut Connection<D>, &mut D, &wl_registry::Event) + Send>;
type RegistryCb<D, T> = Box<dyn FnMut(&mut Connection<D, T>, &mut D, &wl_registry::Event) + Send>;

impl<D> AsRawFd for Connection<D> {
impl<D, T: Transport> AsRawFd for Connection<D, T> {
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
}
}

impl<D> Connection<D> {
/// Connect to a Wayland socket at `$XDG_RUNTIME_DIR/$WAYLAND_DISPLAY` and create a registry.
impl<D, T> Connection<D, T> {
/// Connect to a Wayland socket and create a registry.
///
/// At the moment, only a single registry can be created. This might or might not change in the
/// future, considering registries cannot be destroyed.
pub fn connect() -> Result<Self, ConnectError> {
let runtime_dir = env::var_os("XDG_RUNTIME_DIR").ok_or(ConnectError::NotEnoughEnvVars)?;
let wayland_disp = env::var_os("WAYLAND_DISPLAY").ok_or(ConnectError::NotEnoughEnvVars)?;

let mut path = PathBuf::new();
path.push(runtime_dir);
path.push(wayland_disp);

///
/// With the default `T = UnixStream`, `$XDG_RUNTIME_DIR/$WAYLAND_DISPLAY` path is used for the socket.
pub fn connect() -> Result<Self, T::ConnectError>
where
T: ClientTransport,
{
let mut this = Self {
#[cfg(feature = "tokio")]
async_fd: None,

socket: BufferedSocket::from(UnixStream::connect(path)?),
socket: BufferedSocket::from(T::connect()?),
msg_buffers_pool: MessageBuffersPool::default(),

object_mgr: ObjectManager::new(),
Expand All @@ -113,7 +98,10 @@ impl<D> Connection<D> {
}

/// [`connect`](Self::connect) and collect the initial set of advertised globals.
pub fn connect_and_collect_globals() -> Result<(Self, Vec<GlobalArgs>), ConnectError> {
pub fn connect_and_collect_globals() -> Result<(Self, Vec<GlobalArgs>), T::ConnectError>
where
T: ClientTransport,
{
let mut this = Self::connect()?;
this.blocking_roundtrip()?;
let globals = this
Expand All @@ -130,7 +118,10 @@ impl<D> Connection<D> {
/// Async version of [`connect_and_collect_globals`](Self::connect_and_collect_globals).
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub async fn async_connect_and_collect_globals() -> Result<(Self, Vec<GlobalArgs>), ConnectError>
pub async fn async_connect_and_collect_globals(
) -> Result<(Self, Vec<GlobalArgs>), T::ConnectError>
where
T: ClientTransport,
{
let mut this = Self::connect()?;
this.async_roundtrip().await?;
Expand Down Expand Up @@ -162,7 +153,7 @@ impl<D> Connection<D> {
///
/// This method panics if called from the context of a registry callback.
pub fn add_registry_cb<
F: FnMut(&mut Connection<D>, &mut D, &wl_registry::Event) + Send + 'static,
F: FnMut(&mut Connection<D, T>, &mut D, &wl_registry::Event) + Send + 'static,
>(
&mut self,
cb: F,
Expand All @@ -185,7 +176,7 @@ impl<D> Connection<D> {
///
/// Calling this function on a destroyed object will most likely panic, but this is not
/// guarantied due to id-reuse.
pub fn set_callback_for<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
pub fn set_callback_for<P: Proxy, F: FnMut(EventCtx<D, P, T>) + Send + 'static>(
&mut self,
proxy: P,
cb: F,
Expand All @@ -210,7 +201,7 @@ impl<D> Connection<D> {
/// Remove all callbacks.
///
/// You can use this function to change the "state type" of a connection.
pub fn clear_callbacks<D2>(self) -> Connection<D2> {
pub fn clear_callbacks<D2>(self) -> Connection<D2, T> {
Connection {
#[cfg(feature = "tokio")]
async_fd: self.async_fd,
Expand All @@ -230,7 +221,10 @@ impl<D> Connection<D> {
///
/// This function flushes the buffer of pending requests. All received events during the
/// roundtrip are queued.
pub fn blocking_roundtrip(&mut self) -> io::Result<()> {
pub fn blocking_roundtrip(&mut self) -> io::Result<()>
where
T: Transport,
{
let sync_cb = WlDisplay::INSTANCE.sync(self);
self.flush(IoMode::Blocking)?;

Expand All @@ -247,7 +241,10 @@ impl<D> Connection<D> {
/// Async version of [`blocking_roundtrip`](Self::blocking_roundtrip).
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub async fn async_roundtrip(&mut self) -> io::Result<()> {
pub async fn async_roundtrip(&mut self) -> io::Result<()>
where
T: Transport,
{
let sync_cb = WlDisplay::INSTANCE.sync(self);
self.async_flush().await?;

Expand Down Expand Up @@ -290,7 +287,10 @@ impl<D> Connection<D> {
self.requests_queue.push_back(request);
}

fn recv_event(&mut self, mode: IoMode) -> io::Result<QueuedEvent> {
fn recv_event(&mut self, mode: IoMode) -> io::Result<QueuedEvent>
where
T: Transport,
{
let header = self
.socket
.peek_message_header(mode)
Expand Down Expand Up @@ -373,7 +373,10 @@ impl<D> Connection<D> {
}

#[cfg(feature = "tokio")]
async fn async_recv_event(&mut self) -> io::Result<QueuedEvent> {
async fn async_recv_event(&mut self) -> io::Result<QueuedEvent>
where
T: Transport,
{
let mut async_fd = match self.async_fd.take() {
Some(fd) => fd,
None => AsyncFd::new(self.as_raw_fd())?,
Expand Down Expand Up @@ -401,7 +404,10 @@ impl<D> Connection<D> {
/// Otherwise, [`WouldBlock`](io::ErrorKind::WouldBlock) will be propagated.
///
/// Regular IO errors are propagated as usual.
pub fn recv_events(&mut self, mut mode: IoMode) -> io::Result<()> {
pub fn recv_events(&mut self, mut mode: IoMode) -> io::Result<()>
where
T: Transport,
{
let mut at_least_one = false;

loop {
Expand All @@ -420,7 +426,10 @@ impl<D> Connection<D> {
/// Async version of [`recv_events`](Self::recv_events).
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub async fn async_recv_events(&mut self) -> io::Result<()> {
pub async fn async_recv_events(&mut self) -> io::Result<()>
where
T: Transport,
{
let msg = self.async_recv_event().await?;
self.event_queue.push_back(msg);

Expand All @@ -434,7 +443,10 @@ impl<D> Connection<D> {
}

/// Send the queue of pending request to the server.
pub fn flush(&mut self, mode: IoMode) -> io::Result<()> {
pub fn flush(&mut self, mode: IoMode) -> io::Result<()>
where
T: Transport,
{
// Send pending messages
while let Some(msg) = self.requests_queue.pop_front() {
if let Err(SendMessageError { msg, err }) =
Expand All @@ -453,7 +465,10 @@ impl<D> Connection<D> {
/// Async version of [`flush`](Self::flush).
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub async fn async_flush(&mut self) -> io::Result<()> {
pub async fn async_flush(&mut self) -> io::Result<()>
where
T: Transport,
{
// Try to just flush before even touching async fd. In many cases flushing does not block.
match self.flush(IoMode::NonBlocking) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
Expand All @@ -467,7 +482,8 @@ impl<D> Connection<D> {

loop {
let mut fd_guard = async_fd.writable_mut().await?;
match self.flush(IoMode::NonBlocking) {
let t = self.flush(IoMode::NonBlocking);
match t {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => fd_guard.clear_ready(),
result => {
self.async_fd = Some(async_fd);
Expand Down Expand Up @@ -562,7 +578,7 @@ impl<D> Connection<D> {
/// Allocate a new object and set callback. Returned object must be sent in a request as a
/// "new_id" argument.
#[doc(hidden)]
pub fn allocate_new_object_with_cb<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
pub fn allocate_new_object_with_cb<P: Proxy, F: FnMut(EventCtx<D, P, T>) + Send + 'static>(
&mut self,
version: u32,
cb: F,
Expand All @@ -572,9 +588,9 @@ impl<D> Connection<D> {
P::new(state.object.id, version)
}

fn make_generic_cb<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
fn make_generic_cb<P: Proxy, F: FnMut(EventCtx<D, P, T>) + Send + 'static>(
mut cb: F,
) -> GenericCallback<D> {
) -> GenericCallback<D, T> {
// Note: if `F` does not capture anything, this `Box::new` will not allocate.
Box::new(move |conn, state, object, event| {
let proxy: P = object.try_into().unwrap();
Expand All @@ -588,6 +604,22 @@ impl<D> Connection<D> {
cb(ctx);
})
}

/// Get a reference to the underlying transport.
pub fn transport(&self) -> &T
where
T: Transport,
{
self.socket.transport()
}

/// Get a mutable reference to the underlying transport.
pub fn transport_mut(&mut self) -> &mut T
where
T: Transport,
{
self.socket.transport_mut()
}
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit c08e893

Please sign in to comment.