From e4646077cac8c196f676e2711d99ef46a40bda25 Mon Sep 17 00:00:00 2001
From: licunlong
Date: Fri, 30 Jun 2023 16:07:29 +0800
Subject: [PATCH] fix: use UnixStream to test send_and_recv

---
Review notes (this section is ignored when the patch is applied):
 * Line structure restored: a unified diff needs one physical line per
   header/hunk line. Indentation and the position of the two blank context
   lines in the second hunk were reconstructed -- TODO confirm against the
   target tree before applying.
 * Changes relative to the original submission, all in the test: bind() and
   listen() failures are reported where they happen instead of being
   discarded; the `core::time` import, the extra `Path::new()` and the
   needless `&` borrows are dropped; short comments were added.
 * The test stays best-effort: it still passes when no client connects
   within the five polls.
 * NOTE(review): the `index` line below still carries the blob ids of the
   original submission.

 libs/cmdproto/src/proto/frame.rs | 84 ++++++++++++++++++++-------------
 1 file changed, 51 insertions(+), 33 deletions(-)

diff --git a/libs/cmdproto/src/proto/frame.rs b/libs/cmdproto/src/proto/frame.rs
index 6b6912b..9c48cf0 100644
--- a/libs/cmdproto/src/proto/frame.rs
+++ b/libs/cmdproto/src/proto/frame.rs
@@ -172,25 +172,25 @@ where
         msg.encode_frame(&mut buf)?;
         let encoded = buf.freeze();
         let msg_len = msg_len_vec(encoded.len());
-        match socket::send(self.accept_fd, &msg_len, socket::MsgFlags::empty()) {
-            Ok(len) => {
-                if len != msg_len.len() {
-                    return Err(Error::SendStream {
-                        msg: "Invalid message length".to_string(),
-                    });
-                }
-            }
+
+        let len = match socket::send(self.accept_fd, &msg_len, socket::MsgFlags::empty()) {
             Err(e) => return Err(Error::SendStream { msg: e.to_string() }),
+            Ok(v) => v,
+        };
+        if len != msg_len.len() {
+            return Err(Error::SendStream {
+                msg: "Invalid message length".to_string(),
+            });
         }
-        match socket::send(self.accept_fd, &encoded, socket::MsgFlags::empty()) {
-            Ok(len) => {
-                if len != encoded.len() {
-                    return Err(Error::SendStream {
-                        msg: "Invalid message length".to_string(),
-                    });
-                }
-            }
+
+        let len = match socket::send(self.accept_fd, &encoded, socket::MsgFlags::empty()) {
             Err(e) => return Err(Error::SendStream { msg: e.to_string() }),
+            Ok(v) => v,
+        };
+        if len != encoded.len() {
+            return Err(Error::SendStream {
+                msg: "Invalid message length".to_string(),
+            });
         }
         Ok(())
     }
@@ -240,37 +240,55 @@ where
 mod tests {
     use super::super::abi::unit_comm::Action as UnitAction;
     use super::*;
-    use std::net::{SocketAddr, TcpStream};
+    use std::os::unix::net::UnixStream;
+    use std::path::Path;
     use std::thread;
     use std::time::Duration;
 
     #[test]
     fn test_send_and_recv() {
+        // Remove a stale socket file left behind by an earlier, aborted run.
+        let socket_name = "./test-sctl.sock";
+        let socket_path = Path::new(socket_name);
+        if socket_path.exists() {
+            std::fs::remove_file(socket_path).unwrap();
+        }
         thread::spawn(move || {
             thread::sleep(Duration::from_secs(1));
-            let addrs = [
-                SocketAddr::from(([127, 0, 0, 1], 9528)),
-                SocketAddr::from(([127, 0, 0, 1], 9529)),
-            ];
-            let stream = TcpStream::connect(&addrs[..]).unwrap();
+            let stream = match UnixStream::connect(socket_name) {
+                Err(e) => {
+                    println!("Failed to connect to sysmaster: {e}");
+                    return;
+                }
+                Ok(v) => v,
+            };
             let mut client = ProstClientStream::new(stream);
 
             let cmd =
                 CommandRequest::new_unitcomm(UnitAction::Start, vec!["test.service".to_string()]);
             let _ = client.execute(cmd).unwrap();
         });
-        let addrs = [
-            SocketAddr::from(([127, 0, 0, 1], 9528)),
-            SocketAddr::from(([127, 0, 0, 1], 9529)),
-        ];
-        let fd = std::net::TcpListener::bind(&addrs[..]).unwrap();
-        loop {
-            for stream in fd.incoming() {
-                match stream {
-                    Err(e) => panic!("failed: {e}"),
-                    Ok(_stream) => return,
-                }
+        let sctl_socket_addr = socket::UnixAddr::new(socket_path).unwrap();
+        let socket_fd = socket::socket(
+            socket::AddressFamily::Unix,
+            socket::SockType::Stream,
+            socket::SockFlag::SOCK_CLOEXEC | socket::SockFlag::SOCK_NONBLOCK,
+            None,
+        )
+        .unwrap();
+
+        // Fail here with the real cause instead of a misleading accept() panic below.
+        socket::bind(socket_fd, &sctl_socket_addr).expect("Failed to bind the test socket");
+        socket::listen(socket_fd, 10).expect("Failed to listen on the test socket");
+
+        // The listener is non-blocking: poll up to 5 times, 500ms apart.
+        for _ in 0..5 {
+            match socket::accept(socket_fd) {
+                Err(nix::Error::EAGAIN) => thread::sleep(Duration::from_millis(500)),
+                Ok(_) => break,
+                Err(_) => panic!("Unexpected error when accepting connection."),
             }
         }
+        std::fs::remove_file(socket_path).unwrap();
     }
 }
-- 
2.33.0