From ca86515f0d1cfcaae7d5552ef9f7df89cf6398d9 Mon Sep 17 00:00:00 2001
From: licunlong
Date: Thu, 1 Jun 2023 15:38:02 +0800
Subject: [PATCH] feature: get credential from the connection, and use it to
 check if the request is sent from root

---
 core/bin/manager/commands.rs       | 36 ++++++++----
 core/bin/unit/entry/config.rs      |  6 +-
 core/bin/unit/util/unit_file.rs    | 46 +++++++++++++--
 libs/cmdproto/src/error.rs         |  3 +
 libs/cmdproto/src/proto/execute.rs | 72 +++++++++++++++++++++--
 libs/cmdproto/src/proto/frame.rs   | 93 +++++++++++++++++++++++++-----
 6 files changed, 216 insertions(+), 40 deletions(-)

diff --git a/core/bin/manager/commands.rs b/core/bin/manager/commands.rs
index 1279206..a8320e2 100644
--- a/core/bin/manager/commands.rs
+++ b/core/bin/manager/commands.rs
@@ -73,20 +73,34 @@ where
     }
 
     fn dispatch(&self, _e: &Events) -> i32 {
         log::trace!("Dispatching Command!");
         self.reli.set_last_frame1(ReliLastFrame::CmdOp as u32);
-        match self.fd.incoming().next() {
-            None => log::info!("None CommandRequest!"),
-            Some(stream) => {
-                log::trace!("{stream:?}");
-                if let Ok(s) = stream {
-                    let dispatch = ProstServerStream::new(s, self.command_action.clone());
-                    match dispatch.process() {
-                        Ok(_) => (),
-                        Err(e) => log::error!("Commands failed: {:?}", e),
-                    }
-                }
+        let client = match socket::accept(self.socket_fd) {
+            Err(e) => {
+                log::error!("Failed to accept connection: {}, ignoring.", e);
+                self.reli.clear_last_frame();
+                return 0;
+            }
+            Ok(v) => v,
+        };
+        let cred = match socket::getsockopt(client, socket::sockopt::PeerCredentials) {
+            Err(e) => {
+                log::error!(
+                    "Failed to get the credentials of the connection: {}, refuse any request.",
+                    e
+                );
+                None
             }
+            Ok(v) => Some(v),
+        };
+        let dispatch = ProstServerStream::new(client, self.command_action.clone(), cred);
+        match dispatch.process() {
+            Ok(_) => (),
+            Err(e) => log::error!("Commands failed: {:?}", e),
         }
+        /* The accepted fd is owned here; close it so that requests do not leak fds. */
+        if let Err(e) = nix::unistd::close(client) {
+            log::warn!("Failed to close the command connection: {}", e);
+        }
 
         self.reli.clear_last_frame();
diff --git a/core/bin/unit/entry/config.rs b/core/bin/unit/entry/config.rs
index e629c77..705512d 100644
--- a/core/bin/unit/entry/config.rs
+++ b/core/bin/unit/entry/config.rs
@@ -136,14 +136,12 @@ impl UeConfig {
 
         let unit_conf_frag = files.get_unit_id_fragment_pathbuf(name);
         if unit_conf_frag.is_empty() {
-            log::error!("config file for {} is not exist", name);
-            return Err(format!("config file for {name} is not exist").into());
+            return Err(format!("{name} doesn't have corresponding config file").into());
         }
         // fragment
         for v in unit_conf_frag {
             if !v.exists() {
-                log::error!("config file is not exist");
-                return Err(format!("config file is not exist {name}").into());
+                return Err(format!("Config file {:?} of {name} doesn't exist", v).into());
             }
             builder = builder.file(&v);
         }
diff --git a/core/bin/unit/util/unit_file.rs b/core/bin/unit/util/unit_file.rs
index 31f5838..0871703 100644
--- a/core/bin/unit/util/unit_file.rs
+++ b/core/bin/unit/util/unit_file.rs
@@ -133,12 +133,50 @@ impl UnitFileData {
                 format!("{v}/{name}")
             };
             let tmp = Path::new(&path);
-            if tmp.exists() && !tmp.is_symlink() {
-                let path = format!("{}.toml", tmp.to_string_lossy());
-                if let Err(e) = std::fs::copy(tmp, &path) {
+            if !tmp.exists() {
+                continue;
+            }
+            /* Add .toml to the original path name */
+            if !tmp.is_symlink() {
+                let path_toml = format!("{}.toml", tmp.to_string_lossy());
+                let to = Path::new(&path_toml);
+                if let Err(e) = std::fs::copy(tmp, to) {
                     log::warn!("copy file content to toml file error: {}", e);
                 }
-                let to = Path::new(&path);
+                pathbuf_fragment.push(to.to_path_buf());
+            } else {
+                let link_target = match std::fs::read_link(tmp) {
+                    Err(e) => {
+                        log::error!("Failed to chase the symlink {:?}: {e}", tmp);
+                        continue;
+                    }
+                    Ok(v) => v,
+                };
+                /* A relative link target is relative to the directory holding the
+                 * symlink, not to our working directory: resolve it first. */
+                let real_path = match tmp.parent() {
+                    Some(dir) => dir.join(link_target),
+                    None => link_target,
+                };
+                /* Only support one-level symlink. */
+                if real_path.is_symlink() {
+                    continue;
+                }
+                let real_path = match std::fs::canonicalize(&real_path) {
+                    Err(e) => {
+                        log::error!("Failed to canonicalize {:?}: {e}", real_path);
+                        continue;
+                    }
+                    Ok(v) => v,
+                };
+                let path_toml = format!("{}.toml", real_path.to_string_lossy());
+                let to = Path::new(&path_toml);
+                if let Err(e) = std::fs::copy(&real_path, to) {
+                    log::warn!(
+                        "copy file content {:?} to toml file {path_toml} error: {e}",
+                        real_path
+                    );
+                }
                 pathbuf_fragment.push(to.to_path_buf());
             }
         }
diff --git a/libs/cmdproto/src/error.rs b/libs/cmdproto/src/error.rs
index 5d64109..618e6ad 100644
--- a/libs/cmdproto/src/error.rs
+++ b/libs/cmdproto/src/error.rs
@@ -33,6 +33,9 @@ pub enum Error {
     #[snafu(display("ReadStreamFailed"))]
     ReadStream { msg: String },
 
+    #[snafu(display("SendStreamFailed"))]
+    SendStream { msg: String },
+
     #[snafu(display("ManagerStartFailed"))]
     ManagerStart { msg: String },
 }
diff --git a/libs/cmdproto/src/proto/execute.rs b/libs/cmdproto/src/proto/execute.rs
index f9ca73a..280b878 100644
--- a/libs/cmdproto/src/proto/execute.rs
+++ b/libs/cmdproto/src/proto/execute.rs
@@ -18,7 +18,7 @@ use super::{
 
 use crate::error::*;
 use http::StatusCode;
-use nix;
+use nix::{self, sys::socket::UnixCredentials};
 use std::{fmt::Display, rc::Rc};
 
 pub(crate) trait Executer {
@@ -27,6 +27,7 @@ pub(crate) trait Executer {
         self,
         manager: Rc<impl ExecuterAction>,
         call_back: Option<fn(&str) -> String>,
+        cred: Option<UnixCredentials>,
     ) -> CommandResponse;
 }
 
@@ -73,7 +74,11 @@ pub trait ExecuterAction {
 }
 
 /// Depending on the type of request
-pub(crate) fn dispatch<T>(cmd: CommandRequest, manager: Rc<T>) -> CommandResponse
+pub(crate) fn dispatch<T>(
+    cmd: CommandRequest,
+    manager: Rc<T>,
+    cred: Option<UnixCredentials>,
+) -> CommandResponse
 where
     T: ExecuterAction,
 {
@@ -89,10 +94,10 @@ where
     };
 
     match cmd.request_data {
-        Some(RequestData::Ucomm(param)) => param.execute(manager, Some(call_back)),
-        Some(RequestData::Mcomm(param)) => param.execute(manager, None),
-        Some(RequestData::Syscomm(param)) => param.execute(manager, Some(call_back)),
-        Some(RequestData::Ufile(param)) => param.execute(manager, Some(call_back)),
+        Some(RequestData::Ucomm(param)) => param.execute(manager, Some(call_back), cred),
+        Some(RequestData::Mcomm(param)) => param.execute(manager, None, cred),
+        Some(RequestData::Syscomm(param)) => param.execute(manager, Some(call_back), cred),
+        Some(RequestData::Ufile(param)) => param.execute(manager, Some(call_back), cred),
         _ => CommandResponse::default(),
     }
 }
@@ -103,12 +108,49 @@ fn new_line_break(s: &mut String) {
     }
 }
 
+/// Return the response refusing the request, or `None` if the sender may run it.
+///
+/// A request is refused when the peer credentials are unknown (`cred` is `None`), or when
+/// the sender is not root (uid 0) and the command is not allowed for non-root users.
+fn response_if_credential_dissatisfied(
+    cred: Option<UnixCredentials>,
+    command_is_allowed_for_nonroot: bool,
+) -> Option<CommandResponse> {
+    let sender = match cred {
+        None => {
+            return Some(CommandResponse {
+                status: StatusCode::OK.as_u16() as _,
+                error_code: 1,
+                message: "Failed to execute your command: cannot determine user credentials."
+                    .to_string(),
+            })
+        }
+        Some(v) => v.uid(),
+    };
+    if sender != 0 && !command_is_allowed_for_nonroot {
+        return Some(CommandResponse {
+            status: StatusCode::OK.as_u16() as _,
+            error_code: 1,
+            message: "Failed to execute your command: Operation not permitted.".to_string(),
+        });
+    }
+    None
+}
+
 impl Executer for UnitComm {
     fn execute(
         self,
         manager: Rc<impl ExecuterAction>,
         call_back: Option<fn(&str) -> String>,
+        cred: Option<UnixCredentials>,
     ) -> CommandResponse {
+        if let Some(v) = response_if_credential_dissatisfied(
+            cred,
+            [unit_comm::Action::Status].contains(&self.action()),
+        ) {
+            return v;
+        }
+
         let mut reply = String::new();
         let mut units: Vec<String> = Vec::new();
         let mut error_code: u32 = 0;
@@ -197,7 +239,15 @@ impl Executer for MngrComm {
         self,
         manager: Rc<impl ExecuterAction>,
         _call_back: Option<fn(&str) -> String>,
+        cred: Option<UnixCredentials>,
     ) -> CommandResponse {
+        if let Some(v) = response_if_credential_dissatisfied(
+            cred,
+            [mngr_comm::Action::Listunits].contains(&self.action()),
+        ) {
+            return v;
+        }
+
         match self.action() {
             mngr_comm::Action::Reexec => {
                 manager.daemon_reexec();
@@ -241,7 +291,12 @@ impl Executer for SysComm {
         self,
         manager: Rc<impl ExecuterAction>,
         _call_back: Option<fn(&str) -> String>,
+        cred: Option<UnixCredentials>,
     ) -> CommandResponse {
+        if let Some(v) = response_if_credential_dissatisfied(cred, false) {
+            return v;
+        }
+
         let ret = if self.force {
             let unit_name = self.action().to_string() + ".target";
             match manager.start(&unit_name) {
@@ -279,7 +334,12 @@ impl Executer for UnitFile {
         self,
         manager: Rc<impl ExecuterAction>,
         call_back: Option<fn(&str) -> String>,
+        cred: Option<UnixCredentials>,
     ) -> CommandResponse {
+        if let Some(v) = response_if_credential_dissatisfied(cred, false) {
+            return v;
+        }
+
         let mut reply = String::new();
         let mut units: Vec<String> = Vec::new();
         let mut error_code: u32 = 0;
diff --git a/libs/cmdproto/src/proto/frame.rs b/libs/cmdproto/src/proto/frame.rs
index 9b942a8..6b6912b 100644
--- a/libs/cmdproto/src/proto/frame.rs
+++ b/libs/cmdproto/src/proto/frame.rs
@@ -12,6 +12,7 @@
 //! Encapsulate the command request into a frame
 
 use crate::error::*;
+use nix::sys::socket::{self, UnixCredentials};
 use prost::bytes::{BufMut, BytesMut};
 use prost::Message;
 use std::{
@@ -48,8 +49,67 @@ where
 impl FrameCoder for CommandRequest {}
 impl FrameCoder for CommandResponse {}
 
+/// Receive exactly `buf.len()` bytes from `fd`.
+///
+/// recv() on a stream socket may return fewer bytes than requested, so
+/// keep reading until `buf` is full; 0 bytes means the peer hung up.
+fn recv_exact(fd: i32, buf: &mut [u8]) -> Result<()> {
+    let mut got = 0;
+    while got < buf.len() {
+        match socket::recv(fd, &mut buf[got..], socket::MsgFlags::empty()) {
+            Ok(0) => {
+                return Err(Error::ReadStream {
+                    msg: "Connection closed before the whole message was received".to_string(),
+                })
+            }
+            Ok(len) => got += len,
+            Err(e) => return Err(Error::ReadStream { msg: e.to_string() }),
+        }
+    }
+    Ok(())
+}
+
+/// Send the whole of `data` on `fd`, continuing after partial sends.
+fn send_all(fd: i32, data: &[u8]) -> Result<()> {
+    let mut sent = 0;
+    while sent < data.len() {
+        match socket::send(fd, &data[sent..], socket::MsgFlags::empty()) {
+            Ok(0) => {
+                return Err(Error::SendStream {
+                    msg: "Connection stopped accepting data".to_string(),
+                })
+            }
+            Ok(len) => sent += len,
+            Err(e) => return Err(Error::SendStream { msg: e.to_string() }),
+        }
+    }
+    Ok(())
+}
+
+/// Read frame from accept fd.
+///
+/// A frame is a `USIZE_TO_U8_LENGTH`-byte length prefix followed by exactly that
+/// many bytes of payload; the payload is appended to `buf`.
+pub fn read_frame_from_fd(fd: i32, buf: &mut BytesMut) -> Result<()> {
+    // 1. Got the message length
+    let mut msg_len = [0_u8; USIZE_TO_U8_LENGTH];
+    recv_exact(fd, &mut msg_len)?;
+    let msg_len = get_msg_len(msg_len);
+
+    // 2. Got the message: exactly msg_len bytes, at most MAX_FRAME at a time
+    let mut tmp = vec![0; MAX_FRAME.min(msg_len)];
+    let mut remaining = msg_len;
+    while remaining > 0 {
+        let chunk = &mut tmp[..MAX_FRAME.min(remaining)];
+        recv_exact(fd, chunk)?;
+        buf.put_slice(chunk);
+        remaining -= chunk.len();
+    }
+    Ok(())
+}
+
 /// read frame from stream
-pub fn read_frame<S>(stream: &mut S, buf: &mut BytesMut) -> Result<()>
+pub fn read_frame_from_stream<S>(stream: &mut S, buf: &mut BytesMut) -> Result<()>
 where
     S: Read + Unpin + Send,
 {
@@ -66,6 +126,9 @@ where
             Ok(len) => {
                 cur_len += len;
                 buf.put_slice(&tmp[..len]);
+                /* If there is no more message (len < MAX_FRAME), or
+                 * we have got enough message (cur_len >= msg_len),
+                 * then we finish reading. */
                 if len < MAX_FRAME || cur_len >= msg_len {
                     break;
                 }
@@ -89,9 +152,12 @@ fn get_msg_len(message: [u8; USIZE_TO_U8_LENGTH]) -> usize {
 }
 
 /// Handle read and write of server-side socket
-pub struct ProstServerStream<S, T> {
-    inner: S,
+pub struct ProstServerStream<T> {
+    /// Connected socket returned by accept(); the caller owns it and must close it.
+    accept_fd: i32,
     manager: Rc<T>,
+    /// Credentials of the peer, `None` if they could not be determined.
+    cred: Option<UnixCredentials>,
 }
 
 /// Handle read and write of client-side socket
@@ -99,23 +165,23 @@ pub struct ProstClientStream<S> {
     inner: S,
 }
 
-impl<S, T> ProstServerStream<S, T>
+impl<T> ProstServerStream<T>
 where
-    S: Read + Write + Unpin + Send,
     T: ExecuterAction,
 {
     /// new ProstServerStream
-    pub fn new(stream: S, manager: Rc<T>) -> Self {
+    pub fn new(accept_fd: i32, manager: Rc<T>, cred: Option<UnixCredentials>) -> Self {
         Self {
-            inner: stream,
+            accept_fd,
             manager,
+            cred,
         }
     }
 
     /// process frame in server-side
     pub fn process(mut self) -> Result<()> {
         if let Ok(cmd) = self.recv() {
-            let res = execute::dispatch(cmd, Rc::clone(&self.manager));
+            let res = execute::dispatch(cmd, Rc::clone(&self.manager), self.cred);
             self.send(res)?;
         };
         Ok(())
@@ -126,16 +192,14 @@ where
         msg.encode_frame(&mut buf)?;
         let encoded = buf.freeze();
         let msg_len = msg_len_vec(encoded.len());
-        self.inner.write_all(&msg_len).context(IoSnafu)?;
-        self.inner.write_all(&encoded).context(IoSnafu)?;
-        self.inner.flush().context(IoSnafu)?;
+        send_all(self.accept_fd, &msg_len)?;
+        send_all(self.accept_fd, &encoded)?;
         Ok(())
     }
 
     fn recv(&mut self) -> Result<CommandRequest> {
         let mut buf = BytesMut::new();
-        let stream = &mut self.inner;
-        read_frame(stream, &mut buf)?;
+        read_frame_from_fd(self.accept_fd, &mut buf)?;
         CommandRequest::decode_frame(&mut buf)
     }
 }
@@ -169,8 +233,7 @@ where
 
     fn recv(&mut self) -> Result<CommandResponse> {
         let mut buf = BytesMut::new();
-        let stream = &mut self.inner;
-        read_frame(stream, &mut buf)?;
+        read_frame_from_stream(&mut self.inner, &mut buf)?;
         CommandResponse::decode_frame(&mut buf)
     }
 }
-- 
2.30.2