chore: test streaming

This commit is contained in:
nathan 2024-06-27 22:00:36 +08:00
parent 9472361664
commit e4b1108ff0
17 changed files with 359 additions and 142 deletions

View File

@ -2148,7 +2148,6 @@ dependencies = [
"anyhow", "anyhow",
"crossbeam-utils", "crossbeam-utils",
"dotenv", "dotenv",
"futures",
"lib-infra", "lib-infra",
"log", "log",
"once_cell", "once_cell",
@ -2157,6 +2156,7 @@ dependencies = [
"serde_json", "serde_json",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"uuid", "uuid",
@ -5767,9 +5767,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.14" version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",

View File

@ -11,7 +11,7 @@ use flowy_sidecar::manager::SidecarManager;
use flowy_sqlite::DBConnection; use flowy_sqlite::DBConnection;
use lib_infra::future::FutureResult; use lib_infra::future::FutureResult;
use lib_infra::util::timestamp; use lib_infra::util::timestamp;
use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use tracing::trace; use tracing::trace;
@ -263,9 +263,9 @@ impl ChatCloudService for ChatService {
async fn stream_answer( async fn stream_answer(
&self, &self,
workspace_id: &str, _workspace_id: &str,
chat_id: &str, _chat_id: &str,
message_id: i64, _message_id: i64,
) -> Result<StreamAnswer, FlowyError> { ) -> Result<StreamAnswer, FlowyError> {
todo!() todo!()
} }

View File

@ -16,8 +16,8 @@ tracing.workspace = true
crossbeam-utils = "0.8.20" crossbeam-utils = "0.8.20"
log = "0.4.21" log = "0.4.21"
parking_lot.workspace = true parking_lot.workspace = true
tokio-stream = "0.1.15"
lib-infra.workspace = true lib-infra.workspace = true
futures.workspace = true
[dev-dependencies] [dev-dependencies]

View File

@ -1,8 +1,9 @@
use crate::core::rpc_object::RpcObject; use crate::core::rpc_object::RpcObject;
use crate::core::rpc_peer::ResponsePayload;
use crate::error::{ReadError, RemoteError}; use crate::error::{ReadError, RemoteError};
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use std::io::BufRead; use std::io::BufRead;
use tracing::error;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct MessageReader(String); pub struct MessageReader(String);
@ -57,15 +58,15 @@ pub enum Call<R> {
} }
pub trait ResponseParser { pub trait ResponseParser {
type ValueType; type ValueType: Send + Sync + 'static;
fn parse_response(payload: JsonValue) -> Result<Self::ValueType, RemoteError>; fn parse_json(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
} }
pub struct ChatResponseParser; pub struct ChatResponseParser;
impl ResponseParser for ChatResponseParser { impl ResponseParser for ChatResponseParser {
type ValueType = String; type ValueType = String;
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> { fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() { if json.is_object() {
if let Some(data) = json.get("data") { if let Some(data) = json.get("data") {
if let Some(message) = data.as_str() { if let Some(message) = data.as_str() {
@ -73,7 +74,19 @@ impl ResponseParser for ChatResponseParser {
} }
} }
} }
return Err(RemoteError::InvalidResponse(json)); return Err(RemoteError::ParseResponse(json));
}
}
pub struct ChatStreamResponseParser;
impl ResponseParser for ChatStreamResponseParser {
type ValueType = String;
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if let Some(message) = json.as_str() {
return Ok(message.to_string());
}
return Err(RemoteError::ParseResponse(json));
} }
} }
@ -81,7 +94,7 @@ pub struct ChatRelatedQuestionsResponseParser;
impl ResponseParser for ChatRelatedQuestionsResponseParser { impl ResponseParser for ChatRelatedQuestionsResponseParser {
type ValueType = Vec<JsonValue>; type ValueType = Vec<JsonValue>;
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> { fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() { if json.is_object() {
if let Some(data) = json.get("data") { if let Some(data) = json.get("data") {
if let Some(values) = data.as_array() { if let Some(values) = data.as_array() {
@ -89,7 +102,7 @@ impl ResponseParser for ChatRelatedQuestionsResponseParser {
} }
} }
} }
return Err(RemoteError::InvalidResponse(json)); return Err(RemoteError::ParseResponse(json));
} }
} }
@ -97,7 +110,7 @@ pub struct SimilarityResponseParser;
impl ResponseParser for SimilarityResponseParser { impl ResponseParser for SimilarityResponseParser {
type ValueType = f64; type ValueType = f64;
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> { fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() { if json.is_object() {
if let Some(data) = json.get("data") { if let Some(data) = json.get("data") {
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) { if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
@ -106,6 +119,6 @@ impl ResponseParser for SimilarityResponseParser {
} }
} }
return Err(RemoteError::InvalidResponse(json)); return Err(RemoteError::ParseResponse(json));
} }
} }

View File

@ -1,9 +1,9 @@
use crate::error::Error; use crate::error::SidecarError;
use crate::manager::WeakSidecarState; use crate::manager::WeakSidecarState;
use crate::core::parser::ResponseParser; use crate::core::parser::ResponseParser;
use crate::core::rpc_loop::RpcLoop; use crate::core::rpc_loop::RpcLoop;
use crate::core::rpc_peer::{Callback, ResponsePayload}; use crate::core::rpc_peer::{CloneableCallback, OneShotCallback};
use anyhow::anyhow; use anyhow::anyhow;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
@ -12,6 +12,9 @@ use std::process::{Child, Stdio};
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Instant; use std::time::Instant;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
use tracing::{error, info}; use tracing::{error, info};
#[derive( #[derive(
@ -34,12 +37,12 @@ pub trait Peer: Send + Sync + 'static {
/// Sends an RPC notification to the peer with the specified method and parameters. /// Sends an RPC notification to the peer with the specified method and parameters.
fn send_rpc_notification(&self, method: &str, params: &JsonValue); fn send_rpc_notification(&self, method: &str, params: &JsonValue);
/// Sends an asynchronous RPC request to the peer and executes the provided callback upon completion. fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback);
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>);
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>);
/// Sends a synchronous RPC request to the peer and waits for the result. /// Sends a synchronous RPC request to the peer and waits for the result.
/// Returns the result of the request or an error. /// Returns the result of the request or an error.
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error>; fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError>;
/// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background. /// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background.
fn request_is_pending(&self) -> bool; fn request_is_pending(&self) -> bool;
@ -66,34 +69,53 @@ pub struct Plugin {
} }
impl Plugin { impl Plugin {
pub fn initialize(&self, value: JsonValue) -> Result<(), Error> { pub fn initialize(&self, value: JsonValue) -> Result<(), SidecarError> {
self.peer.send_rpc_request("initialize", &value)?; self.peer.send_rpc_request("initialize", &value)?;
Ok(()) Ok(())
} }
pub fn send_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> { pub fn request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
self.peer.send_rpc_request(method, params) self.peer.send_rpc_request(method, params)
} }
pub async fn async_send_request<P: ResponseParser>( pub async fn async_request<P: ResponseParser>(
&self, &self,
method: &str, method: &str,
params: &JsonValue, params: &JsonValue,
) -> Result<P::ValueType, Error> { ) -> Result<P::ValueType, SidecarError> {
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
self.peer.send_rpc_request_async( self.peer.async_send_rpc_request(
method, method,
params, params,
Box::new(move |result| { Box::new(move |result| {
let _ = tx.send(result); let _ = tx.send(result);
}), }),
); );
let value = rx let value = rx.await.map_err(|err| {
.await SidecarError::Internal(anyhow!("error waiting for async response: {:?}", err))
.map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??; })??;
let value = P::parse_response(value)?; let value = P::parse_json(value)?;
Ok(value) Ok(value)
} }
pub fn stream_request<P: ResponseParser>(
&self,
method: &str,
params: &JsonValue,
) -> Result<ReceiverStream<Result<P::ValueType, SidecarError>>, SidecarError> {
let (tx, stream) = tokio::sync::mpsc::channel(100);
let stream = ReceiverStream::new(stream);
let callback = CloneableCallback::new(move |result| match result {
Ok(json) => {
let result = P::parse_json(json).map_err(SidecarError::from);
let _ = tx.blocking_send(result);
},
Err(err) => {
let _ = tx.blocking_send(Err(err));
},
});
self.peer.stream_rpc_request(method, params, callback);
Ok(stream)
}
pub fn shutdown(&self) { pub fn shutdown(&self) {
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) { if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {

View File

@ -2,14 +2,14 @@ use crate::core::parser::{Call, MessageReader};
use crate::core::plugin::RpcCtx; use crate::core::plugin::RpcCtx;
use crate::core::rpc_object::RpcObject; use crate::core::rpc_object::RpcObject;
use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState}; use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState};
use crate::error::{Error, ReadError, RemoteError}; use crate::error::{ReadError, RemoteError, SidecarError};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde_json::Value;
use std::io::{BufRead, Write}; use std::io::{BufRead, Write};
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use tracing::{error, trace}; use tracing::{error, info, trace};
const MAX_IDLE_WAIT: Duration = Duration::from_millis(5); const MAX_IDLE_WAIT: Duration = Duration::from_millis(5);
@ -97,7 +97,6 @@ impl<W: Write + Send> RpcLoop<W> {
trace!("read loop exit"); trace!("read loop exit");
break; break;
} }
let json = match self.reader.next(&mut stream) { let json = match self.reader.next(&mut stream) {
Ok(json) => json, Ok(json) => json,
Err(err) => { Err(err) => {
@ -109,15 +108,17 @@ impl<W: Write + Send> RpcLoop<W> {
}, },
}; };
if json.is_response() { if json.is_response() {
let id = json.get_id().unwrap(); let request_id = json.get_id().unwrap();
match json.into_response() { match json.into_response() {
Ok(resp) => { Ok(resp) => {
let resp = resp.map_err(Error::from); let resp = resp.map_err(SidecarError::from);
self.peer.handle_response(id, resp); self.peer.handle_response(request_id, resp);
}, },
Err(msg) => { Err(msg) => {
error!("[RPC] failed to parse response: {}", msg); error!("[RPC] failed to parse response: {}", msg);
self.peer.handle_response(id, Err(Error::InvalidResponse)); self
.peer
.handle_response(request_id, Err(SidecarError::InvalidResponse));
}, },
} }
} else { } else {

View File

@ -1,7 +1,8 @@
use crate::core::parser::{Call, RequestId}; use crate::core::parser::{Call, RequestId};
use crate::core::rpc_peer::{Response, ResponsePayload}; use crate::core::rpc_peer::{Response, ResponsePayload};
use crate::error::RemoteError;
use serde::de::{DeserializeOwned, Error}; use serde::de::{DeserializeOwned, Error};
use serde_json::Value; use serde_json::{json, Value};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RpcObject(pub Value); pub struct RpcObject(pub Value);
@ -24,30 +25,63 @@ impl RpcObject {
self.0.get("id").is_some() && self.0.get("method").is_none() self.0.get("id").is_some() && self.0.get("method").is_none()
} }
/// Converts the underlying `Value` into an RPC response object. /// Converts a JSON-RPC response into a structured `Response` object.
/// The caller should verify that the object is a response before calling this method. ///
/// This function validates and parses a JSON-RPC response, ensuring it contains the necessary fields,
/// and then transforms it into a structured `Response` object. The response must contain either a
/// "result" or an "error" field, but not both. If the response contains a "result" field, it may also
/// include streaming data, indicated by a nested "stream" field.
///
/// # Errors /// # Errors
/// If the `Value` is not a well-formed response object, this returns a `String` containing an ///
/// error message. The caller should print this message and exit. /// This function will return an error if:
/// - The "id" field is missing.
/// - The response contains both "result" and "error" fields, or neither.
/// - The "stream" field within the "result" is missing "type" or "data" fields.
/// - The "stream" type is invalid (i.e., not "streaming" or "end").
///
/// # Returns
///
/// - `Ok(Ok(ResponsePayload::Json(result)))`: If the response contains a valid "result".
/// - `Ok(Ok(ResponsePayload::Streaming(data)))`: If the response contains streaming data of type "streaming".
/// - `Ok(Ok(ResponsePayload::StreamEnd(json!({}))))`: If the response contains streaming data of type "end".
/// - `Err(String)`: If any validation or parsing errors occur.
///.
pub fn into_response(mut self) -> Result<Response, String> { pub fn into_response(mut self) -> Result<Response, String> {
let _ = self // Ensure 'id' field is present
self
.get_id() .get_id()
.ok_or("Response requires 'id' field.".to_string())?; .ok_or_else(|| "Response requires 'id' field.".to_string())?;
if self.0.get("result").is_some() == self.0.get("error").is_some() { // Ensure the response contains exactly one of 'result' or 'error'
let has_result = self.0.get("result").is_some();
let has_error = self.0.get("error").is_some();
if has_result == has_error {
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into()); return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
} }
let result = self.0.as_object_mut().and_then(|obj| obj.remove("result"));
match result { // Handle the 'result' field if present
Some(r) => Ok(Ok(ResponsePayload::Json(r))), if let Some(mut result) = self.0.as_object_mut().and_then(|obj| obj.remove("result")) {
None => { if let Some(mut stream) = result.as_object_mut().and_then(|obj| obj.remove("stream")) {
let error = self if let Some((has_more, data)) = stream.as_object_mut().and_then(|obj| {
.0 let has_more = obj.remove("has_more")?.as_bool().unwrap_or(false);
.as_object_mut() let data = obj.remove("data")?;
.and_then(|obj| obj.remove("error")) Some((has_more, data))
.unwrap(); }) {
return match has_more {
true => Ok(Ok(ResponsePayload::Streaming(data))),
false => Ok(Ok(ResponsePayload::StreamEnd(data))),
};
} else {
return Err("Stream response must contain 'type' and 'data' fields.".into());
}
}
Ok(Ok(ResponsePayload::Json(result)))
} else {
// Handle the 'error' field
let error = self.0.as_object_mut().unwrap().remove("error").unwrap();
Err(format!("Error handling response: {:?}", error)) Err(format!("Error handling response: {:?}", error))
},
} }
} }

View File

@ -1,18 +1,18 @@
use crate::core::plugin::{Peer, PluginId}; use crate::core::plugin::{Peer, PluginId};
use crate::core::rpc_object::RpcObject; use crate::core::rpc_object::RpcObject;
use crate::error::{Error, ReadError, RemoteError}; use crate::error::{ReadError, RemoteError, SidecarError};
use futures::Stream;
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use std::collections::{BTreeMap, BinaryHeap, VecDeque}; use std::collections::{BTreeMap, BinaryHeap, VecDeque};
use std::fmt::Display; use std::fmt::Display;
use std::io::Write; use std::io::Write;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{mpsc, Arc}; use std::sync::{mpsc, Arc};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{cmp, io}; use std::{cmp, io};
use tokio_stream::Stream;
use tracing::{error, trace, warn}; use tracing::{error, trace, warn};
pub struct PluginCommand<T> { pub struct PluginCommand<T> {
@ -97,15 +97,19 @@ impl<W: Write + Send + 'static> Peer for RawPeer<W> {
} }
} }
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>) { fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback) {
self.send_rpc(method, params, ResponseHandler::StreamCallback(Arc::new(f)));
}
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>) {
self.send_rpc(method, params, ResponseHandler::Callback(f)); self.send_rpc(method, params, ResponseHandler::Callback(f));
} }
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> { fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
self.0.is_blocking.store(true, Ordering::Release); self.0.is_blocking.store(true, Ordering::Release);
self.send_rpc(method, params, ResponseHandler::Chan(tx)); self.send_rpc(method, params, ResponseHandler::Chan(tx));
rx.recv().unwrap_or(Err(Error::PeerDisconnect)) rx.recv().unwrap_or(Err(SidecarError::PeerDisconnect))
} }
fn request_is_pending(&self) -> bool { fn request_is_pending(&self) -> bool {
@ -133,7 +137,7 @@ impl<W: Write> RawPeer<W> {
match result { match result {
Ok(result) => match result { Ok(result) => match result {
ResponsePayload::Json(value) => response["result"] = value, ResponsePayload::Json(value) => response["result"] = value,
ResponsePayload::Stream(_) => { ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) => {
error!("stream response not supported") error!("stream response not supported")
}, },
}, },
@ -161,29 +165,44 @@ impl<W: Write> RawPeer<W> {
})) { })) {
let mut pending = self.0.pending.lock(); let mut pending = self.0.pending.lock();
if let Some(rh) = pending.remove(&id) { if let Some(rh) = pending.remove(&id) {
rh.invoke(Err(Error::Io(e))); rh.invoke(Err(SidecarError::Io(e)));
} }
} }
} }
pub(crate) fn handle_response(&self, id: u64, resp: Result<ResponsePayload, Error>) { pub(crate) fn handle_response(
let id = id as usize; &self,
request_id: u64,
resp: Result<ResponsePayload, SidecarError>,
) {
let request_id = request_id as usize;
let handler = { let handler = {
let mut pending = self.0.pending.lock(); let mut pending = self.0.pending.lock();
pending.remove(&id) pending.remove(&request_id)
}; };
let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false); let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false);
match handler { match handler {
Some(response_handler) => { Some(response_handler) => {
if is_stream {
let is_stream_end = resp
.as_ref()
.map(|resp| resp.is_stream_end())
.unwrap_or(false);
if !is_stream_end {
// when steam is not end, we need to put the stream callback back to pending in order to
// receive the next stream message.
if let Some(callback) = response_handler.get_stream_callback() {
let mut pending = self.0.pending.lock();
pending.insert(request_id, ResponseHandler::StreamCallback(callback));
}
} else {
trace!("[RPC] {} stream end", request_id);
}
}
let json = resp.map(|resp| resp.into_json()); let json = resp.map(|resp| resp.into_json());
response_handler.invoke(json); response_handler.invoke(json);
// if is_stream {
// let mut pending = self.0.pending.lock();
// pending.insert(id, response_handler);
// }
}, },
None => warn!("[RPC] id {} not found in pending", id), None => error!("[RPC] id {}'s handle not found", request_id),
} }
} }
@ -243,7 +262,7 @@ impl<W: Write> RawPeer<W> {
let ids = pending.keys().cloned().collect::<Vec<_>>(); let ids = pending.keys().cloned().collect::<Vec<_>>();
for id in &ids { for id in &ids {
let callback = pending.remove(id).unwrap(); let callback = pending.remove(id).unwrap();
callback.invoke(Err(Error::PeerDisconnect)); callback.invoke(Err(SidecarError::PeerDisconnect));
} }
self.0.needs_exit.store(true, Ordering::Relaxed); self.0.needs_exit.store(true, Ordering::Relaxed);
} }
@ -267,7 +286,8 @@ impl<W: Write> Clone for RawPeer<W> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum ResponsePayload { pub enum ResponsePayload {
Json(JsonValue), Json(JsonValue),
Stream(JsonValue), Streaming(JsonValue),
StreamEnd(JsonValue),
} }
impl ResponsePayload { impl ResponsePayload {
@ -276,23 +296,29 @@ impl ResponsePayload {
} }
pub fn is_stream(&self) -> bool { pub fn is_stream(&self) -> bool {
match self { matches!(
ResponsePayload::Json(_) => false, self,
ResponsePayload::Stream(_) => true, ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_)
)
} }
pub fn is_stream_end(&self) -> bool {
matches!(self, ResponsePayload::StreamEnd(_))
} }
pub fn json(&self) -> &JsonValue { pub fn json(&self) -> &JsonValue {
match self { match self {
ResponsePayload::Json(v) => v, ResponsePayload::Json(v) => v,
ResponsePayload::Stream(v) => v, ResponsePayload::Streaming(v) => v,
ResponsePayload::StreamEnd(v) => v,
} }
} }
pub fn into_json(self) -> JsonValue { pub fn into_json(self) -> JsonValue {
match self { match self {
ResponsePayload::Json(v) => v, ResponsePayload::Json(v) => v,
ResponsePayload::Stream(v) => v, ResponsePayload::Streaming(v) => v,
ResponsePayload::StreamEnd(v) => v,
} }
} }
} }
@ -301,37 +327,78 @@ impl Display for ResponsePayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
ResponsePayload::Json(v) => write!(f, "{}", v), ResponsePayload::Json(v) => write!(f, "{}", v),
ResponsePayload::Stream(_) => write!(f, "stream"), ResponsePayload::Streaming(_) => write!(f, "stream start"),
ResponsePayload::StreamEnd(_) => write!(f, "stream end"),
} }
} }
} }
pub type Response = Result<ResponsePayload, RemoteError>; pub type Response = Result<ResponsePayload, RemoteError>;
pub trait ResponseStream: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {} pub trait ResponseStream: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {} impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
enum ResponseHandler { enum ResponseHandler {
Chan(mpsc::Sender<Result<JsonValue, Error>>), Chan(mpsc::Sender<Result<JsonValue, SidecarError>>),
Callback(Box<dyn Callback>), Callback(Box<dyn OneShotCallback>),
} StreamCallback(Arc<CloneableCallback>),
pub trait Callback: Send {
fn call(self: Box<Self>, result: Result<JsonValue, Error>);
} }
impl<F: Send + FnOnce(Result<JsonValue, Error>)> Callback for F { impl ResponseHandler {
fn call(self: Box<F>, result: Result<JsonValue, Error>) { pub fn get_stream_callback(&self) -> Option<Arc<CloneableCallback>> {
match self {
ResponseHandler::StreamCallback(cb) => Some(cb.clone()),
_ => None,
}
}
}
pub trait OneShotCallback: Send {
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>);
}
impl<F: Send + FnOnce(Result<JsonValue, SidecarError>)> OneShotCallback for F {
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>) {
(self)(result)
}
}
pub trait Callback: Send + Sync {
fn call(&self, result: Result<JsonValue, SidecarError>);
}
impl<F: Send + Sync + Fn(Result<JsonValue, SidecarError>)> Callback for F {
fn call(&self, result: Result<JsonValue, SidecarError>) {
(*self)(result) (*self)(result)
} }
} }
#[derive(Clone)]
pub struct CloneableCallback {
callback: Arc<dyn Callback>,
}
impl CloneableCallback {
pub fn new<C: Callback + 'static>(callback: C) -> Self {
CloneableCallback {
callback: Arc::new(callback),
}
}
pub fn call(&self, result: Result<JsonValue, SidecarError>) {
self.callback.call(result)
}
}
impl ResponseHandler { impl ResponseHandler {
fn invoke(self, result: Result<JsonValue, Error>) { fn invoke(self, result: Result<JsonValue, SidecarError>) {
match self { match self {
ResponseHandler::Chan(tx) => { ResponseHandler::Chan(tx) => {
let _ = tx.send(result); let _ = tx.send(result);
}, },
ResponseHandler::StreamCallback(cb) => {
cb.call(result);
},
ResponseHandler::Callback(f) => f.call(result), ResponseHandler::Callback(f) => f.call(result),
} }
} }

View File

@ -1,11 +1,10 @@
use crate::core::rpc_peer::ResponsePayload;
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
use std::{fmt, io}; use std::{fmt, io};
/// The error type of `tauri-utils`. /// The error type of `tauri-utils`.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum SidecarError {
/// An IO error occurred on the underlying communication channel. /// An IO error occurred on the underlying communication channel.
#[error(transparent)] #[error(transparent)]
Io(#[from] io::Error), Io(#[from] io::Error),
@ -48,6 +47,9 @@ pub enum RemoteError {
#[error("Invalid response: {0}")] #[error("Invalid response: {0}")]
InvalidResponse(JsonValue), InvalidResponse(JsonValue),
#[error("Parse response: {0}")]
ParseResponse(JsonValue),
/// A custom error, defined by the client. /// A custom error, defined by the client.
#[error("Custom error: {message}")] #[error("Custom error: {message}")]
Custom { Custom {
@ -100,9 +102,9 @@ impl From<serde_json::Error> for RemoteError {
} }
} }
impl From<RemoteError> for Error { impl From<RemoteError> for SidecarError {
fn from(err: RemoteError) -> Error { fn from(err: RemoteError) -> SidecarError {
Error::RemoteError(err) SidecarError::RemoteError(err)
} }
} }
@ -156,6 +158,11 @@ impl Serialize for RemoteError {
"Invalid response".to_string(), "Invalid response".to_string(),
Some(json!(resp.to_string())), Some(json!(resp.to_string())),
), ),
RemoteError::ParseResponse(resp) => (
-1,
"Invalid response".to_string(),
Some(json!(resp.to_string())),
),
}; };
let err = ErrorHelper { let err = ErrorHelper {
code, code,

View File

@ -1,4 +1,4 @@
pub mod core; pub mod core;
mod error; pub mod error;
pub mod manager; pub mod manager;
pub mod plugins; pub mod plugins;

View File

@ -2,11 +2,11 @@ use crate::core::parser::ResponseParser;
use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx}; use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
use crate::core::rpc_loop::Handler; use crate::core::rpc_loop::Handler;
use crate::core::rpc_peer::{PluginCommand, ResponsePayload}; use crate::core::rpc_peer::{PluginCommand, ResponsePayload};
use crate::error::{Error, ReadError, RemoteError}; use crate::error::{ReadError, RemoteError, SidecarError};
use anyhow::anyhow; use anyhow::anyhow;
use lib_infra::util::{get_operating_system, OperatingSystem}; use lib_infra::util::{get_operating_system, OperatingSystem};
use parking_lot::Mutex; use parking_lot::Mutex;
use serde_json::{json, Value}; use serde_json::Value;
use std::io; use std::io;
use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
@ -29,9 +29,9 @@ impl SidecarManager {
} }
} }
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, Error> { pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, SidecarError> {
if self.operating_system.is_not_desktop() { if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!( return Err(SidecarError::Internal(anyhow!(
"plugin not supported on this platform" "plugin not supported on this platform"
))); )));
} }
@ -41,7 +41,7 @@ impl SidecarManager {
Ok(plugin_id) Ok(plugin_id)
} }
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, Error> { pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, SidecarError> {
let state = self.state.lock(); let state = self.state.lock();
let plugin = state let plugin = state
.plugins .plugins
@ -51,9 +51,9 @@ impl SidecarManager {
Ok(Arc::downgrade(plugin)) Ok(Arc::downgrade(plugin))
} }
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> { pub async fn remove_plugin(&self, id: PluginId) -> Result<(), SidecarError> {
if self.operating_system.is_not_desktop() { if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!( return Err(SidecarError::Internal(anyhow!(
"plugin not supported on this platform" "plugin not supported on this platform"
))); )));
} }
@ -71,9 +71,9 @@ impl SidecarManager {
Ok(()) Ok(())
} }
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> { pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), SidecarError> {
if self.operating_system.is_not_desktop() { if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!( return Err(SidecarError::Internal(anyhow!(
"plugin not supported on this platform" "plugin not supported on this platform"
))); )));
} }
@ -94,15 +94,15 @@ impl SidecarManager {
id: PluginId, id: PluginId,
method: &str, method: &str,
request: Value, request: Value,
) -> Result<P::ValueType, Error> { ) -> Result<P::ValueType, SidecarError> {
let state = self.state.lock(); let state = self.state.lock();
let plugin = state let plugin = state
.plugins .plugins
.iter() .iter()
.find(|p| p.id == id) .find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?; .ok_or(anyhow!("plugin not found"))?;
let resp = plugin.send_request(method, &request)?; let resp = plugin.request(method, &request)?;
let value = P::parse_response(resp)?; let value = P::parse_json(resp)?;
Ok(value) Ok(value)
} }
@ -111,14 +111,14 @@ impl SidecarManager {
id: PluginId, id: PluginId,
method: &str, method: &str,
request: Value, request: Value,
) -> Result<P::ValueType, Error> { ) -> Result<P::ValueType, SidecarError> {
let state = self.state.lock(); let state = self.state.lock();
let plugin = state let plugin = state
.plugins .plugins
.iter() .iter()
.find(|p| p.id == id) .find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?; .ok_or(anyhow!("plugin not found"))?;
let value = plugin.async_send_request::<P>(method, &request).await?; let value = plugin.async_request::<P>(method, &request).await?;
Ok(value) Ok(value)
} }
} }

