From 6ae00b8aef54a75e5cf36a3c7560b4bee016e242 Mon Sep 17 00:00:00 2001 From: nathan Date: Tue, 25 Jun 2024 15:53:01 +0800 Subject: [PATCH] chore: sidecar --- frontend/rust-lib/Cargo.lock | 10 +- frontend/rust-lib/flowy-sidecar/src/error.rs | 182 +++++++++-- .../rust-lib/flowy-sidecar/src/manager.rs | 125 ++++++++ frontend/rust-lib/flowy-sidecar/src/parser.rs | 54 ++++ frontend/rust-lib/flowy-sidecar/src/plugin.rs | 142 ++++++++ .../rust-lib/flowy-sidecar/src/rpc_loop.rs | 302 ++++++++++++++++++ .../rust-lib/flowy-sidecar/src/rpc_peer.rs | 279 ++++++++++++++++ .../flowy-sidecar/tests/chat_bin_test.rs | 57 ++-- 8 files changed, 1100 insertions(+), 51 deletions(-) create mode 100644 frontend/rust-lib/flowy-sidecar/src/manager.rs create mode 100644 frontend/rust-lib/flowy-sidecar/src/parser.rs create mode 100644 frontend/rust-lib/flowy-sidecar/src/plugin.rs create mode 100644 frontend/rust-lib/flowy-sidecar/src/rpc_loop.rs create mode 100644 frontend/rust-lib/flowy-sidecar/src/rpc_peer.rs diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index 0177512763..e575382b85 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -1143,12 +1143,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.16" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" -dependencies = [ - "cfg-if", -] +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crunchy" @@ -2158,12 +2155,15 @@ name = "flowy-sidecar" version = "0.1.0" dependencies = [ "anyhow", + "crossbeam-utils", "ctor", "dotenv", "encoding_rs", + "log", "memchr", "once_cell", "os_pipe 1.2.0", + "parking_lot 0.12.1", "serde", "serde_json", "shared_child", diff --git a/frontend/rust-lib/flowy-sidecar/src/error.rs b/frontend/rust-lib/flowy-sidecar/src/error.rs index 5f7cc7d853..c1e97c2891 100644 --- a/frontend/rust-lib/flowy-sidecar/src/error.rs +++ b/frontend/rust-lib/flowy-sidecar/src/error.rs @@ -1,32 +1,158 @@ +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::{json, Value}; +use std::{fmt, io}; + /// The error type of `tauri-utils`. #[derive(Debug, thiserror::Error)] -#[non_exhaustive] pub enum Error { - /// Target triple architecture error - #[error("Unable to determine target-architecture")] - Architecture, - /// Target triple OS error - #[error("Unable to determine target-os")] - Os, - /// Target triple environment error - #[error("Unable to determine target-environment")] - Environment, - /// Tried to get resource on an unsupported platform - #[error("Unsupported platform for reading resources")] - UnsupportedPlatform, - /// Get parent process error - #[error("Could not get parent process")] - ParentProcess, - /// Get parent process PID error - #[error("Could not get parent PID")] - ParentPid, - /// Get child process error - #[error("Could not get child process")] - ChildProcess, - /// IO error - #[error("{0}")] - Io(#[from] std::io::Error), - /// Invalid pattern. - #[error("invalid pattern `{0}`. Expected either `brownfield` or `isolation`.")] - InvalidPattern(String), + /// An IO error occurred on the underlying communication channel. + #[error(transparent)] + Io(#[from] io::Error), + /// The peer returned an error. + #[error("Remote error: {0}")] + RemoteError(RemoteError), + /// The peer closed the connection. + #[error("Peer closed the connection.")] + PeerDisconnect, + /// The peer sent a response containing the id, but was malformed. + #[error("Invalid response.")] + InvalidResponse, +} + +#[derive(Debug)] +pub enum ReadError { + /// An error occurred in the underlying stream + Io(io::Error), + /// The message was not valid JSON. + Json(serde_json::Error), + /// The message was not a JSON object. + NotObject(String), + /// The the method and params were not recognized by the handler. + UnknownRequest(serde_json::Error), + /// The peer closed the connection. + Disconnect, +} + +#[derive(Debug, Clone, PartialEq, thiserror::Error)] +pub enum RemoteError { + /// The JSON was valid, but was not a correctly formed request. + /// + /// This Error is used internally, and should not be returned by + /// clients. + #[error("Invalid request: {0:?}")] + InvalidRequest(Option), + /// A custom error, defined by the client. + #[error("Custom error: {message}")] + Custom { + code: i64, + message: String, + data: Option, + }, + /// An error that cannot be represented by an error object. + /// + /// This error is intended to accommodate clients that return arbitrary + /// error values. It should not be used for new errors. + #[error("Unknown error: {0}")] + Unknown(Value), +} + +impl ReadError { + /// Returns `true` iff this is the `ReadError::Disconnect` variant. + pub fn is_disconnect(&self) -> bool { + matches!(*self, ReadError::Disconnect) + } +} + +impl fmt::Display for ReadError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ReadError::Io(ref err) => write!(f, "I/O Error: {:?}", err), + ReadError::Json(ref err) => write!(f, "JSON Error: {:?}", err), + ReadError::NotObject(s) => write!(f, "Expected JSON object, found: {}", s), + ReadError::UnknownRequest(ref err) => write!(f, "Unknown request: {:?}", err), + ReadError::Disconnect => write!(f, "Peer closed the connection."), + } + } +} + +impl From for ReadError { + fn from(err: serde_json::Error) -> ReadError { + ReadError::Json(err) + } +} + +impl From for ReadError { + fn from(err: io::Error) -> ReadError { + ReadError::Io(err) + } +} + +impl From for RemoteError { + fn from(err: serde_json::Error) -> RemoteError { + RemoteError::InvalidRequest(Some(json!(err.to_string()))) + } +} + +impl From for Error { + fn from(err: RemoteError) -> Error { + Error::RemoteError(err) + } +} + +#[derive(Deserialize, Serialize)] +struct ErrorHelper { + code: i64, + message: String, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, +} + +impl<'de> Deserialize<'de> for RemoteError { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let v = Value::deserialize(deserializer)?; + let resp = match ErrorHelper::deserialize(&v) { + Ok(resp) => resp, + Err(_) => return Ok(RemoteError::Unknown(v)), + }; + + Ok(match resp.code { + -32600 => RemoteError::InvalidRequest(resp.data), + _ => RemoteError::Custom { + code: resp.code, + message: resp.message, + data: resp.data, + }, + }) + } +} + +impl Serialize for RemoteError { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let (code, message, data) = match *self { + RemoteError::InvalidRequest(ref d) => (-32600, "Invalid request", d), + RemoteError::Custom { + code, + ref message, + ref data, + } => (code, message.as_ref(), data), + RemoteError::Unknown(_) => panic!( + "The 'Unknown' error variant is \ + not intended for client use." + ), + }; + let message = message.to_owned(); + let data = data.to_owned(); + let err = ErrorHelper { + code, + message, + data, + }; + err.serialize(serializer) + } } diff --git a/frontend/rust-lib/flowy-sidecar/src/manager.rs b/frontend/rust-lib/flowy-sidecar/src/manager.rs new file mode 100644 index 0000000000..0f10d6a88a --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/src/manager.rs @@ -0,0 +1,125 @@ +use crate::error::{ReadError, RemoteError}; +use crate::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx}; +use crate::rpc_loop::Handler; +use crate::rpc_peer::PluginCommand; +use anyhow::{anyhow, Result}; +use parking_lot::{Mutex, RwLock}; +use serde_json::{json, Value}; +use std::io; +use std::sync::atomic::{AtomicI64, AtomicU8, Ordering}; +use std::sync::{Arc, Weak}; +use tracing::{trace, warn}; + +pub struct SidecarManager { + state: Arc>, + plugin_id_counter: Arc, +} + +impl SidecarManager { + pub fn new() -> Self { + SidecarManager { + state: Arc::new(Mutex::new(SidecarState { + plugins: Vec::new(), + })), + plugin_id_counter: Arc::new(Default::default()), + } + } + + pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result { + let plugin_id = PluginId::from(self.plugin_id_counter.fetch_add(1, Ordering::SeqCst)); + let weak_state = WeakSidecarState(Arc::downgrade(&self.state)); + start_plugin_process(plugin_info, plugin_id, weak_state).await?; + Ok(plugin_id) + } + + pub async fn kill_plugin(&self, id: PluginId) -> Result<()> { + let state = self.state.lock(); + let plugin = state + .plugins + .iter() + .find(|p| p.id == id) + .ok_or(anyhow!("plugin not found"))?; + plugin.shutdown() + Ok(()) + } + + pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<()> { + let state = self.state.lock(); + let plugin = state + .plugins + .iter() + .find(|p| p.id == id) + .ok_or(anyhow!("plugin not found"))?; + plugin.initialize(init_params)?; + + Ok(()) + } + + pub fn send_request(&self, id: PluginId, method: &str, request: Value) -> Result<()> { + let state = self.state.lock(); + let plugin = state + .plugins + .iter() + .find(|p| p.id == id) + .ok_or(anyhow!("plugin not found"))?; + plugin.send_request(method, &request)?; + Ok(()) + } +} + +pub struct SidecarState { + plugins: Vec, +} + +impl SidecarState { + pub fn plugin_connect(&mut self, plugin: Result) { + match plugin { + Ok(plugin) => { + warn!("plugin connected: {:?}", plugin.id); + self.plugins.push(plugin); + }, + Err(err) => { + warn!("plugin failed to connect: {:?}", err); + }, + } + } + + pub fn plugin_exit(&mut self, id: PluginId, error: Result<(), ReadError>) { + warn!("plugin {:?} exited with result {:?}", id, error); + let running_idx = self.plugins.iter().position(|p| p.id == id); + if let Some(idx) = running_idx { + let plugin = self.plugins.remove(idx); + plugin.shutdown(); + } + } +} + +#[derive(Clone)] +pub struct WeakSidecarState(Weak>); + +impl WeakSidecarState { + pub fn upgrade(&self) -> Option>> { + self.0.upgrade() + } + + pub fn plugin_connect(&self, plugin: Result) { + if let Some(state) = self.upgrade() { + state.lock().plugin_connect(plugin) + } + } + + pub fn plugin_exit(&self, plugin: PluginId, error: Result<(), ReadError>) { + if let Some(core) = self.upgrade() { + core.lock().plugin_exit(plugin, error) + } + } +} + +impl Handler for WeakSidecarState { + type Request = PluginCommand; + + fn handle_request(&mut self, ctx: &RpcCtx, rpc: Self::Request) -> Result { + trace!("handling request: {:?}", rpc.cmd); + Ok(json!({})) + } +} diff --git a/frontend/rust-lib/flowy-sidecar/src/parser.rs b/frontend/rust-lib/flowy-sidecar/src/parser.rs new file mode 100644 index 0000000000..ea47c6e4b6 --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/src/parser.rs @@ -0,0 +1,54 @@ +use crate::error::{ReadError, RemoteError}; +use crate::rpc_loop::RpcObject; +use serde_json::Value; +use std::io::BufRead; +use tracing::trace; + +#[derive(Debug, Default)] +pub struct MessageReader(String); + +impl MessageReader { + /// Attempts to read the next line from the stream and parse it as + /// an RPC object. + /// + /// # Errors + /// + /// This function will return an error if there is an underlying + /// I/O error, if the stream is closed, or if the message is not + /// a valid JSON object. + pub fn next(&mut self, reader: &mut R) -> Result { + self.0.clear(); + let _ = reader.read_line(&mut self.0)?; + if self.0.is_empty() { + Err(ReadError::Disconnect) + } else { + self.parse(&self.0) + } + } + + /// Attempts to parse a &str as an RPC Object. + /// + /// This should not be called directly unless you are writing tests. + #[doc(hidden)] + pub fn parse(&self, s: &str) -> Result { + trace!("parsing message: {}", s); + let val = serde_json::from_str::(s)?; + if !val.is_object() { + Err(ReadError::NotObject(s.to_string())) + } else { + Ok(val.into()) + } + } +} + +pub type RequestId = u64; +#[derive(Debug, Clone, PartialEq)] +/// An RPC call, which may be either a notification or a request. +pub enum Call { + Message(Value), + /// An id and an RPC Request + Request(RequestId, R), + /// A malformed request: the request contained an id, but could + /// not be parsed. The client will receive an error. + InvalidRequest(RequestId, RemoteError), +} diff --git a/frontend/rust-lib/flowy-sidecar/src/plugin.rs b/frontend/rust-lib/flowy-sidecar/src/plugin.rs new file mode 100644 index 0000000000..fe592a7e02 --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/src/plugin.rs @@ -0,0 +1,142 @@ +use crate::error::Error; +use crate::manager::WeakSidecarState; +use crate::rpc_loop::RpcLoop; + +use anyhow::anyhow; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::io::BufReader; +use std::path::PathBuf; +use std::process::{Child, Stdio}; +use std::sync::Arc; +use std::thread; +use std::time::Instant; +use tracing::{error, info}; + +#[derive( + Default, Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, +)] +pub struct PluginId(pub(crate) i64); + +impl From for PluginId { + fn from(id: i64) -> Self { + PluginId(id) + } +} + +pub trait Callback: Send { + fn call(self: Box, result: Result); +} + +/// The `Peer` trait represents the interface for the other side of the RPC +/// channel. It is intended to be used behind a pointer, a trait object. +pub trait Peer: Send + 'static { + fn box_clone(&self) -> Box; + fn send_rpc_request_async(&self, method: &str, params: &Value, f: Box); + /// Sends a request (synchronous RPC) to the peer, and waits for the result. + fn send_rpc_request(&self, method: &str, params: &Value) -> Result; + /// Determines whether an incoming request (or notification) is + /// pending. This is intended to reduce latency for bulk operations + /// done in the background. + fn request_is_pending(&self) -> bool; + + fn schedule_idle(&self, token: usize); + /// Like `schedule_idle`, with the guarantee that the handler's `idle` + /// fn will not be called _before_ the provided `Instant`. + /// + /// # Note + /// + /// This is not intended as a high-fidelity timer. Regular RPC messages + /// will always take priority over an idle task. + fn schedule_timer(&self, after: Instant, token: usize); +} + +/// The `Peer` trait object. +pub type RpcPeer = Box; + +pub struct RpcCtx { + pub peer: RpcPeer, +} +pub struct Plugin { + peer: RpcPeer, + pub(crate) id: PluginId, + pub(crate) name: String, + #[allow(dead_code)] + process: Child, +} + +impl Plugin { + pub fn initialize(&self, value: Value) -> Result<(), Error> { + self.peer.send_rpc_request("initialize", &value)?; + Ok(()) + } + + pub fn send_request(&self, method: &str, params: &Value) -> Result { + self.peer.send_rpc_request(method, params) + } + + pub fn shutdown(&self) { + if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) { + error!("error sending shutdown to plugin {}: {:?}", self.name, err); + } + } +} + +pub struct PluginInfo { + pub name: String, + // pub absolute_chat_model_path: String, + pub exec_path: String, +} + +pub(crate) async fn start_plugin_process( + plugin_info: PluginInfo, + id: PluginId, + state: WeakSidecarState, +) -> Result<(), anyhow::Error> { + let (tx, rx) = tokio::sync::oneshot::channel(); + let spawn_result = thread::Builder::new() + .name(format!("<{}> core host thread", &plugin_info.name)) + .spawn(move || { + info!("Load {} plugin", &plugin_info.name); + let child = std::process::Command::new(&plugin_info.exec_path) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn(); + + match child { + Ok(mut child) => { + let child_stdin = child.stdin.take().unwrap(); + let child_stdout = child.stdout.take().unwrap(); + let mut looper = RpcLoop::new(child_stdin); + let peer: RpcPeer = Box::new(looper.get_raw_peer()); + let name = plugin_info.name.clone(); + if let Err(err) = peer.send_rpc_request("ping", &Value::Array(Vec::new())) { + error!("plugin {} failed to respond to ping: {:?}", name, err); + } + let plugin = Plugin { + peer, + process: child, + name, + id, + }; + + state.plugin_connect(Ok(plugin)); + let _ = tx.send(()); + let mut state = state; + let err = looper.mainloop(|| BufReader::new(child_stdout), &mut state); + state.plugin_exit(id, err); + }, + Err(err) => { + let _ = tx.send(()); + state.plugin_connect(Err(err)) + }, + } + }); + + if let Err(err) = spawn_result { + error!("[RPC] thread spawn failed for {:?}, {:?}", id, err); + return Err(err.into()); + } + rx.await?; + Ok(()) +} diff --git a/frontend/rust-lib/flowy-sidecar/src/rpc_loop.rs b/frontend/rust-lib/flowy-sidecar/src/rpc_loop.rs new file mode 100644 index 0000000000..3de64b6ae8 --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/src/rpc_loop.rs @@ -0,0 +1,302 @@ +use crate::error::{Error, ReadError, RemoteError}; +use crate::parser::{Call, MessageReader, RequestId}; +use crate::plugin::RpcCtx; +use crate::rpc_peer::{RawPeer, Response, RpcState}; +use serde::de::DeserializeOwned; +use serde_json::Value; + +use std::io::{BufRead, Write}; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time::Duration; +use tracing::{error, trace}; + +const MAX_IDLE_WAIT: Duration = Duration::from_millis(5); +#[derive(Debug, Clone)] +pub struct RpcObject(pub Value); + +impl RpcObject { + /// Returns the 'id' of the underlying object, if present. + pub fn get_id(&self) -> Option { + self.0.get("id").and_then(Value::as_u64) + } + + /// Returns the 'method' field of the underlying object, if present. + pub fn get_method(&self) -> Option<&str> { + self.0.get("method").and_then(Value::as_str) + } + + /// Returns `true` if this object looks like an RPC response; + /// that is, if it has an 'id' field and does _not_ have a 'method' + /// field. + pub fn is_response(&self) -> bool { + self.0.get("id").is_some() && self.0.get("method").is_none() + } + + /// Attempts to convert the underlying `Value` into an RPC response + /// object, and returns the result. + /// + /// The caller is expected to verify that the object is a response + /// before calling this method. + /// + /// # Errors + /// + /// If the `Value` is not a well formed response object, this will + /// return a `String` containing an error message. The caller should + /// print this message and exit. + pub fn into_response(mut self) -> Result { + let _ = self + .get_id() + .ok_or("Response requires 'id' field.".to_string())?; + + if self.0.get("result").is_some() == self.0.get("error").is_some() { + return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into()); + } + let result = self.0.as_object_mut().and_then(|obj| obj.remove("result")); + + match result { + Some(r) => Ok(Ok(r)), + None => { + let error = self + .0 + .as_object_mut() + .and_then(|obj| obj.remove("error")) + .unwrap(); + Err(format!("Error handling response: {:?}", error)) + }, + } + } + + /// Attempts to convert the underlying `Value` into either an RPC + /// notification or request. + /// + /// # Errors + /// + /// Returns a `serde_json::Error` if the `Value` cannot be converted + /// to one of the expected types. + pub fn into_rpc(self) -> Result, serde_json::Error> + where + R: DeserializeOwned, + { + let id = self.get_id(); + match id { + Some(id) => match serde_json::from_value::(self.0) { + Ok(resp) => Ok(Call::Request(id, resp)), + Err(err) => Ok(Call::InvalidRequest(id, err.into())), + }, + None => Ok(Call::Message(self.0)), + } + } +} + +impl From for RpcObject { + fn from(v: Value) -> RpcObject { + RpcObject(v) + } +} + +pub trait Handler { + type Request: DeserializeOwned; + fn handle_request(&mut self, ctx: &RpcCtx, rpc: Self::Request) -> Result; + #[allow(unused_variables)] + fn idle(&mut self, ctx: &RpcCtx, token: usize) {} +} + +/// A helper type which shuts down the runloop if a panic occurs while +/// handling an RPC. +struct PanicGuard<'a, W: Write + 'static>(&'a RawPeer); + +impl<'a, W: Write + 'static> Drop for PanicGuard<'a, W> { + fn drop(&mut self) { + if thread::panicking() { + error!("[RPC] panic guard hit, closing run loop"); + self.0.disconnect(); + } + } +} + +/// A structure holding the state of a main loop for handling RPC's. +pub struct RpcLoop { + reader: MessageReader, + peer: RawPeer, +} + +impl RpcLoop { + /// Creates a new `RpcLoop` with the given output stream (which is used for + /// sending requests and notifications, as well as responses). + pub fn new(writer: W) -> Self { + let rpc_peer = RawPeer(Arc::new(RpcState::new(writer))); + RpcLoop { + reader: MessageReader::default(), + peer: rpc_peer, + } + } + + /// Gets a reference to the peer. + pub fn get_raw_peer(&self) -> RawPeer { + self.peer.clone() + } + + /// Starts the event loop, reading lines from the reader until EOF, + /// or an error occurs. + /// + /// Returns `Ok()` in the EOF case, otherwise returns the + /// underlying `ReadError`. + /// + /// # Note: + /// The reader is supplied via a closure, as basically a workaround + /// so that the reader doesn't have to be `Send`. Internally, the + /// main loop starts a separate thread for I/O, and at startup that + /// thread calls the given closure. + /// + /// Calls to the handler happen on the caller's thread. + /// + /// Calls to the handler are guaranteed to preserve the order as + /// they appear on on the channel. At the moment, there is no way + /// for there to be more than one incoming request to be outstanding. + pub fn mainloop( + &mut self, + buffer_read_fn: BufferReadFn, + handler: &mut H, + ) -> Result<(), ReadError> + where + R: BufRead, + BufferReadFn: Send + FnOnce() -> R, + H: Handler, + { + let exit = crossbeam_utils::thread::scope(|scope| { + let peer = self.get_raw_peer(); + peer.reset_needs_exit(); + + let ctx = RpcCtx { + peer: Box::new(peer.clone()), + }; + + // 1. Spawn a new thread for reading data from a stream. + // 2. Continuously read data from the stream. + // 3. Parse the data as JSON. + // 4. Handle the JSON data as either a response or another type of JSON object. + // 5. Manage errors and connection status. + scope.spawn(move |_| { + let mut stream = buffer_read_fn(); + loop { + if self.peer.needs_exit() { + trace!("read loop exit"); + break; + } + + let json = match self.reader.next(&mut stream) { + Ok(json) => json, + Err(err) => { + // When the data can't be parsed into JSON. It means the data is not in the correct format. + // Probably the data comes from other stdout. + if self.peer.0.is_blocking() { + self.peer.disconnect(); + } + + error!("[RPC] failed to parse JSON: {:?}", err); + self.peer.put_rpc_object(Err(err)); + break; + }, + }; + if json.is_response() { + let id = json.get_id().unwrap(); + match json.into_response() { + Ok(resp) => { + let resp = resp.map_err(Error::from); + self.peer.handle_response(id, resp); + }, + Err(msg) => { + error!("[RPC] failed to parse response: {}", msg); + self.peer.handle_response(id, Err(Error::InvalidResponse)); + }, + } + } else { + self.peer.put_rpc_object(Ok(json)); + } + } + }); + + loop { + let _guard = PanicGuard(&peer); + let read_result = next_read(&peer, handler, &ctx); + let json = match read_result { + Ok(json) => json, + Err(err) => { + // finish idle work before disconnecting; + // this is mostly useful for integration tests. + if let Some(idle_token) = peer.try_get_idle() { + handler.idle(&ctx, idle_token); + } + peer.disconnect(); + return err; + }, + }; + + match json.into_rpc::() { + Ok(Call::Request(id, cmd)) => { + // Handle request sent from the client. For example from python executable. + let result = handler.handle_request(&ctx, cmd); + peer.respond(result, id); + }, + Ok(Call::InvalidRequest(id, err)) => peer.respond(Err(err), id), + Err(err) => { + peer.disconnect(); + return ReadError::UnknownRequest(err); + }, + Ok(Call::Message(msg)) => { + trace!("[RPC] received message: {}", msg); + }, + } + } + }) + .unwrap(); + + if exit.is_disconnect() { + Ok(()) + } else { + Err(exit) + } + } +} + +/// Returns the next read result, checking for idle work when no +/// result is available. +fn next_read(peer: &RawPeer, handler: &mut H, ctx: &RpcCtx) -> Result +where + W: Write + Send, + H: Handler, +{ + loop { + if let Some(result) = peer.try_get_rx() { + return result; + } + // handle timers before general idle work + let time_to_next_timer = match peer.check_timers() { + Some(Ok(token)) => { + do_idle(handler, ctx, token); + continue; + }, + Some(Err(duration)) => Some(duration), + None => None, + }; + + if let Some(idle_token) = peer.try_get_idle() { + do_idle(handler, ctx, idle_token); + continue; + } + + // we don't want to block indefinitely if there's no current idle work, + // because idle work could be scheduled from another thread. + let idle_timeout = time_to_next_timer + .unwrap_or(MAX_IDLE_WAIT) + .min(MAX_IDLE_WAIT); + + if let Some(result) = peer.get_rx_timeout(idle_timeout) { + return result; + } + } +} + +fn do_idle(handler: &mut H, ctx: &RpcCtx, token: usize) {} diff --git a/frontend/rust-lib/flowy-sidecar/src/rpc_peer.rs b/frontend/rust-lib/flowy-sidecar/src/rpc_peer.rs new file mode 100644 index 0000000000..c27992a40f --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/src/rpc_peer.rs @@ -0,0 +1,279 @@ +use crate::error::{Error, ReadError, RemoteError}; +use crate::plugin::{Callback, Peer, PluginId}; +use crate::rpc_loop::RpcObject; +use parking_lot::{Condvar, Mutex}; +use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::{json, Value}; +use std::collections::{BTreeMap, BinaryHeap, VecDeque}; +use std::io::Write; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{mpsc, Arc}; +use std::time::{Duration, Instant}; +use std::{cmp, io}; +use tracing::{error, trace, warn}; + +pub struct PluginCommand { + pub plugin_id: PluginId, + pub cmd: T, +} + +impl Serialize for PluginCommand { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut v = serde_json::to_value(&self.cmd).map_err(ser::Error::custom)?; + v["params"]["plugin_id"] = json!(self.plugin_id); + v.serialize(serializer) + } +} + +impl<'de, T: Deserialize<'de>> Deserialize<'de> for PluginCommand { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct PluginIdHelper { + plugin_id: PluginId, + } + let v = Value::deserialize(deserializer)?; + let plugin_id = PluginIdHelper::deserialize(&v) + .map_err(de::Error::custom)? + .plugin_id; + let cmd = T::deserialize(v).map_err(de::Error::custom)?; + Ok(PluginCommand { plugin_id, cmd }) + } +} + +pub struct RpcState { + rx_queue: Mutex>>, + rx_cvar: Condvar, + writer: Mutex, + id: AtomicUsize, + pending: Mutex>, + idle_queue: Mutex>, + timers: Mutex>, + needs_exit: AtomicBool, + is_blocking: AtomicBool, +} + +impl RpcState { + pub fn new(writer: W) -> Self { + RpcState { + rx_queue: Mutex::new(VecDeque::new()), + rx_cvar: Condvar::new(), + writer: Mutex::new(writer), + id: AtomicUsize::new(0), + pending: Mutex::new(BTreeMap::new()), + idle_queue: Mutex::new(VecDeque::new()), + timers: Mutex::new(BinaryHeap::new()), + needs_exit: AtomicBool::new(false), + is_blocking: Default::default(), + } + } + + pub fn is_blocking(&self) -> bool { + self.is_blocking.load(Ordering::Acquire) + } +} + +pub struct RawPeer(pub(crate) Arc>); + +impl Peer for RawPeer { + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } + + fn send_rpc_request_async(&self, method: &str, params: &Value, f: Box) { + self.send_rpc(method, params, ResponseHandler::Callback(f)); + } + + fn send_rpc_request(&self, method: &str, params: &Value) -> Result { + let (tx, rx) = mpsc::channel(); + self.0.is_blocking.store(true, Ordering::Release); + self.send_rpc(method, params, ResponseHandler::Chan(tx)); + rx.recv().unwrap_or(Err(Error::PeerDisconnect)) + } + + fn request_is_pending(&self) -> bool { + let queue = self.0.rx_queue.lock(); + !queue.is_empty() + } + + fn schedule_idle(&self, token: usize) { + self.0.idle_queue.lock().push_back(token); + } + + fn schedule_timer(&self, after: Instant, token: usize) { + self.0.timers.lock().push(Timer { + fire_after: after, + token, + }); + } +} + +impl RawPeer { + fn send(&self, v: &Value) -> Result<(), io::Error> { + let mut s = serde_json::to_string(v).unwrap(); + s.push('\n'); + self.0.writer.lock().write_all(s.as_bytes()) + } + + pub(crate) fn respond(&self, result: Response, id: u64) { + let mut response = json!({ "id": id }); + match result { + Ok(result) => response["result"] = result, + Err(error) => response["error"] = json!(error), + }; + if let Err(e) = self.send(&response) { + error!("[RPC] error {} sending response to RPC {:?}", e, id); + } + } + + fn send_rpc(&self, method: &str, params: &Value, rh: ResponseHandler) { + trace!("[RPC] method:{} params: {:?}", method, params); + let id = self.0.id.fetch_add(1, Ordering::Relaxed); + { + let mut pending = self.0.pending.lock(); + pending.insert(id, rh); + } + if let Err(e) = self.send(&json!({ + "id": id, + "method": method, + "params": params, + })) { + let mut pending = self.0.pending.lock(); + if let Some(rh) = pending.remove(&id) { + rh.invoke(Err(Error::Io(e))); + } + } + } + + pub(crate) fn handle_response(&self, id: u64, resp: Result) { + let id = id as usize; + let handler = { + let mut pending = self.0.pending.lock(); + pending.remove(&id) + }; + match handler { + Some(response_handler) => response_handler.invoke(resp), + None => warn!("[RPC] id {} not found in pending", id), + } + } + + /// Get a message from the receive queue if available. + pub(crate) fn try_get_rx(&self) -> Option> { + let mut queue = self.0.rx_queue.lock(); + queue.pop_front() + } + + /// Get a message from the receive queue, waiting for at most `Duration` + /// and returning `None` if no message is available. + pub(crate) fn get_rx_timeout(&self, dur: Duration) -> Option> { + let mut queue = self.0.rx_queue.lock(); + let result = self.0.rx_cvar.wait_for(&mut queue, dur); + if result.timed_out() { + return None; + } + queue.pop_front() + } + + /// Adds a message to the receive queue. The message should only + /// be `None` if the read thread is exiting. + pub(crate) fn put_rpc_object(&self, json: Result) { + let mut queue = self.0.rx_queue.lock(); + queue.push_back(json); + self.0.rx_cvar.notify_one(); + } + + pub(crate) fn try_get_idle(&self) -> Option { + self.0.idle_queue.lock().pop_front() + } + + /// Checks status of the most imminent timer. If that timer has expired, + /// returns `Some(Ok(_))`, with the corresponding token. + /// If a timer exists but has not expired, returns `Some(Err(_))`, + /// with the error value being the `Duration` until the timer is ready. + /// Returns `None` if no timers are registered. + pub(crate) fn check_timers(&self) -> Option> { + let mut timers = self.0.timers.lock(); + match timers.peek() { + None => return None, + Some(t) => { + let now = Instant::now(); + if t.fire_after > now { + return Some(Err(t.fire_after - now)); + } + }, + } + Some(Ok(timers.pop().unwrap().token)) + } + + /// send disconnect error to pending requests. + pub(crate) fn disconnect(&self) { + trace!("[RPC] disconnecting peer"); + let mut pending = self.0.pending.lock(); + let ids = pending.keys().cloned().collect::>(); + for id in &ids { + let callback = pending.remove(id).unwrap(); + callback.invoke(Err(Error::PeerDisconnect)); + } + self.0.needs_exit.store(true, Ordering::Relaxed); + } + + /// Returns `true` if an error has occured in the main thread. + pub(crate) fn needs_exit(&self) -> bool { + self.0.needs_exit.load(Ordering::Relaxed) + } + + pub(crate) fn reset_needs_exit(&self) { + self.0.needs_exit.store(false, Ordering::SeqCst); + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } +} + +impl Clone for RawPeer { + fn clone(&self) -> Self { + RawPeer(self.0.clone()) + } +} + +pub type Response = Result; +enum ResponseHandler { + Chan(mpsc::Sender>), + Callback(Box), +} + +impl ResponseHandler { + fn invoke(self, result: Result) { + match self { + ResponseHandler::Chan(tx) => { + let _ = tx.send(result); + }, + ResponseHandler::Callback(f) => f.call(result), + } + } +} +#[derive(Debug, PartialEq, Eq)] +struct Timer { + fire_after: Instant, + token: usize, +} + +impl Ord for Timer { + fn cmp(&self, other: &Timer) -> cmp::Ordering { + other.fire_after.cmp(&self.fire_after) + } +} + +impl PartialOrd for Timer { + fn partial_cmp(&self, other: &Timer) -> Option { + Some(self.cmp(other)) + } +} diff --git a/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs b/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs index 084222b98a..89c43ae92e 100644 --- a/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs +++ b/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs @@ -1,5 +1,6 @@ use anyhow::Result; -use flowy_sidecar::process::SidecarCommand; +use flowy_sidecar::manager::SidecarManager; +use flowy_sidecar::plugin::PluginInfo; use serde_json::json; use std::sync::Once; use tracing::info; @@ -10,26 +11,46 @@ use tracing_subscriber::EnvFilter; #[tokio::test] async fn load_chat_model_test() { if let Ok(config) = LocalAIConfiguration::new() { - let (mut rx, mut child) = SidecarCommand::new_sidecar(&config.chat_bin_path) - .unwrap() - .spawn() + let manager = SidecarManager::new(); + let info = PluginInfo { + name: "chat".to_string(), + exec_path: config.chat_bin_path.clone(), + }; + let plugin_id = manager.create_plugin(info).await.unwrap(); + manager + .init_plugin( + plugin_id, + json!({ + "absolute_chat_model_path":config.chat_model_absolute_path(), + }), + ) .unwrap(); - tokio::spawn(async move { - while let Some(event) = rx.recv().await { - info!("event: {:?}", event); - } - }); - let json = json!({ - "plugin_id": "example_plugin_id", - "method": "initialize", - "params": { - "absolute_chat_model_path":config.chat_model_absolute_path(), - } - }); - child.write_json(json).unwrap(); - tokio::time::sleep(tokio::time::Duration::from_secs(15)).await; + tokio::spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_secs(15)).await; + manager.kill_plugin(plugin_id).await.unwrap(); + }) + + // let (mut rx, mut child) = SidecarCommand::new_sidecar(&config.chat_bin_path) + // .unwrap() + // .spawn() + // .unwrap(); + // + // tokio::spawn(async move { + // while let Some(event) = rx.recv().await { + // info!("event: {:?}", event); + // } + // }); + // + // let json = json!({ + // "plugin_id": "example_plugin_id", + // "method": "initialize", + // "params": { + // "absolute_chat_model_path":config.chat_model_absolute_path(), + // } + // }); + // child.write_json(json).unwrap(); // let chat_id = uuid::Uuid::new_v4().to_string(); // let json =