diff --git a/src/libos/src/net/socket/host/mod.rs b/src/libos/src/net/socket/host/mod.rs index 92a6a3da3f..d7a4fc36fc 100644 --- a/src/libos/src/net/socket/host/mod.rs +++ b/src/libos/src/net/socket/host/mod.rs @@ -30,12 +30,12 @@ impl HostSocket { domain: Domain, socket_type: Type, socket_flags: SocketFlags, - protocol: SocketProtocol, + protocol: i32, ) -> Result { let raw_host_fd = try_libc!(libc::ocall::socket( domain as i32, socket_type as i32 | socket_flags.bits(), - protocol.into() + protocol )) as FileDesc; let host_fd = HostFd::new(raw_host_fd); Ok(HostSocket::from_host_fd(host_fd)?) @@ -88,14 +88,6 @@ impl HostSocket { Ok((HostSocket::from_host_fd(host_fd)?, addr_option)) } - pub fn setsockopt(&self, level: i32, optname: i32, optval: &[u8]) -> Result<()> { - sockopt::setsockopt_by_host(self.raw_host_fd(), level, optname, optval) - } - - pub fn getsockopt(&self, level: i32, optname: i32, mut optval: &mut [u8]) -> Result { - sockopt::getsockopt_by_host(self.raw_host_fd(), level, optname, optval) - } - pub fn addr(&self) -> Result { let mut sockaddr = RawAddr::default(); let mut addr_len = sockaddr.len(); diff --git a/src/libos/src/net/socket/unix/mod.rs b/src/libos/src/net/socket/unix/mod.rs index baf815d0f4..f82602323a 100644 --- a/src/libos/src/net/socket/unix/mod.rs +++ b/src/libos/src/net/socket/unix/mod.rs @@ -5,12 +5,7 @@ mod stream; pub use self::stream::Stream; //TODO: rewrite this file when a new kind of uds is added -pub fn unix_socket( - socket_type: Type, - flags: SocketFlags, - protocol: SocketProtocol, -) -> Result { - let protocol: i32 = protocol.into(); +pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Result { if protocol != 0 && protocol != Domain::LOCAL as i32 { return_errno!(EPROTONOSUPPORT, "protocol is not supported"); } @@ -25,9 +20,8 @@ pub fn unix_socket( pub fn socketpair( socket_type: Type, flags: SocketFlags, - protocol: SocketProtocol, + protocol: i32, ) -> Result<(Stream, Stream)> { - let protocol: i32 = protocol.into(); if protocol != 0 && protocol != Domain::LOCAL as i32 { return_errno!(EPROTONOSUPPORT, "protocol is not supported"); } diff --git a/src/libos/src/net/socket/unix/stream/file.rs b/src/libos/src/net/socket/unix/stream/file.rs index cee0a42c93..773a9b1dd1 100644 --- a/src/libos/src/net/socket/unix/stream/file.rs +++ b/src/libos/src/net/socket/unix/stream/file.rs @@ -1,4 +1,4 @@ -use crate::fs::{GetReadBufLen, IoctlCmd}; +use crate::fs::{GetReadBufLen, IoctlCmd, SetNonBlocking}; use super::address_space::ADDRESS_SPACE; use super::stream::Status; @@ -68,6 +68,10 @@ impl File for Stream { _ => return_errno!(ENOTCONN, "unconnected socket"), }; }, + cmd : SetNonBlocking => { + let nonblocking = cmd.input(); + self.set_nonblocking(*nonblocking != 0); + } }); Ok(()) } diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index b5dc1f2c6a..54ab4956c2 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -50,8 +50,6 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result Result> = None; - // Determine if type and domain match uring supported socket - let match_uring = move || { - let is_uring_type = (socket_type == Type::DGRAM || socket_type == Type::STREAM); - let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP - || protocol == SocketProtocol::IPPROTO_TCP - || protocol == SocketProtocol::IPPROTO_UDP); - let is_uring_domain = (domain == Domain::INET || domain == Domain::INET6); + if ENABLE_URING.load(Ordering::Relaxed) { + let protocol = SocketProtocol::try_from(protocol) + .map_err(|_| errno!(Errno::EINVAL, "Invalid or unsupported network protocol"))?; - is_uring_type && is_uring_protocol && is_uring_domain - }; + // Determine if type and domain match uring supported socket + let match_uring = move || { + let is_uring_type = (socket_type == Type::DGRAM || socket_type == Type::STREAM); + let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP + || protocol == SocketProtocol::IPPROTO_TCP + || protocol == SocketProtocol::IPPROTO_UDP); + let is_uring_domain = (domain == Domain::INET || domain == Domain::INET6); + + is_uring_type && is_uring_protocol && is_uring_domain + }; - if ENABLE_URING.load(Ordering::Relaxed) && match_uring() { - let nonblocking = flags.contains(SocketFlags::SOCK_NONBLOCK); - let socket_file = SocketFile::new(domain, protocol, socket_type, nonblocking)?; - file_ref = Some(Arc::new(socket_file)); + if match_uring() { + let nonblocking = flags.contains(SocketFlags::SOCK_NONBLOCK); + let socket_file = SocketFile::new(domain, protocol, socket_type, nonblocking)?; + file_ref = Some(Arc::new(socket_file)); + } }; // Dispatch unsupported uring domain and flags to ocall @@ -269,19 +272,21 @@ pub fn do_setsockopt( "setsockopt: fd: {}, level: {}, optname: {}, optval: {:?}, optlen: {:?}", fd, level, optname, optval, optlen ); - if optval as usize != 0 && optlen == 0 { + let file_ref = current!().file(fd as FileDesc)?; + + if optval as usize != 0 && optlen == 0 && ENABLE_URING.load(Ordering::Relaxed) { return_errno!(EINVAL, "the optlen size is 0"); } - let file_ref = current!().file(fd as FileDesc)?; let optval = from_user::make_slice(optval as *const u8, optlen as usize)?; if let Ok(host_socket) = file_ref.as_host_socket() { - host_socket.setsockopt(level, optname, optval)?; + let mut cmd = new_host_setsockopt_cmd(level, optname, optval)?; + host_socket.ioctl(cmd.as_mut())?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { warn!("setsockopt for unix socket is unimplemented"); } else if let Ok(uring_socket) = file_ref.as_uring_socket() { - let mut cmd = new_setsockopt_cmd(level, optname, optval, uring_socket.get_type())?; + let mut cmd = new_uring_setsockopt_cmd(level, optname, optval, uring_socket.get_type())?; uring_socket.ioctl(cmd.as_mut())?; } else { return_errno!(ENOTSOCK, "not a socket") @@ -314,12 +319,14 @@ pub fn do_getsockopt( let file_ref = current!().file(fd as FileDesc)?; if let Ok(host_socket) = file_ref.as_host_socket() { - // Some problem - host_socket.getsockopt(level, optname, optval_mut)?; + let mut cmd = new_host_getsockopt_cmd(level, optname, optlen)?; + host_socket.ioctl(cmd.as_mut())?; + let src_optval = get_optval(cmd.as_ref())?; + copy_bytes_to_user(src_optval, optval_mut, optlen_mut); } else if let Ok(unix_socket) = file_ref.as_unix_socket() { warn!("getsockopt for unix socket is unimplemented"); } else if let Ok(uring_socket) = file_ref.as_uring_socket() { - let mut cmd = new_getsockopt_cmd(level, optname, optlen, uring_socket.get_type())?; + let mut cmd = new_uring_getsockopt_cmd(level, optname, optlen, uring_socket.get_type())?; uring_socket.ioctl(cmd.as_mut())?; let src_optval = get_optval(cmd.as_ref())?; copy_bytes_to_user(src_optval, optval_mut, optlen_mut); @@ -491,9 +498,6 @@ pub fn do_socketpair( std::slice::from_raw_parts_mut(sv as *mut u32, 2) }; - let protocol = SocketProtocol::try_from(protocol) - .map_err(|_| errno!(Errno::EINVAL, "Invalid or unsupported network protocol"))?; - let file_flags = SocketFlags::from_bits_truncate(socket_type); let close_on_spawn = file_flags.contains(SocketFlags::SOCK_CLOEXEC); let sock_type = Type::try_from(socket_type & (!file_flags.bits())) @@ -1068,8 +1072,23 @@ fn copy_sock_addr_to_user( *dst_addr_len = src_addr_len as u32; } -/// Create a new ioctl command for getsockopt syscall -fn new_getsockopt_cmd( +/// Create a new ioctl command for host socket getsockopt syscall +fn new_host_getsockopt_cmd(level: i32, optname: i32, optlen: u32) -> Result> { + if level != libc::SOL_SOCKET { + return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen))); + } + + let opt = + SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?; + + Ok(match opt { + SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"), + _ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)), + }) +} + +/// Create a new ioctl command for uring socket getsockopt syscall +fn new_uring_getsockopt_cmd( level: i32, optname: i32, optlen: u32, @@ -1082,46 +1101,65 @@ fn new_getsockopt_cmd( let opt = SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?; - let enable_uring = ENABLE_URING.load(Ordering::Relaxed); - if !enable_uring { - Ok(match opt { - SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"), - _ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)), - }) - } else { - Ok(match opt { - SockOptName::SO_ACCEPTCONN => Box::new(GetAcceptConnCmd::new(())), - SockOptName::SO_DOMAIN => Box::new(GetDomainCmd::new(())), - SockOptName::SO_ERROR => Box::new(GetErrorCmd::new(())), - SockOptName::SO_PEERNAME => Box::new(GetPeerNameCmd::new(())), - SockOptName::SO_TYPE => Box::new(GetTypeCmd::new(())), - SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())), - SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())), - SockOptName::SO_SNDBUF => { - if socket_type == Type::STREAM { - // Implement dynamic buf size for stream socket only. - Box::new(GetSndBufSizeCmd::new(())) - } else { - Box::new(GetSockOptRawCmd::new(level, optname, optlen)) - } + Ok(match opt { + SockOptName::SO_ACCEPTCONN => Box::new(GetAcceptConnCmd::new(())), + SockOptName::SO_DOMAIN => Box::new(GetDomainCmd::new(())), + SockOptName::SO_ERROR => Box::new(GetErrorCmd::new(())), + SockOptName::SO_PEERNAME => Box::new(GetPeerNameCmd::new(())), + SockOptName::SO_TYPE => Box::new(GetTypeCmd::new(())), + SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())), + SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())), + SockOptName::SO_SNDBUF => { + if socket_type == Type::STREAM { + // Implement dynamic buf size for stream socket only. + Box::new(GetSndBufSizeCmd::new(())) + } else { + Box::new(GetSockOptRawCmd::new(level, optname, optlen)) } - SockOptName::SO_RCVBUF => { - if socket_type == Type::STREAM { - // Implement dynamic buf size for stream socket only. - Box::new(GetRcvBufSizeCmd::new(())) - } else { - Box::new(GetSockOptRawCmd::new(level, optname, optlen)) - } + } + SockOptName::SO_RCVBUF => { + if socket_type == Type::STREAM { + // Implement dynamic buf size for stream socket only. + Box::new(GetRcvBufSizeCmd::new(())) + } else { + Box::new(GetSockOptRawCmd::new(level, optname, optlen)) } + } - SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"), - _ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)), - }) + SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"), + _ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)), + }) +} + +/// Create a new ioctl command for host socket setsockopt syscall +fn new_host_setsockopt_cmd(level: i32, optname: i32, optval: &[u8]) -> Result> { + if level != libc::SOL_SOCKET { + return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval))); } + + let opt = + SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?; + + Ok(match opt { + SockOptName::SO_ACCEPTCONN + | SockOptName::SO_DOMAIN + | SockOptName::SO_PEERNAME + | SockOptName::SO_TYPE + | SockOptName::SO_ERROR + | SockOptName::SO_PEERCRED + | SockOptName::SO_SNDLOWAT + | SockOptName::SO_PEERSEC + | SockOptName::SO_PROTOCOL + | SockOptName::SO_MEMINFO + | SockOptName::SO_INCOMING_NAPI_ID + | SockOptName::SO_COOKIE + | SockOptName::SO_PEERGROUPS => return_errno!(ENOPROTOOPT, "it's a read-only option"), + _ => Box::new(SetSockOptRawCmd::new(level, optname, optval)), + }) } -/// Create a new ioctl command for setsockopt syscall -fn new_setsockopt_cmd( +/// Create a new ioctl command for uring socket setsockopt syscall +fn new_uring_setsockopt_cmd( level: i32, optname: i32, optval: &[u8],