From e4b1108ff044a81311ef6607eeabe23dc228c9cd Mon Sep 17 00:00:00 2001 From: nathan Date: Thu, 27 Jun 2024 22:00:36 +0800 Subject: [PATCH] chore: test streaming --- frontend/rust-lib/Cargo.lock | 6 +- frontend/rust-lib/flowy-chat/src/manager.rs | 8 +- frontend/rust-lib/flowy-sidecar/Cargo.toml | 2 +- .../rust-lib/flowy-sidecar/src/core/parser.rs | 31 ++-- .../rust-lib/flowy-sidecar/src/core/plugin.rs | 50 +++++-- .../flowy-sidecar/src/core/rpc_loop.rs | 17 ++- .../flowy-sidecar/src/core/rpc_object.rs | 72 ++++++--- .../flowy-sidecar/src/core/rpc_peer.rs | 139 +++++++++++++----- frontend/rust-lib/flowy-sidecar/src/error.rs | 17 ++- frontend/rust-lib/flowy-sidecar/src/lib.rs | 2 +- .../rust-lib/flowy-sidecar/src/manager.rs | 28 ++-- .../flowy-sidecar/src/plugins/chat_plugin.rs | 33 +++-- .../src/plugins/embedding_plugin.rs | 18 ++- .../flowy-sidecar/tests/chat_test/mod.rs | 29 +++- frontend/rust-lib/flowy-sidecar/tests/util.rs | 27 +++- frontend/rust-lib/lib-infra/src/lib.rs | 1 + .../rust-lib/lib-infra/src/stream_util.rs | 21 +++ 17 files changed, 359 insertions(+), 142 deletions(-) create mode 100644 frontend/rust-lib/lib-infra/src/stream_util.rs diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index f16761ee37..dc53e80e38 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -2148,7 +2148,6 @@ dependencies = [ "anyhow", "crossbeam-utils", "dotenv", - "futures", "lib-infra", "log", "once_cell", @@ -2157,6 +2156,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-stream", "tracing", "tracing-subscriber", "uuid", @@ -5767,9 +5767,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" dependencies = [ "futures-core", "pin-project-lite", diff --git a/frontend/rust-lib/flowy-chat/src/manager.rs b/frontend/rust-lib/flowy-chat/src/manager.rs index d71c1ca6f4..4a4f1b4131 100644 --- a/frontend/rust-lib/flowy-chat/src/manager.rs +++ b/frontend/rust-lib/flowy-chat/src/manager.rs @@ -11,7 +11,7 @@ use flowy_sidecar::manager::SidecarManager; use flowy_sqlite::DBConnection; use lib_infra::future::FutureResult; use lib_infra::util::timestamp; -use std::sync::atomic::AtomicBool; + use std::sync::Arc; use tracing::trace; @@ -263,9 +263,9 @@ impl ChatCloudService for ChatService { async fn stream_answer( &self, - workspace_id: &str, - chat_id: &str, - message_id: i64, + _workspace_id: &str, + _chat_id: &str, + _message_id: i64, ) -> Result { todo!() } diff --git a/frontend/rust-lib/flowy-sidecar/Cargo.toml b/frontend/rust-lib/flowy-sidecar/Cargo.toml index 31d0fda4d6..b7467dce69 100644 --- a/frontend/rust-lib/flowy-sidecar/Cargo.toml +++ b/frontend/rust-lib/flowy-sidecar/Cargo.toml @@ -16,8 +16,8 @@ tracing.workspace = true crossbeam-utils = "0.8.20" log = "0.4.21" parking_lot.workspace = true +tokio-stream = "0.1.15" lib-infra.workspace = true -futures.workspace = true [dev-dependencies] diff --git a/frontend/rust-lib/flowy-sidecar/src/core/parser.rs b/frontend/rust-lib/flowy-sidecar/src/core/parser.rs index 160fd1eb89..9e8178d44d 100644 --- a/frontend/rust-lib/flowy-sidecar/src/core/parser.rs +++ b/frontend/rust-lib/flowy-sidecar/src/core/parser.rs @@ -1,8 +1,9 @@ use crate::core::rpc_object::RpcObject; -use crate::core::rpc_peer::ResponsePayload; + use crate::error::{ReadError, RemoteError}; use serde_json::{json, Value as JsonValue}; use std::io::BufRead; +use tracing::error; #[derive(Debug, Default)] pub struct MessageReader(String); @@ -57,15 +58,15 @@ pub enum Call { } pub trait ResponseParser { - type ValueType; - fn parse_response(payload: JsonValue) -> Result; + type ValueType: Send + Sync + 'static; + fn parse_json(payload: JsonValue) -> Result; } pub struct ChatResponseParser; impl ResponseParser for ChatResponseParser { type ValueType = String; - fn parse_response(json: JsonValue) -> Result { + fn parse_json(json: JsonValue) -> Result { if json.is_object() { if let Some(data) = json.get("data") { if let Some(message) = data.as_str() { @@ -73,7 +74,19 @@ impl ResponseParser for ChatResponseParser { } } } - return Err(RemoteError::InvalidResponse(json)); + return Err(RemoteError::ParseResponse(json)); + } +} + +pub struct ChatStreamResponseParser; +impl ResponseParser for ChatStreamResponseParser { + type ValueType = String; + + fn parse_json(json: JsonValue) -> Result { + if let Some(message) = json.as_str() { + return Ok(message.to_string()); + } + return Err(RemoteError::ParseResponse(json)); } } @@ -81,7 +94,7 @@ pub struct ChatRelatedQuestionsResponseParser; impl ResponseParser for ChatRelatedQuestionsResponseParser { type ValueType = Vec; - fn parse_response(json: JsonValue) -> Result { + fn parse_json(json: JsonValue) -> Result { if json.is_object() { if let Some(data) = json.get("data") { if let Some(values) = data.as_array() { @@ -89,7 +102,7 @@ impl ResponseParser for ChatRelatedQuestionsResponseParser { } } } - return Err(RemoteError::InvalidResponse(json)); + return Err(RemoteError::ParseResponse(json)); } } @@ -97,7 +110,7 @@ pub struct SimilarityResponseParser; impl ResponseParser for SimilarityResponseParser { type ValueType = f64; - fn parse_response(json: JsonValue) -> Result { + fn parse_json(json: JsonValue) -> Result { if json.is_object() { if let Some(data) = json.get("data") { if let Some(score) = data.get("score").and_then(|v| v.as_f64()) { @@ -106,6 +119,6 @@ impl ResponseParser for SimilarityResponseParser { } } - return Err(RemoteError::InvalidResponse(json)); + return Err(RemoteError::ParseResponse(json)); } } diff --git a/frontend/rust-lib/flowy-sidecar/src/core/plugin.rs b/frontend/rust-lib/flowy-sidecar/src/core/plugin.rs index fe82d441e5..db18e68661 100644 --- a/frontend/rust-lib/flowy-sidecar/src/core/plugin.rs +++ b/frontend/rust-lib/flowy-sidecar/src/core/plugin.rs @@ -1,9 +1,9 @@ -use crate::error::Error; +use crate::error::SidecarError; use crate::manager::WeakSidecarState; use crate::core::parser::ResponseParser; use crate::core::rpc_loop::RpcLoop; -use crate::core::rpc_peer::{Callback, ResponsePayload}; +use crate::core::rpc_peer::{CloneableCallback, OneShotCallback}; use anyhow::anyhow; use serde::{Deserialize, Serialize}; use serde_json::{json, Value as JsonValue}; @@ -12,6 +12,9 @@ use std::process::{Child, Stdio}; use std::sync::Arc; use std::thread; use std::time::Instant; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::Stream; + use tracing::{error, info}; #[derive( @@ -34,12 +37,12 @@ pub trait Peer: Send + Sync + 'static { /// Sends an RPC notification to the peer with the specified method and parameters. fn send_rpc_notification(&self, method: &str, params: &JsonValue); - /// Sends an asynchronous RPC request to the peer and executes the provided callback upon completion. - fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box); + fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback); + fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box); /// Sends a synchronous RPC request to the peer and waits for the result. /// Returns the result of the request or an error. - fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result; + fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result; /// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background. fn request_is_pending(&self) -> bool; @@ -66,34 +69,53 @@ pub struct Plugin { } impl Plugin { - pub fn initialize(&self, value: JsonValue) -> Result<(), Error> { + pub fn initialize(&self, value: JsonValue) -> Result<(), SidecarError> { self.peer.send_rpc_request("initialize", &value)?; Ok(()) } - pub fn send_request(&self, method: &str, params: &JsonValue) -> Result { + pub fn request(&self, method: &str, params: &JsonValue) -> Result { self.peer.send_rpc_request(method, params) } - pub async fn async_send_request( + pub async fn async_request( &self, method: &str, params: &JsonValue, - ) -> Result { + ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - self.peer.send_rpc_request_async( + self.peer.async_send_rpc_request( method, params, Box::new(move |result| { let _ = tx.send(result); }), ); - let value = rx - .await - .map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??; - let value = P::parse_response(value)?; + let value = rx.await.map_err(|err| { + SidecarError::Internal(anyhow!("error waiting for async response: {:?}", err)) + })??; + let value = P::parse_json(value)?; Ok(value) } + pub fn stream_request( + &self, + method: &str, + params: &JsonValue, + ) -> Result>, SidecarError> { + let (tx, stream) = tokio::sync::mpsc::channel(100); + let stream = ReceiverStream::new(stream); + let callback = CloneableCallback::new(move |result| match result { + Ok(json) => { + let result = P::parse_json(json).map_err(SidecarError::from); + let _ = tx.blocking_send(result); + }, + Err(err) => { + let _ = tx.blocking_send(Err(err)); + }, + }); + self.peer.stream_rpc_request(method, params, callback); + Ok(stream) + } pub fn shutdown(&self) { if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) { diff --git a/frontend/rust-lib/flowy-sidecar/src/core/rpc_loop.rs b/frontend/rust-lib/flowy-sidecar/src/core/rpc_loop.rs index c7411ec73b..8712349cc2 100644 --- a/frontend/rust-lib/flowy-sidecar/src/core/rpc_loop.rs +++ b/frontend/rust-lib/flowy-sidecar/src/core/rpc_loop.rs @@ -2,14 +2,14 @@ use crate::core::parser::{Call, MessageReader}; use crate::core::plugin::RpcCtx; use crate::core::rpc_object::RpcObject; use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState}; -use crate::error::{Error, ReadError, RemoteError}; +use crate::error::{ReadError, RemoteError, SidecarError}; use serde::de::DeserializeOwned; -use serde_json::Value; + use std::io::{BufRead, Write}; use std::sync::Arc; use std::thread; use std::time::Duration; -use tracing::{error, trace}; +use tracing::{error, info, trace}; const MAX_IDLE_WAIT: Duration = Duration::from_millis(5); @@ -97,7 +97,6 @@ impl RpcLoop { trace!("read loop exit"); break; } - let json = match self.reader.next(&mut stream) { Ok(json) => json, Err(err) => { @@ -109,15 +108,17 @@ impl RpcLoop { }, }; if json.is_response() { - let id = json.get_id().unwrap(); + let request_id = json.get_id().unwrap(); match json.into_response() { Ok(resp) => { - let resp = resp.map_err(Error::from); - self.peer.handle_response(id, resp); + let resp = resp.map_err(SidecarError::from); + self.peer.handle_response(request_id, resp); }, Err(msg) => { error!("[RPC] failed to parse response: {}", msg); - self.peer.handle_response(id, Err(Error::InvalidResponse)); + self + .peer + .handle_response(request_id, Err(SidecarError::InvalidResponse)); }, } } else { diff --git a/frontend/rust-lib/flowy-sidecar/src/core/rpc_object.rs b/frontend/rust-lib/flowy-sidecar/src/core/rpc_object.rs index 388b66fd31..253e400e45 100644 --- a/frontend/rust-lib/flowy-sidecar/src/core/rpc_object.rs +++ b/frontend/rust-lib/flowy-sidecar/src/core/rpc_object.rs @@ -1,7 +1,8 @@ use crate::core::parser::{Call, RequestId}; use crate::core::rpc_peer::{Response, ResponsePayload}; +use crate::error::RemoteError; use serde::de::{DeserializeOwned, Error}; -use serde_json::Value; +use serde_json::{json, Value}; #[derive(Debug, Clone)] pub struct RpcObject(pub Value); @@ -24,30 +25,63 @@ impl RpcObject { self.0.get("id").is_some() && self.0.get("method").is_none() } - /// Converts the underlying `Value` into an RPC response object. - /// The caller should verify that the object is a response before calling this method. + /// Converts a JSON-RPC response into a structured `Response` object. + /// + /// This function validates and parses a JSON-RPC response, ensuring it contains the necessary fields, + /// and then transforms it into a structured `Response` object. The response must contain either a + /// "result" or an "error" field, but not both. If the response contains a "result" field, it may also + /// include streaming data, indicated by a nested "stream" field. + /// /// # Errors - /// If the `Value` is not a well-formed response object, this returns a `String` containing an - /// error message. The caller should print this message and exit. + /// + /// This function will return an error if: + /// - The "id" field is missing. + /// - The response contains both "result" and "error" fields, or neither. + /// - The "stream" field within the "result" is missing "type" or "data" fields. + /// - The "stream" type is invalid (i.e., not "streaming" or "end"). + /// + /// # Returns + /// + /// - `Ok(Ok(ResponsePayload::Json(result)))`: If the response contains a valid "result". + /// - `Ok(Ok(ResponsePayload::Streaming(data)))`: If the response contains streaming data of type "streaming". + /// - `Ok(Ok(ResponsePayload::StreamEnd(json!({}))))`: If the response contains streaming data of type "end". + /// - `Err(String)`: If any validation or parsing errors occur. + ///. pub fn into_response(mut self) -> Result { - let _ = self + // Ensure 'id' field is present + self .get_id() - .ok_or("Response requires 'id' field.".to_string())?; + .ok_or_else(|| "Response requires 'id' field.".to_string())?; - if self.0.get("result").is_some() == self.0.get("error").is_some() { + // Ensure the response contains exactly one of 'result' or 'error' + let has_result = self.0.get("result").is_some(); + let has_error = self.0.get("error").is_some(); + if has_result == has_error { return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into()); } - let result = self.0.as_object_mut().and_then(|obj| obj.remove("result")); - match result { - Some(r) => Ok(Ok(ResponsePayload::Json(r))), - None => { - let error = self - .0 - .as_object_mut() - .and_then(|obj| obj.remove("error")) - .unwrap(); - Err(format!("Error handling response: {:?}", error)) - }, + + // Handle the 'result' field if present + if let Some(mut result) = self.0.as_object_mut().and_then(|obj| obj.remove("result")) { + if let Some(mut stream) = result.as_object_mut().and_then(|obj| obj.remove("stream")) { + if let Some((has_more, data)) = stream.as_object_mut().and_then(|obj| { + let has_more = obj.remove("has_more")?.as_bool().unwrap_or(false); + let data = obj.remove("data")?; + Some((has_more, data)) + }) { + return match has_more { + true => Ok(Ok(ResponsePayload::Streaming(data))), + false => Ok(Ok(ResponsePayload::StreamEnd(data))), + }; + } else { + return Err("Stream response must contain 'type' and 'data' fields.".into()); + } + } + + Ok(Ok(ResponsePayload::Json(result))) + } else { + // Handle the 'error' field + let error = self.0.as_object_mut().unwrap().remove("error").unwrap(); + Err(format!("Error handling response: {:?}", error)) } } diff --git a/frontend/rust-lib/flowy-sidecar/src/core/rpc_peer.rs b/frontend/rust-lib/flowy-sidecar/src/core/rpc_peer.rs index 346b94e22b..5d8d277f47 100644 --- a/frontend/rust-lib/flowy-sidecar/src/core/rpc_peer.rs +++ b/frontend/rust-lib/flowy-sidecar/src/core/rpc_peer.rs @@ -1,18 +1,18 @@ use crate::core::plugin::{Peer, PluginId}; use crate::core::rpc_object::RpcObject; -use crate::error::{Error, ReadError, RemoteError}; -use futures::Stream; +use crate::error::{ReadError, RemoteError, SidecarError}; use parking_lot::{Condvar, Mutex}; use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{json, Value as JsonValue}; use std::collections::{BTreeMap, BinaryHeap, VecDeque}; use std::fmt::Display; use std::io::Write; -use std::pin::Pin; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{mpsc, Arc}; use std::time::{Duration, Instant}; use std::{cmp, io}; +use tokio_stream::Stream; use tracing::{error, trace, warn}; pub struct PluginCommand { @@ -97,15 +97,19 @@ impl Peer for RawPeer { } } - fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box) { + fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback) { + self.send_rpc(method, params, ResponseHandler::StreamCallback(Arc::new(f))); + } + + fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box) { self.send_rpc(method, params, ResponseHandler::Callback(f)); } - fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result { + fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result { let (tx, rx) = mpsc::channel(); self.0.is_blocking.store(true, Ordering::Release); self.send_rpc(method, params, ResponseHandler::Chan(tx)); - rx.recv().unwrap_or(Err(Error::PeerDisconnect)) + rx.recv().unwrap_or(Err(SidecarError::PeerDisconnect)) } fn request_is_pending(&self) -> bool { @@ -133,7 +137,7 @@ impl RawPeer { match result { Ok(result) => match result { ResponsePayload::Json(value) => response["result"] = value, - ResponsePayload::Stream(_) => { + ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) => { error!("stream response not supported") }, }, @@ -161,29 +165,44 @@ impl RawPeer { })) { let mut pending = self.0.pending.lock(); if let Some(rh) = pending.remove(&id) { - rh.invoke(Err(Error::Io(e))); + rh.invoke(Err(SidecarError::Io(e))); } } } - pub(crate) fn handle_response(&self, id: u64, resp: Result) { - let id = id as usize; + pub(crate) fn handle_response( + &self, + request_id: u64, + resp: Result, + ) { + let request_id = request_id as usize; let handler = { let mut pending = self.0.pending.lock(); - pending.remove(&id) + pending.remove(&request_id) }; let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false); match handler { Some(response_handler) => { + if is_stream { + let is_stream_end = resp + .as_ref() + .map(|resp| resp.is_stream_end()) + .unwrap_or(false); + if !is_stream_end { + // when steam is not end, we need to put the stream callback back to pending in order to + // receive the next stream message. + if let Some(callback) = response_handler.get_stream_callback() { + let mut pending = self.0.pending.lock(); + pending.insert(request_id, ResponseHandler::StreamCallback(callback)); + } + } else { + trace!("[RPC] {} stream end", request_id); + } + } let json = resp.map(|resp| resp.into_json()); response_handler.invoke(json); - - // if is_stream { - // let mut pending = self.0.pending.lock(); - // pending.insert(id, response_handler); - // } }, - None => warn!("[RPC] id {} not found in pending", id), + None => error!("[RPC] id {}'s handle not found", request_id), } } @@ -243,7 +262,7 @@ impl RawPeer { let ids = pending.keys().cloned().collect::>(); for id in &ids { let callback = pending.remove(id).unwrap(); - callback.invoke(Err(Error::PeerDisconnect)); + callback.invoke(Err(SidecarError::PeerDisconnect)); } self.0.needs_exit.store(true, Ordering::Relaxed); } @@ -267,7 +286,8 @@ impl Clone for RawPeer { #[derive(Clone, Debug)] pub enum ResponsePayload { Json(JsonValue), - Stream(JsonValue), + Streaming(JsonValue), + StreamEnd(JsonValue), } impl ResponsePayload { @@ -276,23 +296,29 @@ impl ResponsePayload { } pub fn is_stream(&self) -> bool { - match self { - ResponsePayload::Json(_) => false, - ResponsePayload::Stream(_) => true, - } + matches!( + self, + ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) + ) + } + + pub fn is_stream_end(&self) -> bool { + matches!(self, ResponsePayload::StreamEnd(_)) } pub fn json(&self) -> &JsonValue { match self { ResponsePayload::Json(v) => v, - ResponsePayload::Stream(v) => v, + ResponsePayload::Streaming(v) => v, + ResponsePayload::StreamEnd(v) => v, } } pub fn into_json(self) -> JsonValue { match self { ResponsePayload::Json(v) => v, - ResponsePayload::Stream(v) => v, + ResponsePayload::Streaming(v) => v, + ResponsePayload::StreamEnd(v) => v, } } } @@ -301,37 +327,78 @@ impl Display for ResponsePayload { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ResponsePayload::Json(v) => write!(f, "{}", v), - ResponsePayload::Stream(_) => write!(f, "stream"), + ResponsePayload::Streaming(_) => write!(f, "stream start"), + ResponsePayload::StreamEnd(_) => write!(f, "stream end"), } } } pub type Response = Result; -pub trait ResponseStream: Stream> + Unpin + Send {} +pub trait ResponseStream: Stream> + Unpin + Send {} -impl ResponseStream for T where T: Stream> + Unpin + Send {} +impl ResponseStream for T where T: Stream> + Unpin + Send {} enum ResponseHandler { - Chan(mpsc::Sender>), - Callback(Box), -} -pub trait Callback: Send { - fn call(self: Box, result: Result); + Chan(mpsc::Sender>), + Callback(Box), + StreamCallback(Arc), } -impl)> Callback for F { - fn call(self: Box, result: Result) { +impl ResponseHandler { + pub fn get_stream_callback(&self) -> Option> { + match self { + ResponseHandler::StreamCallback(cb) => Some(cb.clone()), + _ => None, + } + } +} + +pub trait OneShotCallback: Send { + fn call(self: Box, result: Result); +} + +impl)> OneShotCallback for F { + fn call(self: Box, result: Result) { + (self)(result) + } +} + +pub trait Callback: Send + Sync { + fn call(&self, result: Result); +} + +impl)> Callback for F { + fn call(&self, result: Result) { (*self)(result) } } +#[derive(Clone)] +pub struct CloneableCallback { + callback: Arc, +} +impl CloneableCallback { + pub fn new(callback: C) -> Self { + CloneableCallback { + callback: Arc::new(callback), + } + } + + pub fn call(&self, result: Result) { + self.callback.call(result) + } +} + impl ResponseHandler { - fn invoke(self, result: Result) { + fn invoke(self, result: Result) { match self { ResponseHandler::Chan(tx) => { let _ = tx.send(result); }, + ResponseHandler::StreamCallback(cb) => { + cb.call(result); + }, ResponseHandler::Callback(f) => f.call(result), } } diff --git a/frontend/rust-lib/flowy-sidecar/src/error.rs b/frontend/rust-lib/flowy-sidecar/src/error.rs index e7ff1c3339..fb7f7b6a52 100644 --- a/frontend/rust-lib/flowy-sidecar/src/error.rs +++ b/frontend/rust-lib/flowy-sidecar/src/error.rs @@ -1,11 +1,10 @@ -use crate::core::rpc_peer::ResponsePayload; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{json, Value as JsonValue}; use std::{fmt, io}; /// The error type of `tauri-utils`. #[derive(Debug, thiserror::Error)] -pub enum Error { +pub enum SidecarError { /// An IO error occurred on the underlying communication channel. #[error(transparent)] Io(#[from] io::Error), @@ -48,6 +47,9 @@ pub enum RemoteError { #[error("Invalid response: {0}")] InvalidResponse(JsonValue), + + #[error("Parse response: {0}")] + ParseResponse(JsonValue), /// A custom error, defined by the client. #[error("Custom error: {message}")] Custom { @@ -100,9 +102,9 @@ impl From for RemoteError { } } -impl From for Error { - fn from(err: RemoteError) -> Error { - Error::RemoteError(err) +impl From for SidecarError { + fn from(err: RemoteError) -> SidecarError { + SidecarError::RemoteError(err) } } @@ -156,6 +158,11 @@ impl Serialize for RemoteError { "Invalid response".to_string(), Some(json!(resp.to_string())), ), + RemoteError::ParseResponse(resp) => ( + -1, + "Invalid response".to_string(), + Some(json!(resp.to_string())), + ), }; let err = ErrorHelper { code, diff --git a/frontend/rust-lib/flowy-sidecar/src/lib.rs b/frontend/rust-lib/flowy-sidecar/src/lib.rs index 2e508ba35f..7d77790307 100644 --- a/frontend/rust-lib/flowy-sidecar/src/lib.rs +++ b/frontend/rust-lib/flowy-sidecar/src/lib.rs @@ -1,4 +1,4 @@ pub mod core; -mod error; +pub mod error; pub mod manager; pub mod plugins; diff --git a/frontend/rust-lib/flowy-sidecar/src/manager.rs b/frontend/rust-lib/flowy-sidecar/src/manager.rs index 7d23bbe6c2..9541961119 100644 --- a/frontend/rust-lib/flowy-sidecar/src/manager.rs +++ b/frontend/rust-lib/flowy-sidecar/src/manager.rs @@ -2,11 +2,11 @@ use crate::core::parser::ResponseParser; use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx}; use crate::core::rpc_loop::Handler; use crate::core::rpc_peer::{PluginCommand, ResponsePayload}; -use crate::error::{Error, ReadError, RemoteError}; +use crate::error::{ReadError, RemoteError, SidecarError}; use anyhow::anyhow; use lib_infra::util::{get_operating_system, OperatingSystem}; use parking_lot::Mutex; -use serde_json::{json, Value}; +use serde_json::Value; use std::io; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, Weak}; @@ -29,9 +29,9 @@ impl SidecarManager { } } - pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result { + pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result { if self.operating_system.is_not_desktop() { - return Err(Error::Internal(anyhow!( + return Err(SidecarError::Internal(anyhow!( "plugin not supported on this platform" ))); } @@ -41,7 +41,7 @@ impl SidecarManager { Ok(plugin_id) } - pub async fn get_plugin(&self, plugin_id: PluginId) -> Result, Error> { + pub async fn get_plugin(&self, plugin_id: PluginId) -> Result, SidecarError> { let state = self.state.lock(); let plugin = state .plugins @@ -51,9 +51,9 @@ impl SidecarManager { Ok(Arc::downgrade(plugin)) } - pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> { + pub async fn remove_plugin(&self, id: PluginId) -> Result<(), SidecarError> { if self.operating_system.is_not_desktop() { - return Err(Error::Internal(anyhow!( + return Err(SidecarError::Internal(anyhow!( "plugin not supported on this platform" ))); } @@ -71,9 +71,9 @@ impl SidecarManager { Ok(()) } - pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> { + pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), SidecarError> { if self.operating_system.is_not_desktop() { - return Err(Error::Internal(anyhow!( + return Err(SidecarError::Internal(anyhow!( "plugin not supported on this platform" ))); } @@ -94,15 +94,15 @@ impl SidecarManager { id: PluginId, method: &str, request: Value, - ) -> Result { + ) -> Result { let state = self.state.lock(); let plugin = state .plugins .iter() .find(|p| p.id == id) .ok_or(anyhow!("plugin not found"))?; - let resp = plugin.send_request(method, &request)?; - let value = P::parse_response(resp)?; + let resp = plugin.request(method, &request)?; + let value = P::parse_json(resp)?; Ok(value) } @@ -111,14 +111,14 @@ impl SidecarManager { id: PluginId, method: &str, request: Value, - ) -> Result { + ) -> Result { let state = self.state.lock(); let plugin = state .plugins .iter() .find(|p| p.id == id) .ok_or(anyhow!("plugin not found"))?; - let value = plugin.async_send_request::

(method, &request).await?; + let value = plugin.async_request::

(method, &request).await?; Ok(value) } } diff --git a/frontend/rust-lib/flowy-sidecar/src/plugins/chat_plugin.rs b/frontend/rust-lib/flowy-sidecar/src/plugins/chat_plugin.rs index 69b5441065..f346b62933 100644 --- a/frontend/rust-lib/flowy-sidecar/src/plugins/chat_plugin.rs +++ b/frontend/rust-lib/flowy-sidecar/src/plugins/chat_plugin.rs @@ -1,9 +1,13 @@ -use crate::core::parser::{ChatRelatedQuestionsResponseParser, ChatResponseParser}; +use crate::core::parser::{ + ChatRelatedQuestionsResponseParser, ChatResponseParser, ChatStreamResponseParser, +}; use crate::core::plugin::{Plugin, PluginId}; -use crate::error::Error; +use crate::error::SidecarError; use anyhow::anyhow; use serde_json::json; use std::sync::Weak; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::Stream; pub struct ChatPluginOperation { plugin: Weak, @@ -17,17 +21,17 @@ impl ChatPluginOperation { pub async fn send_message( &self, chat_id: &str, - plugin_id: PluginId, + _plugin_id: PluginId, message: &str, - ) -> Result { + ) -> Result { let plugin = self .plugin .upgrade() - .ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; + .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?; let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}}); let resp = plugin - .async_send_request::("handle", ¶ms) + .async_request::("handle", ¶ms) .await?; Ok(resp) } @@ -35,34 +39,31 @@ impl ChatPluginOperation { pub async fn stream_message( &self, chat_id: &str, - plugin_id: PluginId, + _plugin_id: PluginId, message: &str, - ) -> Result { + ) -> Result>, SidecarError> { let plugin = self .plugin .upgrade() - .ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; + .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?; let params = json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}}); - let resp = plugin - .async_send_request::("handle", ¶ms) - .await?; - Ok(resp) + plugin.stream_request::("handle", ¶ms) } pub async fn get_related_questions( &self, chat_id: &str, - ) -> Result, Error> { + ) -> Result, SidecarError> { let plugin = self .plugin .upgrade() - .ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; + .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?; let params = json!({"chat_id": chat_id, "method": "related_question"}); let resp = plugin - .async_send_request::("handle", ¶ms) + .async_request::("handle", ¶ms) .await?; Ok(resp) } diff --git a/frontend/rust-lib/flowy-sidecar/src/plugins/embedding_plugin.rs b/frontend/rust-lib/flowy-sidecar/src/plugins/embedding_plugin.rs index 5f1e065f37..93b2a8717c 100644 --- a/frontend/rust-lib/flowy-sidecar/src/plugins/embedding_plugin.rs +++ b/frontend/rust-lib/flowy-sidecar/src/plugins/embedding_plugin.rs @@ -1,8 +1,6 @@ -use crate::core::parser::{ - ChatRelatedQuestionsResponseParser, ChatResponseParser, SimilarityResponseParser, -}; -use crate::core::plugin::{Plugin, PluginId}; -use crate::error::Error; +use crate::core::parser::SimilarityResponseParser; +use crate::core::plugin::Plugin; +use crate::error::SidecarError; use anyhow::anyhow; use serde_json::json; use std::sync::Weak; @@ -16,15 +14,19 @@ impl EmbeddingPluginOperation { EmbeddingPluginOperation { plugin } } - pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> Result { + pub async fn calculate_similarity( + &self, + message1: &str, + message2: &str, + ) -> Result { let plugin = self .plugin .upgrade() - .ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; + .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?; let params = json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}}); plugin - .async_send_request::("handle", ¶ms) + .async_request::("handle", ¶ms) .await } } diff --git a/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs b/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs index b5b324cba5..e6de70abcf 100644 --- a/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs +++ b/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs @@ -1,11 +1,14 @@ use crate::util::LocalAITest; +use tokio_stream::StreamExt; #[tokio::test] async fn load_chat_model_test() { if let Ok(test) = LocalAITest::new() { let plugin_id = test.init_chat_plugin().await; let chat_id = uuid::Uuid::new_v4().to_string(); - let resp = test.send_message(&chat_id, plugin_id, "hello world").await; + let resp = test + .send_chat_message(&chat_id, plugin_id, "hello world") + .await; eprintln!("chat response: {:?}", resp); let embedding_plugin_id = test.init_embedding_plugin().await; @@ -17,3 +20,27 @@ async fn load_chat_model_test() { // eprintln!("related questions: {:?}", questions); } } +#[tokio::test] +async fn stream_local_model_test() { + if let Ok(test) = LocalAITest::new() { + let plugin_id = test.init_chat_plugin().await; + let chat_id = uuid::Uuid::new_v4().to_string(); + + let mut resp = test + .stream_chat_message(&chat_id, plugin_id, "hello world") + .await; + let a = resp.next().await.unwrap().unwrap(); + eprintln!("chat response: {:?}", a); + + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + + // let mut resp = test + // .stream_chat_message(&chat_id, plugin_id, "How are you") + // .await; + // let a = resp.next().await.unwrap().unwrap(); + // eprintln!("chat response: {:?}", a); + // let questions = test.related_question(&chat_id, plugin_id).await; + // assert_eq!(questions.len(), 3); + // eprintln!("related questions: {:?}", questions); + } +} diff --git a/frontend/rust-lib/flowy-sidecar/tests/util.rs b/frontend/rust-lib/flowy-sidecar/tests/util.rs index 2d7818225a..17d90e0514 100644 --- a/frontend/rust-lib/flowy-sidecar/tests/util.rs +++ b/frontend/rust-lib/flowy-sidecar/tests/util.rs @@ -1,10 +1,12 @@ use anyhow::Result; use flowy_sidecar::manager::SidecarManager; use serde_json::json; -use std::sync::{Arc, Once}; +use std::sync::Once; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::Stream; -use flowy_sidecar::core::parser::{ChatResponseParser, SimilarityResponseParser}; use flowy_sidecar::core::plugin::{PluginId, PluginInfo}; +use flowy_sidecar::error::SidecarError; use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation; use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation; use tracing_subscriber::fmt::Subscriber; @@ -61,7 +63,12 @@ impl LocalAITest { plugin_id } - pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String { + pub async fn send_chat_message( + &self, + chat_id: &str, + plugin_id: PluginId, + message: &str, + ) -> String { let plugin = self.manager.get_plugin(plugin_id).await.unwrap(); let operation = ChatPluginOperation::new(plugin); let resp = operation @@ -72,6 +79,20 @@ impl LocalAITest { resp } + pub async fn stream_chat_message( + &self, + chat_id: &str, + plugin_id: PluginId, + message: &str, + ) -> ReceiverStream> { + let plugin = self.manager.get_plugin(plugin_id).await.unwrap(); + let operation = ChatPluginOperation::new(plugin); + operation + .stream_message(chat_id, plugin_id, message) + .await + .unwrap() + } + pub async fn related_question( &self, chat_id: &str, diff --git a/frontend/rust-lib/lib-infra/src/lib.rs b/frontend/rust-lib/lib-infra/src/lib.rs index 18539e49aa..3b46e162fb 100644 --- a/frontend/rust-lib/lib-infra/src/lib.rs +++ b/frontend/rust-lib/lib-infra/src/lib.rs @@ -23,5 +23,6 @@ if_wasm! { pub mod isolate_stream; pub mod priority_task; pub mod ref_map; +pub mod stream_util; pub mod util; pub mod validator_fn; diff --git a/frontend/rust-lib/lib-infra/src/stream_util.rs b/frontend/rust-lib/lib-infra/src/stream_util.rs new file mode 100644 index 0000000000..41c747d26a --- /dev/null +++ b/frontend/rust-lib/lib-infra/src/stream_util.rs @@ -0,0 +1,21 @@ +use futures_core::Stream; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; +use tokio::sync::mpsc::{Receiver, Sender}; + +struct BoundedStream { + recv: Receiver, +} +impl Stream for BoundedStream { + type Item = T; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::into_inner(self).recv.poll_recv(cx) + } +} + +pub fn mpsc_channel_stream(size: usize) -> (Sender, impl Stream) { + let (tx, rx) = mpsc::channel(size); + let stream = BoundedStream { recv: rx }; + (tx, stream) +}