View File

@ -1,9 +1,13 @@
use crate::core::parser::{ChatRelatedQuestionsResponseParser, ChatResponseParser}; use crate::core::parser::{
ChatRelatedQuestionsResponseParser, ChatResponseParser, ChatStreamResponseParser,
};
use crate::core::plugin::{Plugin, PluginId}; use crate::core::plugin::{Plugin, PluginId};
use crate::error::Error; use crate::error::SidecarError;
use anyhow::anyhow; use anyhow::anyhow;
use serde_json::json; use serde_json::json;
use std::sync::Weak; use std::sync::Weak;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
pub struct ChatPluginOperation { pub struct ChatPluginOperation {
plugin: Weak<Plugin>, plugin: Weak<Plugin>,
@ -17,17 +21,17 @@ impl ChatPluginOperation {
pub async fn send_message( pub async fn send_message(
&self, &self,
chat_id: &str, chat_id: &str,
plugin_id: PluginId, _plugin_id: PluginId,
message: &str, message: &str,
) -> Result<String, Error> { ) -> Result<String, SidecarError> {
let plugin = self let plugin = self
.plugin .plugin
.upgrade() .upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}}); let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}});
let resp = plugin let resp = plugin
.async_send_request::<ChatResponseParser>("handle", &params) .async_request::<ChatResponseParser>("handle", &params)
.await?; .await?;
Ok(resp) Ok(resp)
} }
@ -35,34 +39,31 @@ impl ChatPluginOperation {
pub async fn stream_message( pub async fn stream_message(
&self, &self,
chat_id: &str, chat_id: &str,
plugin_id: PluginId, _plugin_id: PluginId,
message: &str, message: &str,
) -> Result<String, Error> { ) -> Result<ReceiverStream<Result<String, SidecarError>>, SidecarError> {
let plugin = self let plugin = self
.plugin .plugin
.upgrade() .upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = let params =
json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}}); json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}});
let resp = plugin plugin.stream_request::<ChatStreamResponseParser>("handle", &params)
.async_send_request::<ChatResponseParser>("handle", &params)
.await?;
Ok(resp)
} }
pub async fn get_related_questions( pub async fn get_related_questions(
&self, &self,
chat_id: &str, chat_id: &str,
) -> Result<Vec<serde_json::Value>, Error> { ) -> Result<Vec<serde_json::Value>, SidecarError> {
let plugin = self let plugin = self
.plugin .plugin
.upgrade() .upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = json!({"chat_id": chat_id, "method": "related_question"}); let params = json!({"chat_id": chat_id, "method": "related_question"});
let resp = plugin let resp = plugin
.async_send_request::<ChatRelatedQuestionsResponseParser>("handle", &params) .async_request::<ChatRelatedQuestionsResponseParser>("handle", &params)
.await?; .await?;
Ok(resp) Ok(resp)
} }

