From a2e310d9a7e35a2d238e516ccc94886fd0017e90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Andr=C3=A9=20Moreau?= Date: Tue, 30 Jun 2026 18:34:36 -0400 Subject: [PATCH] refactor(agent): reuse NOW execution for PSU gRPC Route the PSU gRPC remote execution path through the existing NOW/DVC Windows process backend while preserving the current gRPC wire protocol. Keep non-Windows execution on the existing tokio fallback and tighten process/stream cleanup and validation around the gRPC adapter. Co-authored-by: Copilot App <223556219+Copilot@users.noreply.github.com> --- Cargo.lock | 2 + devolutions-agent/Cargo.toml | 2 + devolutions-agent/src/main.rs | 4 +- devolutions-agent/src/psu_grpc_agent/mod.rs | 102 +++- .../src/psu_grpc_agent/process.rs | 521 ++++++++++++++++-- devolutions-session/src/dvc/process.rs | 35 +- devolutions-session/src/dvc/task.rs | 2 +- 7 files changed, 612 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 560224325..0aa8fa22b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1603,6 +1603,7 @@ dependencies = [ "devolutions-gateway-task", "devolutions-log", "devolutions-pedm", + "devolutions-session", "embed-resource", "expect-test", "futures", @@ -1612,6 +1613,7 @@ dependencies = [ "ipnetwork", "ironrdp", "notify-debouncer-mini", + "now-proto-pdu", "parking_lot", "prost 0.13.5", "prost-types", diff --git a/devolutions-agent/Cargo.toml b/devolutions-agent/Cargo.toml index 3e19a8017..c1b69a0e8 100644 --- a/devolutions-agent/Cargo.toml +++ b/devolutions-agent/Cargo.toml @@ -80,7 +80,9 @@ features = [ aws-lc-rs = "1.15" time = { version = "0.3", features = ["local-offset", "macros", "parsing"] } devolutions-pedm = { path = "../crates/devolutions-pedm" } +devolutions-session = { path = "../devolutions-session", default-features = false, features = ["dvc"] } notify-debouncer-mini = "0.6" +now-proto-pdu = { version = "0.4.3", features = ["std"] } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2", "socks"] } thiserror = "2" uuid = { version = "1.17", features = ["v4"] } diff --git a/devolutions-agent/src/main.rs b/devolutions-agent/src/main.rs index 02941753e..7d5d64cf1 100644 --- a/devolutions-agent/src/main.rs +++ b/devolutions-agent/src/main.rs @@ -30,8 +30,8 @@ use url as _; use uuid as _; #[cfg(windows)] use { - aws_lc_rs as _, devolutions_pedm as _, hex as _, notify_debouncer_mini as _, sha2 as _, thiserror as _, - win_api_wrappers as _, windows as _, + aws_lc_rs as _, devolutions_pedm as _, devolutions_session as _, hex as _, notify_debouncer_mini as _, + now_proto_pdu as _, sha2 as _, thiserror as _, win_api_wrappers as _, windows as _, }; #[macro_use] diff --git a/devolutions-agent/src/psu_grpc_agent/mod.rs b/devolutions-agent/src/psu_grpc_agent/mod.rs index aa1819355..9059d7df9 100644 --- a/devolutions-agent/src/psu_grpc_agent/mod.rs +++ b/devolutions-agent/src/psu_grpc_agent/mod.rs @@ -27,7 +27,8 @@ use protocol::agent_control_client::AgentControlClient; use protocol::agent_message::Payload as AgentPayload; use protocol::server_message::Payload as ServerPayload; use protocol::{ - AgentCapability, AgentDiagnostic, AgentMessage, PowerShellRuntime, RegisterAgent, StreamClosed, StreamData, + AgentCapability, AgentDiagnostic, AgentMessage, PowerShellRuntime, ProcessCompleted, RegisterAgent, StreamClosed, + StreamData, }; const PROTOCOL_VERSION: &str = "poc.v1"; @@ -217,15 +218,68 @@ impl PsuGrpcAgent { info!(connection_id = %accepted.connection_id, "PSU gRPC agent registration accepted"); } Some(ServerPayload::StartProcess(start_process)) => { - let incoming_rx = registry.register_stream(&start_process.stream_id).await; + if let Some(error_message) = start_process_validation_error(&start_process) { + outgoing_tx + .send(agent_message( + &self.agent_id, + connection_id, + AgentPayload::ProcessCompleted(ProcessCompleted { + correlation_id: start_process.correlation_id, + exit_code: -1, + canceled: false, + error_message: error_message.to_owned(), + }), + )) + .await + .context("failed to send PSU gRPC invalid StartProcess response")?; + return Ok(()); + } + let (control_tx, control_rx) = mpsc::channel(8); - registry + let registered = registry .register_process( start_process.correlation_id.clone(), ProcessControl { stop: control_tx }, ) .await; + if !registered { + let error_message = "process correlation id is already in use".to_owned(); + outgoing_tx + .send(agent_message( + &self.agent_id, + connection_id, + AgentPayload::ProcessCompleted(ProcessCompleted { + correlation_id: start_process.correlation_id, + exit_code: -1, + canceled: false, + error_message, + }), + )) + .await + .context("failed to send PSU gRPC duplicate ProcessCompleted message")?; + return Ok(()); + } + + let Some(incoming_rx) = registry.register_stream(&start_process.stream_id).await else { + registry.remove_process(&start_process.correlation_id).await; + let error_message = "process stream id is already in use".to_owned(); + outgoing_tx + .send(agent_message( + &self.agent_id, + connection_id, + AgentPayload::ProcessCompleted(ProcessCompleted { + correlation_id: start_process.correlation_id, + exit_code: -1, + canceled: false, + error_message, + }), + )) + .await + .context("failed to send PSU gRPC duplicate stream ProcessCompleted message")?; + return Ok(()); + }; + let agent_id = self.agent_id.clone(); let connection_id = connection_id.clone(); let default_executable = self.powershell_executable.clone(); @@ -355,6 +409,18 @@ fn timestamp_now() -> prost_types::Timestamp { } } +fn start_process_validation_error(start_process: &protocol::StartProcess) -> Option<&'static str> { + if start_process.correlation_id.trim().is_empty() { + return Some("process correlation id is required"); + } + + if start_process.stream_id.trim().is_empty() { + return Some("process stream id is required"); + } + + None +} + fn machine_name() -> String { hostname::get() .ok() @@ -437,4 +503,34 @@ mod tests { "Bearer token" ); } + + #[test] + fn start_process_validation_requires_ids() { + let mut start_process = protocol::StartProcess { + correlation_id: String::new(), + stream_id: "stream-id".to_owned(), + executable: String::new(), + arguments: Vec::new(), + working_directory: String::new(), + environment: HashMap::new(), + metadata: HashMap::new(), + }; + + assert_eq!( + start_process_validation_error(&start_process), + Some("process correlation id is required") + ); + + start_process.correlation_id = "correlation-id".to_owned(); + start_process.stream_id.clear(); + + assert_eq!( + start_process_validation_error(&start_process), + Some("process stream id is required") + ); + + start_process.stream_id = "stream-id".to_owned(); + + assert_eq!(start_process_validation_error(&start_process), None); + } } diff --git a/devolutions-agent/src/psu_grpc_agent/process.rs b/devolutions-agent/src/psu_grpc_agent/process.rs index cd61433cb..f3e4eb114 100644 --- a/devolutions-agent/src/psu_grpc_agent/process.rs +++ b/devolutions-agent/src/psu_grpc_agent/process.rs @@ -1,18 +1,34 @@ use std::collections::HashMap; +#[cfg(not(windows))] use std::process::{ExitStatus, Stdio}; use std::sync::Arc; +#[cfg(not(windows))] use std::time::Duration; use anyhow::Context as _; +#[cfg(windows)] +use devolutions_session::dvc::encoding::DataEncoding; +#[cfg(windows)] +use devolutions_session::dvc::process::{ExecError, ServerChannelEvent, WinApiProcessBuilder}; +#[cfg(windows)] +use now_proto_pdu::NowExecDataStreamKind; +#[cfg(windows)] +use sha2::Digest as _; +#[cfg(not(windows))] use tokio::io::{AsyncBufReadExt as _, AsyncRead, AsyncReadExt as _, AsyncWriteExt as _, BufReader}; +#[cfg(not(windows))] use tokio::process::{Child, Command}; use tokio::sync::{Mutex, mpsc}; +#[cfg(not(windows))] use tokio::task::JoinHandle; +#[cfg(windows)] +use win_api_wrappers::utils::CommandLine; use crate::psu_grpc_agent::protocol::agent_message::Payload as AgentPayload; use crate::psu_grpc_agent::protocol::{AgentMessage, ProcessCompleted, ProcessStarted, StartProcess, StreamData}; use crate::psu_grpc_agent::{agent_message, diagnostic, stream_closed, stream_data}; +#[cfg(not(windows))] const PWSH_STDIN_CLOSED_EXIT_CODE: i32 = 160; #[derive(Debug)] @@ -32,10 +48,15 @@ struct ProcessRegistryInner { } impl ProcessRegistry { - pub(super) async fn register_stream(&self, stream_id: &str) -> mpsc::Receiver { + pub(super) async fn register_stream(&self, stream_id: &str) -> Option> { let (tx, rx) = mpsc::channel(256); - self.inner.lock().await.streams.insert(stream_id.to_owned(), tx); - rx + let mut inner = self.inner.lock().await; + if inner.streams.contains_key(stream_id) { + return None; + } + + inner.streams.insert(stream_id.to_owned(), tx); + Some(rx) } pub(super) async fn dispatch_stream_data(&self, stream_data: StreamData) { @@ -56,8 +77,14 @@ impl ProcessRegistry { self.inner.lock().await.streams.remove(stream_id); } - pub(super) async fn register_process(&self, correlation_id: String, control: ProcessControl) { - self.inner.lock().await.processes.insert(correlation_id, control); + pub(super) async fn register_process(&self, correlation_id: String, control: ProcessControl) -> bool { + let mut inner = self.inner.lock().await; + if inner.processes.contains_key(&correlation_id) { + return false; + } + + inner.processes.insert(correlation_id, control); + true } pub(super) async fn stop_process(&self, correlation_id: &str, kill_process: bool) { @@ -75,7 +102,7 @@ impl ProcessRegistry { } } - async fn remove_process(&self, correlation_id: &str) { + pub(super) async fn remove_process(&self, correlation_id: &str) { self.inner.lock().await.processes.remove(correlation_id); } } @@ -111,6 +138,30 @@ pub(super) async fn run_process( result } +#[cfg(windows)] +#[allow(clippy::too_many_arguments)] +async fn run_process_inner( + request: StartProcess, + incoming_rx: mpsc::Receiver, + control_rx: mpsc::Receiver, + outgoing_tx: mpsc::Sender, + agent_id: String, + connection_id: String, + default_executable: String, +) -> anyhow::Result<()> { + run_process_inner_windows( + request, + incoming_rx, + control_rx, + outgoing_tx, + agent_id, + connection_id, + default_executable, + ) + .await +} + +#[cfg(not(windows))] #[allow(clippy::too_many_arguments)] async fn run_process_inner( request: StartProcess, @@ -137,7 +188,29 @@ async fn run_process_inner( .stderr(Stdio::piped()) .kill_on_drop(true); - if !request.working_directory.trim().is_empty() && std::path::Path::new(&request.working_directory).is_dir() { + if !request.working_directory.trim().is_empty() { + if !std::path::Path::new(&request.working_directory).is_dir() { + let error_message = format!("working directory does not exist: {}", request.working_directory); + let _ = outgoing_tx + .send(agent_message( + &agent_id, + &connection_id, + AgentPayload::StreamClosed(stream_closed(request.stream_id.clone(), error_message.clone(), true)), + )) + .await; + send_process_completed( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + -1, + false, + error_message.clone(), + ) + .await?; + return Err(anyhow::anyhow!(error_message)); + } + command.current_dir(&request.working_directory); } @@ -300,6 +373,320 @@ async fn run_process_inner( Ok(()) } +#[cfg(windows)] +#[allow(clippy::too_many_arguments)] +async fn run_process_inner_windows( + request: StartProcess, + mut incoming_rx: mpsc::Receiver, + mut control_rx: mpsc::Receiver, + outgoing_tx: mpsc::Sender, + agent_id: String, + connection_id: String, + default_executable: String, +) -> anyhow::Result<()> { + let session_id = session_id_from_correlation_id(&request.correlation_id); + let executable = if request.executable.trim().is_empty() { + default_executable + } else { + request.executable.clone() + }; + + info!( + correlation_id = %request.correlation_id, + session_id, + executable = %executable, + arguments = ?request.arguments, + "Starting PSU gRPC child process through NOW_EXEC backend" + ); + + let command_line = CommandLine::new(request.arguments.clone()).to_command_line(); + + let mut process_builder = WinApiProcessBuilder::new(&executable) + .with_command_line(&command_line) + .with_io_redirection(true) + .with_encoding(DataEncoding::Raw) + .with_kill_on_drop(true); + + if !request.working_directory.trim().is_empty() { + let working_directory = std::path::Path::new(&request.working_directory); + if !working_directory.is_dir() { + let error_message = format!("working directory does not exist: {}", request.working_directory); + let _ = outgoing_tx + .send(agent_message( + &agent_id, + &connection_id, + AgentPayload::StreamClosed(stream_closed(request.stream_id.clone(), error_message.clone(), true)), + )) + .await; + send_process_completed( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + -1, + false, + error_message.clone(), + ) + .await?; + return Err(anyhow::anyhow!(error_message)); + } + + process_builder = process_builder.with_current_directory(&request.working_directory); + } + + for (key, value) in &request.environment { + process_builder = process_builder.with_env(key, value); + } + + let (io_notification_tx, mut io_notification_rx) = mpsc::channel(100); + let mut process = match process_builder.run(session_id, io_notification_tx) { + Ok(process) => process, + Err(error) => { + let error_message = format!( + "failed to start PSU gRPC child process using {executable}: {}", + format_exec_error(error) + ); + let _ = outgoing_tx + .send(agent_message( + &agent_id, + &connection_id, + AgentPayload::StreamClosed(stream_closed(request.stream_id.clone(), error_message.clone(), true)), + )) + .await; + let _ = send_process_completed( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + -1, + false, + error_message.clone(), + ) + .await; + return Err(anyhow::anyhow!(error_message)); + } + }; + + let mut canceled = false; + let mut stdout_closed = false; + let mut stderr_closed = false; + let mut stdin_closed = false; + let mut control_closed = false; + let mut stdout_sequence = 0; + let mut stderr_sequence = 0; + + loop { + tokio::select! { + event = io_notification_rx.recv() => { + match event { + Some(ServerChannelEvent::SessionStarted { process_id, .. }) => { + let process_id = i32::try_from(process_id).unwrap_or(i32::MAX); + outgoing_tx + .send(agent_message( + &agent_id, + &connection_id, + AgentPayload::ProcessStarted(ProcessStarted { + correlation_id: request.correlation_id.clone(), + process_id, + }), + )) + .await + .context("failed to send PSU gRPC ProcessStarted message")?; + } + Some(ServerChannelEvent::SessionDataOut { stream, last, data, .. }) => { + match stream { + NowExecDataStreamKind::Stdout => { + if !data.is_empty() || last { + send_stream_frame( + &outgoing_tx, + &agent_id, + &connection_id, + &request.stream_id, + stdout_sequence, + data, + last, + ) + .await?; + stdout_sequence += 1; + } + stdout_closed |= last; + } + NowExecDataStreamKind::Stderr => { + if !data.is_empty() { + send_stderr_diagnostic( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + stderr_sequence, + data, + ) + .await?; + stderr_sequence += 1; + } + stderr_closed |= last; + } + NowExecDataStreamKind::Stdin => {} + } + } + Some(ServerChannelEvent::SessionCancelSuccess { .. }) => { + canceled = true; + } + Some(ServerChannelEvent::SessionCancelFailed { error, .. }) => { + warn!(error = %error, correlation_id = %request.correlation_id, "PSU gRPC NOW_EXEC cancel failed"); + } + Some(ServerChannelEvent::SessionExited { exit_code, .. }) => { + process.disable_kill_on_drop(); + + if !stdout_closed { + send_stream_frame( + &outgoing_tx, + &agent_id, + &connection_id, + &request.stream_id, + stdout_sequence, + Vec::new(), + true, + ) + .await?; + } + if !stderr_closed { + send_stderr_diagnostic( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + stderr_sequence, + Vec::new(), + ) + .await?; + } + + let _ = outgoing_tx + .send(agent_message( + &agent_id, + &connection_id, + AgentPayload::StreamClosed(stream_closed( + request.stream_id.clone(), + "child process completed".to_owned(), + false, + )), + )) + .await; + + let exit_code = i32::try_from(exit_code).unwrap_or(i32::MAX); + send_process_completed( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + exit_code, + canceled, + String::new(), + ) + .await + .context("failed to send PSU gRPC ProcessCompleted message")?; + return Ok(()); + } + Some(ServerChannelEvent::SessionFailed { error, .. }) => { + let error_message = format_exec_error(error); + let _ = outgoing_tx + .send(agent_message( + &agent_id, + &connection_id, + AgentPayload::StreamClosed(stream_closed( + request.stream_id.clone(), + error_message.clone(), + true, + )), + )) + .await; + send_process_completed( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + -1, + canceled, + error_message, + ) + .await?; + return Ok(()); + } + Some(ServerChannelEvent::CloseChannel | ServerChannelEvent::WindowRecordingEvent { .. }) => {} + None => { + let error_message = "NOW_EXEC process event channel closed before completion".to_owned(); + send_process_completed( + &outgoing_tx, + &agent_id, + &connection_id, + &request.correlation_id, + -1, + canceled, + error_message.clone(), + ) + .await?; + return Err(anyhow::anyhow!(error_message)); + } + } + } + frame = incoming_rx.recv(), if !stdin_closed => { + match frame { + Some(frame) => { + stdin_closed = frame.end_of_stream; + if let Err(error) = process.send_stdin(frame.data, frame.end_of_stream).await { + warn!( + error = format!("{error:#}"), + correlation_id = %request.correlation_id, + "Failed to send PSU gRPC stdin frame through NOW_EXEC backend" + ); + stdin_closed = true; + } + } + None => { + stdin_closed = true; + if let Err(error) = process.send_stdin(Vec::new(), true).await { + warn!( + error = format!("{error:#}"), + correlation_id = %request.correlation_id, + "Failed to close PSU gRPC stdin through NOW_EXEC backend" + ); + } + } + } + } + kill_process = control_rx.recv(), if !control_closed => { + match kill_process { + Some(true) => { + canceled = true; + control_closed = true; + if let Err(error) = process.abort_execution(1).await { + warn!( + error = format!("{error:#}"), + correlation_id = %request.correlation_id, + "Failed to abort PSU gRPC NOW_EXEC process" + ); + } + } + Some(false) => { + if let Err(error) = process.cancel_execution().await { + warn!( + error = format!("{error:#}"), + correlation_id = %request.correlation_id, + "Failed to cancel PSU gRPC NOW_EXEC process" + ); + } + } + None => { + control_closed = true; + } + } + } + } + } +} + +#[cfg(not(windows))] async fn wait_for_graceful_child_exit(child: &mut Child, process_id: i32) -> anyhow::Result<(ExitStatus, bool)> { match tokio::time::timeout(Duration::from_secs(5), child.wait()).await { Ok(status) => Ok((status.context("failed to wait for PSU gRPC child process")?, false)), @@ -318,6 +705,7 @@ async fn wait_for_graceful_child_exit(child: &mut Child, process_id: i32) -> any } } +#[cfg(not(windows))] async fn await_pump_task(mut task: JoinHandle>, process_id: i32, stream_name: &'static str) { tokio::select! { result = &mut task => match result { @@ -357,6 +745,7 @@ async fn send_process_completed( .context("failed to send PSU gRPC ProcessCompleted message") } +#[cfg(not(windows))] async fn pump_stdout_to_server( mut stdout: R, stream_id: String, @@ -369,7 +758,6 @@ where R: AsyncRead + Unpin, { let mut buffer = [0u8; 4096]; - let mut line = Vec::new(); let mut sequence = 0; loop { @@ -378,35 +766,13 @@ where break; } - for byte in &buffer[..read] { - match *byte { - b'\r' => {} - b'\n' => { - send_stream_frame( - &outgoing_tx, - &agent_id, - &connection_id, - &stream_id, - sequence, - std::mem::take(&mut line), - false, - ) - .await?; - sequence += 1; - } - byte => line.push(byte), - } - } - } - - if !line.is_empty() { send_stream_frame( &outgoing_tx, &agent_id, &connection_id, &stream_id, sequence, - line, + buffer[..read].to_vec(), false, ) .await?; @@ -446,6 +812,48 @@ async fn send_stream_frame( .context("failed to send PSU gRPC stdout frame") } +#[cfg(windows)] +async fn send_stderr_diagnostic( + outgoing_tx: &mpsc::Sender, + agent_id: &str, + connection_id: &str, + correlation_id: &str, + sequence: u64, + data: Vec, +) -> anyhow::Result<()> { + if data.is_empty() { + return Ok(()); + } + + let message = String::from_utf8_lossy(&data); + outgoing_tx + .send(agent_message( + agent_id, + connection_id, + AgentPayload::Diagnostic(diagnostic( + "warning", + format!("stderr[{correlation_id}:{sequence}] {message}"), + )), + )) + .await + .context("failed to send PSU gRPC stderr diagnostic") +} + +#[cfg(windows)] +fn session_id_from_correlation_id(correlation_id: &str) -> u32 { + let digest = sha2::Sha256::digest(correlation_id.as_bytes()); + u32::from_le_bytes(digest[..4].try_into().expect("BUG: SHA-256 digest is at least 4 bytes")) +} + +#[cfg(windows)] +fn format_exec_error(error: ExecError) -> String { + match error { + ExecError::Other(error) => format!("{error:#}"), + error => error.to_string(), + } +} + +#[cfg(not(windows))] async fn pump_server_to_stdin( mut incoming_rx: mpsc::Receiver, mut stdin: tokio::process::ChildStdin, @@ -460,12 +868,7 @@ async fn pump_server_to_stdin( break; } - let mut data = frame.data; - if !ends_with_line_ending(&data) { - data.push(b'\n'); - } - - if let Err(error) = stdin.write_all(&data).await { + if let Err(error) = stdin.write_all(&frame.data).await { warn!(process_id, %error, "Failed to write PSU gRPC frame to child stdin"); break; } @@ -480,6 +883,7 @@ async fn pump_server_to_stdin( closed_from_end_of_stream } +#[cfg(not(windows))] async fn pump_stderr_diagnostics( stderr: R, outgoing_tx: mpsc::Sender, @@ -509,10 +913,6 @@ where Ok(()) } -fn ends_with_line_ending(data: &[u8]) -> bool { - data.ends_with(b"\n") || data.ends_with(b"\r") -} - #[cfg(test)] mod tests { use super::*; @@ -522,9 +922,11 @@ mod tests { let registry = ProcessRegistry::default(); let (control_tx, mut control_rx) = mpsc::channel(8); - registry - .register_process("correlation-id".to_owned(), ProcessControl { stop: control_tx }) - .await; + assert!( + registry + .register_process("correlation-id".to_owned(), ProcessControl { stop: control_tx }) + .await + ); registry.stop_process("correlation-id", false).await; assert_eq!(control_rx.recv().await, Some(false)); @@ -538,13 +940,15 @@ mod tests { #[tokio::test] async fn run_process_cleans_registry_and_reports_spawn_failure() { let registry = ProcessRegistry::default(); - let incoming_rx = registry.register_stream("stream-id").await; + let incoming_rx = registry.register_stream("stream-id").await.expect("register stream"); let (control_tx, control_rx) = mpsc::channel(8); let (outgoing_tx, mut outgoing_rx) = mpsc::channel(8); - registry - .register_process("correlation-id".to_owned(), ProcessControl { stop: control_tx }) - .await; + assert!( + registry + .register_process("correlation-id".to_owned(), ProcessControl { stop: control_tx }) + .await + ); let result = run_process( StartProcess { @@ -598,4 +1002,25 @@ mod tests { payload => panic!("unexpected payload: {payload:?}"), } } + + #[tokio::test] + async fn registry_rejects_duplicate_processes_and_streams() { + let registry = ProcessRegistry::default(); + let (control_tx, _control_rx) = mpsc::channel(8); + let (duplicate_tx, _duplicate_rx) = mpsc::channel(8); + + assert!( + registry + .register_process("correlation-id".to_owned(), ProcessControl { stop: control_tx }) + .await + ); + assert!( + !registry + .register_process("correlation-id".to_owned(), ProcessControl { stop: duplicate_tx }) + .await + ); + + assert!(registry.register_stream("stream-id").await.is_some()); + assert!(registry.register_stream("stream-id").await.is_none()); + } } diff --git a/devolutions-session/src/dvc/process.rs b/devolutions-session/src/dvc/process.rs index c7c56f3fa..4149e70a0 100644 --- a/devolutions-session/src/dvc/process.rs +++ b/devolutions-session/src/dvc/process.rs @@ -71,6 +71,7 @@ pub enum ServerChannelEvent { CloseChannel, SessionStarted { session_id: u32, + process_id: u32, }, SessionDataOut { session_id: u32, @@ -159,7 +160,10 @@ impl WinApiProcessCtx { const WAIT_OBJECT_INPUT_MESSAGE: WAIT_EVENT = WAIT_OBJECT_0; const WAIT_OBJECT_PROCESS_EXIT: WAIT_EVENT = WAIT_EVENT(WAIT_OBJECT_0.0 + 1); - io_notification_tx.blocking_send(ServerChannelEvent::SessionStarted { session_id })?; + io_notification_tx.blocking_send(ServerChannelEvent::SessionStarted { + session_id, + process_id: self.pid, + })?; loop { // SAFETY: No preconditions. @@ -288,7 +292,10 @@ impl WinApiProcessCtx { // Signal client side about started execution - io_notification_tx.blocking_send(ServerChannelEvent::SessionStarted { session_id })?; + io_notification_tx.blocking_send(ServerChannelEvent::SessionStarted { + session_id, + process_id: self.pid, + })?; info!(session_id, "Process IO is ready for async loop execution"); loop { @@ -557,6 +564,7 @@ pub struct WinApiProcessBuilder { encoding: DataEncoding, env: HashMap, temp_files: Vec, + kill_on_drop: bool, } impl WinApiProcessBuilder { @@ -569,6 +577,7 @@ impl WinApiProcessBuilder { encoding: DataEncoding::from_oem_codepage(), env: HashMap::new(), temp_files: Vec::new(), + kill_on_drop: false, } } @@ -608,6 +617,12 @@ impl WinApiProcessBuilder { self } + #[must_use] + pub fn with_kill_on_drop(mut self, enable: bool) -> Self { + self.kill_on_drop = enable; + self + } + fn run_impl( mut self, session_id: u32, @@ -637,6 +652,7 @@ impl WinApiProcessBuilder { let io_redirection = self.enable_io_redirection; let encoding = self.encoding; + let kill_on_drop = self.kill_on_drop; let process_ctx = if io_redirection { prepare_process_with_io_redirection(session_id, command_line, current_directory, self.env, encoding)? @@ -666,6 +682,9 @@ impl WinApiProcessBuilder { // Create channel for `task` -> `Process IO thread` communication let (input_event_tx, input_event_rx) = winapi_signaled_mpsc_channel()?; + let kill_on_drop_process = kill_on_drop + .then(|| process_ctx.process.handle.try_clone().map(Process::from)) + .transpose()?; let io_notification_tx = io_notification_tx.expect("BUG: io_notification_tx must be Some for non-detached mode"); @@ -690,6 +709,7 @@ impl WinApiProcessBuilder { Ok(Some(WinApiProcess { input_event_tx, join_handle, + kill_on_drop_process, _temp_files: temp_files, })) } @@ -870,11 +890,18 @@ fn prepare_process_with_io_redirection( pub struct WinApiProcess { input_event_tx: WinapiSignaledSender, join_handle: std::thread::JoinHandle<()>, + kill_on_drop_process: Option, _temp_files: Vec, } impl Drop for WinApiProcess { fn drop(&mut self) { + if let Some(process) = self.kill_on_drop_process.take() + && let Err(error) = process.terminate(1) + { + trace!(%error, "Failed to terminate process on drop"); + } + // Ensure that the input event channel is closed when the process is dropped. // This will signal the IO thread to terminate if it is still running. let _ = self.input_event_tx.try_send(ProcessIoInputEvent::TerminateIo); @@ -903,6 +930,10 @@ impl WinApiProcess { Ok(()) } + pub fn disable_kill_on_drop(&mut self) { + self.kill_on_drop_process = None; + } + pub fn is_session_terminated(&self) -> bool { self.join_handle.is_finished() } diff --git a/devolutions-session/src/dvc/task.rs b/devolutions-session/src/dvc/task.rs index 21e47cba8..234f7c94a 100644 --- a/devolutions-session/src/dvc/task.rs +++ b/devolutions-session/src/dvc/task.rs @@ -201,7 +201,7 @@ async fn process_messages( match task_rx { Some(notification) => { match notification { - ServerChannelEvent::SessionStarted { session_id } => { + ServerChannelEvent::SessionStarted { session_id, .. } => { info!(session_id, "Session started"); let message = NowExecStartedMsg::new(session_id); dvc_tx.send(message.into()).await?;