mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2024-08-30 18:12:39 +00:00
chore: test streaming
This commit is contained in:
parent
9472361664
commit
e4b1108ff0
6
frontend/rust-lib/Cargo.lock
generated
6
frontend/rust-lib/Cargo.lock
generated
@ -2148,7 +2148,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"crossbeam-utils",
|
||||
"dotenv",
|
||||
"futures",
|
||||
"lib-infra",
|
||||
"log",
|
||||
"once_cell",
|
||||
@ -2157,6 +2156,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
@ -5767,9 +5767,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.14"
|
||||
version = "0.1.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
|
||||
checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
|
@ -11,7 +11,7 @@ use flowy_sidecar::manager::SidecarManager;
|
||||
use flowy_sqlite::DBConnection;
|
||||
use lib_infra::future::FutureResult;
|
||||
use lib_infra::util::timestamp;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tracing::trace;
|
||||
|
||||
@ -263,9 +263,9 @@ impl ChatCloudService for ChatService {
|
||||
|
||||
async fn stream_answer(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
_workspace_id: &str,
|
||||
_chat_id: &str,
|
||||
_message_id: i64,
|
||||
) -> Result<StreamAnswer, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
@ -16,8 +16,8 @@ tracing.workspace = true
|
||||
crossbeam-utils = "0.8.20"
|
||||
log = "0.4.21"
|
||||
parking_lot.workspace = true
|
||||
tokio-stream = "0.1.15"
|
||||
lib-infra.workspace = true
|
||||
futures.workspace = true
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
|
@ -1,8 +1,9 @@
|
||||
use crate::core::rpc_object::RpcObject;
|
||||
use crate::core::rpc_peer::ResponsePayload;
|
||||
|
||||
use crate::error::{ReadError, RemoteError};
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use std::io::BufRead;
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MessageReader(String);
|
||||
@ -57,15 +58,15 @@ pub enum Call<R> {
|
||||
}
|
||||
|
||||
pub trait ResponseParser {
|
||||
type ValueType;
|
||||
fn parse_response(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
|
||||
type ValueType: Send + Sync + 'static;
|
||||
fn parse_json(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
|
||||
}
|
||||
|
||||
pub struct ChatResponseParser;
|
||||
impl ResponseParser for ChatResponseParser {
|
||||
type ValueType = String;
|
||||
|
||||
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
if json.is_object() {
|
||||
if let Some(data) = json.get("data") {
|
||||
if let Some(message) = data.as_str() {
|
||||
@ -73,7 +74,19 @@ impl ResponseParser for ChatResponseParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(RemoteError::InvalidResponse(json));
|
||||
return Err(RemoteError::ParseResponse(json));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ChatStreamResponseParser;
|
||||
impl ResponseParser for ChatStreamResponseParser {
|
||||
type ValueType = String;
|
||||
|
||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
if let Some(message) = json.as_str() {
|
||||
return Ok(message.to_string());
|
||||
}
|
||||
return Err(RemoteError::ParseResponse(json));
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,7 +94,7 @@ pub struct ChatRelatedQuestionsResponseParser;
|
||||
impl ResponseParser for ChatRelatedQuestionsResponseParser {
|
||||
type ValueType = Vec<JsonValue>;
|
||||
|
||||
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
if json.is_object() {
|
||||
if let Some(data) = json.get("data") {
|
||||
if let Some(values) = data.as_array() {
|
||||
@ -89,7 +102,7 @@ impl ResponseParser for ChatRelatedQuestionsResponseParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(RemoteError::InvalidResponse(json));
|
||||
return Err(RemoteError::ParseResponse(json));
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,7 +110,7 @@ pub struct SimilarityResponseParser;
|
||||
impl ResponseParser for SimilarityResponseParser {
|
||||
type ValueType = f64;
|
||||
|
||||
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||
if json.is_object() {
|
||||
if let Some(data) = json.get("data") {
|
||||
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
|
||||
@ -106,6 +119,6 @@ impl ResponseParser for SimilarityResponseParser {
|
||||
}
|
||||
}
|
||||
|
||||
return Err(RemoteError::InvalidResponse(json));
|
||||
return Err(RemoteError::ParseResponse(json));
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
use crate::error::Error;
|
||||
use crate::error::SidecarError;
|
||||
use crate::manager::WeakSidecarState;
|
||||
|
||||
use crate::core::parser::ResponseParser;
|
||||
use crate::core::rpc_loop::RpcLoop;
|
||||
use crate::core::rpc_peer::{Callback, ResponsePayload};
|
||||
use crate::core::rpc_peer::{CloneableCallback, OneShotCallback};
|
||||
use anyhow::anyhow;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
@ -12,6 +12,9 @@ use std::process::{Child, Stdio};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::Stream;
|
||||
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(
|
||||
@ -34,12 +37,12 @@ pub trait Peer: Send + Sync + 'static {
|
||||
/// Sends an RPC notification to the peer with the specified method and parameters.
|
||||
fn send_rpc_notification(&self, method: &str, params: &JsonValue);
|
||||
|
||||
/// Sends an asynchronous RPC request to the peer and executes the provided callback upon completion.
|
||||
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>);
|
||||
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback);
|
||||
|
||||
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>);
|
||||
/// Sends a synchronous RPC request to the peer and waits for the result.
|
||||
/// Returns the result of the request or an error.
|
||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error>;
|
||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError>;
|
||||
|
||||
/// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background.
|
||||
fn request_is_pending(&self) -> bool;
|
||||
@ -66,34 +69,53 @@ pub struct Plugin {
|
||||
}
|
||||
|
||||
impl Plugin {
|
||||
pub fn initialize(&self, value: JsonValue) -> Result<(), Error> {
|
||||
pub fn initialize(&self, value: JsonValue) -> Result<(), SidecarError> {
|
||||
self.peer.send_rpc_request("initialize", &value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn send_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> {
|
||||
pub fn request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
|
||||
self.peer.send_rpc_request(method, params)
|
||||
}
|
||||
|
||||
pub async fn async_send_request<P: ResponseParser>(
|
||||
pub async fn async_request<P: ResponseParser>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: &JsonValue,
|
||||
) -> Result<P::ValueType, Error> {
|
||||
) -> Result<P::ValueType, SidecarError> {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
self.peer.send_rpc_request_async(
|
||||
self.peer.async_send_rpc_request(
|
||||
method,
|
||||
params,
|
||||
Box::new(move |result| {
|
||||
let _ = tx.send(result);
|
||||
}),
|
||||
);
|
||||
let value = rx
|
||||
.await
|
||||
.map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??;
|
||||
let value = P::parse_response(value)?;
|
||||
let value = rx.await.map_err(|err| {
|
||||
SidecarError::Internal(anyhow!("error waiting for async response: {:?}", err))
|
||||
})??;
|
||||
let value = P::parse_json(value)?;
|
||||
Ok(value)
|
||||
}
|
||||
pub fn stream_request<P: ResponseParser>(
|
||||
&self,
|
||||
method: &str,
|
||||
params: &JsonValue,
|
||||
) -> Result<ReceiverStream<Result<P::ValueType, SidecarError>>, SidecarError> {
|
||||
let (tx, stream) = tokio::sync::mpsc::channel(100);
|
||||
let stream = ReceiverStream::new(stream);
|
||||
let callback = CloneableCallback::new(move |result| match result {
|
||||
Ok(json) => {
|
||||
let result = P::parse_json(json).map_err(SidecarError::from);
|
||||
let _ = tx.blocking_send(result);
|
||||
},
|
||||
Err(err) => {
|
||||
let _ = tx.blocking_send(Err(err));
|
||||
},
|
||||
});
|
||||
self.peer.stream_rpc_request(method, params, callback);
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {
|
||||
|
@ -2,14 +2,14 @@ use crate::core::parser::{Call, MessageReader};
|
||||
use crate::core::plugin::RpcCtx;
|
||||
use crate::core::rpc_object::RpcObject;
|
||||
use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState};
|
||||
use crate::error::{Error, ReadError, RemoteError};
|
||||
use crate::error::{ReadError, RemoteError, SidecarError};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::Value;
|
||||
|
||||
use std::io::{BufRead, Write};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::{error, trace};
|
||||
use tracing::{error, info, trace};
|
||||
|
||||
const MAX_IDLE_WAIT: Duration = Duration::from_millis(5);
|
||||
|
||||
@ -97,7 +97,6 @@ impl<W: Write + Send> RpcLoop<W> {
|
||||
trace!("read loop exit");
|
||||
break;
|
||||
}
|
||||
|
||||
let json = match self.reader.next(&mut stream) {
|
||||
Ok(json) => json,
|
||||
Err(err) => {
|
||||
@ -109,15 +108,17 @@ impl<W: Write + Send> RpcLoop<W> {
|
||||
},
|
||||
};
|
||||
if json.is_response() {
|
||||
let id = json.get_id().unwrap();
|
||||
let request_id = json.get_id().unwrap();
|
||||
match json.into_response() {
|
||||
Ok(resp) => {
|
||||
let resp = resp.map_err(Error::from);
|
||||
self.peer.handle_response(id, resp);
|
||||
let resp = resp.map_err(SidecarError::from);
|
||||
self.peer.handle_response(request_id, resp);
|
||||
},
|
||||
Err(msg) => {
|
||||
error!("[RPC] failed to parse response: {}", msg);
|
||||
self.peer.handle_response(id, Err(Error::InvalidResponse));
|
||||
self
|
||||
.peer
|
||||
.handle_response(request_id, Err(SidecarError::InvalidResponse));
|
||||
},
|
||||
}
|
||||
} else {
|
||||
|
@ -1,7 +1,8 @@
|
||||
use crate::core::parser::{Call, RequestId};
|
||||
use crate::core::rpc_peer::{Response, ResponsePayload};
|
||||
use crate::error::RemoteError;
|
||||
use serde::de::{DeserializeOwned, Error};
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RpcObject(pub Value);
|
||||
@ -24,30 +25,63 @@ impl RpcObject {
|
||||
self.0.get("id").is_some() && self.0.get("method").is_none()
|
||||
}
|
||||
|
||||
/// Converts the underlying `Value` into an RPC response object.
|
||||
/// The caller should verify that the object is a response before calling this method.
|
||||
/// Converts a JSON-RPC response into a structured `Response` object.
|
||||
///
|
||||
/// This function validates and parses a JSON-RPC response, ensuring it contains the necessary fields,
|
||||
/// and then transforms it into a structured `Response` object. The response must contain either a
|
||||
/// "result" or an "error" field, but not both. If the response contains a "result" field, it may also
|
||||
/// include streaming data, indicated by a nested "stream" field.
|
||||
///
|
||||
/// # Errors
|
||||
/// If the `Value` is not a well-formed response object, this returns a `String` containing an
|
||||
/// error message. The caller should print this message and exit.
|
||||
///
|
||||
/// This function will return an error if:
|
||||
/// - The "id" field is missing.
|
||||
/// - The response contains both "result" and "error" fields, or neither.
|
||||
/// - The "stream" field within the "result" is missing "type" or "data" fields.
|
||||
/// - The "stream" type is invalid (i.e., not "streaming" or "end").
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - `Ok(Ok(ResponsePayload::Json(result)))`: If the response contains a valid "result".
|
||||
/// - `Ok(Ok(ResponsePayload::Streaming(data)))`: If the response contains streaming data of type "streaming".
|
||||
/// - `Ok(Ok(ResponsePayload::StreamEnd(json!({}))))`: If the response contains streaming data of type "end".
|
||||
/// - `Err(String)`: If any validation or parsing errors occur.
|
||||
///.
|
||||
pub fn into_response(mut self) -> Result<Response, String> {
|
||||
let _ = self
|
||||
// Ensure 'id' field is present
|
||||
self
|
||||
.get_id()
|
||||
.ok_or("Response requires 'id' field.".to_string())?;
|
||||
.ok_or_else(|| "Response requires 'id' field.".to_string())?;
|
||||
|
||||
if self.0.get("result").is_some() == self.0.get("error").is_some() {
|
||||
// Ensure the response contains exactly one of 'result' or 'error'
|
||||
let has_result = self.0.get("result").is_some();
|
||||
let has_error = self.0.get("error").is_some();
|
||||
if has_result == has_error {
|
||||
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
|
||||
}
|
||||
let result = self.0.as_object_mut().and_then(|obj| obj.remove("result"));
|
||||
match result {
|
||||
Some(r) => Ok(Ok(ResponsePayload::Json(r))),
|
||||
None => {
|
||||
let error = self
|
||||
.0
|
||||
.as_object_mut()
|
||||
.and_then(|obj| obj.remove("error"))
|
||||
.unwrap();
|
||||
|
||||
// Handle the 'result' field if present
|
||||
if let Some(mut result) = self.0.as_object_mut().and_then(|obj| obj.remove("result")) {
|
||||
if let Some(mut stream) = result.as_object_mut().and_then(|obj| obj.remove("stream")) {
|
||||
if let Some((has_more, data)) = stream.as_object_mut().and_then(|obj| {
|
||||
let has_more = obj.remove("has_more")?.as_bool().unwrap_or(false);
|
||||
let data = obj.remove("data")?;
|
||||
Some((has_more, data))
|
||||
}) {
|
||||
return match has_more {
|
||||
true => Ok(Ok(ResponsePayload::Streaming(data))),
|
||||
false => Ok(Ok(ResponsePayload::StreamEnd(data))),
|
||||
};
|
||||
} else {
|
||||
return Err("Stream response must contain 'type' and 'data' fields.".into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Ok(ResponsePayload::Json(result)))
|
||||
} else {
|
||||
// Handle the 'error' field
|
||||
let error = self.0.as_object_mut().unwrap().remove("error").unwrap();
|
||||
Err(format!("Error handling response: {:?}", error))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,18 +1,18 @@
|
||||
use crate::core::plugin::{Peer, PluginId};
|
||||
use crate::core::rpc_object::RpcObject;
|
||||
use crate::error::{Error, ReadError, RemoteError};
|
||||
use futures::Stream;
|
||||
use crate::error::{ReadError, RemoteError, SidecarError};
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use std::collections::{BTreeMap, BinaryHeap, VecDeque};
|
||||
use std::fmt::Display;
|
||||
use std::io::Write;
|
||||
use std::pin::Pin;
|
||||
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{mpsc, Arc};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{cmp, io};
|
||||
use tokio_stream::Stream;
|
||||
use tracing::{error, trace, warn};
|
||||
|
||||
pub struct PluginCommand<T> {
|
||||
@ -97,15 +97,19 @@ impl<W: Write + Send + 'static> Peer for RawPeer<W> {
|
||||
}
|
||||
}
|
||||
|
||||
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>) {
|
||||
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback) {
|
||||
self.send_rpc(method, params, ResponseHandler::StreamCallback(Arc::new(f)));
|
||||
}
|
||||
|
||||
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>) {
|
||||
self.send_rpc(method, params, ResponseHandler::Callback(f));
|
||||
}
|
||||
|
||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> {
|
||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
self.0.is_blocking.store(true, Ordering::Release);
|
||||
self.send_rpc(method, params, ResponseHandler::Chan(tx));
|
||||
rx.recv().unwrap_or(Err(Error::PeerDisconnect))
|
||||
rx.recv().unwrap_or(Err(SidecarError::PeerDisconnect))
|
||||
}
|
||||
|
||||
fn request_is_pending(&self) -> bool {
|
||||
@ -133,7 +137,7 @@ impl<W: Write> RawPeer<W> {
|
||||
match result {
|
||||
Ok(result) => match result {
|
||||
ResponsePayload::Json(value) => response["result"] = value,
|
||||
ResponsePayload::Stream(_) => {
|
||||
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) => {
|
||||
error!("stream response not supported")
|
||||
},
|
||||
},
|
||||
@ -161,29 +165,44 @@ impl<W: Write> RawPeer<W> {
|
||||
})) {
|
||||
let mut pending = self.0.pending.lock();
|
||||
if let Some(rh) = pending.remove(&id) {
|
||||
rh.invoke(Err(Error::Io(e)));
|
||||
rh.invoke(Err(SidecarError::Io(e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_response(&self, id: u64, resp: Result<ResponsePayload, Error>) {
|
||||
let id = id as usize;
|
||||
pub(crate) fn handle_response(
|
||||
&self,
|
||||
request_id: u64,
|
||||
resp: Result<ResponsePayload, SidecarError>,
|
||||
) {
|
||||
let request_id = request_id as usize;
|
||||
let handler = {
|
||||
let mut pending = self.0.pending.lock();
|
||||
pending.remove(&id)
|
||||
pending.remove(&request_id)
|
||||
};
|
||||
let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false);
|
||||
match handler {
|
||||
Some(response_handler) => {
|
||||
if is_stream {
|
||||
let is_stream_end = resp
|
||||
.as_ref()
|
||||
.map(|resp| resp.is_stream_end())
|
||||
.unwrap_or(false);
|
||||
if !is_stream_end {
|
||||
// when steam is not end, we need to put the stream callback back to pending in order to
|
||||
// receive the next stream message.
|
||||
if let Some(callback) = response_handler.get_stream_callback() {
|
||||
let mut pending = self.0.pending.lock();
|
||||
pending.insert(request_id, ResponseHandler::StreamCallback(callback));
|
||||
}
|
||||
} else {
|
||||
trace!("[RPC] {} stream end", request_id);
|
||||
}
|
||||
}
|
||||
let json = resp.map(|resp| resp.into_json());
|
||||
response_handler.invoke(json);
|
||||
|
||||
// if is_stream {
|
||||
// let mut pending = self.0.pending.lock();
|
||||
// pending.insert(id, response_handler);
|
||||
// }
|
||||
},
|
||||
None => warn!("[RPC] id {} not found in pending", id),
|
||||
None => error!("[RPC] id {}'s handle not found", request_id),
|
||||
}
|
||||
}
|
||||
|
||||
@ -243,7 +262,7 @@ impl<W: Write> RawPeer<W> {
|
||||
let ids = pending.keys().cloned().collect::<Vec<_>>();
|
||||
for id in &ids {
|
||||
let callback = pending.remove(id).unwrap();
|
||||
callback.invoke(Err(Error::PeerDisconnect));
|
||||
callback.invoke(Err(SidecarError::PeerDisconnect));
|
||||
}
|
||||
self.0.needs_exit.store(true, Ordering::Relaxed);
|
||||
}
|
||||
@ -267,7 +286,8 @@ impl<W: Write> Clone for RawPeer<W> {
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ResponsePayload {
|
||||
Json(JsonValue),
|
||||
Stream(JsonValue),
|
||||
Streaming(JsonValue),
|
||||
StreamEnd(JsonValue),
|
||||
}
|
||||
|
||||
impl ResponsePayload {
|
||||
@ -276,23 +296,29 @@ impl ResponsePayload {
|
||||
}
|
||||
|
||||
pub fn is_stream(&self) -> bool {
|
||||
match self {
|
||||
ResponsePayload::Json(_) => false,
|
||||
ResponsePayload::Stream(_) => true,
|
||||
matches!(
|
||||
self,
|
||||
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_)
|
||||
)
|
||||
}
|
||||
|
||||
pub fn is_stream_end(&self) -> bool {
|
||||
matches!(self, ResponsePayload::StreamEnd(_))
|
||||
}
|
||||
|
||||
pub fn json(&self) -> &JsonValue {
|
||||
match self {
|
||||
ResponsePayload::Json(v) => v,
|
||||
ResponsePayload::Stream(v) => v,
|
||||
ResponsePayload::Streaming(v) => v,
|
||||
ResponsePayload::StreamEnd(v) => v,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_json(self) -> JsonValue {
|
||||
match self {
|
||||
ResponsePayload::Json(v) => v,
|
||||
ResponsePayload::Stream(v) => v,
|
||||
ResponsePayload::Streaming(v) => v,
|
||||
ResponsePayload::StreamEnd(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -301,37 +327,78 @@ impl Display for ResponsePayload {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ResponsePayload::Json(v) => write!(f, "{}", v),
|
||||
ResponsePayload::Stream(_) => write!(f, "stream"),
|
||||
ResponsePayload::Streaming(_) => write!(f, "stream start"),
|
||||
ResponsePayload::StreamEnd(_) => write!(f, "stream end"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type Response = Result<ResponsePayload, RemoteError>;
|
||||
|
||||
pub trait ResponseStream: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {}
|
||||
pub trait ResponseStream: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
|
||||
|
||||
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {}
|
||||
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
|
||||
|
||||
enum ResponseHandler {
|
||||
Chan(mpsc::Sender<Result<JsonValue, Error>>),
|
||||
Callback(Box<dyn Callback>),
|
||||
}
|
||||
pub trait Callback: Send {
|
||||
fn call(self: Box<Self>, result: Result<JsonValue, Error>);
|
||||
Chan(mpsc::Sender<Result<JsonValue, SidecarError>>),
|
||||
Callback(Box<dyn OneShotCallback>),
|
||||
StreamCallback(Arc<CloneableCallback>),
|
||||
}
|
||||
|
||||
impl<F: Send + FnOnce(Result<JsonValue, Error>)> Callback for F {
|
||||
fn call(self: Box<F>, result: Result<JsonValue, Error>) {
|
||||
impl ResponseHandler {
|
||||
pub fn get_stream_callback(&self) -> Option<Arc<CloneableCallback>> {
|
||||
match self {
|
||||
ResponseHandler::StreamCallback(cb) => Some(cb.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait OneShotCallback: Send {
|
||||
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>);
|
||||
}
|
||||
|
||||
impl<F: Send + FnOnce(Result<JsonValue, SidecarError>)> OneShotCallback for F {
|
||||
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>) {
|
||||
(self)(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Callback: Send + Sync {
|
||||
fn call(&self, result: Result<JsonValue, SidecarError>);
|
||||
}
|
||||
|
||||
impl<F: Send + Sync + Fn(Result<JsonValue, SidecarError>)> Callback for F {
|
||||
fn call(&self, result: Result<JsonValue, SidecarError>) {
|
||||
(*self)(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CloneableCallback {
|
||||
callback: Arc<dyn Callback>,
|
||||
}
|
||||
impl CloneableCallback {
|
||||
pub fn new<C: Callback + 'static>(callback: C) -> Self {
|
||||
CloneableCallback {
|
||||
callback: Arc::new(callback),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call(&self, result: Result<JsonValue, SidecarError>) {
|
||||
self.callback.call(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseHandler {
|
||||
fn invoke(self, result: Result<JsonValue, Error>) {
|
||||
fn invoke(self, result: Result<JsonValue, SidecarError>) {
|
||||
match self {
|
||||
ResponseHandler::Chan(tx) => {
|
||||
let _ = tx.send(result);
|
||||
},
|
||||
ResponseHandler::StreamCallback(cb) => {
|
||||
cb.call(result);
|
||||
},
|
||||
ResponseHandler::Callback(f) => f.call(result),
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,10 @@
|
||||
use crate::core::rpc_peer::ResponsePayload;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use std::{fmt, io};
|
||||
|
||||
/// The error type of `tauri-utils`.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
pub enum SidecarError {
|
||||
/// An IO error occurred on the underlying communication channel.
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
@ -48,6 +47,9 @@ pub enum RemoteError {
|
||||
|
||||
#[error("Invalid response: {0}")]
|
||||
InvalidResponse(JsonValue),
|
||||
|
||||
#[error("Parse response: {0}")]
|
||||
ParseResponse(JsonValue),
|
||||
/// A custom error, defined by the client.
|
||||
#[error("Custom error: {message}")]
|
||||
Custom {
|
||||
@ -100,9 +102,9 @@ impl From<serde_json::Error> for RemoteError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RemoteError> for Error {
|
||||
fn from(err: RemoteError) -> Error {
|
||||
Error::RemoteError(err)
|
||||
impl From<RemoteError> for SidecarError {
|
||||
fn from(err: RemoteError) -> SidecarError {
|
||||
SidecarError::RemoteError(err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,6 +158,11 @@ impl Serialize for RemoteError {
|
||||
"Invalid response".to_string(),
|
||||
Some(json!(resp.to_string())),
|
||||
),
|
||||
RemoteError::ParseResponse(resp) => (
|
||||
-1,
|
||||
"Invalid response".to_string(),
|
||||
Some(json!(resp.to_string())),
|
||||
),
|
||||
};
|
||||
let err = ErrorHelper {
|
||||
code,
|
||||
|
@ -1,4 +1,4 @@
|
||||
pub mod core;
|
||||
mod error;
|
||||
pub mod error;
|
||||
pub mod manager;
|
||||
pub mod plugins;
|
||||
|
@ -2,11 +2,11 @@ use crate::core::parser::ResponseParser;
|
||||
use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
|
||||
use crate::core::rpc_loop::Handler;
|
||||
use crate::core::rpc_peer::{PluginCommand, ResponsePayload};
|
||||
use crate::error::{Error, ReadError, RemoteError};
|
||||
use crate::error::{ReadError, RemoteError, SidecarError};
|
||||
use anyhow::anyhow;
|
||||
use lib_infra::util::{get_operating_system, OperatingSystem};
|
||||
use parking_lot::Mutex;
|
||||
use serde_json::{json, Value};
|
||||
use serde_json::Value;
|
||||
use std::io;
|
||||
use std::sync::atomic::{AtomicI64, Ordering};
|
||||
use std::sync::{Arc, Weak};
|
||||
@ -29,9 +29,9 @@ impl SidecarManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, Error> {
|
||||
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, SidecarError> {
|
||||
if self.operating_system.is_not_desktop() {
|
||||
return Err(Error::Internal(anyhow!(
|
||||
return Err(SidecarError::Internal(anyhow!(
|
||||
"plugin not supported on this platform"
|
||||
)));
|
||||
}
|
||||
@ -41,7 +41,7 @@ impl SidecarManager {
|
||||
Ok(plugin_id)
|
||||
}
|
||||
|
||||
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, Error> {
|
||||
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, SidecarError> {
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
@ -51,9 +51,9 @@ impl SidecarManager {
|
||||
Ok(Arc::downgrade(plugin))
|
||||
}
|
||||
|
||||
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> {
|
||||
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), SidecarError> {
|
||||
if self.operating_system.is_not_desktop() {
|
||||
return Err(Error::Internal(anyhow!(
|
||||
return Err(SidecarError::Internal(anyhow!(
|
||||
"plugin not supported on this platform"
|
||||
)));
|
||||
}
|
||||
@ -71,9 +71,9 @@ impl SidecarManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> {
|
||||
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), SidecarError> {
|
||||
if self.operating_system.is_not_desktop() {
|
||||
return Err(Error::Internal(anyhow!(
|
||||
return Err(SidecarError::Internal(anyhow!(
|
||||
"plugin not supported on this platform"
|
||||
)));
|
||||
}
|
||||
@ -94,15 +94,15 @@ impl SidecarManager {
|
||||
id: PluginId,
|
||||
method: &str,
|
||||
request: Value,
|
||||
) -> Result<P::ValueType, Error> {
|
||||
) -> Result<P::ValueType, SidecarError> {
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
.iter()
|
||||
.find(|p| p.id == id)
|
||||
.ok_or(anyhow!("plugin not found"))?;
|
||||
let resp = plugin.send_request(method, &request)?;
|
||||
let value = P::parse_response(resp)?;
|
||||
let resp = plugin.request(method, &request)?;
|
||||
let value = P::parse_json(resp)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
@ -111,14 +111,14 @@ impl SidecarManager {
|
||||
id: PluginId,
|
||||
method: &str,
|
||||
request: Value,
|
||||
) -> Result<P::ValueType, Error> {
|
||||
) -> Result<P::ValueType, SidecarError> {
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
.iter()
|
||||
.find(|p| p.id == id)
|
||||
.ok_or(anyhow!("plugin not found"))?;
|
||||
let value = plugin.async_send_request::<P>(method, &request).await?;
|
||||
let value = plugin.async_request::<P>(method, &request).await?;
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,13 @@
|
||||
use crate::core::parser::{ChatRelatedQuestionsResponseParser, ChatResponseParser};
|
||||
use crate::core::parser::{
|
||||
ChatRelatedQuestionsResponseParser, ChatResponseParser, ChatStreamResponseParser,
|
||||
};
|
||||
use crate::core::plugin::{Plugin, PluginId};
|
||||
use crate::error::Error;
|
||||
use crate::error::SidecarError;
|
||||
use anyhow::anyhow;
|
||||
use serde_json::json;
|
||||
use std::sync::Weak;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::Stream;
|
||||
|
||||
pub struct ChatPluginOperation {
|
||||
plugin: Weak<Plugin>,
|
||||
@ -17,17 +21,17 @@ impl ChatPluginOperation {
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
plugin_id: PluginId,
|
||||
_plugin_id: PluginId,
|
||||
message: &str,
|
||||
) -> Result<String, Error> {
|
||||
) -> Result<String, SidecarError> {
|
||||
let plugin = self
|
||||
.plugin
|
||||
.upgrade()
|
||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
||||
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||
|
||||
let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}});
|
||||
let resp = plugin
|
||||
.async_send_request::<ChatResponseParser>("handle", ¶ms)
|
||||
.async_request::<ChatResponseParser>("handle", ¶ms)
|
||||
.await?;
|
||||
Ok(resp)
|
||||
}
|
||||
@ -35,34 +39,31 @@ impl ChatPluginOperation {
|
||||
pub async fn stream_message(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
plugin_id: PluginId,
|
||||
_plugin_id: PluginId,
|
||||
message: &str,
|
||||
) -> Result<String, Error> {
|
||||
) -> Result<ReceiverStream<Result<String, SidecarError>>, SidecarError> {
|
||||
let plugin = self
|
||||
.plugin
|
||||
.upgrade()
|
||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
||||
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||
|
||||
let params =
|
||||
json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}});
|
||||
let resp = plugin
|
||||
.async_send_request::<ChatResponseParser>("handle", ¶ms)
|
||||
.await?;
|
||||
Ok(resp)
|
||||
plugin.stream_request::<ChatStreamResponseParser>("handle", ¶ms)
|
||||
}
|
||||
|
||||
pub async fn get_related_questions(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
) -> Result<Vec<serde_json::Value>, Error> {
|
||||
) -> Result<Vec<serde_json::Value>, SidecarError> {
|
||||
let plugin = self
|
||||
.plugin
|
||||
.upgrade()
|
||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
||||
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||
|
||||
let params = json!({"chat_id": chat_id, "method": "related_question"});
|
||||
let resp = plugin
|
||||
.async_send_request::<ChatRelatedQuestionsResponseParser>("handle", ¶ms)
|
||||
.async_request::<ChatRelatedQuestionsResponseParser>("handle", ¶ms)
|
||||
.await?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
use crate::core::parser::{
|
||||
ChatRelatedQuestionsResponseParser, ChatResponseParser, SimilarityResponseParser,
|
||||
};
|
||||
use crate::core::plugin::{Plugin, PluginId};
|
||||
use crate::error::Error;
|
||||
use crate::core::parser::SimilarityResponseParser;
|
||||
use crate::core::plugin::Plugin;
|
||||
use crate::error::SidecarError;
|
||||
use anyhow::anyhow;
|
||||
use serde_json::json;
|
||||
use std::sync::Weak;
|
||||
@ -16,15 +14,19 @@ impl EmbeddingPluginOperation {
|
||||
EmbeddingPluginOperation { plugin }
|
||||
}
|
||||
|
||||
pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> Result<f64, Error> {
|
||||
pub async fn calculate_similarity(
|
||||
&self,
|
||||
message1: &str,
|
||||
message2: &str,
|
||||
) -> Result<f64, SidecarError> {
|
||||
let plugin = self
|
||||
.plugin
|
||||
.upgrade()
|
||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
||||
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||
let params =
|
||||
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}});
|
||||
plugin
|
||||
.async_send_request::<SimilarityResponseParser>("handle", ¶ms)
|
||||
.async_request::<SimilarityResponseParser>("handle", ¶ms)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,14 @@
|
||||
use crate::util::LocalAITest;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_chat_model_test() {
|
||||
if let Ok(test) = LocalAITest::new() {
|
||||
let plugin_id = test.init_chat_plugin().await;
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let resp = test.send_message(&chat_id, plugin_id, "hello world").await;
|
||||
let resp = test
|
||||
.send_chat_message(&chat_id, plugin_id, "hello world")
|
||||
.await;
|
||||
eprintln!("chat response: {:?}", resp);
|
||||
|
||||
let embedding_plugin_id = test.init_embedding_plugin().await;
|
||||
@ -17,3 +20,27 @@ async fn load_chat_model_test() {
|
||||
// eprintln!("related questions: {:?}", questions);
|
||||
}
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn stream_local_model_test() {
|
||||
if let Ok(test) = LocalAITest::new() {
|
||||
let plugin_id = test.init_chat_plugin().await;
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let mut resp = test
|
||||
.stream_chat_message(&chat_id, plugin_id, "hello world")
|
||||
.await;
|
||||
let a = resp.next().await.unwrap().unwrap();
|
||||
eprintln!("chat response: {:?}", a);
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
|
||||
// let mut resp = test
|
||||
// .stream_chat_message(&chat_id, plugin_id, "How are you")
|
||||
// .await;
|
||||
// let a = resp.next().await.unwrap().unwrap();
|
||||
// eprintln!("chat response: {:?}", a);
|
||||
// let questions = test.related_question(&chat_id, plugin_id).await;
|
||||
// assert_eq!(questions.len(), 3);
|
||||
// eprintln!("related questions: {:?}", questions);
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,12 @@
|
||||
use anyhow::Result;
|
||||
use flowy_sidecar::manager::SidecarManager;
|
||||
use serde_json::json;
|
||||
use std::sync::{Arc, Once};
|
||||
use std::sync::Once;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::Stream;
|
||||
|
||||
use flowy_sidecar::core::parser::{ChatResponseParser, SimilarityResponseParser};
|
||||
use flowy_sidecar::core::plugin::{PluginId, PluginInfo};
|
||||
use flowy_sidecar::error::SidecarError;
|
||||
use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation;
|
||||
use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation;
|
||||
use tracing_subscriber::fmt::Subscriber;
|
||||
@ -61,7 +63,12 @@ impl LocalAITest {
|
||||
plugin_id
|
||||
}
|
||||
|
||||
pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String {
|
||||
pub async fn send_chat_message(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
plugin_id: PluginId,
|
||||
message: &str,
|
||||
) -> String {
|
||||
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
|
||||
let operation = ChatPluginOperation::new(plugin);
|
||||
let resp = operation
|
||||
@ -72,6 +79,20 @@ impl LocalAITest {
|
||||
resp
|
||||
}
|
||||
|
||||
pub async fn stream_chat_message(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
plugin_id: PluginId,
|
||||
message: &str,
|
||||
) -> ReceiverStream<Result<String, SidecarError>> {
|
||||
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
|
||||
let operation = ChatPluginOperation::new(plugin);
|
||||
operation
|
||||
.stream_message(chat_id, plugin_id, message)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn related_question(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
|
@ -23,5 +23,6 @@ if_wasm! {
|
||||
pub mod isolate_stream;
|
||||
pub mod priority_task;
|
||||
pub mod ref_map;
|
||||
pub mod stream_util;
|
||||
pub mod util;
|
||||
pub mod validator_fn;
|
||||
|
21
frontend/rust-lib/lib-infra/src/stream_util.rs
Normal file
21
frontend/rust-lib/lib-infra/src/stream_util.rs
Normal file
@ -0,0 +1,21 @@
|
||||
use futures_core::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
struct BoundedStream<T> {
|
||||
recv: Receiver<T>,
|
||||
}
|
||||
impl<T> Stream for BoundedStream<T> {
|
||||
type Item = T;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
||||
Pin::into_inner(self).recv.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mpsc_channel_stream<T: Unpin>(size: usize) -> (Sender<T>, impl Stream<Item = T>) {
|
||||
let (tx, rx) = mpsc::channel(size);
|
||||
let stream = BoundedStream { recv: rx };
|
||||
(tx, stream)
|
||||
}
|
Loading…
Reference in New Issue
Block a user