Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions devolutions-agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ features = [
aws-lc-rs = "1.15"
time = { version = "0.3", features = ["local-offset", "macros", "parsing"] }
devolutions-pedm = { path = "../crates/devolutions-pedm" }
devolutions-session = { path = "../devolutions-session", default-features = false, features = ["dvc"] }
notify-debouncer-mini = "0.6"
now-proto-pdu = { version = "0.4.3", features = ["std"] }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2", "socks"] }
thiserror = "2"
uuid = { version = "1.17", features = ["v4"] }
Expand Down
4 changes: 2 additions & 2 deletions devolutions-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use url as _;
use uuid as _;
#[cfg(windows)]
use {
aws_lc_rs as _, devolutions_pedm as _, hex as _, notify_debouncer_mini as _, sha2 as _, thiserror as _,
win_api_wrappers as _, windows as _,
aws_lc_rs as _, devolutions_pedm as _, devolutions_session as _, hex as _, notify_debouncer_mini as _,
now_proto_pdu as _, sha2 as _, thiserror as _, win_api_wrappers as _, windows as _,
};

#[macro_use]
Expand Down
102 changes: 99 additions & 3 deletions devolutions-agent/src/psu_grpc_agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use protocol::agent_control_client::AgentControlClient;
use protocol::agent_message::Payload as AgentPayload;
use protocol::server_message::Payload as ServerPayload;
use protocol::{
AgentCapability, AgentDiagnostic, AgentMessage, PowerShellRuntime, RegisterAgent, StreamClosed, StreamData,
AgentCapability, AgentDiagnostic, AgentMessage, PowerShellRuntime, ProcessCompleted, RegisterAgent, StreamClosed,
StreamData,
};

const PROTOCOL_VERSION: &str = "poc.v1";
Expand Down Expand Up @@ -217,15 +218,68 @@ impl PsuGrpcAgent {
info!(connection_id = %accepted.connection_id, "PSU gRPC agent registration accepted");
}
Some(ServerPayload::StartProcess(start_process)) => {
let incoming_rx = registry.register_stream(&start_process.stream_id).await;
if let Some(error_message) = start_process_validation_error(&start_process) {
outgoing_tx
.send(agent_message(
&self.agent_id,
connection_id,
AgentPayload::ProcessCompleted(ProcessCompleted {
correlation_id: start_process.correlation_id,
exit_code: -1,
canceled: false,
error_message: error_message.to_owned(),
}),
))
.await
.context("failed to send PSU gRPC invalid StartProcess response")?;
return Ok(());
}

let (control_tx, control_rx) = mpsc::channel(8);
registry
let registered = registry
.register_process(
start_process.correlation_id.clone(),
ProcessControl { stop: control_tx },
)
.await;

if !registered {
let error_message = "process correlation id is already in use".to_owned();
outgoing_tx
.send(agent_message(
&self.agent_id,
connection_id,
AgentPayload::ProcessCompleted(ProcessCompleted {
correlation_id: start_process.correlation_id,
exit_code: -1,
canceled: false,
error_message,
}),
))
.await
.context("failed to send PSU gRPC duplicate ProcessCompleted message")?;
return Ok(());
}

let Some(incoming_rx) = registry.register_stream(&start_process.stream_id).await else {
registry.remove_process(&start_process.correlation_id).await;
let error_message = "process stream id is already in use".to_owned();
outgoing_tx
.send(agent_message(
&self.agent_id,
connection_id,
AgentPayload::ProcessCompleted(ProcessCompleted {
correlation_id: start_process.correlation_id,
exit_code: -1,
canceled: false,
error_message,
}),
))
.await
.context("failed to send PSU gRPC duplicate stream ProcessCompleted message")?;
return Ok(());
};

let agent_id = self.agent_id.clone();
let connection_id = connection_id.clone();
let default_executable = self.powershell_executable.clone();
Expand Down Expand Up @@ -355,6 +409,18 @@ fn timestamp_now() -> prost_types::Timestamp {
}
}

fn start_process_validation_error(start_process: &protocol::StartProcess) -> Option<&'static str> {
if start_process.correlation_id.trim().is_empty() {
return Some("process correlation id is required");
}

if start_process.stream_id.trim().is_empty() {
return Some("process stream id is required");
}

None
}

fn machine_name() -> String {
hostname::get()
.ok()
Expand Down Expand Up @@ -437,4 +503,34 @@ mod tests {
"Bearer token"
);
}

#[test]
fn start_process_validation_requires_ids() {
let mut start_process = protocol::StartProcess {
correlation_id: String::new(),
stream_id: "stream-id".to_owned(),
executable: String::new(),
arguments: Vec::new(),
working_directory: String::new(),
environment: HashMap::new(),
metadata: HashMap::new(),
};

assert_eq!(
start_process_validation_error(&start_process),
Some("process correlation id is required")
);

start_process.correlation_id = "correlation-id".to_owned();
start_process.stream_id.clear();

assert_eq!(
start_process_validation_error(&start_process),
Some("process stream id is required")
);

start_process.stream_id = "stream-id".to_owned();

assert_eq!(start_process_validation_error(&start_process), None);
}
}
Loading
Loading