View File

@ -1,8 +1,6 @@
use crate::core::parser::{ use crate::core::parser::SimilarityResponseParser;
ChatRelatedQuestionsResponseParser, ChatResponseParser, SimilarityResponseParser, use crate::core::plugin::Plugin;
}; use crate::error::SidecarError;
use crate::core::plugin::{Plugin, PluginId};
use crate::error::Error;
use anyhow::anyhow; use anyhow::anyhow;
use serde_json::json; use serde_json::json;
use std::sync::Weak; use std::sync::Weak;
@ -16,15 +14,19 @@ impl EmbeddingPluginOperation {
EmbeddingPluginOperation { plugin } EmbeddingPluginOperation { plugin }
} }
pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> Result<f64, Error> { pub async fn calculate_similarity(
&self,
message1: &str,
message2: &str,
) -> Result<f64, SidecarError> {
let plugin = self let plugin = self
.plugin .plugin
.upgrade() .upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?; .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = let params =
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}}); json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}});
plugin plugin
.async_send_request::<SimilarityResponseParser>("handle", &params) .async_request::<SimilarityResponseParser>("handle", &params)
.await .await
} }
} }

View File

@ -1,11 +1,14 @@
use crate::util::LocalAITest; use crate::util::LocalAITest;
use tokio_stream::StreamExt;
#[tokio::test] #[tokio::test]
async fn load_chat_model_test() { async fn load_chat_model_test() {
if let Ok(test) = LocalAITest::new() { if let Ok(test) = LocalAITest::new() {
let plugin_id = test.init_chat_plugin().await; let plugin_id = test.init_chat_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string(); let chat_id = uuid::Uuid::new_v4().to_string();
let resp = test.send_message(&chat_id, plugin_id, "hello world").await; let resp = test
.send_chat_message(&chat_id, plugin_id, "hello world")
.await;
eprintln!("chat response: {:?}", resp); eprintln!("chat response: {:?}", resp);
let embedding_plugin_id = test.init_embedding_plugin().await; let embedding_plugin_id = test.init_embedding_plugin().await;
@ -17,3 +20,27 @@ async fn load_chat_model_test() {
// eprintln!("related questions: {:?}", questions); // eprintln!("related questions: {:?}", questions);
} }
} }
#[tokio::test]
async fn stream_local_model_test() {
if let Ok(test) = LocalAITest::new() {
let plugin_id = test.init_chat_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let mut resp = test
.stream_chat_message(&chat_id, plugin_id, "hello world")
.await;
let a = resp.next().await.unwrap().unwrap();
eprintln!("chat response: {:?}", a);
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
// let mut resp = test
// .stream_chat_message(&chat_id, plugin_id, "How are you")
// .await;
// let a = resp.next().await.unwrap().unwrap();
// eprintln!("chat response: {:?}", a);
// let questions = test.related_question(&chat_id, plugin_id).await;
// assert_eq!(questions.len(), 3);
// eprintln!("related questions: {:?}", questions);
}
}

View File

@ -1,10 +1,12 @@
use anyhow::Result; use anyhow::Result;
use flowy_sidecar::manager::SidecarManager; use flowy_sidecar::manager::SidecarManager;
use serde_json::json; use serde_json::json;
use std::sync::{Arc, Once}; use std::sync::Once;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
use flowy_sidecar::core::parser::{ChatResponseParser, SimilarityResponseParser};
use flowy_sidecar::core::plugin::{PluginId, PluginInfo}; use flowy_sidecar::core::plugin::{PluginId, PluginInfo};
use flowy_sidecar::error::SidecarError;
use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation; use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation;
use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation; use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation;
use tracing_subscriber::fmt::Subscriber; use tracing_subscriber::fmt::Subscriber;
@ -61,7 +63,12 @@ impl LocalAITest {
plugin_id plugin_id
} }
pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String { pub async fn send_chat_message(
&self,
chat_id: &str,
plugin_id: PluginId,
message: &str,
) -> String {
let plugin = self.manager.get_plugin(plugin_id).await.unwrap(); let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
let operation = ChatPluginOperation::new(plugin); let operation = ChatPluginOperation::new(plugin);
let resp = operation let resp = operation
@ -72,6 +79,20 @@ impl LocalAITest {
resp resp
} }
pub async fn stream_chat_message(
&self,
chat_id: &str,
plugin_id: PluginId,
message: &str,
) -> ReceiverStream<Result<String, SidecarError>> {
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
let operation = ChatPluginOperation::new(plugin);
operation
.stream_message(chat_id, plugin_id, message)
.await
.unwrap()
}
pub async fn related_question( pub async fn related_question(
&self, &self,
chat_id: &str, chat_id: &str,

View File

@ -23,5 +23,6 @@ if_wasm! {
pub mod isolate_stream; pub mod isolate_stream;
pub mod priority_task; pub mod priority_task;
pub mod ref_map; pub mod ref_map;
pub mod stream_util;
pub mod util; pub mod util;
pub mod validator_fn; pub mod validator_fn;

View File

@ -0,0 +1,21 @@
use futures_core::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio::sync::mpsc::{Receiver, Sender};
struct BoundedStream<T> {
recv: Receiver<T>,
}
impl<T> Stream for BoundedStream<T> {
type Item = T;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
Pin::into_inner(self).recv.poll_recv(cx)
}
}
pub fn mpsc_channel_stream<T: Unpin>(size: usize) -> (Sender<T>, impl Stream<Item = T>) {
let (tx, rx) = mpsc::channel(size);
let stream = BoundedStream { recv: rx };
(tx, stream)
}