chore: test streaming

This commit is contained in:
nathan 2024-06-27 22:00:36 +08:00
parent 9472361664
commit e4b1108ff0
17 changed files with 359 additions and 142 deletions

View File

@ -2148,7 +2148,6 @@ dependencies = [
"anyhow",
"crossbeam-utils",
"dotenv",
"futures",
"lib-infra",
"log",
"once_cell",
@ -2157,6 +2156,7 @@ dependencies = [
"serde_json",
"thiserror",
"tokio",
"tokio-stream",
"tracing",
"tracing-subscriber",
"uuid",
@ -5767,9 +5767,9 @@ dependencies = [
[[package]]
name = "tokio-stream"
version = "0.1.14"
version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
dependencies = [
"futures-core",
"pin-project-lite",

View File

@ -11,7 +11,7 @@ use flowy_sidecar::manager::SidecarManager;
use flowy_sqlite::DBConnection;
use lib_infra::future::FutureResult;
use lib_infra::util::timestamp;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tracing::trace;
@ -263,9 +263,9 @@ impl ChatCloudService for ChatService {
async fn stream_answer(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
_workspace_id: &str,
_chat_id: &str,
_message_id: i64,
) -> Result<StreamAnswer, FlowyError> {
todo!()
}

View File

@ -16,8 +16,8 @@ tracing.workspace = true
crossbeam-utils = "0.8.20"
log = "0.4.21"
parking_lot.workspace = true
tokio-stream = "0.1.15"
lib-infra.workspace = true
futures.workspace = true
[dev-dependencies]

View File

@ -1,8 +1,9 @@
use crate::core::rpc_object::RpcObject;
use crate::core::rpc_peer::ResponsePayload;
use crate::error::{ReadError, RemoteError};
use serde_json::{json, Value as JsonValue};
use std::io::BufRead;
use tracing::error;
#[derive(Debug, Default)]
pub struct MessageReader(String);
@ -57,15 +58,15 @@ pub enum Call<R> {
}
pub trait ResponseParser {
type ValueType;
fn parse_response(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
type ValueType: Send + Sync + 'static;
fn parse_json(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
}
pub struct ChatResponseParser;
impl ResponseParser for ChatResponseParser {
type ValueType = String;
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() {
if let Some(data) = json.get("data") {
if let Some(message) = data.as_str() {
@ -73,7 +74,19 @@ impl ResponseParser for ChatResponseParser {
}
}
}
return Err(RemoteError::InvalidResponse(json));
return Err(RemoteError::ParseResponse(json));
}
}
pub struct ChatStreamResponseParser;
impl ResponseParser for ChatStreamResponseParser {
type ValueType = String;
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if let Some(message) = json.as_str() {
return Ok(message.to_string());
}
return Err(RemoteError::ParseResponse(json));
}
}
@ -81,7 +94,7 @@ pub struct ChatRelatedQuestionsResponseParser;
impl ResponseParser for ChatRelatedQuestionsResponseParser {
type ValueType = Vec<JsonValue>;
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() {
if let Some(data) = json.get("data") {
if let Some(values) = data.as_array() {
@ -89,7 +102,7 @@ impl ResponseParser for ChatRelatedQuestionsResponseParser {
}
}
}
return Err(RemoteError::InvalidResponse(json));
return Err(RemoteError::ParseResponse(json));
}
}
@ -97,7 +110,7 @@ pub struct SimilarityResponseParser;
impl ResponseParser for SimilarityResponseParser {
type ValueType = f64;
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() {
if let Some(data) = json.get("data") {
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
@ -106,6 +119,6 @@ impl ResponseParser for SimilarityResponseParser {
}
}
return Err(RemoteError::InvalidResponse(json));
return Err(RemoteError::ParseResponse(json));
}
}

View File

@ -1,9 +1,9 @@
use crate::error::Error;
use crate::error::SidecarError;
use crate::manager::WeakSidecarState;
use crate::core::parser::ResponseParser;
use crate::core::rpc_loop::RpcLoop;
use crate::core::rpc_peer::{Callback, ResponsePayload};
use crate::core::rpc_peer::{CloneableCallback, OneShotCallback};
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value as JsonValue};
@ -12,6 +12,9 @@ use std::process::{Child, Stdio};
use std::sync::Arc;
use std::thread;
use std::time::Instant;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
use tracing::{error, info};
#[derive(
@ -34,12 +37,12 @@ pub trait Peer: Send + Sync + 'static {
/// Sends an RPC notification to the peer with the specified method and parameters.
fn send_rpc_notification(&self, method: &str, params: &JsonValue);
/// Sends an asynchronous RPC request to the peer and executes the provided callback upon completion.
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>);
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback);
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>);
/// Sends a synchronous RPC request to the peer and waits for the result.
/// Returns the result of the request or an error.
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error>;
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError>;
/// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background.
fn request_is_pending(&self) -> bool;
@ -66,34 +69,53 @@ pub struct Plugin {
}
impl Plugin {
pub fn initialize(&self, value: JsonValue) -> Result<(), Error> {
pub fn initialize(&self, value: JsonValue) -> Result<(), SidecarError> {
self.peer.send_rpc_request("initialize", &value)?;
Ok(())
}
pub fn send_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> {
pub fn request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
self.peer.send_rpc_request(method, params)
}
pub async fn async_send_request<P: ResponseParser>(
pub async fn async_request<P: ResponseParser>(
&self,
method: &str,
params: &JsonValue,
) -> Result<P::ValueType, Error> {
) -> Result<P::ValueType, SidecarError> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.peer.send_rpc_request_async(
self.peer.async_send_rpc_request(
method,
params,
Box::new(move |result| {
let _ = tx.send(result);
}),
);
let value = rx
.await
.map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??;
let value = P::parse_response(value)?;
let value = rx.await.map_err(|err| {
SidecarError::Internal(anyhow!("error waiting for async response: {:?}", err))
})??;
let value = P::parse_json(value)?;
Ok(value)
}
pub fn stream_request<P: ResponseParser>(
&self,
method: &str,
params: &JsonValue,
) -> Result<ReceiverStream<Result<P::ValueType, SidecarError>>, SidecarError> {
let (tx, stream) = tokio::sync::mpsc::channel(100);
let stream = ReceiverStream::new(stream);
let callback = CloneableCallback::new(move |result| match result {
Ok(json) => {
let result = P::parse_json(json).map_err(SidecarError::from);
let _ = tx.blocking_send(result);
},
Err(err) => {
let _ = tx.blocking_send(Err(err));
},
});
self.peer.stream_rpc_request(method, params, callback);
Ok(stream)
}
pub fn shutdown(&self) {
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {

View File

@ -2,14 +2,14 @@ use crate::core::parser::{Call, MessageReader};
use crate::core::plugin::RpcCtx;
use crate::core::rpc_object::RpcObject;
use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState};
use crate::error::{Error, ReadError, RemoteError};
use crate::error::{ReadError, RemoteError, SidecarError};
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::io::{BufRead, Write};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tracing::{error, trace};
use tracing::{error, info, trace};
const MAX_IDLE_WAIT: Duration = Duration::from_millis(5);
@ -97,7 +97,6 @@ impl<W: Write + Send> RpcLoop<W> {
trace!("read loop exit");
break;
}
let json = match self.reader.next(&mut stream) {
Ok(json) => json,
Err(err) => {
@ -109,15 +108,17 @@ impl<W: Write + Send> RpcLoop<W> {
},
};
if json.is_response() {
let id = json.get_id().unwrap();
let request_id = json.get_id().unwrap();
match json.into_response() {
Ok(resp) => {
let resp = resp.map_err(Error::from);
self.peer.handle_response(id, resp);
let resp = resp.map_err(SidecarError::from);
self.peer.handle_response(request_id, resp);
},
Err(msg) => {
error!("[RPC] failed to parse response: {}", msg);
self.peer.handle_response(id, Err(Error::InvalidResponse));
self
.peer
.handle_response(request_id, Err(SidecarError::InvalidResponse));
},
}
} else {

View File

@ -1,7 +1,8 @@
use crate::core::parser::{Call, RequestId};
use crate::core::rpc_peer::{Response, ResponsePayload};
use crate::error::RemoteError;
use serde::de::{DeserializeOwned, Error};
use serde_json::Value;
use serde_json::{json, Value};
#[derive(Debug, Clone)]
pub struct RpcObject(pub Value);
@ -24,30 +25,63 @@ impl RpcObject {
self.0.get("id").is_some() && self.0.get("method").is_none()
}
/// Converts the underlying `Value` into an RPC response object.
/// The caller should verify that the object is a response before calling this method.
/// Converts a JSON-RPC response into a structured `Response` object.
///
/// This function validates and parses a JSON-RPC response, ensuring it contains the necessary fields,
/// and then transforms it into a structured `Response` object. The response must contain either a
/// "result" or an "error" field, but not both. If the response contains a "result" field, it may also
/// include streaming data, indicated by a nested "stream" field.
///
/// # Errors
/// If the `Value` is not a well-formed response object, this returns a `String` containing an
/// error message. The caller should print this message and exit.
///
/// This function will return an error if:
/// - The "id" field is missing.
/// - The response contains both "result" and "error" fields, or neither.
/// - The "stream" field within the "result" is missing "type" or "data" fields.
/// - The "stream" type is invalid (i.e., not "streaming" or "end").
///
/// # Returns
///
/// - `Ok(Ok(ResponsePayload::Json(result)))`: If the response contains a valid "result".
/// - `Ok(Ok(ResponsePayload::Streaming(data)))`: If the response contains streaming data of type "streaming".
/// - `Ok(Ok(ResponsePayload::StreamEnd(json!({}))))`: If the response contains streaming data of type "end".
/// - `Err(String)`: If any validation or parsing errors occur.
///.
pub fn into_response(mut self) -> Result<Response, String> {
let _ = self
// Ensure 'id' field is present
self
.get_id()
.ok_or("Response requires 'id' field.".to_string())?;
.ok_or_else(|| "Response requires 'id' field.".to_string())?;
if self.0.get("result").is_some() == self.0.get("error").is_some() {
// Ensure the response contains exactly one of 'result' or 'error'
let has_result = self.0.get("result").is_some();
let has_error = self.0.get("error").is_some();
if has_result == has_error {
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
}
let result = self.0.as_object_mut().and_then(|obj| obj.remove("result"));
match result {
Some(r) => Ok(Ok(ResponsePayload::Json(r))),
None => {
let error = self
.0
.as_object_mut()
.and_then(|obj| obj.remove("error"))
.unwrap();
// Handle the 'result' field if present
if let Some(mut result) = self.0.as_object_mut().and_then(|obj| obj.remove("result")) {
if let Some(mut stream) = result.as_object_mut().and_then(|obj| obj.remove("stream")) {
if let Some((has_more, data)) = stream.as_object_mut().and_then(|obj| {
let has_more = obj.remove("has_more")?.as_bool().unwrap_or(false);
let data = obj.remove("data")?;
Some((has_more, data))
}) {
return match has_more {
true => Ok(Ok(ResponsePayload::Streaming(data))),
false => Ok(Ok(ResponsePayload::StreamEnd(data))),
};
} else {
return Err("Stream response must contain 'type' and 'data' fields.".into());
}
}
Ok(Ok(ResponsePayload::Json(result)))
} else {
// Handle the 'error' field
let error = self.0.as_object_mut().unwrap().remove("error").unwrap();
Err(format!("Error handling response: {:?}", error))
},
}
}

View File

@ -1,18 +1,18 @@
use crate::core::plugin::{Peer, PluginId};
use crate::core::rpc_object::RpcObject;
use crate::error::{Error, ReadError, RemoteError};
use futures::Stream;
use crate::error::{ReadError, RemoteError, SidecarError};
use parking_lot::{Condvar, Mutex};
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value as JsonValue};
use std::collections::{BTreeMap, BinaryHeap, VecDeque};
use std::fmt::Display;
use std::io::Write;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{mpsc, Arc};
use std::time::{Duration, Instant};
use std::{cmp, io};
use tokio_stream::Stream;
use tracing::{error, trace, warn};
pub struct PluginCommand<T> {
@ -97,15 +97,19 @@ impl<W: Write + Send + 'static> Peer for RawPeer<W> {
}
}
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>) {
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback) {
self.send_rpc(method, params, ResponseHandler::StreamCallback(Arc::new(f)));
}
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>) {
self.send_rpc(method, params, ResponseHandler::Callback(f));
}
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> {
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
let (tx, rx) = mpsc::channel();
self.0.is_blocking.store(true, Ordering::Release);
self.send_rpc(method, params, ResponseHandler::Chan(tx));
rx.recv().unwrap_or(Err(Error::PeerDisconnect))
rx.recv().unwrap_or(Err(SidecarError::PeerDisconnect))
}
fn request_is_pending(&self) -> bool {
@ -133,7 +137,7 @@ impl<W: Write> RawPeer<W> {
match result {
Ok(result) => match result {
ResponsePayload::Json(value) => response["result"] = value,
ResponsePayload::Stream(_) => {
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) => {
error!("stream response not supported")
},
},
@ -161,29 +165,44 @@ impl<W: Write> RawPeer<W> {
})) {
let mut pending = self.0.pending.lock();
if let Some(rh) = pending.remove(&id) {
rh.invoke(Err(Error::Io(e)));
rh.invoke(Err(SidecarError::Io(e)));
}
}
}
pub(crate) fn handle_response(&self, id: u64, resp: Result<ResponsePayload, Error>) {
let id = id as usize;
pub(crate) fn handle_response(
&self,
request_id: u64,
resp: Result<ResponsePayload, SidecarError>,
) {
let request_id = request_id as usize;
let handler = {
let mut pending = self.0.pending.lock();
pending.remove(&id)
pending.remove(&request_id)
};
let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false);
match handler {
Some(response_handler) => {
if is_stream {
let is_stream_end = resp
.as_ref()
.map(|resp| resp.is_stream_end())
.unwrap_or(false);
if !is_stream_end {
// when steam is not end, we need to put the stream callback back to pending in order to
// receive the next stream message.
if let Some(callback) = response_handler.get_stream_callback() {
let mut pending = self.0.pending.lock();
pending.insert(request_id, ResponseHandler::StreamCallback(callback));
}
} else {
trace!("[RPC] {} stream end", request_id);
}
}
let json = resp.map(|resp| resp.into_json());
response_handler.invoke(json);
// if is_stream {
// let mut pending = self.0.pending.lock();
// pending.insert(id, response_handler);
// }
},
None => warn!("[RPC] id {} not found in pending", id),
None => error!("[RPC] id {}'s handle not found", request_id),
}
}
@ -243,7 +262,7 @@ impl<W: Write> RawPeer<W> {
let ids = pending.keys().cloned().collect::<Vec<_>>();
for id in &ids {
let callback = pending.remove(id).unwrap();
callback.invoke(Err(Error::PeerDisconnect));
callback.invoke(Err(SidecarError::PeerDisconnect));
}
self.0.needs_exit.store(true, Ordering::Relaxed);
}
@ -267,7 +286,8 @@ impl<W: Write> Clone for RawPeer<W> {
#[derive(Clone, Debug)]
pub enum ResponsePayload {
Json(JsonValue),
Stream(JsonValue),
Streaming(JsonValue),
StreamEnd(JsonValue),
}
impl ResponsePayload {
@ -276,23 +296,29 @@ impl ResponsePayload {
}
pub fn is_stream(&self) -> bool {
match self {
ResponsePayload::Json(_) => false,
ResponsePayload::Stream(_) => true,
matches!(
self,
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_)
)
}
pub fn is_stream_end(&self) -> bool {
matches!(self, ResponsePayload::StreamEnd(_))
}
pub fn json(&self) -> &JsonValue {
match self {
ResponsePayload::Json(v) => v,
ResponsePayload::Stream(v) => v,
ResponsePayload::Streaming(v) => v,
ResponsePayload::StreamEnd(v) => v,
}
}
pub fn into_json(self) -> JsonValue {
match self {
ResponsePayload::Json(v) => v,
ResponsePayload::Stream(v) => v,
ResponsePayload::Streaming(v) => v,
ResponsePayload::StreamEnd(v) => v,
}
}
}
@ -301,37 +327,78 @@ impl Display for ResponsePayload {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ResponsePayload::Json(v) => write!(f, "{}", v),
ResponsePayload::Stream(_) => write!(f, "stream"),
ResponsePayload::Streaming(_) => write!(f, "stream start"),
ResponsePayload::StreamEnd(_) => write!(f, "stream end"),
}
}
}
pub type Response = Result<ResponsePayload, RemoteError>;
pub trait ResponseStream: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {}
pub trait ResponseStream: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {}
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
enum ResponseHandler {
Chan(mpsc::Sender<Result<JsonValue, Error>>),
Callback(Box<dyn Callback>),
}
pub trait Callback: Send {
fn call(self: Box<Self>, result: Result<JsonValue, Error>);
Chan(mpsc::Sender<Result<JsonValue, SidecarError>>),
Callback(Box<dyn OneShotCallback>),
StreamCallback(Arc<CloneableCallback>),
}
impl<F: Send + FnOnce(Result<JsonValue, Error>)> Callback for F {
fn call(self: Box<F>, result: Result<JsonValue, Error>) {
impl ResponseHandler {
pub fn get_stream_callback(&self) -> Option<Arc<CloneableCallback>> {
match self {
ResponseHandler::StreamCallback(cb) => Some(cb.clone()),
_ => None,
}
}
}
pub trait OneShotCallback: Send {
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>);
}
impl<F: Send + FnOnce(Result<JsonValue, SidecarError>)> OneShotCallback for F {
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>) {
(self)(result)
}
}
pub trait Callback: Send + Sync {
fn call(&self, result: Result<JsonValue, SidecarError>);
}
impl<F: Send + Sync + Fn(Result<JsonValue, SidecarError>)> Callback for F {
fn call(&self, result: Result<JsonValue, SidecarError>) {
(*self)(result)
}
}
#[derive(Clone)]
pub struct CloneableCallback {
callback: Arc<dyn Callback>,
}
impl CloneableCallback {
pub fn new<C: Callback + 'static>(callback: C) -> Self {
CloneableCallback {
callback: Arc::new(callback),
}
}
pub fn call(&self, result: Result<JsonValue, SidecarError>) {
self.callback.call(result)
}
}
impl ResponseHandler {
fn invoke(self, result: Result<JsonValue, Error>) {
fn invoke(self, result: Result<JsonValue, SidecarError>) {
match self {
ResponseHandler::Chan(tx) => {
let _ = tx.send(result);
},
ResponseHandler::StreamCallback(cb) => {
cb.call(result);
},
ResponseHandler::Callback(f) => f.call(result),
}
}

View File

@ -1,11 +1,10 @@
use crate::core::rpc_peer::ResponsePayload;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value as JsonValue};
use std::{fmt, io};
/// The error type of `tauri-utils`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
pub enum SidecarError {
/// An IO error occurred on the underlying communication channel.
#[error(transparent)]
Io(#[from] io::Error),
@ -48,6 +47,9 @@ pub enum RemoteError {
#[error("Invalid response: {0}")]
InvalidResponse(JsonValue),
#[error("Parse response: {0}")]
ParseResponse(JsonValue),
/// A custom error, defined by the client.
#[error("Custom error: {message}")]
Custom {
@ -100,9 +102,9 @@ impl From<serde_json::Error> for RemoteError {
}
}
impl From<RemoteError> for Error {
fn from(err: RemoteError) -> Error {
Error::RemoteError(err)
impl From<RemoteError> for SidecarError {
fn from(err: RemoteError) -> SidecarError {
SidecarError::RemoteError(err)
}
}
@ -156,6 +158,11 @@ impl Serialize for RemoteError {
"Invalid response".to_string(),
Some(json!(resp.to_string())),
),
RemoteError::ParseResponse(resp) => (
-1,
"Invalid response".to_string(),
Some(json!(resp.to_string())),
),
};
let err = ErrorHelper {
code,

View File

@ -1,4 +1,4 @@
pub mod core;
mod error;
pub mod error;
pub mod manager;
pub mod plugins;

View File

@ -2,11 +2,11 @@ use crate::core::parser::ResponseParser;
use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
use crate::core::rpc_loop::Handler;
use crate::core::rpc_peer::{PluginCommand, ResponsePayload};
use crate::error::{Error, ReadError, RemoteError};
use crate::error::{ReadError, RemoteError, SidecarError};
use anyhow::anyhow;
use lib_infra::util::{get_operating_system, OperatingSystem};
use parking_lot::Mutex;
use serde_json::{json, Value};
use serde_json::Value;
use std::io;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Weak};
@ -29,9 +29,9 @@ impl SidecarManager {
}
}
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, Error> {
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, SidecarError> {
if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!(
return Err(SidecarError::Internal(anyhow!(
"plugin not supported on this platform"
)));
}
@ -41,7 +41,7 @@ impl SidecarManager {
Ok(plugin_id)
}
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, Error> {
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, SidecarError> {
let state = self.state.lock();
let plugin = state
.plugins
@ -51,9 +51,9 @@ impl SidecarManager {
Ok(Arc::downgrade(plugin))
}
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> {
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), SidecarError> {
if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!(
return Err(SidecarError::Internal(anyhow!(
"plugin not supported on this platform"
)));
}
@ -71,9 +71,9 @@ impl SidecarManager {
Ok(())
}
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> {
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), SidecarError> {
if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!(
return Err(SidecarError::Internal(anyhow!(
"plugin not supported on this platform"
)));
}
@ -94,15 +94,15 @@ impl SidecarManager {
id: PluginId,
method: &str,
request: Value,
) -> Result<P::ValueType, Error> {
) -> Result<P::ValueType, SidecarError> {
let state = self.state.lock();
let plugin = state
.plugins
.iter()
.find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?;
let resp = plugin.send_request(method, &request)?;
let value = P::parse_response(resp)?;
let resp = plugin.request(method, &request)?;
let value = P::parse_json(resp)?;
Ok(value)
}
@ -111,14 +111,14 @@ impl SidecarManager {
id: PluginId,
method: &str,
request: Value,
) -> Result<P::ValueType, Error> {
) -> Result<P::ValueType, SidecarError> {
let state = self.state.lock();
let plugin = state
.plugins
.iter()
.find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?;
let value = plugin.async_send_request::<P>(method, &request).await?;
let value = plugin.async_request::<P>(method, &request).await?;
Ok(value)
}
}

View File

@ -1,9 +1,13 @@
use crate::core::parser::{ChatRelatedQuestionsResponseParser, ChatResponseParser};
use crate::core::parser::{
ChatRelatedQuestionsResponseParser, ChatResponseParser, ChatStreamResponseParser,
};
use crate::core::plugin::{Plugin, PluginId};
use crate::error::Error;
use crate::error::SidecarError;
use anyhow::anyhow;
use serde_json::json;
use std::sync::Weak;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
pub struct ChatPluginOperation {
plugin: Weak<Plugin>,
@ -17,17 +21,17 @@ impl ChatPluginOperation {
pub async fn send_message(
&self,
chat_id: &str,
plugin_id: PluginId,
_plugin_id: PluginId,
message: &str,
) -> Result<String, Error> {
) -> Result<String, SidecarError> {
let plugin = self
.plugin
.upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}});
let resp = plugin
.async_send_request::<ChatResponseParser>("handle", &params)
.async_request::<ChatResponseParser>("handle", &params)
.await?;
Ok(resp)
}
@ -35,34 +39,31 @@ impl ChatPluginOperation {
pub async fn stream_message(
&self,
chat_id: &str,
plugin_id: PluginId,
_plugin_id: PluginId,
message: &str,
) -> Result<String, Error> {
) -> Result<ReceiverStream<Result<String, SidecarError>>, SidecarError> {
let plugin = self
.plugin
.upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params =
json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}});
let resp = plugin
.async_send_request::<ChatResponseParser>("handle", &params)
.await?;
Ok(resp)
plugin.stream_request::<ChatStreamResponseParser>("handle", &params)
}
pub async fn get_related_questions(
&self,
chat_id: &str,
) -> Result<Vec<serde_json::Value>, Error> {
) -> Result<Vec<serde_json::Value>, SidecarError> {
let plugin = self
.plugin
.upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = json!({"chat_id": chat_id, "method": "related_question"});
let resp = plugin
.async_send_request::<ChatRelatedQuestionsResponseParser>("handle", &params)
.async_request::<ChatRelatedQuestionsResponseParser>("handle", &params)
.await?;
Ok(resp)
}

View File

@ -1,8 +1,6 @@
use crate::core::parser::{
ChatRelatedQuestionsResponseParser, ChatResponseParser, SimilarityResponseParser,
};
use crate::core::plugin::{Plugin, PluginId};
use crate::error::Error;
use crate::core::parser::SimilarityResponseParser;
use crate::core::plugin::Plugin;
use crate::error::SidecarError;
use anyhow::anyhow;
use serde_json::json;
use std::sync::Weak;
@ -16,15 +14,19 @@ impl EmbeddingPluginOperation {
EmbeddingPluginOperation { plugin }
}
pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> Result<f64, Error> {
pub async fn calculate_similarity(
&self,
message1: &str,
message2: &str,
) -> Result<f64, SidecarError> {
let plugin = self
.plugin
.upgrade()
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params =
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}});
plugin
.async_send_request::<SimilarityResponseParser>("handle", &params)
.async_request::<SimilarityResponseParser>("handle", &params)
.await
}
}

View File

@ -1,11 +1,14 @@
use crate::util::LocalAITest;
use tokio_stream::StreamExt;
#[tokio::test]
async fn load_chat_model_test() {
if let Ok(test) = LocalAITest::new() {
let plugin_id = test.init_chat_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let resp = test.send_message(&chat_id, plugin_id, "hello world").await;
let resp = test
.send_chat_message(&chat_id, plugin_id, "hello world")
.await;
eprintln!("chat response: {:?}", resp);
let embedding_plugin_id = test.init_embedding_plugin().await;
@ -17,3 +20,27 @@ async fn load_chat_model_test() {
// eprintln!("related questions: {:?}", questions);
}
}
#[tokio::test]
async fn stream_local_model_test() {
if let Ok(test) = LocalAITest::new() {
let plugin_id = test.init_chat_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let mut resp = test
.stream_chat_message(&chat_id, plugin_id, "hello world")
.await;
let a = resp.next().await.unwrap().unwrap();
eprintln!("chat response: {:?}", a);
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
// let mut resp = test
// .stream_chat_message(&chat_id, plugin_id, "How are you")
// .await;
// let a = resp.next().await.unwrap().unwrap();
// eprintln!("chat response: {:?}", a);
// let questions = test.related_question(&chat_id, plugin_id).await;
// assert_eq!(questions.len(), 3);
// eprintln!("related questions: {:?}", questions);
}
}

View File

@ -1,10 +1,12 @@
use anyhow::Result;
use flowy_sidecar::manager::SidecarManager;
use serde_json::json;
use std::sync::{Arc, Once};
use std::sync::Once;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::Stream;
use flowy_sidecar::core::parser::{ChatResponseParser, SimilarityResponseParser};
use flowy_sidecar::core::plugin::{PluginId, PluginInfo};
use flowy_sidecar::error::SidecarError;
use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation;
use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation;
use tracing_subscriber::fmt::Subscriber;
@ -61,7 +63,12 @@ impl LocalAITest {
plugin_id
}
pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String {
pub async fn send_chat_message(
&self,
chat_id: &str,
plugin_id: PluginId,
message: &str,
) -> String {
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
let operation = ChatPluginOperation::new(plugin);
let resp = operation
@ -72,6 +79,20 @@ impl LocalAITest {
resp
}
pub async fn stream_chat_message(
&self,
chat_id: &str,
plugin_id: PluginId,
message: &str,
) -> ReceiverStream<Result<String, SidecarError>> {
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
let operation = ChatPluginOperation::new(plugin);
operation
.stream_message(chat_id, plugin_id, message)
.await
.unwrap()
}
pub async fn related_question(
&self,
chat_id: &str,

View File

@ -23,5 +23,6 @@ if_wasm! {
pub mod isolate_stream;
pub mod priority_task;
pub mod ref_map;
pub mod stream_util;
pub mod util;
pub mod validator_fn;

View File

@ -0,0 +1,21 @@
use futures_core::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio::sync::mpsc::{Receiver, Sender};
struct BoundedStream<T> {
recv: Receiver<T>,
}
impl<T> Stream for BoundedStream<T> {
type Item = T;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
Pin::into_inner(self).recv.poll_recv(cx)
}
}
pub fn mpsc_channel_stream<T: Unpin>(size: usize) -> (Sender<T>, impl Stream<Item = T>) {
let (tx, rx) = mpsc::channel(size);
let stream = BoundedStream { recv: rx };
(tx, stream)
}