chore: fix test

This commit is contained in:
nathan 2024-06-25 22:04:05 +08:00
parent 6ae00b8aef
commit 94f7add54d
15 changed files with 129 additions and 877 deletions

View File

@ -747,7 +747,7 @@ dependencies = [
"faccess",
"lazy_static",
"log",
"os_pipe 0.9.2",
"os_pipe",
]
[[package]]
@ -1208,16 +1208,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "ctor"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f"
dependencies = [
"quote",
"syn 2.0.47",
]
[[package]]
name = "ctr"
version = "0.9.2"
@ -2156,17 +2146,12 @@ version = "0.1.0"
dependencies = [
"anyhow",
"crossbeam-utils",
"ctor",
"dotenv",
"encoding_rs",
"log",
"memchr",
"once_cell",
"os_pipe 1.2.0",
"parking_lot 0.12.1",
"serde",
"serde_json",
"shared_child",
"thiserror",
"tokio",
"tracing",
@ -3671,16 +3656,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "os_pipe"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29d73ba8daf8fac13b0501d1abeddcfe21ba7401ada61a819144b6c2a4f32209"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "overload"
version = "0.1.1"
@ -5111,16 +5086,6 @@ dependencies = [
"uuid",
]
[[package]]
name = "shared_child"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0d94659ad3c2137fef23ae75b03d5241d633f8acded53d672decfa0e6e0caef"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "shlex"
version = "1.2.0"

View File

@ -41,6 +41,9 @@ pub enum RemoteError {
/// clients.
#[error("Invalid request: {0:?}")]
InvalidRequest(Option<Value>),
#[error("Invalid response: {0}")]
InvalidResponse(Value),
/// A custom error, defined by the client.
#[error("Custom error: {message}")]
Custom {
@ -134,20 +137,18 @@ impl Serialize for RemoteError {
where
S: Serializer,
{
let (code, message, data) = match *self {
RemoteError::InvalidRequest(ref d) => (-32600, "Invalid request", d),
let (code, message, data) = match self {
RemoteError::InvalidRequest(ref d) => (-32600, "Invalid request".to_string(), d.clone()),
RemoteError::Custom {
code,
ref message,
ref data,
} => (code, message.as_ref(), data),
RemoteError::Unknown(_) => panic!(
"The 'Unknown' error variant is \
not intended for client use."
),
} => (*code, message.clone(), data.clone()),
RemoteError::Unknown(_) => {
panic!("The 'Unknown' error variant is not intended for client use.")
},
RemoteError::InvalidResponse(s) => (-1, "Invalid response".to_string(), Some(s.clone())),
};
let message = message.to_owned();
let data = data.to_owned();
let err = ErrorHelper {
code,
message,

View File

@ -1,12 +1,13 @@
use crate::error::{ReadError, RemoteError};
use crate::parser::ResponseParser;
use crate::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
use crate::rpc_loop::Handler;
use crate::rpc_peer::PluginCommand;
use anyhow::{anyhow, Result};
use parking_lot::{Mutex, RwLock};
use parking_lot::Mutex;
use serde_json::{json, Value};
use std::io;
use std::sync::atomic::{AtomicI64, AtomicU8, Ordering};
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Weak};
use tracing::{trace, warn};
@ -32,14 +33,9 @@ impl SidecarManager {
Ok(plugin_id)
}
pub async fn kill_plugin(&self, id: PluginId) -> Result<()> {
let state = self.state.lock();
let plugin = state
.plugins
.iter()
.find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?;
plugin.shutdown()
pub async fn remove_plugin(&self, id: PluginId) -> Result<()> {
let mut state = self.state.lock();
state.plugin_disconnect(id, Ok(()));
Ok(())
}
@ -55,15 +51,21 @@ impl SidecarManager {
Ok(())
}
pub fn send_request(&self, id: PluginId, method: &str, request: Value) -> Result<()> {
pub fn send_request<P: ResponseParser>(
&self,
id: PluginId,
method: &str,
request: Value,
) -> Result<P::ValueType> {
let state = self.state.lock();
let plugin = state
.plugins
.iter()
.find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?;
plugin.send_request(method, &request)?;
Ok(())
let resp = plugin.send_request(method, &request)?;
let value = P::parse_response(resp)?;
Ok(value)
}
}
@ -75,7 +77,7 @@ impl SidecarState {
pub fn plugin_connect(&mut self, plugin: Result<Plugin, io::Error>) {
match plugin {
Ok(plugin) => {
warn!("plugin connected: {:?}", plugin.id);
trace!("plugin connected: {:?}", plugin.id);
self.plugins.push(plugin);
},
Err(err) => {
@ -84,8 +86,10 @@ impl SidecarState {
}
}
pub fn plugin_exit(&mut self, id: PluginId, error: Result<(), ReadError>) {
warn!("plugin {:?} exited with result {:?}", id, error);
pub fn plugin_disconnect(&mut self, id: PluginId, error: Result<(), ReadError>) {
if let Err(err) = error {
warn!("[RPC] plugin {:?} exited with result {:?}", id, err);
}
let running_idx = self.plugins.iter().position(|p| p.id == id);
if let Some(idx) = running_idx {
let plugin = self.plugins.remove(idx);
@ -110,7 +114,7 @@ impl WeakSidecarState {
pub fn plugin_exit(&self, plugin: PluginId, error: Result<(), ReadError>) {
if let Some(core) = self.upgrade() {
core.lock().plugin_exit(plugin, error)
core.lock().plugin_disconnect(plugin, error)
}
}
}
@ -118,7 +122,7 @@ impl WeakSidecarState {
impl Handler for WeakSidecarState {
type Request = PluginCommand<String>;
fn handle_request(&mut self, ctx: &RpcCtx, rpc: Self::Request) -> Result<Value, RemoteError> {
fn handle_request(&mut self, _ctx: &RpcCtx, rpc: Self::Request) -> Result<Value, RemoteError> {
trace!("handling request: {:?}", rpc.cmd);
Ok(json!({}))
}

View File

@ -1,8 +1,7 @@
use crate::error::{ReadError, RemoteError};
use crate::rpc_loop::RpcObject;
use serde_json::Value;
use serde_json::{json, Value};
use std::io::BufRead;
use tracing::trace;
#[derive(Debug, Default)]
pub struct MessageReader(String);
@ -31,12 +30,15 @@ impl MessageReader {
/// This should not be called directly unless you are writing tests.
#[doc(hidden)]
pub fn parse(&self, s: &str) -> Result<RpcObject, ReadError> {
trace!("parsing message: {}", s);
let val = serde_json::from_str::<Value>(s)?;
if !val.is_object() {
Err(ReadError::NotObject(s.to_string()))
} else {
Ok(val.into())
match serde_json::from_str::<Value>(s) {
Ok(val) => {
if !val.is_object() {
Err(ReadError::NotObject(s.to_string()))
} else {
Ok(val.into())
}
},
Err(_) => Ok(RpcObject(json!({"message": s.to_string()}))),
}
}
}
@ -52,3 +54,24 @@ pub enum Call<R> {
/// not be parsed. The client will receive an error.
InvalidRequest(RequestId, RemoteError),
}
pub trait ResponseParser {
type ValueType;
fn parse_response(json: serde_json::Value) -> Result<Self::ValueType, RemoteError>;
}
pub struct ChatResponseParser;
impl ResponseParser for ChatResponseParser {
type ValueType = String;
fn parse_response(json: Value) -> Result<Self::ValueType, RemoteError> {
if json.is_object() {
if let Some(message) = json.get("data") {
if let Some(message) = message.as_str() {
return Ok(message.to_string());
}
}
}
return Err(RemoteError::InvalidResponse(json));
}
}

View File

@ -2,13 +2,12 @@ use crate::error::Error;
use crate::manager::WeakSidecarState;
use crate::rpc_loop::RpcLoop;
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::io::BufReader;
use std::path::PathBuf;
use std::process::{Child, Stdio};
use std::sync::Arc;
use std::thread;
use std::time::Instant;
use tracing::{error, info};
@ -32,6 +31,8 @@ pub trait Callback: Send {
/// channel. It is intended to be used behind a pointer, a trait object.
pub trait Peer: Send + 'static {
fn box_clone(&self) -> Box<dyn Peer>;
fn send_rpc_notification(&self, method: &str, params: &Value);
fn send_rpc_request_async(&self, method: &str, params: &Value, f: Box<dyn Callback>);
/// Sends a request (synchronous RPC) to the peer, and waits for the result.
fn send_rpc_request(&self, method: &str, params: &Value) -> Result<Value, Error>;
@ -110,9 +111,8 @@ pub(crate) async fn start_plugin_process(
let mut looper = RpcLoop::new(child_stdin);
let peer: RpcPeer = Box::new(looper.get_raw_peer());
let name = plugin_info.name.clone();
if let Err(err) = peer.send_rpc_request("ping", &Value::Array(Vec::new())) {
error!("plugin {} failed to respond to ping: {:?}", name, err);
}
peer.send_rpc_notification("ping", &Value::Array(Vec::new()));
let plugin = Plugin {
peer,
process: child,
@ -123,7 +123,11 @@ pub(crate) async fn start_plugin_process(
state.plugin_connect(Ok(plugin));
let _ = tx.send(());
let mut state = state;
let err = looper.mainloop(|| BufReader::new(child_stdout), &mut state);
let err = looper.mainloop(
&plugin_info.name,
|| BufReader::new(child_stdout),
&mut state,
);
state.plugin_exit(id, err);
},
Err(err) => {

View File

@ -1,383 +0,0 @@
use std::{
collections::HashMap,
io::{BufReader, Write},
path::PathBuf,
process::{Command as StdCommand, Stdio},
sync::{Arc, Mutex, RwLock},
};
#[cfg(unix)]
use std::os::unix::process::ExitStatusExt;
#[cfg(windows)]
use std::os::windows::process::CommandExt;
#[cfg(windows)]
const CREATE_NO_WINDOW: u32 = 0x0800_0000;
use crate::process::runtime::block_on;
use crate::utils;
use crate::utils::platform;
use anyhow::{anyhow, Result};
pub use encoding_rs::Encoding;
use os_pipe::{pipe, PipeReader, PipeWriter};
use serde::Serialize;
use shared_child::SharedChild;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tracing::{error, trace};
type ChildStore = Arc<Mutex<HashMap<u32, Arc<SharedChild>>>>;
fn commands() -> &'static ChildStore {
use once_cell::sync::Lazy;
static STORE: Lazy<ChildStore> = Lazy::new(Default::default);
&STORE
}
/// Kills all child processes created with [`SidecarCommand`].
pub fn kill_children() {
let commands = commands().lock().unwrap();
let children = commands.values();
for child in children {
let _ = child.kill();
}
}
/// Payload for the [`CommandEvent::Terminated`] command event.
#[derive(Debug, Clone, Serialize)]
pub struct TerminatedPayload {
/// Exit code of the process.
pub code: Option<i32>,
/// If the process was terminated by a signal, represents that signal.
pub signal: Option<i32>,
}
/// A event sent to the command callback.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "event", content = "payload")]
#[non_exhaustive]
pub enum CommandEvent {
/// Stderr bytes until a newline (\n) or carriage return (\r) is found.
Stderr(String),
/// Stdout bytes until a newline (\n) or carriage return (\r) is found.
Stdout(String),
/// An error happened waiting for the command to finish or converting the stdout/stderr bytes to an UTF-8 string.
Error(String),
/// Command process terminated.
Terminated(TerminatedPayload),
}
/// The type to spawn commands.
#[derive(Debug)]
pub struct SidecarCommand {
program: String,
args: Vec<String>,
env_clear: bool,
env: HashMap<String, String>,
current_dir: Option<PathBuf>,
encoding: Option<&'static Encoding>,
}
/// Spawned child process.
#[derive(Debug)]
pub struct CommandChild {
inner: Arc<SharedChild>,
stdin_writer: PipeWriter,
}
impl CommandChild {
/// Writes to process stdin.
pub fn write(&mut self, buf: &[u8]) -> Result<()> {
self.stdin_writer.write_all(buf)?;
Ok(())
}
pub fn write_json(&mut self, value: serde_json::Value) -> Result<()> {
let s = value.to_string();
self.write(s.as_bytes())?;
Ok(())
}
/// Sends a kill signal to the child.
pub fn kill(self) -> Result<()> {
self.inner.kill()?;
Ok(())
}
/// Returns the process pid.
pub fn pid(&self) -> u32 {
self.inner.id()
}
}
/// Describes the result of a process after it has terminated.
#[derive(Debug)]
pub struct ExitStatus {
code: Option<i32>,
}
impl ExitStatus {
/// Returns the exit code of the process, if any.
pub fn code(&self) -> Option<i32> {
self.code
}
/// Returns true if exit status is zero. Signal termination is not considered a success, and success is defined as a zero exit status.
pub fn success(&self) -> bool {
self.code == Some(0)
}
}
/// The output of a finished process.
#[derive(Debug)]
pub struct Output {
/// The status (exit code) of the process.
pub status: ExitStatus,
/// The data that the process wrote to stdout.
pub stdout: String,
/// The data that the process wrote to stderr.
pub stderr: String,
}
#[allow(dead_code)]
fn relative_command_path(command: String) -> Result<String> {
match platform::current_exe()?.parent() {
#[cfg(windows)]
Some(exe_dir) => Ok(format!("{}\\{command}.exe", exe_dir.display())),
#[cfg(not(windows))]
Some(exe_dir) => Ok(format!("{}/{command}", exe_dir.display())),
None => Err(anyhow!("Could not evaluate executable dir".to_string())),
}
}
impl From<SidecarCommand> for StdCommand {
fn from(cmd: SidecarCommand) -> StdCommand {
let mut command = StdCommand::new(cmd.program);
command.args(cmd.args);
command.stdout(Stdio::piped());
command.stdin(Stdio::piped());
command.stderr(Stdio::piped());
if cmd.env_clear {
command.env_clear();
}
command.envs(cmd.env);
if let Some(current_dir) = cmd.current_dir {
command.current_dir(current_dir);
}
#[cfg(windows)]
command.creation_flags(CREATE_NO_WINDOW);
command
}
}
impl SidecarCommand {
/// Creates a new Command for launching the given program.
pub fn new<S: Into<String>>(program: S) -> Self {
let program = program.into();
Self {
program,
args: Default::default(),
env_clear: false,
env: Default::default(),
current_dir: None,
encoding: None,
}
}
/// Creates a new Command for launching the given sidecar program.
///
/// A sidecar program is a embedded external binary in order to make your application work
/// or to prevent users having to install additional dependencies (e.g. Node.js, Python, etc).
pub fn new_sidecar<S: Into<String>>(program: S) -> Result<Self> {
let program = program.into();
Ok(Self::new(program))
}
/// Appends arguments to the command.
#[must_use]
pub fn args<I, S>(mut self, args: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
for arg in args {
self.args.push(arg.as_ref().to_string());
}
self
}
/// Clears the entire environment map for the child process.
#[must_use]
pub fn env_clear(mut self) -> Self {
self.env_clear = true;
self
}
/// Adds or updates multiple environment variable mappings.
#[must_use]
pub fn envs(mut self, env: HashMap<String, String>) -> Self {
self.env = env;
self
}
/// Sets the working directory for the child process.
#[must_use]
pub fn current_dir(mut self, current_dir: PathBuf) -> Self {
self.current_dir.replace(current_dir);
self
}
/// Sets the character encoding for stdout/stderr.
#[must_use]
pub fn encoding(mut self, encoding: &'static Encoding) -> Self {
self.encoding.replace(encoding);
self
}
pub fn spawn(self) -> Result<(Receiver<CommandEvent>, CommandChild)> {
let encoding = self.encoding;
let mut command: StdCommand = self.into();
let (stdout_reader, stdout_writer) = pipe()?;
let (stderr_reader, stderr_writer) = pipe()?;
let (stdin_reader, stdin_writer) = pipe()?;
command.stdout(stdout_writer);
command.stderr(stderr_writer);
command.stdin(stdin_reader);
let shared_child = SharedChild::spawn(&mut command)?;
let child = Arc::new(shared_child);
let child_ = child.clone();
let guard = Arc::new(RwLock::new(()));
commands().lock().unwrap().insert(child.id(), child.clone());
let (tx, rx) = channel(1);
spawn_pipe_reader(
tx.clone(),
guard.clone(),
stdout_reader,
CommandEvent::Stdout,
encoding,
);
spawn_pipe_reader(
tx.clone(),
guard.clone(),
stderr_reader,
CommandEvent::Stderr,
encoding,
);
std::thread::spawn(move || {
let _ = match child_.wait() {
Ok(status) => {
let _l = guard.write().unwrap();
commands().lock().unwrap().remove(&child_.id());
block_on(async move {
tx.send(CommandEvent::Terminated(TerminatedPayload {
code: status.code(),
#[cfg(windows)]
signal: None,
#[cfg(unix)]
signal: status.signal(),
}))
.await
})
},
Err(e) => {
let _l = guard.write().unwrap();
block_on(async move { tx.send(CommandEvent::Error(e.to_string())).await })
},
};
});
Ok((
rx,
CommandChild {
inner: child,
stdin_writer,
},
))
}
pub async fn status(self) -> Result<ExitStatus> {
let (mut rx, _child) = self.spawn()?;
let mut code = None;
#[allow(clippy::collapsible_match)]
while let Some(event) = rx.recv().await {
if let CommandEvent::Terminated(payload) = event {
code = payload.code;
}
}
Ok(ExitStatus { code })
}
pub async fn output(self) -> Result<Output> {
let (mut rx, _child) = self.spawn()?;
let mut code = None;
let mut stdout = String::new();
let mut stderr = String::new();
while let Some(event) = rx.recv().await {
match event {
CommandEvent::Terminated(payload) => {
code = payload.code;
},
CommandEvent::Stdout(line) => {
stdout.push_str(line.as_str());
stdout.push('\n');
},
CommandEvent::Stderr(line) => {
stderr.push_str(line.as_str());
stderr.push('\n');
},
CommandEvent::Error(_) => {},
}
}
Ok(Output {
status: ExitStatus { code },
stdout,
stderr,
})
}
}
fn spawn_pipe_reader<F: Fn(String) -> CommandEvent + Send + Copy + 'static>(
tx: Sender<CommandEvent>,
guard: Arc<RwLock<()>>,
pipe_reader: PipeReader,
wrapper: F,
character_encoding: Option<&'static Encoding>,
) {
std::thread::spawn(move || {
let _lock = guard.read().unwrap();
let mut reader = BufReader::new(pipe_reader);
let mut buf = Vec::new();
loop {
buf.clear();
match utils::io::read_line(&mut reader, &mut buf) {
Ok(n) => {
if n == 0 {
break;
}
let tx_ = tx.clone();
let line = match character_encoding {
Some(encoding) => Ok(encoding.decode_with_bom_removal(&buf).0.into()),
None => String::from_utf8(buf.clone()),
};
block_on(async move {
let _ = match line {
Ok(line) => {
trace!("{}", line);
tx_.send(wrapper(line)).await
},
Err(e) => tx_.send(CommandEvent::Error(e.to_string())).await,
};
});
},
Err(e) => {
let tx_ = tx.clone();
let _ = block_on(async move { tx_.send(CommandEvent::Error(e.to_string())).await });
},
}
}
});
}

View File

@ -1,4 +0,0 @@
mod command;
mod runtime;
pub use command::*;

View File

@ -1,86 +0,0 @@
use once_cell::sync::OnceCell;
use std::future::Future;
use tokio::runtime::{Handle, Runtime};
use tokio::task::JoinHandle;
static RUNTIME: OnceCell<GlobalRuntime> = OnceCell::new();
struct GlobalRuntime {
runtime: Option<Runtime>,
handle: Handle,
}
impl GlobalRuntime {
fn handle(&self) -> &Handle {
if let Some(r) = &self.runtime {
r.handle()
} else {
&self.handle
}
}
fn spawn<F: Future>(&self, task: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
if let Some(r) = &self.runtime {
r.spawn(task)
} else {
self.handle.spawn(task)
}
}
pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
if let Some(r) = &self.runtime {
r.spawn_blocking(func)
} else {
self.handle.spawn_blocking(func)
}
}
fn block_on<F: Future>(&self, task: F) -> F::Output {
if let Some(r) = &self.runtime {
r.block_on(task)
} else {
self.handle.block_on(task)
}
}
}
pub fn block_on<F: Future>(task: F) -> F::Output {
let runtime = RUNTIME.get_or_init(default_runtime);
runtime.block_on(task)
}
pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let runtime = RUNTIME.get_or_init(default_runtime);
runtime.spawn(task)
}
pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let runtime = RUNTIME.get_or_init(default_runtime);
runtime.spawn_blocking(func)
}
fn default_runtime() -> GlobalRuntime {
let runtime = Runtime::new().unwrap();
let handle = runtime.handle().clone();
GlobalRuntime {
runtime: Some(runtime),
handle,
}
}

View File

@ -2,11 +2,11 @@ use crate::error::{Error, ReadError, RemoteError};
use crate::parser::{Call, MessageReader, RequestId};
use crate::plugin::RpcCtx;
use crate::rpc_peer::{RawPeer, Response, RpcState};
use serde::de::DeserializeOwned;
use serde::de::{DeserializeOwned, Error as SerdeError};
use serde_json::Value;
use std::io::{BufRead, Write};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
@ -34,17 +34,11 @@ impl RpcObject {
self.0.get("id").is_some() && self.0.get("method").is_none()
}
/// Attempts to convert the underlying `Value` into an RPC response
/// object, and returns the result.
///
/// The caller is expected to verify that the object is a response
/// before calling this method.
///
/// Converts the underlying `Value` into an RPC response object.
/// The caller should verify that the object is a response before calling this method.
/// # Errors
///
/// If the `Value` is not a well formed response object, this will
/// return a `String` containing an error message. The caller should
/// print this message and exit.
/// If the `Value` is not a well-formed response object, this returns a `String` containing an
/// error message. The caller should print this message and exit.
pub fn into_response(mut self) -> Result<Response, String> {
let _ = self
.get_id()
@ -68,13 +62,7 @@ impl RpcObject {
}
}
/// Attempts to convert the underlying `Value` into either an RPC
/// notification or request.
///
/// # Errors
///
/// Returns a `serde_json::Error` if the `Value` cannot be converted
/// to one of the expected types.
/// Converts the underlying `Value` into either an RPC notification or request.
pub fn into_rpc<R>(self) -> Result<Call<R>, serde_json::Error>
where
R: DeserializeOwned,
@ -85,7 +73,10 @@ impl RpcObject {
Ok(resp) => Ok(Call::Request(id, resp)),
Err(err) => Ok(Call::InvalidRequest(id, err.into())),
},
None => Ok(Call::Message(self.0)),
None => match self.0.get("message").and_then(|value| value.as_str()) {
None => Err(serde_json::Error::missing_field("message")),
Some(s) => Ok(Call::Message(s.to_string().into())),
},
}
}
}
@ -138,25 +129,16 @@ impl<W: Write + Send> RpcLoop<W> {
self.peer.clone()
}
/// Starts the event loop, reading lines from the reader until EOF,
/// or an error occurs.
/// Starts the event loop, reading lines from the reader until EOF or an error occurs.
///
/// Returns `Ok()` in the EOF case, otherwise returns the
/// underlying `ReadError`.
/// Returns `Ok()` if EOF is reached, otherwise returns the underlying `ReadError`.
///
/// # Note:
/// The reader is supplied via a closure, as basically a workaround
/// so that the reader doesn't have to be `Send`. Internally, the
/// main loop starts a separate thread for I/O, and at startup that
/// thread calls the given closure.
///
/// Calls to the handler happen on the caller's thread.
///
/// Calls to the handler are guaranteed to preserve the order as
/// they appear on on the channel. At the moment, there is no way
/// for there to be more than one incoming request to be outstanding.
/// The reader is provided via a closure to avoid needing `Send`. The main loop runs on a separate I/O thread that calls this closure at startup.
/// Calls to the handler occur on the caller's thread and maintain the order from the channel. Currently, there can only be one outstanding incoming request.
pub fn mainloop<R, BufferReadFn, H>(
&mut self,
plugin_name: &str,
buffer_read_fn: BufferReadFn,
handler: &mut H,
) -> Result<(), ReadError>
@ -189,13 +171,9 @@ impl<W: Write + Send> RpcLoop<W> {
let json = match self.reader.next(&mut stream) {
Ok(json) => json,
Err(err) => {
// When the data can't be parsed into JSON. It means the data is not in the correct format.
// Probably the data comes from other stdout.
if self.peer.0.is_blocking() {
self.peer.disconnect();
}
error!("[RPC] failed to parse JSON: {:?}", err);
self.peer.put_rpc_object(Err(err));
break;
},
@ -237,16 +215,21 @@ impl<W: Write + Send> RpcLoop<W> {
match json.into_rpc::<H::Request>() {
Ok(Call::Request(id, cmd)) => {
// Handle request sent from the client. For example from python executable.
trace!("[RPC] received request: {}", id);
let result = handler.handle_request(&ctx, cmd);
peer.respond(result, id);
},
Ok(Call::InvalidRequest(id, err)) => peer.respond(Err(err), id),
Ok(Call::InvalidRequest(id, err)) => {
trace!("[RPC] received invalid request: {}", id);
peer.respond(Err(err), id)
},
Err(err) => {
error!("[RPC] error parsing message: {:?}", err);
peer.disconnect();
return ReadError::UnknownRequest(err);
},
Ok(Call::Message(msg)) => {
trace!("[RPC] received message: {}", msg);
trace!("[RPC {}]: {}", plugin_name, msg);
},
}
}
@ -299,4 +282,4 @@ where
}
}
fn do_idle<H: Handler>(handler: &mut H, ctx: &RpcCtx, token: usize) {}
fn do_idle<H: Handler>(_handler: &mut H, _ctx: &RpcCtx, _token: usize) {}

View File

@ -84,6 +84,17 @@ impl<W: Write + Send + 'static> Peer for RawPeer<W> {
fn box_clone(&self) -> Box<dyn Peer> {
Box::new((*self).clone())
}
fn send_rpc_notification(&self, method: &str, params: &Value) {
if let Err(e) = self.send(&json!({
"method": method,
"params": params,
})) {
error!(
"send error on send_rpc_notification method {}: {}",
method, e
);
}
}
fn send_rpc_request_async(&self, method: &str, params: &Value, f: Box<dyn Callback>) {
self.send_rpc(method, params, ResponseHandler::Callback(f));
@ -132,7 +143,6 @@ impl<W: Write> RawPeer<W> {
}
fn send_rpc(&self, method: &str, params: &Value, rh: ResponseHandler) {
trace!("[RPC] method:{} params: {:?}", method, params);
let id = self.0.id.fetch_add(1, Ordering::Relaxed);
{
let mut pending = self.0.pending.lock();

View File

@ -1,41 +0,0 @@
use std::io::BufRead;
/// Read all bytes until a newline (the `0xA` byte) or a carriage return (`\r`) is reached, and append them to the provided buffer.
///
/// Adapted from <https://doc.rust-lang.org/std/io/trait.BufRead.html#method.read_line>.
pub fn read_line<R: BufRead + ?Sized>(r: &mut R, buf: &mut Vec<u8>) -> std::io::Result<usize> {
let mut read = 0;
loop {
let (done, used) = {
let available = match r.fill_buf() {
Ok(n) => n,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e),
};
match memchr::memchr(b'\n', available) {
Some(i) => {
let end = i + 1;
buf.extend_from_slice(&available[..end]);
(true, end)
},
None => match memchr::memchr(b'\r', available) {
Some(i) => {
let end = i + 1;
buf.extend_from_slice(&available[..end]);
(true, end)
},
None => {
buf.extend_from_slice(available);
(false, available.len())
},
},
}
};
r.consume(used);
read += used;
if done || used == 0 {
return Ok(read);
}
}
}

View File

@ -1,3 +0,0 @@
pub(crate) mod io;
pub(crate) mod platform;
mod starting_binary;

View File

@ -1,135 +0,0 @@
use crate::error::Error;
use crate::utils::starting_binary;
use anyhow::Result;
use std::path::PathBuf;
/// Retrieves the currently running binary's path, taking into account security considerations.
///
/// The path is cached as soon as possible (before even `main` runs) and that value is returned
/// repeatedly instead of fetching the path every time. It is possible for the path to not be found,
/// or explicitly disabled (see following macOS specific behavior).
///
/// # Platform-specific behavior
///
/// On `macOS`, this function will return an error if the original path contained any symlinks
/// due to less protection on macOS regarding symlinks. This behavior can be disabled by setting the
/// `process-relaunch-dangerous-allow-symlink-macos` feature, although it is *highly discouraged*.
///
/// # Security
///
/// If the above platform-specific behavior does **not** take place, this function uses the
/// following resolution.
///
/// We canonicalize the path we received from [`std::env::current_exe`] to resolve any soft links.
/// This avoids the usual issue of needing the file to exist at the passed path because a valid
/// current executable result for our purpose should always exist. Notably,
/// [`std::env::current_exe`] also has a security section that goes over a theoretical attack using
/// hard links. Let's cover some specific topics that relate to different ways an attacker might
/// try to trick this function into returning the wrong binary path.
///
/// ## Symlinks ("Soft Links")
///
/// [`std::path::Path::canonicalize`] is used to resolve symbolic links to the original path,
/// including nested symbolic links (`link2 -> link1 -> bin`). On macOS, any results that include
/// a symlink are rejected by default due to lesser symlink protections. This can be disabled,
/// **although discouraged**, with the `process-relaunch-dangerous-allow-symlink-macos` feature.
///
/// ## Hard Links
///
/// A [Hard Link] is a named entry that points to a file in the file system.
/// On most systems, this is what you would think of as a "file". The term is
/// used on filesystems that allow multiple entries to point to the same file.
/// The linked [Hard Link] Wikipedia page provides a decent overview.
///
/// In short, unless the attacker was able to create the link with elevated
/// permissions, it should generally not be possible for them to hard link
/// to a file they do not have permissions to - with exception to possible
/// operating system exploits.
///
/// There are also some platform-specific information about this below.
///
/// ### Windows
///
/// Windows requires a permission to be set for the user to create a symlink
/// or a hard link, regardless of ownership status of the target. Elevated
/// permissions users have the ability to create them.
///
/// ### macOS
///
/// macOS allows for the creation of symlinks and hard links to any file.
/// Accessing through those links will fail if the user who owns the links
/// does not have the proper permissions on the original file.
///
/// ### Linux
///
/// Linux allows for the creation of symlinks to any file. Accessing the
/// symlink will fail if the user who owns the symlink does not have the
/// proper permissions on the original file.
///
/// Linux additionally provides a kernel hardening feature since version
/// 3.6 (30 September 2012). Most distributions since then have enabled
/// the protection (setting `fs.protected_hardlinks = 1`) by default, which
/// means that a vast majority of desktop Linux users should have it enabled.
/// **The feature prevents the creation of hardlinks that the user does not own
/// or have read/write access to.** [See the patch that enabled this].
///
/// [Hard Link]: https://en.wikipedia.org/wiki/Hard_link
/// [See the patch that enabled this]: https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=800179c9b8a1e796e441674776d11cd4c05d61d7
#[allow(dead_code)]
pub fn current_exe() -> std::io::Result<PathBuf> {
starting_binary::STARTING_BINARY.cloned()
}
/// Try to determine the current target triple.
///
/// Returns a target triple (e.g. `x86_64-unknown-linux-gnu` or `i686-pc-windows-msvc`) or an
/// `Error::Config` if the current config cannot be determined or is not some combination of the
/// following values:
/// `linux, mac, windows` -- `i686, x86, armv7` -- `gnu, musl, msvc`
///
/// * Errors:
/// * Unexpected system config
#[allow(dead_code)]
pub fn target_triple() -> Result<String, Error> {
let arch = if cfg!(target_arch = "x86") {
"i686"
} else if cfg!(target_arch = "x86_64") {
"x86_64"
} else if cfg!(target_arch = "arm") {
"armv7"
} else if cfg!(target_arch = "aarch64") {
"aarch64"
} else {
return Err(Error::Architecture);
};
let os = if cfg!(target_os = "linux") {
"unknown-linux"
} else if cfg!(target_os = "macos") {
"apple-darwin"
} else if cfg!(target_os = "windows") {
"pc-windows"
} else if cfg!(target_os = "freebsd") {
"unknown-freebsd"
} else {
return Err(Error::Os);
};
let os = if cfg!(target_os = "macos") || cfg!(target_os = "freebsd") {
String::from(os)
} else {
let env = if cfg!(target_env = "gnu") {
"gnu"
} else if cfg!(target_env = "musl") {
"musl"
} else if cfg!(target_env = "msvc") {
"msvc"
} else {
return Err(Error::Environment);
};
format!("{os}-{env}")
};
Ok(format!("{arch}-{os}"))
}

View File

@ -1,73 +0,0 @@
use ctor::ctor;
use std::{
io::{Error, ErrorKind, Result},
path::{Path, PathBuf},
};
/// A cached version of the current binary using [`ctor`] to cache it before even `main` runs.
#[ctor]
#[used]
pub(super) static STARTING_BINARY: StartingBinary = StartingBinary::new();
/// Represents a binary path that was cached when the program was loaded.
pub(super) struct StartingBinary(std::io::Result<PathBuf>);
impl StartingBinary {
/// Find the starting executable as safely as possible.
fn new() -> Self {
// see notes on current_exe() for security implications
let dangerous_path = match std::env::current_exe() {
Ok(dangerous_path) => dangerous_path,
error @ Err(_) => return Self(error),
};
// note: this only checks symlinks on problematic platforms, see implementation below
if let Some(symlink) = Self::has_symlink(&dangerous_path) {
let msg = format!(
"StartingBinary found current_exe() that contains a symlink on a non-allowed platform: {}",
symlink.display()
);
return Self(Err(Error::new(ErrorKind::InvalidData, msg)));
}
// we canonicalize the path to resolve any symlinks to the real exe path
Self(dangerous_path.canonicalize())
}
#[allow(dead_code)]
pub(super) fn cloned(&self) -> Result<PathBuf> {
self
.0
.as_ref()
.map(Clone::clone)
.map_err(|e| Error::new(e.kind(), e.to_string()))
}
/// We only care about checking this on macOS currently, as it has the least symlink protections.
#[cfg(any(
not(target_os = "macos"),
feature = "process-relaunch-dangerous-allow-symlink-macos"
))]
fn has_symlink(_: &Path) -> Option<&Path> {
None
}
/// We only care about checking this on macOS currently, as it has the least symlink protections.
#[cfg(all(
target_os = "macos",
not(feature = "process-relaunch-dangerous-allow-symlink-macos")
))]
fn has_symlink(path: &Path) -> Option<&Path> {
path.ancestors().find(|ancestor| {
matches!(
ancestor
.symlink_metadata()
.as_ref()
.map(std::fs::Metadata::file_type)
.as_ref()
.map(std::fs::FileType::is_symlink),
Ok(true)
)
})
}
}

View File

@ -1,9 +1,10 @@
use anyhow::Result;
use flowy_sidecar::manager::SidecarManager;
use flowy_sidecar::parser::ChatResponseParser;
use flowy_sidecar::plugin::PluginInfo;
use serde_json::json;
use std::sync::Once;
use tracing::info;
use tracing_subscriber::fmt::Subscriber;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
@ -26,39 +27,24 @@ async fn load_chat_model_test() {
)
.unwrap();
let _json = json!({
"plugin_id": "example_plugin_id",
"method": "initialize",
"params": {
"absolute_chat_model_path":config.chat_model_absolute_path(),
}
});
tokio::spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_secs(15)).await;
manager.kill_plugin(plugin_id).await.unwrap();
})
let chat_id = uuid::Uuid::new_v4().to_string();
let resp = manager
.send_request::<ChatResponseParser>(
plugin_id,
"handle",
json!({"chat_id": chat_id, "method": "answer", "params": {"content": "hello world"}}),
)
.unwrap();
// let (mut rx, mut child) = SidecarCommand::new_sidecar(&config.chat_bin_path)
// .unwrap()
// .spawn()
// .unwrap();
//
// tokio::spawn(async move {
// while let Some(event) = rx.recv().await {
// info!("event: {:?}", event);
// }
// });
//
// let json = json!({
// "plugin_id": "example_plugin_id",
// "method": "initialize",
// "params": {
// "absolute_chat_model_path":config.chat_model_absolute_path(),
// }
// });
// child.write_json(json).unwrap();
// let chat_id = uuid::Uuid::new_v4().to_string();
// let json =
// json!({"chat_id": chat_id, "method": "answer", "params": {"content": "hello world"}});
// child.write_json(json).unwrap();
//
// tokio::time::sleep(tokio::time::Duration::from_secs(15)).await;
// child.kill().unwrap();
eprintln!("chat response: {:?}", resp);
}
}
@ -100,6 +86,7 @@ pub fn setup_log() {
let subscriber = Subscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_line_number(true)
.with_ansi(true)
.finish();
subscriber.try_init().unwrap();