mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2024-08-30 18:12:39 +00:00
chore: test streaming
This commit is contained in:
parent
9472361664
commit
e4b1108ff0
6
frontend/rust-lib/Cargo.lock
generated
6
frontend/rust-lib/Cargo.lock
generated
@ -2148,7 +2148,6 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"crossbeam-utils",
|
"crossbeam-utils",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"futures",
|
|
||||||
"lib-infra",
|
"lib-infra",
|
||||||
"log",
|
"log",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
@ -2157,6 +2156,7 @@ dependencies = [
|
|||||||
"serde_json",
|
"serde_json",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"uuid",
|
"uuid",
|
||||||
@ -5767,9 +5767,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-stream"
|
name = "tokio-stream"
|
||||||
version = "0.1.14"
|
version = "0.1.15"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
|
checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
@ -11,7 +11,7 @@ use flowy_sidecar::manager::SidecarManager;
|
|||||||
use flowy_sqlite::DBConnection;
|
use flowy_sqlite::DBConnection;
|
||||||
use lib_infra::future::FutureResult;
|
use lib_infra::future::FutureResult;
|
||||||
use lib_infra::util::timestamp;
|
use lib_infra::util::timestamp;
|
||||||
use std::sync::atomic::AtomicBool;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
|
|
||||||
@ -263,9 +263,9 @@ impl ChatCloudService for ChatService {
|
|||||||
|
|
||||||
async fn stream_answer(
|
async fn stream_answer(
|
||||||
&self,
|
&self,
|
||||||
workspace_id: &str,
|
_workspace_id: &str,
|
||||||
chat_id: &str,
|
_chat_id: &str,
|
||||||
message_id: i64,
|
_message_id: i64,
|
||||||
) -> Result<StreamAnswer, FlowyError> {
|
) -> Result<StreamAnswer, FlowyError> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
@ -16,8 +16,8 @@ tracing.workspace = true
|
|||||||
crossbeam-utils = "0.8.20"
|
crossbeam-utils = "0.8.20"
|
||||||
log = "0.4.21"
|
log = "0.4.21"
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
|
tokio-stream = "0.1.15"
|
||||||
lib-infra.workspace = true
|
lib-infra.workspace = true
|
||||||
futures.workspace = true
|
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
use crate::core::rpc_object::RpcObject;
|
use crate::core::rpc_object::RpcObject;
|
||||||
use crate::core::rpc_peer::ResponsePayload;
|
|
||||||
use crate::error::{ReadError, RemoteError};
|
use crate::error::{ReadError, RemoteError};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use std::io::BufRead;
|
use std::io::BufRead;
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct MessageReader(String);
|
pub struct MessageReader(String);
|
||||||
@ -57,15 +58,15 @@ pub enum Call<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub trait ResponseParser {
|
pub trait ResponseParser {
|
||||||
type ValueType;
|
type ValueType: Send + Sync + 'static;
|
||||||
fn parse_response(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
|
fn parse_json(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ChatResponseParser;
|
pub struct ChatResponseParser;
|
||||||
impl ResponseParser for ChatResponseParser {
|
impl ResponseParser for ChatResponseParser {
|
||||||
type ValueType = String;
|
type ValueType = String;
|
||||||
|
|
||||||
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||||
if json.is_object() {
|
if json.is_object() {
|
||||||
if let Some(data) = json.get("data") {
|
if let Some(data) = json.get("data") {
|
||||||
if let Some(message) = data.as_str() {
|
if let Some(message) = data.as_str() {
|
||||||
@ -73,7 +74,19 @@ impl ResponseParser for ChatResponseParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Err(RemoteError::InvalidResponse(json));
|
return Err(RemoteError::ParseResponse(json));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChatStreamResponseParser;
|
||||||
|
impl ResponseParser for ChatStreamResponseParser {
|
||||||
|
type ValueType = String;
|
||||||
|
|
||||||
|
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||||
|
if let Some(message) = json.as_str() {
|
||||||
|
return Ok(message.to_string());
|
||||||
|
}
|
||||||
|
return Err(RemoteError::ParseResponse(json));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +94,7 @@ pub struct ChatRelatedQuestionsResponseParser;
|
|||||||
impl ResponseParser for ChatRelatedQuestionsResponseParser {
|
impl ResponseParser for ChatRelatedQuestionsResponseParser {
|
||||||
type ValueType = Vec<JsonValue>;
|
type ValueType = Vec<JsonValue>;
|
||||||
|
|
||||||
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||||
if json.is_object() {
|
if json.is_object() {
|
||||||
if let Some(data) = json.get("data") {
|
if let Some(data) = json.get("data") {
|
||||||
if let Some(values) = data.as_array() {
|
if let Some(values) = data.as_array() {
|
||||||
@ -89,7 +102,7 @@ impl ResponseParser for ChatRelatedQuestionsResponseParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Err(RemoteError::InvalidResponse(json));
|
return Err(RemoteError::ParseResponse(json));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +110,7 @@ pub struct SimilarityResponseParser;
|
|||||||
impl ResponseParser for SimilarityResponseParser {
|
impl ResponseParser for SimilarityResponseParser {
|
||||||
type ValueType = f64;
|
type ValueType = f64;
|
||||||
|
|
||||||
fn parse_response(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
||||||
if json.is_object() {
|
if json.is_object() {
|
||||||
if let Some(data) = json.get("data") {
|
if let Some(data) = json.get("data") {
|
||||||
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
|
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
|
||||||
@ -106,6 +119,6 @@ impl ResponseParser for SimilarityResponseParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Err(RemoteError::InvalidResponse(json));
|
return Err(RemoteError::ParseResponse(json));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use crate::error::Error;
|
use crate::error::SidecarError;
|
||||||
use crate::manager::WeakSidecarState;
|
use crate::manager::WeakSidecarState;
|
||||||
|
|
||||||
use crate::core::parser::ResponseParser;
|
use crate::core::parser::ResponseParser;
|
||||||
use crate::core::rpc_loop::RpcLoop;
|
use crate::core::rpc_loop::RpcLoop;
|
||||||
use crate::core::rpc_peer::{Callback, ResponsePayload};
|
use crate::core::rpc_peer::{CloneableCallback, OneShotCallback};
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
@ -12,6 +12,9 @@ use std::process::{Child, Stdio};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tokio_stream::Stream;
|
||||||
|
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
@ -34,12 +37,12 @@ pub trait Peer: Send + Sync + 'static {
|
|||||||
/// Sends an RPC notification to the peer with the specified method and parameters.
|
/// Sends an RPC notification to the peer with the specified method and parameters.
|
||||||
fn send_rpc_notification(&self, method: &str, params: &JsonValue);
|
fn send_rpc_notification(&self, method: &str, params: &JsonValue);
|
||||||
|
|
||||||
/// Sends an asynchronous RPC request to the peer and executes the provided callback upon completion.
|
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback);
|
||||||
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>);
|
|
||||||
|
|
||||||
|
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>);
|
||||||
/// Sends a synchronous RPC request to the peer and waits for the result.
|
/// Sends a synchronous RPC request to the peer and waits for the result.
|
||||||
/// Returns the result of the request or an error.
|
/// Returns the result of the request or an error.
|
||||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error>;
|
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError>;
|
||||||
|
|
||||||
/// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background.
|
/// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background.
|
||||||
fn request_is_pending(&self) -> bool;
|
fn request_is_pending(&self) -> bool;
|
||||||
@ -66,34 +69,53 @@ pub struct Plugin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Plugin {
|
impl Plugin {
|
||||||
pub fn initialize(&self, value: JsonValue) -> Result<(), Error> {
|
pub fn initialize(&self, value: JsonValue) -> Result<(), SidecarError> {
|
||||||
self.peer.send_rpc_request("initialize", &value)?;
|
self.peer.send_rpc_request("initialize", &value)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> {
|
pub fn request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
|
||||||
self.peer.send_rpc_request(method, params)
|
self.peer.send_rpc_request(method, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn async_send_request<P: ResponseParser>(
|
pub async fn async_request<P: ResponseParser>(
|
||||||
&self,
|
&self,
|
||||||
method: &str,
|
method: &str,
|
||||||
params: &JsonValue,
|
params: &JsonValue,
|
||||||
) -> Result<P::ValueType, Error> {
|
) -> Result<P::ValueType, SidecarError> {
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||||
self.peer.send_rpc_request_async(
|
self.peer.async_send_rpc_request(
|
||||||
method,
|
method,
|
||||||
params,
|
params,
|
||||||
Box::new(move |result| {
|
Box::new(move |result| {
|
||||||
let _ = tx.send(result);
|
let _ = tx.send(result);
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
let value = rx
|
let value = rx.await.map_err(|err| {
|
||||||
.await
|
SidecarError::Internal(anyhow!("error waiting for async response: {:?}", err))
|
||||||
.map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??;
|
})??;
|
||||||
let value = P::parse_response(value)?;
|
let value = P::parse_json(value)?;
|
||||||
Ok(value)
|
Ok(value)
|
||||||
}
|
}
|
||||||
|
pub fn stream_request<P: ResponseParser>(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
params: &JsonValue,
|
||||||
|
) -> Result<ReceiverStream<Result<P::ValueType, SidecarError>>, SidecarError> {
|
||||||
|
let (tx, stream) = tokio::sync::mpsc::channel(100);
|
||||||
|
let stream = ReceiverStream::new(stream);
|
||||||
|
let callback = CloneableCallback::new(move |result| match result {
|
||||||
|
Ok(json) => {
|
||||||
|
let result = P::parse_json(json).map_err(SidecarError::from);
|
||||||
|
let _ = tx.blocking_send(result);
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.blocking_send(Err(err));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
self.peer.stream_rpc_request(method, params, callback);
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn shutdown(&self) {
|
pub fn shutdown(&self) {
|
||||||
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {
|
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {
|
||||||
|
@ -2,14 +2,14 @@ use crate::core::parser::{Call, MessageReader};
|
|||||||
use crate::core::plugin::RpcCtx;
|
use crate::core::plugin::RpcCtx;
|
||||||
use crate::core::rpc_object::RpcObject;
|
use crate::core::rpc_object::RpcObject;
|
||||||
use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState};
|
use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState};
|
||||||
use crate::error::{Error, ReadError, RemoteError};
|
use crate::error::{ReadError, RemoteError, SidecarError};
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde_json::Value;
|
|
||||||
use std::io::{BufRead, Write};
|
use std::io::{BufRead, Write};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tracing::{error, trace};
|
use tracing::{error, info, trace};
|
||||||
|
|
||||||
const MAX_IDLE_WAIT: Duration = Duration::from_millis(5);
|
const MAX_IDLE_WAIT: Duration = Duration::from_millis(5);
|
||||||
|
|
||||||
@ -97,7 +97,6 @@ impl<W: Write + Send> RpcLoop<W> {
|
|||||||
trace!("read loop exit");
|
trace!("read loop exit");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let json = match self.reader.next(&mut stream) {
|
let json = match self.reader.next(&mut stream) {
|
||||||
Ok(json) => json,
|
Ok(json) => json,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@ -109,15 +108,17 @@ impl<W: Write + Send> RpcLoop<W> {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
if json.is_response() {
|
if json.is_response() {
|
||||||
let id = json.get_id().unwrap();
|
let request_id = json.get_id().unwrap();
|
||||||
match json.into_response() {
|
match json.into_response() {
|
||||||
Ok(resp) => {
|
Ok(resp) => {
|
||||||
let resp = resp.map_err(Error::from);
|
let resp = resp.map_err(SidecarError::from);
|
||||||
self.peer.handle_response(id, resp);
|
self.peer.handle_response(request_id, resp);
|
||||||
},
|
},
|
||||||
Err(msg) => {
|
Err(msg) => {
|
||||||
error!("[RPC] failed to parse response: {}", msg);
|
error!("[RPC] failed to parse response: {}", msg);
|
||||||
self.peer.handle_response(id, Err(Error::InvalidResponse));
|
self
|
||||||
|
.peer
|
||||||
|
.handle_response(request_id, Err(SidecarError::InvalidResponse));
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
use crate::core::parser::{Call, RequestId};
|
use crate::core::parser::{Call, RequestId};
|
||||||
use crate::core::rpc_peer::{Response, ResponsePayload};
|
use crate::core::rpc_peer::{Response, ResponsePayload};
|
||||||
|
use crate::error::RemoteError;
|
||||||
use serde::de::{DeserializeOwned, Error};
|
use serde::de::{DeserializeOwned, Error};
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RpcObject(pub Value);
|
pub struct RpcObject(pub Value);
|
||||||
@ -24,30 +25,63 @@ impl RpcObject {
|
|||||||
self.0.get("id").is_some() && self.0.get("method").is_none()
|
self.0.get("id").is_some() && self.0.get("method").is_none()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the underlying `Value` into an RPC response object.
|
/// Converts a JSON-RPC response into a structured `Response` object.
|
||||||
/// The caller should verify that the object is a response before calling this method.
|
///
|
||||||
|
/// This function validates and parses a JSON-RPC response, ensuring it contains the necessary fields,
|
||||||
|
/// and then transforms it into a structured `Response` object. The response must contain either a
|
||||||
|
/// "result" or an "error" field, but not both. If the response contains a "result" field, it may also
|
||||||
|
/// include streaming data, indicated by a nested "stream" field.
|
||||||
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// If the `Value` is not a well-formed response object, this returns a `String` containing an
|
///
|
||||||
/// error message. The caller should print this message and exit.
|
/// This function will return an error if:
|
||||||
|
/// - The "id" field is missing.
|
||||||
|
/// - The response contains both "result" and "error" fields, or neither.
|
||||||
|
/// - The "stream" field within the "result" is missing "type" or "data" fields.
|
||||||
|
/// - The "stream" type is invalid (i.e., not "streaming" or "end").
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// - `Ok(Ok(ResponsePayload::Json(result)))`: If the response contains a valid "result".
|
||||||
|
/// - `Ok(Ok(ResponsePayload::Streaming(data)))`: If the response contains streaming data of type "streaming".
|
||||||
|
/// - `Ok(Ok(ResponsePayload::StreamEnd(json!({}))))`: If the response contains streaming data of type "end".
|
||||||
|
/// - `Err(String)`: If any validation or parsing errors occur.
|
||||||
|
///.
|
||||||
pub fn into_response(mut self) -> Result<Response, String> {
|
pub fn into_response(mut self) -> Result<Response, String> {
|
||||||
let _ = self
|
// Ensure 'id' field is present
|
||||||
|
self
|
||||||
.get_id()
|
.get_id()
|
||||||
.ok_or("Response requires 'id' field.".to_string())?;
|
.ok_or_else(|| "Response requires 'id' field.".to_string())?;
|
||||||
|
|
||||||
if self.0.get("result").is_some() == self.0.get("error").is_some() {
|
// Ensure the response contains exactly one of 'result' or 'error'
|
||||||
|
let has_result = self.0.get("result").is_some();
|
||||||
|
let has_error = self.0.get("error").is_some();
|
||||||
|
if has_result == has_error {
|
||||||
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
|
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
|
||||||
}
|
}
|
||||||
let result = self.0.as_object_mut().and_then(|obj| obj.remove("result"));
|
|
||||||
match result {
|
// Handle the 'result' field if present
|
||||||
Some(r) => Ok(Ok(ResponsePayload::Json(r))),
|
if let Some(mut result) = self.0.as_object_mut().and_then(|obj| obj.remove("result")) {
|
||||||
None => {
|
if let Some(mut stream) = result.as_object_mut().and_then(|obj| obj.remove("stream")) {
|
||||||
let error = self
|
if let Some((has_more, data)) = stream.as_object_mut().and_then(|obj| {
|
||||||
.0
|
let has_more = obj.remove("has_more")?.as_bool().unwrap_or(false);
|
||||||
.as_object_mut()
|
let data = obj.remove("data")?;
|
||||||
.and_then(|obj| obj.remove("error"))
|
Some((has_more, data))
|
||||||
.unwrap();
|
}) {
|
||||||
|
return match has_more {
|
||||||
|
true => Ok(Ok(ResponsePayload::Streaming(data))),
|
||||||
|
false => Ok(Ok(ResponsePayload::StreamEnd(data))),
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return Err("Stream response must contain 'type' and 'data' fields.".into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Ok(ResponsePayload::Json(result)))
|
||||||
|
} else {
|
||||||
|
// Handle the 'error' field
|
||||||
|
let error = self.0.as_object_mut().unwrap().remove("error").unwrap();
|
||||||
Err(format!("Error handling response: {:?}", error))
|
Err(format!("Error handling response: {:?}", error))
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
use crate::core::plugin::{Peer, PluginId};
|
use crate::core::plugin::{Peer, PluginId};
|
||||||
use crate::core::rpc_object::RpcObject;
|
use crate::core::rpc_object::RpcObject;
|
||||||
use crate::error::{Error, ReadError, RemoteError};
|
use crate::error::{ReadError, RemoteError, SidecarError};
|
||||||
use futures::Stream;
|
|
||||||
use parking_lot::{Condvar, Mutex};
|
use parking_lot::{Condvar, Mutex};
|
||||||
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use std::collections::{BTreeMap, BinaryHeap, VecDeque};
|
use std::collections::{BTreeMap, BinaryHeap, VecDeque};
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||||
use std::sync::{mpsc, Arc};
|
use std::sync::{mpsc, Arc};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{cmp, io};
|
use std::{cmp, io};
|
||||||
|
use tokio_stream::Stream;
|
||||||
use tracing::{error, trace, warn};
|
use tracing::{error, trace, warn};
|
||||||
|
|
||||||
pub struct PluginCommand<T> {
|
pub struct PluginCommand<T> {
|
||||||
@ -97,15 +97,19 @@ impl<W: Write + Send + 'static> Peer for RawPeer<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_rpc_request_async(&self, method: &str, params: &JsonValue, f: Box<dyn Callback>) {
|
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback) {
|
||||||
|
self.send_rpc(method, params, ResponseHandler::StreamCallback(Arc::new(f)));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>) {
|
||||||
self.send_rpc(method, params, ResponseHandler::Callback(f));
|
self.send_rpc(method, params, ResponseHandler::Callback(f));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, Error> {
|
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
|
||||||
let (tx, rx) = mpsc::channel();
|
let (tx, rx) = mpsc::channel();
|
||||||
self.0.is_blocking.store(true, Ordering::Release);
|
self.0.is_blocking.store(true, Ordering::Release);
|
||||||
self.send_rpc(method, params, ResponseHandler::Chan(tx));
|
self.send_rpc(method, params, ResponseHandler::Chan(tx));
|
||||||
rx.recv().unwrap_or(Err(Error::PeerDisconnect))
|
rx.recv().unwrap_or(Err(SidecarError::PeerDisconnect))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn request_is_pending(&self) -> bool {
|
fn request_is_pending(&self) -> bool {
|
||||||
@ -133,7 +137,7 @@ impl<W: Write> RawPeer<W> {
|
|||||||
match result {
|
match result {
|
||||||
Ok(result) => match result {
|
Ok(result) => match result {
|
||||||
ResponsePayload::Json(value) => response["result"] = value,
|
ResponsePayload::Json(value) => response["result"] = value,
|
||||||
ResponsePayload::Stream(_) => {
|
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) => {
|
||||||
error!("stream response not supported")
|
error!("stream response not supported")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -161,29 +165,44 @@ impl<W: Write> RawPeer<W> {
|
|||||||
})) {
|
})) {
|
||||||
let mut pending = self.0.pending.lock();
|
let mut pending = self.0.pending.lock();
|
||||||
if let Some(rh) = pending.remove(&id) {
|
if let Some(rh) = pending.remove(&id) {
|
||||||
rh.invoke(Err(Error::Io(e)));
|
rh.invoke(Err(SidecarError::Io(e)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn handle_response(&self, id: u64, resp: Result<ResponsePayload, Error>) {
|
pub(crate) fn handle_response(
|
||||||
let id = id as usize;
|
&self,
|
||||||
|
request_id: u64,
|
||||||
|
resp: Result<ResponsePayload, SidecarError>,
|
||||||
|
) {
|
||||||
|
let request_id = request_id as usize;
|
||||||
let handler = {
|
let handler = {
|
||||||
let mut pending = self.0.pending.lock();
|
let mut pending = self.0.pending.lock();
|
||||||
pending.remove(&id)
|
pending.remove(&request_id)
|
||||||
};
|
};
|
||||||
let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false);
|
let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false);
|
||||||
match handler {
|
match handler {
|
||||||
Some(response_handler) => {
|
Some(response_handler) => {
|
||||||
|
if is_stream {
|
||||||
|
let is_stream_end = resp
|
||||||
|
.as_ref()
|
||||||
|
.map(|resp| resp.is_stream_end())
|
||||||
|
.unwrap_or(false);
|
||||||
|
if !is_stream_end {
|
||||||
|
// when steam is not end, we need to put the stream callback back to pending in order to
|
||||||
|
// receive the next stream message.
|
||||||
|
if let Some(callback) = response_handler.get_stream_callback() {
|
||||||
|
let mut pending = self.0.pending.lock();
|
||||||
|
pending.insert(request_id, ResponseHandler::StreamCallback(callback));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
trace!("[RPC] {} stream end", request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
let json = resp.map(|resp| resp.into_json());
|
let json = resp.map(|resp| resp.into_json());
|
||||||
response_handler.invoke(json);
|
response_handler.invoke(json);
|
||||||
|
|
||||||
// if is_stream {
|
|
||||||
// let mut pending = self.0.pending.lock();
|
|
||||||
// pending.insert(id, response_handler);
|
|
||||||
// }
|
|
||||||
},
|
},
|
||||||
None => warn!("[RPC] id {} not found in pending", id),
|
None => error!("[RPC] id {}'s handle not found", request_id),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -243,7 +262,7 @@ impl<W: Write> RawPeer<W> {
|
|||||||
let ids = pending.keys().cloned().collect::<Vec<_>>();
|
let ids = pending.keys().cloned().collect::<Vec<_>>();
|
||||||
for id in &ids {
|
for id in &ids {
|
||||||
let callback = pending.remove(id).unwrap();
|
let callback = pending.remove(id).unwrap();
|
||||||
callback.invoke(Err(Error::PeerDisconnect));
|
callback.invoke(Err(SidecarError::PeerDisconnect));
|
||||||
}
|
}
|
||||||
self.0.needs_exit.store(true, Ordering::Relaxed);
|
self.0.needs_exit.store(true, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
@ -267,7 +286,8 @@ impl<W: Write> Clone for RawPeer<W> {
|
|||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum ResponsePayload {
|
pub enum ResponsePayload {
|
||||||
Json(JsonValue),
|
Json(JsonValue),
|
||||||
Stream(JsonValue),
|
Streaming(JsonValue),
|
||||||
|
StreamEnd(JsonValue),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponsePayload {
|
impl ResponsePayload {
|
||||||
@ -276,23 +296,29 @@ impl ResponsePayload {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_stream(&self) -> bool {
|
pub fn is_stream(&self) -> bool {
|
||||||
match self {
|
matches!(
|
||||||
ResponsePayload::Json(_) => false,
|
self,
|
||||||
ResponsePayload::Stream(_) => true,
|
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_stream_end(&self) -> bool {
|
||||||
|
matches!(self, ResponsePayload::StreamEnd(_))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn json(&self) -> &JsonValue {
|
pub fn json(&self) -> &JsonValue {
|
||||||
match self {
|
match self {
|
||||||
ResponsePayload::Json(v) => v,
|
ResponsePayload::Json(v) => v,
|
||||||
ResponsePayload::Stream(v) => v,
|
ResponsePayload::Streaming(v) => v,
|
||||||
|
ResponsePayload::StreamEnd(v) => v,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_json(self) -> JsonValue {
|
pub fn into_json(self) -> JsonValue {
|
||||||
match self {
|
match self {
|
||||||
ResponsePayload::Json(v) => v,
|
ResponsePayload::Json(v) => v,
|
||||||
ResponsePayload::Stream(v) => v,
|
ResponsePayload::Streaming(v) => v,
|
||||||
|
ResponsePayload::StreamEnd(v) => v,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -301,37 +327,78 @@ impl Display for ResponsePayload {
|
|||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
ResponsePayload::Json(v) => write!(f, "{}", v),
|
ResponsePayload::Json(v) => write!(f, "{}", v),
|
||||||
ResponsePayload::Stream(_) => write!(f, "stream"),
|
ResponsePayload::Streaming(_) => write!(f, "stream start"),
|
||||||
|
ResponsePayload::StreamEnd(_) => write!(f, "stream end"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Response = Result<ResponsePayload, RemoteError>;
|
pub type Response = Result<ResponsePayload, RemoteError>;
|
||||||
|
|
||||||
pub trait ResponseStream: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {}
|
pub trait ResponseStream: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
|
||||||
|
|
||||||
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, Error>> + Unpin + Send {}
|
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
|
||||||
|
|
||||||
enum ResponseHandler {
|
enum ResponseHandler {
|
||||||
Chan(mpsc::Sender<Result<JsonValue, Error>>),
|
Chan(mpsc::Sender<Result<JsonValue, SidecarError>>),
|
||||||
Callback(Box<dyn Callback>),
|
Callback(Box<dyn OneShotCallback>),
|
||||||
}
|
StreamCallback(Arc<CloneableCallback>),
|
||||||
pub trait Callback: Send {
|
|
||||||
fn call(self: Box<Self>, result: Result<JsonValue, Error>);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Send + FnOnce(Result<JsonValue, Error>)> Callback for F {
|
impl ResponseHandler {
|
||||||
fn call(self: Box<F>, result: Result<JsonValue, Error>) {
|
pub fn get_stream_callback(&self) -> Option<Arc<CloneableCallback>> {
|
||||||
|
match self {
|
||||||
|
ResponseHandler::StreamCallback(cb) => Some(cb.clone()),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait OneShotCallback: Send {
|
||||||
|
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Send + FnOnce(Result<JsonValue, SidecarError>)> OneShotCallback for F {
|
||||||
|
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>) {
|
||||||
|
(self)(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Callback: Send + Sync {
|
||||||
|
fn call(&self, result: Result<JsonValue, SidecarError>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Send + Sync + Fn(Result<JsonValue, SidecarError>)> Callback for F {
|
||||||
|
fn call(&self, result: Result<JsonValue, SidecarError>) {
|
||||||
(*self)(result)
|
(*self)(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CloneableCallback {
|
||||||
|
callback: Arc<dyn Callback>,
|
||||||
|
}
|
||||||
|
impl CloneableCallback {
|
||||||
|
pub fn new<C: Callback + 'static>(callback: C) -> Self {
|
||||||
|
CloneableCallback {
|
||||||
|
callback: Arc::new(callback),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn call(&self, result: Result<JsonValue, SidecarError>) {
|
||||||
|
self.callback.call(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ResponseHandler {
|
impl ResponseHandler {
|
||||||
fn invoke(self, result: Result<JsonValue, Error>) {
|
fn invoke(self, result: Result<JsonValue, SidecarError>) {
|
||||||
match self {
|
match self {
|
||||||
ResponseHandler::Chan(tx) => {
|
ResponseHandler::Chan(tx) => {
|
||||||
let _ = tx.send(result);
|
let _ = tx.send(result);
|
||||||
},
|
},
|
||||||
|
ResponseHandler::StreamCallback(cb) => {
|
||||||
|
cb.call(result);
|
||||||
|
},
|
||||||
ResponseHandler::Callback(f) => f.call(result),
|
ResponseHandler::Callback(f) => f.call(result),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
use crate::core::rpc_peer::ResponsePayload;
|
|
||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
use serde_json::{json, Value as JsonValue};
|
use serde_json::{json, Value as JsonValue};
|
||||||
use std::{fmt, io};
|
use std::{fmt, io};
|
||||||
|
|
||||||
/// The error type of `tauri-utils`.
|
/// The error type of `tauri-utils`.
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum Error {
|
pub enum SidecarError {
|
||||||
/// An IO error occurred on the underlying communication channel.
|
/// An IO error occurred on the underlying communication channel.
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Io(#[from] io::Error),
|
Io(#[from] io::Error),
|
||||||
@ -48,6 +47,9 @@ pub enum RemoteError {
|
|||||||
|
|
||||||
#[error("Invalid response: {0}")]
|
#[error("Invalid response: {0}")]
|
||||||
InvalidResponse(JsonValue),
|
InvalidResponse(JsonValue),
|
||||||
|
|
||||||
|
#[error("Parse response: {0}")]
|
||||||
|
ParseResponse(JsonValue),
|
||||||
/// A custom error, defined by the client.
|
/// A custom error, defined by the client.
|
||||||
#[error("Custom error: {message}")]
|
#[error("Custom error: {message}")]
|
||||||
Custom {
|
Custom {
|
||||||
@ -100,9 +102,9 @@ impl From<serde_json::Error> for RemoteError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<RemoteError> for Error {
|
impl From<RemoteError> for SidecarError {
|
||||||
fn from(err: RemoteError) -> Error {
|
fn from(err: RemoteError) -> SidecarError {
|
||||||
Error::RemoteError(err)
|
SidecarError::RemoteError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,6 +158,11 @@ impl Serialize for RemoteError {
|
|||||||
"Invalid response".to_string(),
|
"Invalid response".to_string(),
|
||||||
Some(json!(resp.to_string())),
|
Some(json!(resp.to_string())),
|
||||||
),
|
),
|
||||||
|
RemoteError::ParseResponse(resp) => (
|
||||||
|
-1,
|
||||||
|
"Invalid response".to_string(),
|
||||||
|
Some(json!(resp.to_string())),
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let err = ErrorHelper {
|
let err = ErrorHelper {
|
||||||
code,
|
code,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
pub mod core;
|
pub mod core;
|
||||||
mod error;
|
pub mod error;
|
||||||
pub mod manager;
|
pub mod manager;
|
||||||
pub mod plugins;
|
pub mod plugins;
|
||||||
|
@ -2,11 +2,11 @@ use crate::core::parser::ResponseParser;
|
|||||||
use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
|
use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
|
||||||
use crate::core::rpc_loop::Handler;
|
use crate::core::rpc_loop::Handler;
|
||||||
use crate::core::rpc_peer::{PluginCommand, ResponsePayload};
|
use crate::core::rpc_peer::{PluginCommand, ResponsePayload};
|
||||||
use crate::error::{Error, ReadError, RemoteError};
|
use crate::error::{ReadError, RemoteError, SidecarError};
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use lib_infra::util::{get_operating_system, OperatingSystem};
|
use lib_infra::util::{get_operating_system, OperatingSystem};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use serde_json::{json, Value};
|
use serde_json::Value;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::sync::atomic::{AtomicI64, Ordering};
|
use std::sync::atomic::{AtomicI64, Ordering};
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
@ -29,9 +29,9 @@ impl SidecarManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, Error> {
|
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, SidecarError> {
|
||||||
if self.operating_system.is_not_desktop() {
|
if self.operating_system.is_not_desktop() {
|
||||||
return Err(Error::Internal(anyhow!(
|
return Err(SidecarError::Internal(anyhow!(
|
||||||
"plugin not supported on this platform"
|
"plugin not supported on this platform"
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@ -41,7 +41,7 @@ impl SidecarManager {
|
|||||||
Ok(plugin_id)
|
Ok(plugin_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, Error> {
|
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, SidecarError> {
|
||||||
let state = self.state.lock();
|
let state = self.state.lock();
|
||||||
let plugin = state
|
let plugin = state
|
||||||
.plugins
|
.plugins
|
||||||
@ -51,9 +51,9 @@ impl SidecarManager {
|
|||||||
Ok(Arc::downgrade(plugin))
|
Ok(Arc::downgrade(plugin))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> {
|
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), SidecarError> {
|
||||||
if self.operating_system.is_not_desktop() {
|
if self.operating_system.is_not_desktop() {
|
||||||
return Err(Error::Internal(anyhow!(
|
return Err(SidecarError::Internal(anyhow!(
|
||||||
"plugin not supported on this platform"
|
"plugin not supported on this platform"
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@ -71,9 +71,9 @@ impl SidecarManager {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> {
|
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), SidecarError> {
|
||||||
if self.operating_system.is_not_desktop() {
|
if self.operating_system.is_not_desktop() {
|
||||||
return Err(Error::Internal(anyhow!(
|
return Err(SidecarError::Internal(anyhow!(
|
||||||
"plugin not supported on this platform"
|
"plugin not supported on this platform"
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@ -94,15 +94,15 @@ impl SidecarManager {
|
|||||||
id: PluginId,
|
id: PluginId,
|
||||||
method: &str,
|
method: &str,
|
||||||
request: Value,
|
request: Value,
|
||||||
) -> Result<P::ValueType, Error> {
|
) -> Result<P::ValueType, SidecarError> {
|
||||||
let state = self.state.lock();
|
let state = self.state.lock();
|
||||||
let plugin = state
|
let plugin = state
|
||||||
.plugins
|
.plugins
|
||||||
.iter()
|
.iter()
|
||||||
.find(|p| p.id == id)
|
.find(|p| p.id == id)
|
||||||
.ok_or(anyhow!("plugin not found"))?;
|
.ok_or(anyhow!("plugin not found"))?;
|
||||||
let resp = plugin.send_request(method, &request)?;
|
let resp = plugin.request(method, &request)?;
|
||||||
let value = P::parse_response(resp)?;
|
let value = P::parse_json(resp)?;
|
||||||
Ok(value)
|
Ok(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,14 +111,14 @@ impl SidecarManager {
|
|||||||
id: PluginId,
|
id: PluginId,
|
||||||
method: &str,
|
method: &str,
|
||||||
request: Value,
|
request: Value,
|
||||||
) -> Result<P::ValueType, Error> {
|
) -> Result<P::ValueType, SidecarError> {
|
||||||
let state = self.state.lock();
|
let state = self.state.lock();
|
||||||
let plugin = state
|
let plugin = state
|
||||||
.plugins
|
.plugins
|
||||||
.iter()
|
.iter()
|
||||||
.find(|p| p.id == id)
|
.find(|p| p.id == id)
|
||||||
.ok_or(anyhow!("plugin not found"))?;
|
.ok_or(anyhow!("plugin not found"))?;
|
||||||
let value = plugin.async_send_request::<P>(method, &request).await?;
|
let value = plugin.async_request::<P>(method, &request).await?;
|
||||||
Ok(value)
|
Ok(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
use crate::core::parser::{ChatRelatedQuestionsResponseParser, ChatResponseParser};
|
use crate::core::parser::{
|
||||||
|
ChatRelatedQuestionsResponseParser, ChatResponseParser, ChatStreamResponseParser,
|
||||||
|
};
|
||||||
use crate::core::plugin::{Plugin, PluginId};
|
use crate::core::plugin::{Plugin, PluginId};
|
||||||
use crate::error::Error;
|
use crate::error::SidecarError;
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::Weak;
|
use std::sync::Weak;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tokio_stream::Stream;
|
||||||
|
|
||||||
pub struct ChatPluginOperation {
|
pub struct ChatPluginOperation {
|
||||||
plugin: Weak<Plugin>,
|
plugin: Weak<Plugin>,
|
||||||
@ -17,17 +21,17 @@ impl ChatPluginOperation {
|
|||||||
pub async fn send_message(
|
pub async fn send_message(
|
||||||
&self,
|
&self,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
plugin_id: PluginId,
|
_plugin_id: PluginId,
|
||||||
message: &str,
|
message: &str,
|
||||||
) -> Result<String, Error> {
|
) -> Result<String, SidecarError> {
|
||||||
let plugin = self
|
let plugin = self
|
||||||
.plugin
|
.plugin
|
||||||
.upgrade()
|
.upgrade()
|
||||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||||
|
|
||||||
let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}});
|
let params = json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}});
|
||||||
let resp = plugin
|
let resp = plugin
|
||||||
.async_send_request::<ChatResponseParser>("handle", ¶ms)
|
.async_request::<ChatResponseParser>("handle", ¶ms)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
@ -35,34 +39,31 @@ impl ChatPluginOperation {
|
|||||||
pub async fn stream_message(
|
pub async fn stream_message(
|
||||||
&self,
|
&self,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
plugin_id: PluginId,
|
_plugin_id: PluginId,
|
||||||
message: &str,
|
message: &str,
|
||||||
) -> Result<String, Error> {
|
) -> Result<ReceiverStream<Result<String, SidecarError>>, SidecarError> {
|
||||||
let plugin = self
|
let plugin = self
|
||||||
.plugin
|
.plugin
|
||||||
.upgrade()
|
.upgrade()
|
||||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||||
|
|
||||||
let params =
|
let params =
|
||||||
json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}});
|
json!({"chat_id": chat_id, "method": "stream_answer", "params": {"content": message}});
|
||||||
let resp = plugin
|
plugin.stream_request::<ChatStreamResponseParser>("handle", ¶ms)
|
||||||
.async_send_request::<ChatResponseParser>("handle", ¶ms)
|
|
||||||
.await?;
|
|
||||||
Ok(resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_related_questions(
|
pub async fn get_related_questions(
|
||||||
&self,
|
&self,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
) -> Result<Vec<serde_json::Value>, Error> {
|
) -> Result<Vec<serde_json::Value>, SidecarError> {
|
||||||
let plugin = self
|
let plugin = self
|
||||||
.plugin
|
.plugin
|
||||||
.upgrade()
|
.upgrade()
|
||||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||||
|
|
||||||
let params = json!({"chat_id": chat_id, "method": "related_question"});
|
let params = json!({"chat_id": chat_id, "method": "related_question"});
|
||||||
let resp = plugin
|
let resp = plugin
|
||||||
.async_send_request::<ChatRelatedQuestionsResponseParser>("handle", ¶ms)
|
.async_request::<ChatRelatedQuestionsResponseParser>("handle", ¶ms)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
use crate::core::parser::{
|
use crate::core::parser::SimilarityResponseParser;
|
||||||
ChatRelatedQuestionsResponseParser, ChatResponseParser, SimilarityResponseParser,
|
use crate::core::plugin::Plugin;
|
||||||
};
|
use crate::error::SidecarError;
|
||||||
use crate::core::plugin::{Plugin, PluginId};
|
|
||||||
use crate::error::Error;
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::Weak;
|
use std::sync::Weak;
|
||||||
@ -16,15 +14,19 @@ impl EmbeddingPluginOperation {
|
|||||||
EmbeddingPluginOperation { plugin }
|
EmbeddingPluginOperation { plugin }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> Result<f64, Error> {
|
pub async fn calculate_similarity(
|
||||||
|
&self,
|
||||||
|
message1: &str,
|
||||||
|
message2: &str,
|
||||||
|
) -> Result<f64, SidecarError> {
|
||||||
let plugin = self
|
let plugin = self
|
||||||
.plugin
|
.plugin
|
||||||
.upgrade()
|
.upgrade()
|
||||||
.ok_or(Error::Internal(anyhow!("Plugin is dropped")))?;
|
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
||||||
let params =
|
let params =
|
||||||
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}});
|
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}});
|
||||||
plugin
|
plugin
|
||||||
.async_send_request::<SimilarityResponseParser>("handle", ¶ms)
|
.async_request::<SimilarityResponseParser>("handle", ¶ms)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
use crate::util::LocalAITest;
|
use crate::util::LocalAITest;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn load_chat_model_test() {
|
async fn load_chat_model_test() {
|
||||||
if let Ok(test) = LocalAITest::new() {
|
if let Ok(test) = LocalAITest::new() {
|
||||||
let plugin_id = test.init_chat_plugin().await;
|
let plugin_id = test.init_chat_plugin().await;
|
||||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||||
let resp = test.send_message(&chat_id, plugin_id, "hello world").await;
|
let resp = test
|
||||||
|
.send_chat_message(&chat_id, plugin_id, "hello world")
|
||||||
|
.await;
|
||||||
eprintln!("chat response: {:?}", resp);
|
eprintln!("chat response: {:?}", resp);
|
||||||
|
|
||||||
let embedding_plugin_id = test.init_embedding_plugin().await;
|
let embedding_plugin_id = test.init_embedding_plugin().await;
|
||||||
@ -17,3 +20,27 @@ async fn load_chat_model_test() {
|
|||||||
// eprintln!("related questions: {:?}", questions);
|
// eprintln!("related questions: {:?}", questions);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[tokio::test]
|
||||||
|
async fn stream_local_model_test() {
|
||||||
|
if let Ok(test) = LocalAITest::new() {
|
||||||
|
let plugin_id = test.init_chat_plugin().await;
|
||||||
|
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
|
let mut resp = test
|
||||||
|
.stream_chat_message(&chat_id, plugin_id, "hello world")
|
||||||
|
.await;
|
||||||
|
let a = resp.next().await.unwrap().unwrap();
|
||||||
|
eprintln!("chat response: {:?}", a);
|
||||||
|
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||||
|
|
||||||
|
// let mut resp = test
|
||||||
|
// .stream_chat_message(&chat_id, plugin_id, "How are you")
|
||||||
|
// .await;
|
||||||
|
// let a = resp.next().await.unwrap().unwrap();
|
||||||
|
// eprintln!("chat response: {:?}", a);
|
||||||
|
// let questions = test.related_question(&chat_id, plugin_id).await;
|
||||||
|
// assert_eq!(questions.len(), 3);
|
||||||
|
// eprintln!("related questions: {:?}", questions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use flowy_sidecar::manager::SidecarManager;
|
use flowy_sidecar::manager::SidecarManager;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::{Arc, Once};
|
use std::sync::Once;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tokio_stream::Stream;
|
||||||
|
|
||||||
use flowy_sidecar::core::parser::{ChatResponseParser, SimilarityResponseParser};
|
|
||||||
use flowy_sidecar::core::plugin::{PluginId, PluginInfo};
|
use flowy_sidecar::core::plugin::{PluginId, PluginInfo};
|
||||||
|
use flowy_sidecar::error::SidecarError;
|
||||||
use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation;
|
use flowy_sidecar::plugins::chat_plugin::ChatPluginOperation;
|
||||||
use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation;
|
use flowy_sidecar::plugins::embedding_plugin::EmbeddingPluginOperation;
|
||||||
use tracing_subscriber::fmt::Subscriber;
|
use tracing_subscriber::fmt::Subscriber;
|
||||||
@ -61,7 +63,12 @@ impl LocalAITest {
|
|||||||
plugin_id
|
plugin_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String {
|
pub async fn send_chat_message(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
plugin_id: PluginId,
|
||||||
|
message: &str,
|
||||||
|
) -> String {
|
||||||
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
|
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
|
||||||
let operation = ChatPluginOperation::new(plugin);
|
let operation = ChatPluginOperation::new(plugin);
|
||||||
let resp = operation
|
let resp = operation
|
||||||
@ -72,6 +79,20 @@ impl LocalAITest {
|
|||||||
resp
|
resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn stream_chat_message(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
plugin_id: PluginId,
|
||||||
|
message: &str,
|
||||||
|
) -> ReceiverStream<Result<String, SidecarError>> {
|
||||||
|
let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
|
||||||
|
let operation = ChatPluginOperation::new(plugin);
|
||||||
|
operation
|
||||||
|
.stream_message(chat_id, plugin_id, message)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn related_question(
|
pub async fn related_question(
|
||||||
&self,
|
&self,
|
||||||
chat_id: &str,
|
chat_id: &str,
|
||||||
|
@ -23,5 +23,6 @@ if_wasm! {
|
|||||||
pub mod isolate_stream;
|
pub mod isolate_stream;
|
||||||
pub mod priority_task;
|
pub mod priority_task;
|
||||||
pub mod ref_map;
|
pub mod ref_map;
|
||||||
|
pub mod stream_util;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
pub mod validator_fn;
|
pub mod validator_fn;
|
||||||
|
21
frontend/rust-lib/lib-infra/src/stream_util.rs
Normal file
21
frontend/rust-lib/lib-infra/src/stream_util.rs
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
use futures_core::Stream;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::mpsc::{Receiver, Sender};
|
||||||
|
|
||||||
|
struct BoundedStream<T> {
|
||||||
|
recv: Receiver<T>,
|
||||||
|
}
|
||||||
|
impl<T> Stream for BoundedStream<T> {
|
||||||
|
type Item = T;
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
|
||||||
|
Pin::into_inner(self).recv.poll_recv(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mpsc_channel_stream<T: Unpin>(size: usize) -> (Sender<T>, impl Stream<Item = T>) {
|
||||||
|
let (tx, rx) = mpsc::channel(size);
|
||||||
|
let stream = BoundedStream { recv: rx };
|
||||||
|
(tx, stream)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user