sysmaster/backport-feature-get-credential-from-the-connection-and-use-i.patch
licunlong 0c7d548f94 sync patches from upstream
(cherry picked from commit e23ebb83bd7672e4dc8da68a9a8c73fe6e016341)
2023-06-19 10:39:49 +08:00

448 lines
16 KiB
Diff

From ca86515f0d1cfcaae7d5552ef9f7df89cf6398d9 Mon Sep 17 00:00:00 2001
From: licunlong <licunlong1@huawei.com>
Date: Thu, 1 Jun 2023 15:38:02 +0800
Subject: [PATCH] feature: get credentials from the connection, and use them to
check whether the request is sent by root
---
core/bin/manager/commands.rs | 33 ++++++-----
core/bin/unit/entry/config.rs | 6 +-
core/bin/unit/util/unit_file.rs | 35 ++++++++++--
libs/cmdproto/src/error.rs | 3 +
libs/cmdproto/src/proto/execute.rs | 68 ++++++++++++++++++++--
libs/cmdproto/src/proto/frame.rs | 91 +++++++++++++++++++++++++-----
6 files changed, 194 insertions(+), 42 deletions(-)
diff --git a/core/bin/manager/commands.rs b/core/bin/manager/commands.rs
index 1279206..a8320e2 100644
--- a/core/bin/manager/commands.rs
+++ b/core/bin/manager/commands.rs
@@ -73,21 +73,28 @@ where
}
fn dispatch(&self, _e: &Events) -> i32 {
- log::trace!("Dispatching Command!");
-
self.reli.set_last_frame1(ReliLastFrame::CmdOp as u32);
- match self.fd.incoming().next() {
- None => log::info!("None CommandRequest!"),
- Some(stream) => {
- log::trace!("{stream:?}");
- if let Ok(s) = stream {
- let dispatch = ProstServerStream::new(s, self.command_action.clone());
- match dispatch.process() {
- Ok(_) => (),
- Err(e) => log::error!("Commands failed: {:?}", e),
- }
- }
+ let client = match socket::accept(self.socket_fd) {
+ Err(e) => {
+ log::error!("Failed to accept connection: {}, ignoring.", e);
+ return 0;
+ }
+ Ok(v) => v,
+ };
+ let cred = match socket::getsockopt(client, socket::sockopt::PeerCredentials) {
+ Err(e) => {
+ log::error!(
+ "Failed to get the credentials of the connection: {}, refuse any request.",
+ e
+ );
+ None
}
+ Ok(v) => Some(v),
+ };
+ let dispatch = ProstServerStream::new(client, self.command_action.clone(), cred);
+ match dispatch.process() {
+ Ok(_) => (),
+ Err(e) => log::error!("Commands failed: {:?}", e),
}
self.reli.clear_last_frame();
diff --git a/core/bin/unit/entry/config.rs b/core/bin/unit/entry/config.rs
index e629c77..705512d 100644
--- a/core/bin/unit/entry/config.rs
+++ b/core/bin/unit/entry/config.rs
@@ -136,14 +136,12 @@ impl UeConfig {
let unit_conf_frag = files.get_unit_id_fragment_pathbuf(name);
if unit_conf_frag.is_empty() {
- log::error!("config file for {} is not exist", name);
- return Err(format!("config file for {name} is not exist").into());
+ return Err(format!("{name} doesn't have corresponding config file").into());
}
// fragment
for v in unit_conf_frag {
if !v.exists() {
- log::error!("config file is not exist");
- return Err(format!("config file is not exist {name}").into());
+ return Err(format!("Config file {:?} of {name} doesn't exist", v).into());
}
builder = builder.file(&v);
}
diff --git a/core/bin/unit/util/unit_file.rs b/core/bin/unit/util/unit_file.rs
index 31f5838..0871703 100644
--- a/core/bin/unit/util/unit_file.rs
+++ b/core/bin/unit/util/unit_file.rs
@@ -133,12 +133,39 @@ impl UnitFileData {
format!("{v}/{name}")
};
let tmp = Path::new(&path);
- if tmp.exists() && !tmp.is_symlink() {
- let path = format!("{}.toml", tmp.to_string_lossy());
- if let Err(e) = std::fs::copy(tmp, &path) {
+ if !tmp.exists() {
+ continue;
+ }
+ /* Add .toml to the original path name */
+ if !tmp.is_symlink() {
+ let path_toml = format!("{}.toml", tmp.to_string_lossy());
+ let to = Path::new(&path_toml);
+ if let Err(e) = std::fs::copy(tmp, to) {
log::warn!("copy file content to toml file error: {}", e);
}
- let to = Path::new(&path);
+ pathbuf_fragment.push(to.to_path_buf());
+ } else {
+ let real_path = match std::fs::read_link(tmp) {
+ Err(e) => {
+ log::error!("Failed to chase the symlink {:?}: {e}", tmp);
+ continue;
+ }
+ Ok(v) => v,
+ };
+ /* Only support one-level symlink. */
+ if real_path.is_symlink() {
+ continue;
+ }
+ let real_path = tmp.parent().unwrap().join(real_path);
+ let real_path = fs::canonicalize(&real_path).unwrap();
+ let path_toml = format!("{}.toml", real_path.to_string_lossy());
+ let to = Path::new(&path_toml);
+ if let Err(e) = std::fs::copy(&real_path, to) {
+ log::warn!(
+ "copy file content {:?} to toml file {path_toml} error: {e}",
+ real_path
+ );
+ }
pathbuf_fragment.push(to.to_path_buf());
}
}
diff --git a/libs/cmdproto/src/error.rs b/libs/cmdproto/src/error.rs
index 5d64109..618e6ad 100644
--- a/libs/cmdproto/src/error.rs
+++ b/libs/cmdproto/src/error.rs
@@ -33,6 +33,9 @@ pub enum Error {
#[snafu(display("ReadStreamFailed"))]
ReadStream { msg: String },
+ #[snafu(display("SendStreamFailed"))]
+ SendStream { msg: String },
+
#[snafu(display("ManagerStartFailed"))]
ManagerStart { msg: String },
}
diff --git a/libs/cmdproto/src/proto/execute.rs b/libs/cmdproto/src/proto/execute.rs
index f9ca73a..280b878 100644
--- a/libs/cmdproto/src/proto/execute.rs
+++ b/libs/cmdproto/src/proto/execute.rs
@@ -18,7 +18,7 @@ use super::{
use crate::error::*;
use http::StatusCode;
-use nix;
+use nix::{self, sys::socket::UnixCredentials};
use std::{fmt::Display, rc::Rc};
pub(crate) trait Executer {
@@ -27,6 +27,7 @@ pub(crate) trait Executer {
self,
manager: Rc<impl ExecuterAction>,
call_back: Option<fn(&str) -> String>,
+ cred: Option<UnixCredentials>,
) -> CommandResponse;
}
@@ -73,7 +74,11 @@ pub trait ExecuterAction {
}
/// Depending on the type of request
-pub(crate) fn dispatch<T>(cmd: CommandRequest, manager: Rc<T>) -> CommandResponse
+pub(crate) fn dispatch<T>(
+ cmd: CommandRequest,
+ manager: Rc<T>,
+ cred: Option<UnixCredentials>,
+) -> CommandResponse
where
T: ExecuterAction,
{
@@ -89,10 +94,10 @@ where
};
match cmd.request_data {
- Some(RequestData::Ucomm(param)) => param.execute(manager, Some(call_back)),
- Some(RequestData::Mcomm(param)) => param.execute(manager, None),
- Some(RequestData::Syscomm(param)) => param.execute(manager, Some(call_back)),
- Some(RequestData::Ufile(param)) => param.execute(manager, Some(call_back)),
+ Some(RequestData::Ucomm(param)) => param.execute(manager, Some(call_back), cred),
+ Some(RequestData::Mcomm(param)) => param.execute(manager, None, cred),
+ Some(RequestData::Syscomm(param)) => param.execute(manager, Some(call_back), cred),
+ Some(RequestData::Ufile(param)) => param.execute(manager, Some(call_back), cred),
_ => CommandResponse::default(),
}
}
@@ -103,12 +108,45 @@ fn new_line_break(s: &mut String) {
}
}
+fn response_if_credential_dissatisfied(
+ cred: Option<UnixCredentials>,
+ command_is_allowed_for_nonroot: bool,
+) -> Option<CommandResponse> {
+ let sender = match cred {
+ None => {
+ return Some(CommandResponse {
+ status: StatusCode::OK.as_u16() as _,
+ error_code: 1,
+ message: "Failed to execute your command: cannot determine user credentials."
+ .to_string(),
+ })
+ }
+ Some(v) => v.uid(),
+ };
+ if sender != 0 && !command_is_allowed_for_nonroot {
+ return Some(CommandResponse {
+ status: StatusCode::OK.as_u16() as _,
+ error_code: 1,
+ message: "Failed to execute your command: Operation not permitted.".to_string(),
+ });
+ }
+ None
+}
+
impl Executer for UnitComm {
fn execute(
self,
manager: Rc<impl ExecuterAction>,
call_back: Option<fn(&str) -> String>,
+ cred: Option<UnixCredentials>,
) -> CommandResponse {
+ if let Some(v) = response_if_credential_dissatisfied(
+ cred,
+ [unit_comm::Action::Status].contains(&self.action()),
+ ) {
+ return v;
+ }
+
let mut reply = String::new();
let mut units: Vec<String> = Vec::new();
let mut error_code: u32 = 0;
@@ -197,7 +235,15 @@ impl Executer for MngrComm {
self,
manager: Rc<impl ExecuterAction>,
_call_back: Option<fn(&str) -> String>,
+ cred: Option<UnixCredentials>,
) -> CommandResponse {
+ if let Some(v) = response_if_credential_dissatisfied(
+ cred,
+ [mngr_comm::Action::Listunits].contains(&self.action()),
+ ) {
+ return v;
+ }
+
match self.action() {
mngr_comm::Action::Reexec => {
manager.daemon_reexec();
@@ -241,7 +287,12 @@ impl Executer for SysComm {
self,
manager: Rc<impl ExecuterAction>,
_call_back: Option<fn(&str) -> String>,
+ cred: Option<UnixCredentials>,
) -> CommandResponse {
+ if let Some(v) = response_if_credential_dissatisfied(cred, false) {
+ return v;
+ }
+
let ret = if self.force {
let unit_name = self.action().to_string() + ".target";
match manager.start(&unit_name) {
@@ -279,7 +330,12 @@ impl Executer for UnitFile {
self,
manager: Rc<impl ExecuterAction>,
call_back: Option<fn(&str) -> String>,
+ cred: Option<UnixCredentials>,
) -> CommandResponse {
+ if let Some(v) = response_if_credential_dissatisfied(cred, false) {
+ return v;
+ }
+
let mut reply = String::new();
let mut units: Vec<String> = Vec::new();
let mut error_code: u32 = 0;
diff --git a/libs/cmdproto/src/proto/frame.rs b/libs/cmdproto/src/proto/frame.rs
index 9b942a8..6b6912b 100644
--- a/libs/cmdproto/src/proto/frame.rs
+++ b/libs/cmdproto/src/proto/frame.rs
@@ -12,6 +12,7 @@
//! Encapsulate the command request into a frame
use crate::error::*;
+use nix::sys::socket::{self, UnixCredentials};
use prost::bytes::{BufMut, BytesMut};
use prost::Message;
use std::{
@@ -48,8 +49,49 @@ where
impl FrameCoder for CommandRequest {}
impl FrameCoder for CommandResponse {}
+/// Read frame from accept fd.
+pub fn read_frame_from_fd(fd: i32, buf: &mut BytesMut) -> Result<()> {
+ // 1. Got the message length
+ let mut msg_len = [0_u8; USIZE_TO_U8_LENGTH];
+ match socket::recv(fd, &mut msg_len, socket::MsgFlags::empty()) {
+ Ok(len) => {
+ if len != USIZE_TO_U8_LENGTH {
+ return Err(Error::ReadStream {
+ msg: "Invalid message length".to_string(),
+ });
+ }
+ }
+ Err(e) => {
+ return Err(Error::ReadStream { msg: e.to_string() });
+ }
+ }
+ let msg_len = get_msg_len(msg_len);
+
+ // 2. Got the message
+ let mut tmp = vec![0; MAX_FRAME];
+ let mut cur_len: usize = 0;
+ loop {
+ match socket::recv(fd, &mut tmp, socket::MsgFlags::empty()) {
+ Ok(len) => {
+ cur_len += len;
+ buf.put_slice(&tmp[..len]);
+ /* If there is no more message (len < MAX_FRAME), or
+ * we have got enough message (cur_len >= msg_len),
+ * then we finish reading. */
+ if len < MAX_FRAME || cur_len >= msg_len {
+ break;
+ }
+ }
+ Err(e) => {
+ return Err(Error::ReadStream { msg: e.to_string() });
+ }
+ }
+ }
+ Ok(())
+}
+
/// read frame from stream
-pub fn read_frame<S>(stream: &mut S, buf: &mut BytesMut) -> Result<()>
+pub fn read_frame_from_stream<S>(stream: &mut S, buf: &mut BytesMut) -> Result<()>
where
S: Read + Unpin + Send,
{
@@ -66,6 +108,9 @@ where
Ok(len) => {
cur_len += len;
buf.put_slice(&tmp[..len]);
+ /* If there is no more message (len < MAX_FRAME), or
+ * we have got enough message (cur_len >= msg_len),
+ * then we finish reading. */
if len < MAX_FRAME || cur_len >= msg_len {
break;
}
@@ -89,9 +134,10 @@ fn get_msg_len(message: [u8; USIZE_TO_U8_LENGTH]) -> usize {
}
/// Handle read and write of server-side socket
-pub struct ProstServerStream<S, T> {
- inner: S,
+pub struct ProstServerStream<T> {
+ accept_fd: i32,
manager: Rc<T>,
+ cred: Option<UnixCredentials>,
}
/// Handle read and write of client-side socket
@@ -99,23 +145,23 @@ pub struct ProstClientStream<S> {
inner: S,
}
-impl<S, T> ProstServerStream<S, T>
+impl<T> ProstServerStream<T>
where
- S: Read + Write + Unpin + Send,
T: ExecuterAction,
{
/// new ProstServerStream
- pub fn new(stream: S, manager: Rc<T>) -> Self {
+ pub fn new(accept_fd: i32, manager: Rc<T>, cred: Option<UnixCredentials>) -> Self {
Self {
- inner: stream,
+ accept_fd,
manager,
+ cred,
}
}
/// process frame in server-side
pub fn process(mut self) -> Result<()> {
if let Ok(cmd) = self.recv() {
- let res = execute::dispatch(cmd, Rc::clone(&self.manager));
+ let res = execute::dispatch(cmd, Rc::clone(&self.manager), self.cred);
self.send(res)?;
};
Ok(())
@@ -126,16 +172,32 @@ where
msg.encode_frame(&mut buf)?;
let encoded = buf.freeze();
let msg_len = msg_len_vec(encoded.len());
- self.inner.write_all(&msg_len).context(IoSnafu)?;
- self.inner.write_all(&encoded).context(IoSnafu)?;
- self.inner.flush().context(IoSnafu)?;
+ match socket::send(self.accept_fd, &msg_len, socket::MsgFlags::empty()) {
+ Ok(len) => {
+ if len != msg_len.len() {
+ return Err(Error::SendStream {
+ msg: "Invalid message length".to_string(),
+ });
+ }
+ }
+ Err(e) => return Err(Error::SendStream { msg: e.to_string() }),
+ }
+ match socket::send(self.accept_fd, &encoded, socket::MsgFlags::empty()) {
+ Ok(len) => {
+ if len != encoded.len() {
+ return Err(Error::SendStream {
+ msg: "Invalid message length".to_string(),
+ });
+ }
+ }
+ Err(e) => return Err(Error::SendStream { msg: e.to_string() }),
+ }
Ok(())
}
fn recv(&mut self) -> Result<CommandRequest> {
let mut buf = BytesMut::new();
- let stream = &mut self.inner;
- read_frame(stream, &mut buf)?;
+ read_frame_from_fd(self.accept_fd, &mut buf)?;
CommandRequest::decode_frame(&mut buf)
}
}
@@ -169,8 +231,7 @@ where
fn recv(&mut self) -> Result<CommandResponse> {
let mut buf = BytesMut::new();
- let stream = &mut self.inner;
- read_frame(stream, &mut buf)?;
+ read_frame_from_stream(&mut self.inner, &mut buf)?;
CommandResponse::decode_frame(&mut buf)
}
}
--
2.30.2