Skip to content

Commit

Permalink
[test] Fix gvisor and essential tests bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ClawSeven committed Apr 18, 2024
1 parent 2066ddc commit 5ce1526
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 80 deletions.
12 changes: 2 additions & 10 deletions src/libos/src/net/socket/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ impl HostSocket {
domain: Domain,
socket_type: Type,
socket_flags: SocketFlags,
protocol: SocketProtocol,
protocol: i32,
) -> Result<Self> {
let raw_host_fd = try_libc!(libc::ocall::socket(
domain as i32,
socket_type as i32 | socket_flags.bits(),
protocol.into()
protocol
)) as FileDesc;
let host_fd = HostFd::new(raw_host_fd);
Ok(HostSocket::from_host_fd(host_fd)?)
Expand Down Expand Up @@ -88,14 +88,6 @@ impl HostSocket {
Ok((HostSocket::from_host_fd(host_fd)?, addr_option))
}

pub fn setsockopt(&self, level: i32, optname: i32, optval: &[u8]) -> Result<()> {
sockopt::setsockopt_by_host(self.raw_host_fd(), level, optname, optval)
}

pub fn getsockopt(&self, level: i32, optname: i32, mut optval: &mut [u8]) -> Result<u32> {
sockopt::getsockopt_by_host(self.raw_host_fd(), level, optname, optval)
}

pub fn addr(&self) -> Result<RawAddr> {
let mut sockaddr = RawAddr::default();
let mut addr_len = sockaddr.len();
Expand Down
10 changes: 2 additions & 8 deletions src/libos/src/net/socket/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@ mod stream;
pub use self::stream::Stream;

//TODO: rewrite this file when a new kind of uds is added
pub fn unix_socket(
socket_type: Type,
flags: SocketFlags,
protocol: SocketProtocol,
) -> Result<Stream> {
let protocol: i32 = protocol.into();
pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Result<Stream> {
if protocol != 0 && protocol != Domain::LOCAL as i32 {
return_errno!(EPROTONOSUPPORT, "protocol is not supported");
}
Expand All @@ -25,9 +20,8 @@ pub fn unix_socket(
pub fn socketpair(
socket_type: Type,
flags: SocketFlags,
protocol: SocketProtocol,
protocol: i32,
) -> Result<(Stream, Stream)> {
let protocol: i32 = protocol.into();
if protocol != 0 && protocol != Domain::LOCAL as i32 {
return_errno!(EPROTONOSUPPORT, "protocol is not supported");
}
Expand Down
6 changes: 5 additions & 1 deletion src/libos/src/net/socket/unix/stream/file.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::fs::{GetReadBufLen, IoctlCmd};
use crate::fs::{GetReadBufLen, IoctlCmd, SetNonBlocking};

use super::address_space::ADDRESS_SPACE;
use super::stream::Status;
Expand Down Expand Up @@ -68,6 +68,10 @@ impl File for Stream {
_ => return_errno!(ENOTCONN, "unconnected socket"),
};
},
cmd : SetNonBlocking => {
let nonblocking = cmd.input();
self.set_nonblocking(*nonblocking != 0);
}
});
Ok(())
}
Expand Down
160 changes: 99 additions & 61 deletions src/libos/src/net/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<i
let type_bits = socket_type & !flags.bits();
let socket_type =
Type::try_from(type_bits).map_err(|_| errno!(EINVAL, "invalid socket type"))?;
let protocol = SocketProtocol::try_from(protocol)
.map_err(|_| errno!(Errno::EINVAL, "Invalid or unsupported network protocol"))?;

debug!(
"socket domain: {:?}, type: {:?}, protocol: {:?}",
Expand All @@ -60,21 +58,26 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<i

let mut file_ref: Option<Arc<dyn File>> = None;

// Determine if type and domain match uring supported socket
let match_uring = move || {
let is_uring_type = (socket_type == Type::DGRAM || socket_type == Type::STREAM);
let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP
|| protocol == SocketProtocol::IPPROTO_TCP
|| protocol == SocketProtocol::IPPROTO_UDP);
let is_uring_domain = (domain == Domain::INET || domain == Domain::INET6);
if ENABLE_URING.load(Ordering::Relaxed) {
let protocol = SocketProtocol::try_from(protocol)
.map_err(|_| errno!(Errno::EINVAL, "Invalid or unsupported network protocol"))?;

is_uring_type && is_uring_protocol && is_uring_domain
};
// Determine if type and domain match uring supported socket
let match_uring = move || {
let is_uring_type = (socket_type == Type::DGRAM || socket_type == Type::STREAM);
let is_uring_protocol = (protocol == SocketProtocol::IPPROTO_IP
|| protocol == SocketProtocol::IPPROTO_TCP
|| protocol == SocketProtocol::IPPROTO_UDP);
let is_uring_domain = (domain == Domain::INET || domain == Domain::INET6);

is_uring_type && is_uring_protocol && is_uring_domain
};

if ENABLE_URING.load(Ordering::Relaxed) && match_uring() {
let nonblocking = flags.contains(SocketFlags::SOCK_NONBLOCK);
let socket_file = SocketFile::new(domain, protocol, socket_type, nonblocking)?;
file_ref = Some(Arc::new(socket_file));
if match_uring() {
let nonblocking = flags.contains(SocketFlags::SOCK_NONBLOCK);
let socket_file = SocketFile::new(domain, protocol, socket_type, nonblocking)?;
file_ref = Some(Arc::new(socket_file));
}
};

// Dispatch unsupported uring domain and flags to ocall
Expand Down Expand Up @@ -269,19 +272,21 @@ pub fn do_setsockopt(
"setsockopt: fd: {}, level: {}, optname: {}, optval: {:?}, optlen: {:?}",
fd, level, optname, optval, optlen
);
if optval as usize != 0 && optlen == 0 {
let file_ref = current!().file(fd as FileDesc)?;

if optval as usize != 0 && optlen == 0 && ENABLE_URING.load(Ordering::Relaxed) {
return_errno!(EINVAL, "the optlen size is 0");
}

let file_ref = current!().file(fd as FileDesc)?;
let optval = from_user::make_slice(optval as *const u8, optlen as usize)?;

if let Ok(host_socket) = file_ref.as_host_socket() {
host_socket.setsockopt(level, optname, optval)?;
let mut cmd = new_host_setsockopt_cmd(level, optname, optval)?;
host_socket.ioctl(cmd.as_mut())?;
} else if let Ok(unix_socket) = file_ref.as_unix_socket() {
warn!("setsockopt for unix socket is unimplemented");
} else if let Ok(uring_socket) = file_ref.as_uring_socket() {
let mut cmd = new_setsockopt_cmd(level, optname, optval, uring_socket.get_type())?;
let mut cmd = new_uring_setsockopt_cmd(level, optname, optval, uring_socket.get_type())?;
uring_socket.ioctl(cmd.as_mut())?;
} else {
return_errno!(ENOTSOCK, "not a socket")
Expand Down Expand Up @@ -314,12 +319,14 @@ pub fn do_getsockopt(
let file_ref = current!().file(fd as FileDesc)?;

if let Ok(host_socket) = file_ref.as_host_socket() {
// Some problem
host_socket.getsockopt(level, optname, optval_mut)?;
let mut cmd = new_host_getsockopt_cmd(level, optname, optlen)?;
host_socket.ioctl(cmd.as_mut())?;
let src_optval = get_optval(cmd.as_ref())?;
copy_bytes_to_user(src_optval, optval_mut, optlen_mut);
} else if let Ok(unix_socket) = file_ref.as_unix_socket() {
warn!("getsockopt for unix socket is unimplemented");
} else if let Ok(uring_socket) = file_ref.as_uring_socket() {
let mut cmd = new_getsockopt_cmd(level, optname, optlen, uring_socket.get_type())?;
let mut cmd = new_uring_getsockopt_cmd(level, optname, optlen, uring_socket.get_type())?;
uring_socket.ioctl(cmd.as_mut())?;
let src_optval = get_optval(cmd.as_ref())?;
copy_bytes_to_user(src_optval, optval_mut, optlen_mut);
Expand Down Expand Up @@ -491,9 +498,6 @@ pub fn do_socketpair(
std::slice::from_raw_parts_mut(sv as *mut u32, 2)
};

let protocol = SocketProtocol::try_from(protocol)
.map_err(|_| errno!(Errno::EINVAL, "Invalid or unsupported network protocol"))?;

let file_flags = SocketFlags::from_bits_truncate(socket_type);
let close_on_spawn = file_flags.contains(SocketFlags::SOCK_CLOEXEC);
let sock_type = Type::try_from(socket_type & (!file_flags.bits()))
Expand Down Expand Up @@ -1068,8 +1072,23 @@ fn copy_sock_addr_to_user(
*dst_addr_len = src_addr_len as u32;
}

/// Create a new ioctl command for getsockopt syscall
fn new_getsockopt_cmd(
/// Create a new ioctl command for host socket getsockopt syscall
fn new_host_getsockopt_cmd(level: i32, optname: i32, optlen: u32) -> Result<Box<dyn IoctlCmd>> {
if level != libc::SOL_SOCKET {
return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen)));
}

let opt =
SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;

Ok(match opt {
SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"),
_ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)),
})
}

/// Create a new ioctl command for uring socket getsockopt syscall
fn new_uring_getsockopt_cmd(
level: i32,
optname: i32,
optlen: u32,
Expand All @@ -1082,46 +1101,65 @@ fn new_getsockopt_cmd(
let opt =
SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;

let enable_uring = ENABLE_URING.load(Ordering::Relaxed);
if !enable_uring {
Ok(match opt {
SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"),
_ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)),
})
} else {
Ok(match opt {
SockOptName::SO_ACCEPTCONN => Box::new(GetAcceptConnCmd::new(())),
SockOptName::SO_DOMAIN => Box::new(GetDomainCmd::new(())),
SockOptName::SO_ERROR => Box::new(GetErrorCmd::new(())),
SockOptName::SO_PEERNAME => Box::new(GetPeerNameCmd::new(())),
SockOptName::SO_TYPE => Box::new(GetTypeCmd::new(())),
SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())),
SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())),
SockOptName::SO_SNDBUF => {
if socket_type == Type::STREAM {
// Implement dynamic buf size for stream socket only.
Box::new(GetSndBufSizeCmd::new(()))
} else {
Box::new(GetSockOptRawCmd::new(level, optname, optlen))
}
Ok(match opt {
SockOptName::SO_ACCEPTCONN => Box::new(GetAcceptConnCmd::new(())),
SockOptName::SO_DOMAIN => Box::new(GetDomainCmd::new(())),
SockOptName::SO_ERROR => Box::new(GetErrorCmd::new(())),
SockOptName::SO_PEERNAME => Box::new(GetPeerNameCmd::new(())),
SockOptName::SO_TYPE => Box::new(GetTypeCmd::new(())),
SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())),
SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())),
SockOptName::SO_SNDBUF => {
if socket_type == Type::STREAM {
// Implement dynamic buf size for stream socket only.
Box::new(GetSndBufSizeCmd::new(()))
} else {
Box::new(GetSockOptRawCmd::new(level, optname, optlen))
}
SockOptName::SO_RCVBUF => {
if socket_type == Type::STREAM {
// Implement dynamic buf size for stream socket only.
Box::new(GetRcvBufSizeCmd::new(()))
} else {
Box::new(GetSockOptRawCmd::new(level, optname, optlen))
}
}
SockOptName::SO_RCVBUF => {
if socket_type == Type::STREAM {
// Implement dynamic buf size for stream socket only.
Box::new(GetRcvBufSizeCmd::new(()))
} else {
Box::new(GetSockOptRawCmd::new(level, optname, optlen))
}
}

SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"),
_ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)),
})
SockOptName::SO_CNX_ADVICE => return_errno!(ENOPROTOOPT, "it's a write-only option"),
_ => Box::new(GetSockOptRawCmd::new(level, optname, optlen)),
})
}

/// Create a new ioctl command for host socket setsockopt syscall
fn new_host_setsockopt_cmd(level: i32, optname: i32, optval: &[u8]) -> Result<Box<dyn IoctlCmd>> {
if level != libc::SOL_SOCKET {
return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval)));
}

let opt =
SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?;

Ok(match opt {
SockOptName::SO_ACCEPTCONN
| SockOptName::SO_DOMAIN
| SockOptName::SO_PEERNAME
| SockOptName::SO_TYPE
| SockOptName::SO_ERROR
| SockOptName::SO_PEERCRED
| SockOptName::SO_SNDLOWAT
| SockOptName::SO_PEERSEC
| SockOptName::SO_PROTOCOL
| SockOptName::SO_MEMINFO
| SockOptName::SO_INCOMING_NAPI_ID
| SockOptName::SO_COOKIE
| SockOptName::SO_PEERGROUPS => return_errno!(ENOPROTOOPT, "it's a read-only option"),
_ => Box::new(SetSockOptRawCmd::new(level, optname, optval)),
})
}

/// Create a new ioctl command for setsockopt syscall
fn new_setsockopt_cmd(
/// Create a new ioctl command for uring socket setsockopt syscall
fn new_uring_setsockopt_cmd(
level: i32,
optname: i32,
optval: &[u8],
Expand Down

0 comments on commit 5ce1526

Please sign in to comment.