mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2024-08-30 18:12:39 +00:00
chore: separate to new crate
This commit is contained in:
parent
e4f4a128be
commit
ffc75106f3
56
frontend/rust-lib/Cargo.lock
generated
56
frontend/rust-lib/Cargo.lock
generated
@ -194,6 +194,40 @@ dependencies = [
|
|||||||
"thiserror",
|
"thiserror",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "appflowy-local-ai-chat"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=b7f51a3f#b7f51a3fe79142582d89c4e577ccd36957cc2c00"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"appflowy-plugin",
|
||||||
|
"bytes",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "appflowy-plugin"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "git+https://github.com/AppFlowy-IO/AppFlowy-LocalAI?rev=b7f51a3f#b7f51a3fe79142582d89c4e577ccd36957cc2c00"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"log",
|
||||||
|
"once_cell",
|
||||||
|
"parking_lot 0.12.1",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "arc-swap"
|
name = "arc-swap"
|
||||||
version = "1.7.1"
|
version = "1.7.1"
|
||||||
@ -1667,6 +1701,8 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"allo-isolate",
|
"allo-isolate",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"appflowy-local-ai-chat",
|
||||||
|
"appflowy-plugin",
|
||||||
"bytes",
|
"bytes",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
@ -1675,7 +1711,6 @@ dependencies = [
|
|||||||
"flowy-derive",
|
"flowy-derive",
|
||||||
"flowy-error",
|
"flowy-error",
|
||||||
"flowy-notification",
|
"flowy-notification",
|
||||||
"flowy-sidecar",
|
|
||||||
"flowy-sqlite",
|
"flowy-sqlite",
|
||||||
"futures",
|
"futures",
|
||||||
"lib-dispatch",
|
"lib-dispatch",
|
||||||
@ -1963,7 +1998,6 @@ dependencies = [
|
|||||||
"fancy-regex 0.11.0",
|
"fancy-regex 0.11.0",
|
||||||
"flowy-codegen",
|
"flowy-codegen",
|
||||||
"flowy-derive",
|
"flowy-derive",
|
||||||
"flowy-sidecar",
|
|
||||||
"flowy-sqlite",
|
"flowy-sqlite",
|
||||||
"lib-dispatch",
|
"lib-dispatch",
|
||||||
"protobuf",
|
"protobuf",
|
||||||
@ -2154,24 +2188,6 @@ dependencies = [
|
|||||||
"serde_repr",
|
"serde_repr",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "flowy-sidecar"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"crossbeam-utils",
|
|
||||||
"lib-infra",
|
|
||||||
"log",
|
|
||||||
"once_cell",
|
|
||||||
"parking_lot 0.12.1",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"thiserror",
|
|
||||||
"tokio",
|
|
||||||
"tokio-stream",
|
|
||||||
"tracing",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flowy-sqlite"
|
name = "flowy-sqlite"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
@ -32,7 +32,6 @@ members = [
|
|||||||
"flowy-chat",
|
"flowy-chat",
|
||||||
"flowy-chat-pub",
|
"flowy-chat-pub",
|
||||||
"flowy-storage-pub",
|
"flowy-storage-pub",
|
||||||
"flowy-sidecar",
|
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
@ -68,7 +67,6 @@ collab-integrate = { workspace = true, path = "collab-integrate" }
|
|||||||
flowy-date = { workspace = true, path = "flowy-date" }
|
flowy-date = { workspace = true, path = "flowy-date" }
|
||||||
flowy-chat = { workspace = true, path = "flowy-chat" }
|
flowy-chat = { workspace = true, path = "flowy-chat" }
|
||||||
flowy-chat-pub = { workspace = true, path = "flowy-chat-pub" }
|
flowy-chat-pub = { workspace = true, path = "flowy-chat-pub" }
|
||||||
flowy-sidecar = { workspace = true, path = "flowy-sidecar" }
|
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
bytes = "1.5.0"
|
bytes = "1.5.0"
|
||||||
@ -146,3 +144,6 @@ collab-document = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFl
|
|||||||
collab-database = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3a58d95" }
|
collab-database = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3a58d95" }
|
||||||
collab-plugins = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3a58d95" }
|
collab-plugins = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3a58d95" }
|
||||||
collab-user = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3a58d95" }
|
collab-user = { version = "0.2", git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev = "3a58d95" }
|
||||||
|
|
||||||
|
appflowy-local-ai-chat = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "b7f51a3f" }
|
||||||
|
appflowy-plugin = { version = "0.1", git = "https://github.com/AppFlowy-IO/AppFlowy-LocalAI", rev = "b7f51a3f" }
|
||||||
|
@ -12,7 +12,6 @@ flowy-error = { path = "../flowy-error", features = [
|
|||||||
"impl_from_dispatch_error",
|
"impl_from_dispatch_error",
|
||||||
"impl_from_collab_folder",
|
"impl_from_collab_folder",
|
||||||
"impl_from_sqlite",
|
"impl_from_sqlite",
|
||||||
"impl_from_sidecar"
|
|
||||||
] }
|
] }
|
||||||
lib-dispatch = { workspace = true }
|
lib-dispatch = { workspace = true }
|
||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
@ -29,12 +28,13 @@ tokio.workspace = true
|
|||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
allo-isolate = { version = "^0.1", features = ["catch-unwind"] }
|
allo-isolate = { version = "^0.1", features = ["catch-unwind"] }
|
||||||
log = "0.4.21"
|
log = "0.4.21"
|
||||||
flowy-sidecar = { workspace = true, features = ["verbose"] }
|
|
||||||
serde = { workspace = true, features = ["derive"] }
|
serde = { workspace = true, features = ["derive"] }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
anyhow = "1.0.86"
|
anyhow = "1.0.86"
|
||||||
tokio-stream = "0.1.15"
|
tokio-stream = "0.1.15"
|
||||||
parking_lot.workspace = true
|
parking_lot.workspace = true
|
||||||
|
appflowy-local-ai-chat = { version = "0.1.0", features = ["verbose"] }
|
||||||
|
appflowy-plugin = { version = "0.1.0", features = ["verbose"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
dotenv = "0.15.0"
|
dotenv = "0.15.0"
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
use crate::chat::Chat;
|
use crate::chat::Chat;
|
||||||
use crate::chat_service_impl::ChatService;
|
use crate::chat_service_impl::ChatService;
|
||||||
use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB};
|
use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB};
|
||||||
use crate::local_ai::llm_chat::{LocalChatLLMChat, LocalLLMSetting};
|
|
||||||
use crate::persistence::{insert_chat, ChatTable};
|
use crate::persistence::{insert_chat, ChatTable};
|
||||||
|
use appflowy_local_ai_chat::llm_chat::{LocalChatLLMChat, LocalLLMSetting};
|
||||||
|
use appflowy_plugin::manager::SidecarManager;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use flowy_chat_pub::cloud::{ChatCloudService, ChatMessageType};
|
use flowy_chat_pub::cloud::{ChatCloudService, ChatMessageType};
|
||||||
use flowy_error::{FlowyError, FlowyResult};
|
use flowy_error::{FlowyError, FlowyResult};
|
||||||
use flowy_sidecar::manager::SidecarManager;
|
|
||||||
use flowy_sqlite::kv::KVStorePreferences;
|
use flowy_sqlite::kv::KVStorePreferences;
|
||||||
use flowy_sqlite::DBConnection;
|
use flowy_sqlite::DBConnection;
|
||||||
use lib_infra::util::timestamp;
|
use lib_infra::util::timestamp;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::chat_manager::ChatUserService;
|
use crate::chat_manager::ChatUserService;
|
||||||
use crate::local_ai::llm_chat::{LocalChatLLMChat, LocalLLMSetting};
|
|
||||||
use crate::persistence::select_single_message;
|
use crate::persistence::select_single_message;
|
||||||
|
use appflowy_local_ai_chat::llm_chat::{LocalChatLLMChat, LocalLLMSetting};
|
||||||
use flowy_chat_pub::cloud::{
|
use flowy_chat_pub::cloud::{
|
||||||
ChatCloudService, ChatMessage, ChatMessageType, CompletionType, MessageCursor,
|
ChatCloudService, ChatMessage, ChatMessageType, CompletionType, MessageCursor,
|
||||||
RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer, StreamComplete,
|
RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer, StreamComplete,
|
||||||
@ -141,7 +141,7 @@ impl ChatCloudService for ChatService {
|
|||||||
.local_llm_chat
|
.local_llm_chat
|
||||||
.ask_question(chat_id, &content)
|
.ask_question(chat_id, &content)
|
||||||
.await?
|
.await?
|
||||||
.map_err(FlowyError::from);
|
.map_err(|err| FlowyError::local_ai().with_context(err));
|
||||||
Ok(stream.boxed())
|
Ok(stream.boxed())
|
||||||
} else {
|
} else {
|
||||||
self
|
self
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::local_ai::llm_chat::LocalLLMSetting;
|
use appflowy_local_ai_chat::llm_chat::LocalLLMSetting;
|
||||||
use flowy_chat_pub::cloud::{
|
use flowy_chat_pub::cloud::{
|
||||||
ChatMessage, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
|
ChatMessage, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
|
||||||
};
|
};
|
||||||
|
@ -5,7 +5,6 @@ mod chat;
|
|||||||
pub mod chat_manager;
|
pub mod chat_manager;
|
||||||
mod chat_service_impl;
|
mod chat_service_impl;
|
||||||
pub mod entities;
|
pub mod entities;
|
||||||
pub mod local_ai;
|
|
||||||
pub mod notification;
|
pub mod notification;
|
||||||
mod persistence;
|
mod persistence;
|
||||||
mod protobuf;
|
mod protobuf;
|
||||||
|
@ -1,128 +0,0 @@
|
|||||||
use anyhow::anyhow;
|
|
||||||
use bytes::Bytes;
|
|
||||||
use flowy_error::FlowyError;
|
|
||||||
use flowy_sidecar::core::parser::{DefaultResponseParser, ResponseParser};
|
|
||||||
use flowy_sidecar::core::plugin::Plugin;
|
|
||||||
use flowy_sidecar::error::{RemoteError, SidecarError};
|
|
||||||
use serde_json::json;
|
|
||||||
use serde_json::Value as JsonValue;
|
|
||||||
use std::sync::Weak;
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
|
||||||
use tracing::instrument;
|
|
||||||
|
|
||||||
pub struct ChatPluginOperation {
|
|
||||||
plugin: Weak<Plugin>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatPluginOperation {
|
|
||||||
pub fn new(plugin: Weak<Plugin>) -> Self {
|
|
||||||
ChatPluginOperation { plugin }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_plugin(&self) -> Result<std::sync::Arc<Plugin>, SidecarError> {
|
|
||||||
self
|
|
||||||
.plugin
|
|
||||||
.upgrade()
|
|
||||||
.ok_or_else(|| SidecarError::Internal(anyhow!("Plugin is dropped")))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_request<T: ResponseParser>(
|
|
||||||
&self,
|
|
||||||
method: &str,
|
|
||||||
params: JsonValue,
|
|
||||||
) -> Result<T::ValueType, SidecarError> {
|
|
||||||
let plugin = self.get_plugin()?;
|
|
||||||
let mut request = json!({ "method": method });
|
|
||||||
request
|
|
||||||
.as_object_mut()
|
|
||||||
.unwrap()
|
|
||||||
.extend(params.as_object().unwrap().clone());
|
|
||||||
plugin.async_request::<T>("handle", &request).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create_chat(&self, chat_id: &str) -> Result<(), SidecarError> {
|
|
||||||
self
|
|
||||||
.send_request::<DefaultResponseParser>("create_chat", json!({ "chat_id": chat_id }))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn close_chat(&self, chat_id: &str) -> Result<(), SidecarError> {
|
|
||||||
self
|
|
||||||
.send_request::<DefaultResponseParser>("close_chat", json!({ "chat_id": chat_id }))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send_message(&self, chat_id: &str, message: &str) -> Result<String, SidecarError> {
|
|
||||||
self
|
|
||||||
.send_request::<ChatResponseParser>(
|
|
||||||
"answer",
|
|
||||||
json!({ "chat_id": chat_id, "params": { "content": message } }),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(level = "debug", skip(self), err)]
|
|
||||||
pub async fn stream_message(
|
|
||||||
&self,
|
|
||||||
chat_id: &str,
|
|
||||||
message: &str,
|
|
||||||
) -> Result<ReceiverStream<Result<Bytes, SidecarError>>, FlowyError> {
|
|
||||||
let plugin = self
|
|
||||||
.get_plugin()
|
|
||||||
.map_err(|err| FlowyError::internal().with_context(err.to_string()))?;
|
|
||||||
let params = json!({
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"method": "stream_answer",
|
|
||||||
"params": { "content": message }
|
|
||||||
});
|
|
||||||
plugin
|
|
||||||
.stream_request::<ChatStreamResponseParser>("handle", ¶ms)
|
|
||||||
.map_err(|err| FlowyError::internal().with_context(err.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_related_questions(&self, chat_id: &str) -> Result<Vec<JsonValue>, SidecarError> {
|
|
||||||
self
|
|
||||||
.send_request::<ChatRelatedQuestionsResponseParser>(
|
|
||||||
"related_question",
|
|
||||||
json!({ "chat_id": chat_id }),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ChatResponseParser;
|
|
||||||
impl ResponseParser for ChatResponseParser {
|
|
||||||
type ValueType = String;
|
|
||||||
|
|
||||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
|
||||||
json
|
|
||||||
.get("data")
|
|
||||||
.and_then(|data| data.as_str())
|
|
||||||
.map(String::from)
|
|
||||||
.ok_or(RemoteError::ParseResponse(json))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ChatStreamResponseParser;
|
|
||||||
impl ResponseParser for ChatStreamResponseParser {
|
|
||||||
type ValueType = Bytes;
|
|
||||||
|
|
||||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
|
||||||
json
|
|
||||||
.as_str()
|
|
||||||
.map(|message| Bytes::from(message.to_string()))
|
|
||||||
.ok_or(RemoteError::ParseResponse(json))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ChatRelatedQuestionsResponseParser;
|
|
||||||
impl ResponseParser for ChatRelatedQuestionsResponseParser {
|
|
||||||
type ValueType = Vec<JsonValue>;
|
|
||||||
|
|
||||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
|
||||||
json
|
|
||||||
.get("data")
|
|
||||||
.and_then(|data| data.as_array()).cloned()
|
|
||||||
.ok_or(RemoteError::ParseResponse(json))
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,62 +0,0 @@
|
|||||||
use anyhow::anyhow;
|
|
||||||
use flowy_sidecar::core::parser::ResponseParser;
|
|
||||||
use flowy_sidecar::core::plugin::Plugin;
|
|
||||||
use flowy_sidecar::error::{RemoteError, SidecarError};
|
|
||||||
use serde_json::json;
|
|
||||||
use serde_json::Value as JsonValue;
|
|
||||||
use std::sync::Weak;
|
|
||||||
|
|
||||||
pub struct EmbeddingPluginOperation {
|
|
||||||
plugin: Weak<Plugin>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EmbeddingPluginOperation {
|
|
||||||
pub fn new(plugin: Weak<Plugin>) -> Self {
|
|
||||||
EmbeddingPluginOperation { plugin }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_embeddings(&self, message: &str) -> Result<Vec<Vec<f64>>, SidecarError> {
|
|
||||||
let plugin = self
|
|
||||||
.plugin
|
|
||||||
.upgrade()
|
|
||||||
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
|
|
||||||
let params = json!({"method": "get_embeddings", "params": {"input": message }});
|
|
||||||
plugin
|
|
||||||
.async_request::<EmbeddingResponseParse>("handle", ¶ms)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct EmbeddingResponseParse;
|
|
||||||
impl ResponseParser for EmbeddingResponseParse {
|
|
||||||
type ValueType = Vec<Vec<f64>>;
|
|
||||||
|
|
||||||
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
|
||||||
if json.is_object() {
|
|
||||||
if let Some(data) = json.get("data") {
|
|
||||||
if let Some(embeddings) = data.get("embeddings") {
|
|
||||||
if let Some(array) = embeddings.as_array() {
|
|
||||||
let mut result = Vec::new();
|
|
||||||
for item in array {
|
|
||||||
if let Some(inner_array) = item.as_array() {
|
|
||||||
let mut inner_result = Vec::new();
|
|
||||||
for num in inner_array {
|
|
||||||
if let Some(value) = num.as_f64() {
|
|
||||||
inner_result.push(value);
|
|
||||||
} else {
|
|
||||||
return Err(RemoteError::ParseResponse(json));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result.push(inner_result);
|
|
||||||
} else {
|
|
||||||
return Err(RemoteError::ParseResponse(json));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Ok(result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(RemoteError::ParseResponse(json))
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,344 +0,0 @@
|
|||||||
use crate::local_ai::chat_plugin::ChatPluginOperation;
|
|
||||||
use bytes::Bytes;
|
|
||||||
use flowy_error::{FlowyError, FlowyResult};
|
|
||||||
use flowy_sidecar::core::plugin::{Plugin, PluginId, PluginInfo};
|
|
||||||
use flowy_sidecar::error::SidecarError;
|
|
||||||
use flowy_sidecar::manager::SidecarManager;
|
|
||||||
use lib_infra::util::{get_operating_system, OperatingSystem};
|
|
||||||
use log::error;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::{Arc, Weak};
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use tokio::time::timeout;
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
|
||||||
use tracing::{info, instrument, trace};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
|
||||||
pub struct LocalLLMSetting {
|
|
||||||
pub chat_bin_path: String,
|
|
||||||
pub chat_model_path: String,
|
|
||||||
pub enabled: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LocalLLMSetting {
|
|
||||||
pub fn validate(&self) -> FlowyResult<()> {
|
|
||||||
ChatPluginConfig::new(&self.chat_bin_path, &self.chat_model_path)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
pub fn chat_config(&self) -> FlowyResult<ChatPluginConfig> {
|
|
||||||
let config = ChatPluginConfig::new(&self.chat_bin_path, &self.chat_model_path)?;
|
|
||||||
Ok(config)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct LocalChatLLMChat {
|
|
||||||
sidecar_manager: Arc<SidecarManager>,
|
|
||||||
state: RwLock<LLMState>,
|
|
||||||
state_notify: tokio::sync::broadcast::Sender<LLMState>,
|
|
||||||
plugin_config: RwLock<Option<ChatPluginConfig>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LocalChatLLMChat {
|
|
||||||
pub fn new(sidecar_manager: Arc<SidecarManager>) -> Self {
|
|
||||||
let (state_notify, _) = tokio::sync::broadcast::channel(10);
|
|
||||||
Self {
|
|
||||||
sidecar_manager,
|
|
||||||
state: RwLock::new(LLMState::Loading),
|
|
||||||
state_notify,
|
|
||||||
plugin_config: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn update_state(&self, state: LLMState) {
|
|
||||||
*self.state.write().await = state.clone();
|
|
||||||
let _ = self.state_notify.send(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Waits for the plugin to be ready.
|
|
||||||
///
|
|
||||||
/// The wait_plugin_ready method is an asynchronous function designed to ensure that the chat
|
|
||||||
/// plugin is in a ready state before allowing further operations. This is crucial for maintaining
|
|
||||||
/// the correct sequence of operations and preventing errors that could occur if operations are
|
|
||||||
/// attempted on an unready plugin.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `FlowyResult<()>` indicating success or failure.
|
|
||||||
async fn wait_plugin_ready(&self) -> FlowyResult<()> {
|
|
||||||
let is_loading = self.state.read().await.is_loading();
|
|
||||||
if !is_loading {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
info!("[Chat Plugin] wait for chat plugin to be ready");
|
|
||||||
let mut rx = self.state_notify.subscribe();
|
|
||||||
let timeout_duration = Duration::from_secs(30);
|
|
||||||
let result = timeout(timeout_duration, async {
|
|
||||||
while let Ok(state) = rx.recv().await {
|
|
||||||
if state.is_ready() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => {
|
|
||||||
trace!("[Chat Plugin] chat plugin is ready");
|
|
||||||
Ok(())
|
|
||||||
},
|
|
||||||
Err(_) => Err(
|
|
||||||
FlowyError::local_ai().with_context("Timeout while waiting for chat plugin to be ready"),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Retrieves the chat plugin.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `FlowyResult<Weak<Plugin>>` containing a weak reference to the plugin.
|
|
||||||
async fn get_chat_plugin(&self) -> FlowyResult<Weak<Plugin>> {
|
|
||||||
let plugin_id = self.state.read().await.plugin_id()?;
|
|
||||||
let plugin = self.sidecar_manager.get_plugin(plugin_id).await?;
|
|
||||||
Ok(plugin)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new chat session.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `chat_id` - A string slice containing the unique identifier for the chat session.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `FlowyResult<()>` indicating success or failure.
|
|
||||||
pub async fn create_chat(&self, chat_id: &str) -> FlowyResult<()> {
|
|
||||||
trace!("[Chat Plugin] create chat: {}", chat_id);
|
|
||||||
self.wait_plugin_ready().await?;
|
|
||||||
|
|
||||||
let plugin = self.get_chat_plugin().await?;
|
|
||||||
let operation = ChatPluginOperation::new(plugin);
|
|
||||||
operation.create_chat(chat_id).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Closes an existing chat session.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `chat_id` - A string slice containing the unique identifier for the chat session to close.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `FlowyResult<()>` indicating success or failure.
|
|
||||||
pub async fn close_chat(&self, chat_id: &str) -> FlowyResult<()> {
|
|
||||||
trace!("[Chat Plugin] close chat: {}", chat_id);
|
|
||||||
let plugin = self.get_chat_plugin().await?;
|
|
||||||
let operation = ChatPluginOperation::new(plugin);
|
|
||||||
operation.close_chat(chat_id).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Asks a question and returns a stream of responses.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `chat_id` - A string slice containing the unique identifier for the chat session.
|
|
||||||
/// * `message` - A string slice containing the question or message to send.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `FlowyResult<ReceiverStream<anyhow::Result<Bytes, SidecarError>>>` containing a stream of responses.
|
|
||||||
pub async fn ask_question(
|
|
||||||
&self,
|
|
||||||
chat_id: &str,
|
|
||||||
message: &str,
|
|
||||||
) -> FlowyResult<ReceiverStream<anyhow::Result<Bytes, SidecarError>>> {
|
|
||||||
trace!("[Chat Plugin] ask question: {}", message);
|
|
||||||
self.wait_plugin_ready().await?;
|
|
||||||
let plugin = self.get_chat_plugin().await?;
|
|
||||||
let operation = ChatPluginOperation::new(plugin);
|
|
||||||
let stream = operation.stream_message(chat_id, message).await?;
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a complete answer for a given message.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `chat_id` - A string slice containing the unique identifier for the chat session.
|
|
||||||
/// * `message` - A string slice containing the message to generate an answer for.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `FlowyResult<String>` containing the generated answer.
|
|
||||||
pub async fn generate_answer(&self, chat_id: &str, message: &str) -> FlowyResult<String> {
|
|
||||||
let plugin = self.get_chat_plugin().await?;
|
|
||||||
let operation = ChatPluginOperation::new(plugin);
|
|
||||||
let answer = operation.send_message(chat_id, message).await?;
|
|
||||||
Ok(answer)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all, err)]
|
|
||||||
pub async fn destroy_chat_plugin(&self) -> FlowyResult<()> {
|
|
||||||
if let Ok(plugin_id) = self.state.read().await.plugin_id() {
|
|
||||||
if let Err(err) = self.sidecar_manager.remove_plugin(plugin_id).await {
|
|
||||||
error!("remove plugin failed: {:?}", err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
self.update_state(LLMState::Uninitialized).await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip_all, err)]
|
|
||||||
pub async fn init_chat_plugin(&self, config: ChatPluginConfig) -> FlowyResult<()> {
|
|
||||||
if self.state.read().await.is_ready() {
|
|
||||||
if let Some(existing_config) = self.plugin_config.read().await.as_ref() {
|
|
||||||
if existing_config == &config {
|
|
||||||
trace!("[Chat Plugin] chat plugin already initialized with the same config");
|
|
||||||
return Ok(());
|
|
||||||
} else {
|
|
||||||
trace!(
|
|
||||||
"[Chat Plugin] existing config: {:?}, new config:{:?}",
|
|
||||||
existing_config,
|
|
||||||
config
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let system = get_operating_system();
|
|
||||||
// Initialize chat plugin if the config is different
|
|
||||||
// If the chat_bin_path is different, remove the old plugin
|
|
||||||
if let Err(err) = self.destroy_chat_plugin().await {
|
|
||||||
error!("[Chat Plugin] failed to destroy plugin: {:?}", err);
|
|
||||||
}
|
|
||||||
self.update_state(LLMState::Loading).await;
|
|
||||||
|
|
||||||
// create new plugin
|
|
||||||
trace!("[Chat Plugin] create chat plugin: {:?}", config);
|
|
||||||
let plugin_info = PluginInfo {
|
|
||||||
name: "chat_plugin".to_string(),
|
|
||||||
exec_path: config.chat_bin_path.clone(),
|
|
||||||
};
|
|
||||||
let plugin_id = self.sidecar_manager.create_plugin(plugin_info).await?;
|
|
||||||
|
|
||||||
// init plugin
|
|
||||||
trace!("[Chat Plugin] init chat plugin model: {:?}", plugin_id);
|
|
||||||
let model_path = config.chat_model_path.clone();
|
|
||||||
let params = match system {
|
|
||||||
OperatingSystem::Windows => {
|
|
||||||
serde_json::json!({
|
|
||||||
"absolute_chat_model_path": model_path,
|
|
||||||
"device": "cpu",
|
|
||||||
})
|
|
||||||
},
|
|
||||||
OperatingSystem::Linux => {
|
|
||||||
serde_json::json!({
|
|
||||||
"absolute_chat_model_path": model_path,
|
|
||||||
"device": "cpu",
|
|
||||||
})
|
|
||||||
},
|
|
||||||
OperatingSystem::MacOS => {
|
|
||||||
serde_json::json!({
|
|
||||||
"absolute_chat_model_path": model_path,
|
|
||||||
"device": "gpu",
|
|
||||||
})
|
|
||||||
},
|
|
||||||
_ => {
|
|
||||||
return Err(FlowyError::local_ai().with_context("Unsupported operating system"));
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"[Chat Plugin] setup chat plugin: {:?}, params: {:?}",
|
|
||||||
plugin_id, params
|
|
||||||
);
|
|
||||||
let plugin = self.sidecar_manager.init_plugin(plugin_id, params)?;
|
|
||||||
info!("[Chat Plugin] {} setup success", plugin);
|
|
||||||
self.plugin_config.write().await.replace(config);
|
|
||||||
self.update_state(LLMState::Ready { plugin_id }).await;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Eq, PartialEq, Debug, Clone)]
|
|
||||||
pub struct ChatPluginConfig {
|
|
||||||
chat_bin_path: PathBuf,
|
|
||||||
chat_model_path: PathBuf,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatPluginConfig {
|
|
||||||
pub fn new(chat_bin: &str, chat_model_path: &str) -> FlowyResult<Self> {
|
|
||||||
let chat_bin_path = PathBuf::from(chat_bin);
|
|
||||||
if !chat_bin_path.exists() {
|
|
||||||
return Err(FlowyError::invalid_data().with_context(format!(
|
|
||||||
"Chat binary path does not exist: {:?}",
|
|
||||||
chat_bin_path
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
if !chat_bin_path.is_file() {
|
|
||||||
return Err(FlowyError::invalid_data().with_context(format!(
|
|
||||||
"Chat binary path is not a file: {:?}",
|
|
||||||
chat_bin_path
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if local_model_dir exists and is a directory
|
|
||||||
let chat_model_path = PathBuf::from(&chat_model_path);
|
|
||||||
if !chat_model_path.exists() {
|
|
||||||
return Err(
|
|
||||||
FlowyError::invalid_data()
|
|
||||||
.with_context(format!("Local model does not exist: {:?}", chat_model_path)),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if !chat_model_path.is_file() {
|
|
||||||
return Err(
|
|
||||||
FlowyError::invalid_data()
|
|
||||||
.with_context(format!("Local model is not a file: {:?}", chat_model_path)),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
chat_bin_path,
|
|
||||||
chat_model_path,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum LLMState {
|
|
||||||
Uninitialized,
|
|
||||||
Loading,
|
|
||||||
Ready { plugin_id: PluginId },
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LLMState {
|
|
||||||
fn plugin_id(&self) -> FlowyResult<PluginId> {
|
|
||||||
match self {
|
|
||||||
LLMState::Ready { plugin_id } => Ok(*plugin_id),
|
|
||||||
_ => Err(FlowyError::local_ai().with_context("chat plugin is not ready")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_loading(&self) -> bool {
|
|
||||||
matches!(self, LLMState::Loading)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn is_uninitialized(&self) -> bool {
|
|
||||||
matches!(self, LLMState::Uninitialized)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_ready(&self) -> bool {
|
|
||||||
let system = get_operating_system();
|
|
||||||
if system.is_desktop() {
|
|
||||||
return matches!(self, LLMState::Ready { .. });
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,4 +0,0 @@
|
|||||||
pub mod chat_plugin;
|
|
||||||
pub mod llm_chat;
|
|
||||||
|
|
||||||
pub mod embedding_plugin;
|
|
@ -33,7 +33,6 @@ collab-plugins = { workspace = true, optional = true }
|
|||||||
collab-folder = { workspace = true, optional = true }
|
collab-folder = { workspace = true, optional = true }
|
||||||
client-api = { workspace = true, optional = true }
|
client-api = { workspace = true, optional = true }
|
||||||
tantivy = { version = "0.21.1", optional = true }
|
tantivy = { version = "0.21.1", optional = true }
|
||||||
flowy-sidecar = { workspace = true, optional = true }
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
impl_from_dispatch_error = ["lib-dispatch"]
|
impl_from_dispatch_error = ["lib-dispatch"]
|
||||||
@ -49,7 +48,6 @@ impl_from_collab_folder = ["collab-folder"]
|
|||||||
impl_from_collab_database = ["collab-database"]
|
impl_from_collab_database = ["collab-database"]
|
||||||
impl_from_url = ["url"]
|
impl_from_url = ["url"]
|
||||||
impl_from_tantivy = ["tantivy"]
|
impl_from_tantivy = ["tantivy"]
|
||||||
impl_from_sidecar = ["flowy-sidecar"]
|
|
||||||
|
|
||||||
impl_from_sqlite = ["flowy-sqlite", "r2d2"]
|
impl_from_sqlite = ["flowy-sqlite", "r2d2"]
|
||||||
impl_from_appflowy_cloud = ["client-api"]
|
impl_from_appflowy_cloud = ["client-api"]
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
use crate::{ErrorCode, FlowyError};
|
|
||||||
use flowy_sidecar::error::SidecarError;
|
|
||||||
|
|
||||||
impl std::convert::From<SidecarError> for FlowyError {
|
|
||||||
fn from(error: SidecarError) -> Self {
|
|
||||||
FlowyError::new(ErrorCode::LocalAIError, error)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,23 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "flowy-sidecar"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
anyhow = { version = "1.0" }
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
|
||||||
once_cell = "1.19.0"
|
|
||||||
thiserror = "1.0"
|
|
||||||
serde_json = "1.0.117"
|
|
||||||
tracing.workspace = true
|
|
||||||
crossbeam-utils = "0.8.20"
|
|
||||||
log = "0.4.21"
|
|
||||||
parking_lot.workspace = true
|
|
||||||
tokio-stream = "0.1.15"
|
|
||||||
lib-infra.workspace = true
|
|
||||||
|
|
||||||
[features]
|
|
||||||
verbose = []
|
|
@ -1,5 +0,0 @@
|
|||||||
pub mod parser;
|
|
||||||
pub mod plugin;
|
|
||||||
pub mod rpc_loop;
|
|
||||||
mod rpc_object;
|
|
||||||
pub mod rpc_peer;
|
|
@ -1,71 +0,0 @@
|
|||||||
use crate::core::rpc_object::RpcObject;
|
|
||||||
|
|
||||||
use crate::error::{ReadError, RemoteError};
|
|
||||||
use serde_json::{json, Value as JsonValue};
|
|
||||||
use std::io::BufRead;
|
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
pub struct MessageReader(String);
|
|
||||||
|
|
||||||
impl MessageReader {
|
|
||||||
/// Attempts to read the next line from the stream and parse it as
|
|
||||||
/// an RPC object.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// This function will return an error if there is an underlying
|
|
||||||
/// I/O error, if the stream is closed, or if the message is not
|
|
||||||
/// a valid JSON object.
|
|
||||||
pub fn next<R: BufRead>(&mut self, reader: &mut R) -> Result<RpcObject, ReadError> {
|
|
||||||
self.0.clear();
|
|
||||||
let _ = reader.read_line(&mut self.0)?;
|
|
||||||
if self.0.is_empty() {
|
|
||||||
Err(ReadError::Disconnect("Empty line".to_string()))
|
|
||||||
} else {
|
|
||||||
self.parse(&self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Attempts to parse a &str as an RPC Object.
|
|
||||||
///
|
|
||||||
/// This should not be called directly unless you are writing tests.
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub fn parse(&self, s: &str) -> Result<RpcObject, ReadError> {
|
|
||||||
match serde_json::from_str::<JsonValue>(s) {
|
|
||||||
Ok(val) => {
|
|
||||||
if !val.is_object() {
|
|
||||||
Err(ReadError::NotObject(s.to_string()))
|
|
||||||
} else {
|
|
||||||
Ok(val.into())
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(_) => Ok(RpcObject(json!({"message": s.to_string()}))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type RequestId = u64;
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
/// An RPC call, which may be either a notification or a request.
|
|
||||||
pub enum Call<R> {
|
|
||||||
Message(JsonValue),
|
|
||||||
/// An id and an RPC Request
|
|
||||||
Request(RequestId, R),
|
|
||||||
/// A malformed request: the request contained an id, but could
|
|
||||||
/// not be parsed. The client will receive an error.
|
|
||||||
InvalidRequest(RequestId, RemoteError),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait ResponseParser {
|
|
||||||
type ValueType: Send + Sync + 'static;
|
|
||||||
fn parse_json(payload: JsonValue) -> Result<Self::ValueType, RemoteError>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DefaultResponseParser;
|
|
||||||
impl ResponseParser for DefaultResponseParser {
|
|
||||||
type ValueType = ();
|
|
||||||
|
|
||||||
fn parse_json(_payload: JsonValue) -> Result<Self::ValueType, RemoteError> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,205 +0,0 @@
|
|||||||
use crate::error::SidecarError;
|
|
||||||
use crate::manager::WeakSidecarState;
|
|
||||||
use std::fmt::{Display, Formatter};
|
|
||||||
|
|
||||||
use crate::core::parser::ResponseParser;
|
|
||||||
use crate::core::rpc_loop::RpcLoop;
|
|
||||||
use crate::core::rpc_peer::{CloneableCallback, OneShotCallback};
|
|
||||||
use anyhow::anyhow;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::{json, Value as JsonValue};
|
|
||||||
use std::io::BufReader;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::process::{Child, Stdio};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::thread;
|
|
||||||
use std::time::Instant;
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
|
||||||
|
|
||||||
use tracing::{error, info};
|
|
||||||
|
|
||||||
#[derive(
|
|
||||||
Default, Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
|
|
||||||
)]
|
|
||||||
pub struct PluginId(pub(crate) i64);
|
|
||||||
|
|
||||||
impl From<i64> for PluginId {
|
|
||||||
fn from(id: i64) -> Self {
|
|
||||||
PluginId(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The `Peer` trait defines the interface for the opposite side of the RPC channel,
|
|
||||||
/// designed to be used behind a pointer or as a trait object.
|
|
||||||
pub trait Peer: Send + Sync + 'static {
|
|
||||||
/// Clones the peer into a boxed trait object.
|
|
||||||
fn box_clone(&self) -> Arc<dyn Peer>;
|
|
||||||
|
|
||||||
/// Sends an RPC notification to the peer with the specified method and parameters.
|
|
||||||
fn send_rpc_notification(&self, method: &str, params: &JsonValue);
|
|
||||||
|
|
||||||
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback);
|
|
||||||
|
|
||||||
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>);
|
|
||||||
/// Sends a synchronous RPC request to the peer and waits for the result.
|
|
||||||
/// Returns the result of the request or an error.
|
|
||||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError>;
|
|
||||||
|
|
||||||
/// Checks if there is an incoming request pending, intended to reduce latency for bulk operations done in the background.
|
|
||||||
fn request_is_pending(&self) -> bool;
|
|
||||||
|
|
||||||
/// Schedules a timer to execute the handler's `idle` function after the specified `Instant`.
|
|
||||||
/// Note: This is not a high-fidelity timer. Regular RPC messages will always take priority over idle tasks.
|
|
||||||
fn schedule_timer(&self, after: Instant, token: usize);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The `Peer` trait object.
|
|
||||||
pub type RpcPeer = Arc<dyn Peer>;
|
|
||||||
|
|
||||||
pub struct RpcCtx {
|
|
||||||
pub peer: RpcPeer,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Plugin {
|
|
||||||
peer: RpcPeer,
|
|
||||||
pub(crate) id: PluginId,
|
|
||||||
pub(crate) name: String,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub(crate) process: Arc<Child>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for Plugin {
|
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"{}, plugin id: {:?}, process id: {}",
|
|
||||||
self.name,
|
|
||||||
self.id,
|
|
||||||
self.process.id()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Plugin {
|
|
||||||
pub fn initialize(&self, value: JsonValue) -> Result<(), SidecarError> {
|
|
||||||
self.peer.send_rpc_request("initialize", &value)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
|
|
||||||
self.peer.send_rpc_request(method, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn async_request<P: ResponseParser>(
|
|
||||||
&self,
|
|
||||||
method: &str,
|
|
||||||
params: &JsonValue,
|
|
||||||
) -> Result<P::ValueType, SidecarError> {
|
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
||||||
self.peer.async_send_rpc_request(
|
|
||||||
method,
|
|
||||||
params,
|
|
||||||
Box::new(move |result| {
|
|
||||||
let _ = tx.send(result);
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
let value = rx.await.map_err(|err| {
|
|
||||||
SidecarError::Internal(anyhow!("error waiting for async response: {:?}", err))
|
|
||||||
})??;
|
|
||||||
let value = P::parse_json(value)?;
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stream_request<P: ResponseParser>(
|
|
||||||
&self,
|
|
||||||
method: &str,
|
|
||||||
params: &JsonValue,
|
|
||||||
) -> Result<ReceiverStream<Result<P::ValueType, SidecarError>>, SidecarError> {
|
|
||||||
let (tx, stream) = tokio::sync::mpsc::channel(100);
|
|
||||||
let stream = ReceiverStream::new(stream);
|
|
||||||
let callback = CloneableCallback::new(move |result| match result {
|
|
||||||
Ok(json) => {
|
|
||||||
let result = P::parse_json(json).map_err(SidecarError::from);
|
|
||||||
let _ = tx.blocking_send(result);
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
let _ = tx.blocking_send(Err(err));
|
|
||||||
},
|
|
||||||
});
|
|
||||||
self.peer.stream_rpc_request(method, params, callback);
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shutdown(&self) {
|
|
||||||
match self.peer.send_rpc_request("shutdown", &json!({})) {
|
|
||||||
Ok(_) => {
|
|
||||||
info!("shutting down plugin {}", self);
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
error!("error sending shutdown to plugin {}: {:?}", self, err);
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct PluginInfo {
|
|
||||||
pub name: String,
|
|
||||||
pub exec_path: PathBuf,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn start_plugin_process(
|
|
||||||
plugin_info: PluginInfo,
|
|
||||||
id: PluginId,
|
|
||||||
state: WeakSidecarState,
|
|
||||||
) -> Result<(), anyhow::Error> {
|
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
||||||
let spawn_result = thread::Builder::new()
|
|
||||||
.name(format!("<{}> core host thread", &plugin_info.name))
|
|
||||||
.spawn(move || {
|
|
||||||
info!("Load {} plugin", &plugin_info.name);
|
|
||||||
let child = std::process::Command::new(&plugin_info.exec_path)
|
|
||||||
.stdin(Stdio::piped())
|
|
||||||
.stdout(Stdio::piped())
|
|
||||||
.spawn();
|
|
||||||
|
|
||||||
match child {
|
|
||||||
Ok(mut child) => {
|
|
||||||
let child_stdin = child.stdin.take().unwrap();
|
|
||||||
let child_stdout = child.stdout.take().unwrap();
|
|
||||||
let mut looper = RpcLoop::new(child_stdin);
|
|
||||||
let peer: RpcPeer = Arc::new(looper.get_raw_peer());
|
|
||||||
let name = plugin_info.name.clone();
|
|
||||||
peer.send_rpc_notification("ping", &JsonValue::Array(Vec::new()));
|
|
||||||
|
|
||||||
let plugin = Plugin {
|
|
||||||
peer,
|
|
||||||
process: Arc::new(child),
|
|
||||||
name,
|
|
||||||
id,
|
|
||||||
};
|
|
||||||
|
|
||||||
state.plugin_connect(Ok(plugin));
|
|
||||||
let _ = tx.send(());
|
|
||||||
let mut state = state;
|
|
||||||
let err = looper.mainloop(
|
|
||||||
&plugin_info.name,
|
|
||||||
|| BufReader::new(child_stdout),
|
|
||||||
&mut state,
|
|
||||||
);
|
|
||||||
state.plugin_exit(id, err);
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
let _ = tx.send(());
|
|
||||||
state.plugin_connect(Err(err))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(err) = spawn_result {
|
|
||||||
error!("[RPC] thread spawn failed for {:?}, {:?}", id, err);
|
|
||||||
return Err(err.into());
|
|
||||||
}
|
|
||||||
rx.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,271 +0,0 @@
|
|||||||
use crate::core::parser::{Call, MessageReader};
|
|
||||||
use crate::core::plugin::RpcCtx;
|
|
||||||
use crate::core::rpc_object::RpcObject;
|
|
||||||
use crate::core::rpc_peer::{RawPeer, ResponsePayload, RpcState};
|
|
||||||
use crate::error::{ReadError, RemoteError, SidecarError};
|
|
||||||
use serde::de::DeserializeOwned;
|
|
||||||
|
|
||||||
use std::io::{BufRead, Write};
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::thread;
|
|
||||||
use std::time::Duration;
|
|
||||||
use tracing::{error, trace};
|
|
||||||
|
|
||||||
const MAX_IDLE_WAIT: Duration = Duration::from_millis(5);
|
|
||||||
|
|
||||||
pub trait Handler {
|
|
||||||
type Request: DeserializeOwned;
|
|
||||||
fn handle_request(
|
|
||||||
&mut self,
|
|
||||||
ctx: &RpcCtx,
|
|
||||||
rpc: Self::Request,
|
|
||||||
) -> Result<ResponsePayload, RemoteError>;
|
|
||||||
#[allow(unused_variables)]
|
|
||||||
fn idle(&mut self, ctx: &RpcCtx, token: usize) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A helper type which shuts down the runloop if a panic occurs while
|
|
||||||
/// handling an RPC.
|
|
||||||
struct PanicGuard<'a, W: Write + 'static>(&'a RawPeer<W>);
|
|
||||||
|
|
||||||
impl<'a, W: Write + 'static> Drop for PanicGuard<'a, W> {
|
|
||||||
/// Implements the cleanup behavior when the guard is dropped.
|
|
||||||
///
|
|
||||||
/// This method is automatically called when the `PanicGuard` goes out of scope.
|
|
||||||
/// It checks if a panic is occurring and, if so, logs an error message and
|
|
||||||
/// disconnects the peer.
|
|
||||||
fn drop(&mut self) {
|
|
||||||
// - If no panic is occurring, this method does nothing.
|
|
||||||
// - If a panic is detected:
|
|
||||||
// 1. An error message is logged.
|
|
||||||
// 2. The `disconnect()` method is called on the peer.
|
|
||||||
if thread::panicking() {
|
|
||||||
error!("[RPC] panic guard hit, closing run loop");
|
|
||||||
self.0.disconnect();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A structure holding the state of a main loop for handling RPC's.
|
|
||||||
pub struct RpcLoop<W: Write + 'static> {
|
|
||||||
reader: MessageReader,
|
|
||||||
peer: RawPeer<W>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<W: Write + Send> RpcLoop<W> {
|
|
||||||
/// Creates a new `RpcLoop` with the given output stream (which is used for
|
|
||||||
/// sending requests and notifications, as well as responses).
|
|
||||||
pub fn new(writer: W) -> Self {
|
|
||||||
let rpc_peer = RawPeer(Arc::new(RpcState::new(writer)));
|
|
||||||
RpcLoop {
|
|
||||||
reader: MessageReader::default(),
|
|
||||||
peer: rpc_peer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Gets a reference to the peer.
|
|
||||||
pub fn get_raw_peer(&self) -> RawPeer<W> {
|
|
||||||
self.peer.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Starts the event loop, reading lines from the reader until EOF or an error occurs.
|
|
||||||
///
|
|
||||||
/// Returns `Ok()` if EOF is reached, otherwise returns the underlying `ReadError`.
|
|
||||||
///
|
|
||||||
/// # Note:
|
|
||||||
/// The reader is provided via a closure to avoid needing `Send`. The main loop runs on a separate I/O thread that calls this closure at startup.
|
|
||||||
/// Calls to the handler occur on the caller's thread and maintain the order from the channel. Currently, there can only be one outstanding incoming request.
|
|
||||||
|
|
||||||
/// Starts and manages the main event loop for processing RPC messages.
|
|
||||||
///
|
|
||||||
/// This function is the core of the RPC system, handling incoming messages,
|
|
||||||
/// dispatching requests to the appropriate handler, and managing the overall
|
|
||||||
/// lifecycle of the RPC communication.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `&mut self` - A mutable reference to the `RpcLoop` instance.
|
|
||||||
/// * `_plugin_name: &str` - The name of the plugin (currently unused in the function body).
|
|
||||||
/// * `buffer_read_fn: BufferReadFn` - A closure that returns a `BufRead` instance for reading input.
|
|
||||||
/// * `handler: &mut H` - A mutable reference to the handler implementing the `Handler` trait.
|
|
||||||
///
|
|
||||||
/// # Type Parameters
|
|
||||||
///
|
|
||||||
/// * `R: BufRead` - The type returned by `buffer_read_fn`, must implement `BufRead`.
|
|
||||||
/// * `BufferReadFn: Send + FnOnce() -> R` - The type of the closure that provides the input reader.
|
|
||||||
/// * `H: Handler` - The type of the handler, must implement the `Handler` trait.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Result<(), ReadError>` - Returns `Ok(())` if the loop exits normally (EOF),
|
|
||||||
/// or an error if an unrecoverable error occurs.
|
|
||||||
///
|
|
||||||
/// # Behavior
|
|
||||||
///
|
|
||||||
/// 1. Creates a new `RpcCtx` with a clone of the `RawPeer`.
|
|
||||||
/// 2. Spawns a separate thread for reading input using `crossbeam_utils::thread::scope`.
|
|
||||||
/// 3. In the reading thread:
|
|
||||||
/// - Continuously reads and parses JSON messages from the input.
|
|
||||||
/// - Handles responses by calling `handle_response` on the peer.
|
|
||||||
/// - Puts other messages into the peer's queue using `put_rpc_object`.
|
|
||||||
/// 4. In the main thread:
|
|
||||||
/// - Retrieves messages using `next_read`.
|
|
||||||
/// - Processes requests by calling the handler's `handle_request` method.
|
|
||||||
/// - Sends responses back using the peer's `respond` method.
|
|
||||||
/// 5. Continues looping until an error occurs or the peer is disconnected.
|
|
||||||
pub fn mainloop<R, BufferReadFn, H>(
|
|
||||||
&mut self,
|
|
||||||
_plugin_name: &str,
|
|
||||||
buffer_read_fn: BufferReadFn,
|
|
||||||
handler: &mut H,
|
|
||||||
) -> Result<(), ReadError>
|
|
||||||
where
|
|
||||||
R: BufRead,
|
|
||||||
BufferReadFn: Send + FnOnce() -> R,
|
|
||||||
H: Handler,
|
|
||||||
{
|
|
||||||
// uses `crossbeam_utils::thread::scope` for thread management,
|
|
||||||
// which offers several advantages over `std::thread`:
|
|
||||||
// 1. Scoped Threads: Guarantees thread termination when the scope ends,
|
|
||||||
// preventing resource leaks.
|
|
||||||
// 2. Simplified Lifetime Management: Allows threads to borrow data from
|
|
||||||
// their parent stack frame, enabling more ergonomic code.
|
|
||||||
// 3. Improved Safety: Prevents threads from outliving the data they operate on,
|
|
||||||
// reducing risks of data races and use-after-free errors.
|
|
||||||
// 4. Efficiency: Potentially more efficient due to known thread lifetimes,
|
|
||||||
// leading to better resource management.
|
|
||||||
// 5. Error Propagation: Simplifies propagating errors from spawned threads
|
|
||||||
// back to the parent thread.
|
|
||||||
// 6. Consistency with Rust's Ownership Model: Aligns well with Rust's
|
|
||||||
// ownership and borrowing rules.
|
|
||||||
// 7. Automatic Thread Joining: No need for manual thread joining, reducing
|
|
||||||
// the risk of thread management errors.
|
|
||||||
let exit = crossbeam_utils::thread::scope(|scope| {
|
|
||||||
let peer = self.get_raw_peer();
|
|
||||||
peer.reset_needs_exit();
|
|
||||||
|
|
||||||
let ctx = RpcCtx {
|
|
||||||
peer: Arc::new(peer.clone()),
|
|
||||||
};
|
|
||||||
|
|
||||||
// 1. Spawn a new thread for reading data from a stream.
|
|
||||||
// 2. Continuously read data from the stream.
|
|
||||||
// 3. Parse the data as JSON.
|
|
||||||
// 4. Handle the JSON data as either a response or another type of JSON object.
|
|
||||||
// 5. Manage errors and connection status.
|
|
||||||
scope.spawn(move |_| {
|
|
||||||
let mut stream = buffer_read_fn();
|
|
||||||
loop {
|
|
||||||
if self.peer.needs_exit() {
|
|
||||||
trace!("read loop exit");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let json = match self.reader.next(&mut stream) {
|
|
||||||
Ok(json) => json,
|
|
||||||
Err(err) => {
|
|
||||||
if self.peer.0.is_blocking() {
|
|
||||||
error!("[RPC] {:?}, disconnecting peer", err);
|
|
||||||
self.peer.disconnect();
|
|
||||||
}
|
|
||||||
self.peer.put_rpc_object(Err(err));
|
|
||||||
break;
|
|
||||||
},
|
|
||||||
};
|
|
||||||
if json.is_response() {
|
|
||||||
let request_id = json.get_id().unwrap();
|
|
||||||
match json.into_response() {
|
|
||||||
Ok(resp) => {
|
|
||||||
let resp = resp.map_err(SidecarError::from);
|
|
||||||
self.peer.handle_response(request_id, resp);
|
|
||||||
},
|
|
||||||
Err(msg) => {
|
|
||||||
error!("[RPC] failed to parse response: {}", msg);
|
|
||||||
self
|
|
||||||
.peer
|
|
||||||
.handle_response(request_id, Err(SidecarError::InvalidResponse));
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
self.peer.put_rpc_object(Ok(json));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Main processing loop
|
|
||||||
loop {
|
|
||||||
// `PanicGuard` is a critical safety mechanism in the RPC system. It's designed to detect
|
|
||||||
// panics that occur during RPC request handling and ensure that the system shuts down
|
|
||||||
// gracefully, preventing resource leaks and maintaining system integrity.
|
|
||||||
//
|
|
||||||
let _guard = PanicGuard(&peer);
|
|
||||||
let read_result = next_read(&peer, &ctx);
|
|
||||||
let json = match read_result {
|
|
||||||
Ok(json) => json,
|
|
||||||
Err(err) => {
|
|
||||||
error!("[RPC] error reading message: {:?}, disconnecting peer", err);
|
|
||||||
peer.disconnect();
|
|
||||||
return err;
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
match json.into_rpc::<H::Request>() {
|
|
||||||
Ok(Call::Request(id, cmd)) => {
|
|
||||||
// Handle request sent from the client. For example from python executable.
|
|
||||||
trace!("[RPC] received request: {}", id);
|
|
||||||
let result = handler.handle_request(&ctx, cmd);
|
|
||||||
peer.respond(result, id);
|
|
||||||
},
|
|
||||||
Ok(Call::InvalidRequest(id, err)) => {
|
|
||||||
trace!("[RPC] received invalid request: {}", id);
|
|
||||||
peer.respond(Err(err), id)
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
error!("[RPC] error parsing message: {:?}", err);
|
|
||||||
peer.disconnect();
|
|
||||||
return ReadError::UnknownRequest(err);
|
|
||||||
},
|
|
||||||
Ok(Call::Message(_msg)) => {
|
|
||||||
#[cfg(feature = "verbose")]
|
|
||||||
trace!("[RPC {}]: {}", _plugin_name, _msg);
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
if exit.is_disconnect() {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(exit)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// retrieves the next available read result from a peer, performing idle work if no result is
|
|
||||||
/// immediately available.
|
|
||||||
fn next_read<W>(peer: &RawPeer<W>, _ctx: &RpcCtx) -> Result<RpcObject, ReadError>
|
|
||||||
where
|
|
||||||
W: Write + Send,
|
|
||||||
{
|
|
||||||
loop {
|
|
||||||
// Continuously checks if there is a result available from the peer using
|
|
||||||
if let Some(result) = peer.try_get_rx() {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
let time_to_next_timer = match peer.check_timers() {
|
|
||||||
Some(Ok(_token)) => continue,
|
|
||||||
Some(Err(duration)) => Some(duration),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Ensures the function does not block indefinitely by setting a maximum wait time
|
|
||||||
let idle_timeout = time_to_next_timer
|
|
||||||
.unwrap_or(MAX_IDLE_WAIT)
|
|
||||||
.min(MAX_IDLE_WAIT);
|
|
||||||
|
|
||||||
if let Some(result) = peer.get_rx_timeout(idle_timeout) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,111 +0,0 @@
|
|||||||
use crate::core::parser::{Call, RequestId};
|
|
||||||
use crate::core::rpc_peer::{Response, ResponsePayload};
|
|
||||||
|
|
||||||
use serde::de::{DeserializeOwned, Error};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct RpcObject(pub Value);
|
|
||||||
|
|
||||||
impl RpcObject {
|
|
||||||
/// Returns the 'id' of the underlying object, if present.
|
|
||||||
pub fn get_id(&self) -> Option<RequestId> {
|
|
||||||
self.0.get("id").and_then(Value::as_u64)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the 'method' field of the underlying object, if present.
|
|
||||||
pub fn get_method(&self) -> Option<&str> {
|
|
||||||
self.0.get("method").and_then(Value::as_str)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns `true` if this object looks like an RPC response;
|
|
||||||
/// that is, if it has an 'id' field and does _not_ have a 'method'
|
|
||||||
/// field.
|
|
||||||
pub fn is_response(&self) -> bool {
|
|
||||||
self.0.get("id").is_some() && self.0.get("method").is_none()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Converts a JSON-RPC response into a structured `Response` object.
|
|
||||||
///
|
|
||||||
/// This function validates and parses a JSON-RPC response, ensuring it contains the necessary fields,
|
|
||||||
/// and then transforms it into a structured `Response` object. The response must contain either a
|
|
||||||
/// "result" or an "error" field, but not both. If the response contains a "result" field, it may also
|
|
||||||
/// include streaming data, indicated by a nested "stream" field.
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// This function will return an error if:
|
|
||||||
/// - The "id" field is missing.
|
|
||||||
/// - The response contains both "result" and "error" fields, or neither.
|
|
||||||
/// - The "stream" field within the "result" is missing "type" or "data" fields.
|
|
||||||
/// - The "stream" type is invalid (i.e., not "streaming" or "end").
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// - `Ok(Ok(ResponsePayload::Json(result)))`: If the response contains a valid "result".
|
|
||||||
/// - `Ok(Ok(ResponsePayload::Streaming(data)))`: If the response contains streaming data of type "streaming".
|
|
||||||
/// - `Ok(Ok(ResponsePayload::StreamEnd(json!({}))))`: If the response contains streaming data of type "end".
|
|
||||||
/// - `Err(String)`: If any validation or parsing errors occur.
|
|
||||||
///.
|
|
||||||
pub fn into_response(mut self) -> Result<Response, String> {
|
|
||||||
// Ensure 'id' field is present
|
|
||||||
self
|
|
||||||
.get_id()
|
|
||||||
.ok_or_else(|| "Response requires 'id' field.".to_string())?;
|
|
||||||
|
|
||||||
// Ensure the response contains exactly one of 'result' or 'error'
|
|
||||||
let has_result = self.0.get("result").is_some();
|
|
||||||
let has_error = self.0.get("error").is_some();
|
|
||||||
if has_result == has_error {
|
|
||||||
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle the 'result' field if present
|
|
||||||
if let Some(mut result) = self.0.as_object_mut().and_then(|obj| obj.remove("result")) {
|
|
||||||
if let Some(mut stream) = result.as_object_mut().and_then(|obj| obj.remove("stream")) {
|
|
||||||
if let Some((has_more, data)) = stream.as_object_mut().and_then(|obj| {
|
|
||||||
let has_more = obj.remove("has_more")?.as_bool().unwrap_or(false);
|
|
||||||
let data = obj.remove("data")?;
|
|
||||||
Some((has_more, data))
|
|
||||||
}) {
|
|
||||||
return match has_more {
|
|
||||||
true => Ok(Ok(ResponsePayload::Streaming(data))),
|
|
||||||
false => Ok(Ok(ResponsePayload::StreamEnd(data))),
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
return Err("Stream response must contain 'type' and 'data' fields.".into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Ok(ResponsePayload::Json(result)))
|
|
||||||
} else {
|
|
||||||
// Handle the 'error' field
|
|
||||||
let error = self.0.as_object_mut().unwrap().remove("error").unwrap();
|
|
||||||
Err(format!("Error handling response: {:?}", error))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Converts the underlying `Value` into either an RPC notification or request.
|
|
||||||
pub fn into_rpc<R>(self) -> Result<Call<R>, serde_json::Error>
|
|
||||||
where
|
|
||||||
R: DeserializeOwned,
|
|
||||||
{
|
|
||||||
let id = self.get_id();
|
|
||||||
match id {
|
|
||||||
Some(id) => match serde_json::from_value::<R>(self.0) {
|
|
||||||
Ok(resp) => Ok(Call::Request(id, resp)),
|
|
||||||
Err(err) => Ok(Call::InvalidRequest(id, err.into())),
|
|
||||||
},
|
|
||||||
None => match self.0.get("message").and_then(|value| value.as_str()) {
|
|
||||||
None => Err(serde_json::Error::missing_field("message")),
|
|
||||||
Some(s) => Ok(Call::Message(s.to_string().into())),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Value> for RpcObject {
|
|
||||||
fn from(v: Value) -> RpcObject {
|
|
||||||
RpcObject(v)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,500 +0,0 @@
|
|||||||
use crate::core::plugin::{Peer, PluginId};
|
|
||||||
use crate::core::rpc_object::RpcObject;
|
|
||||||
use crate::error::{ReadError, RemoteError, SidecarError};
|
|
||||||
use parking_lot::{Condvar, Mutex};
|
|
||||||
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
|
|
||||||
use serde_json::{json, Value as JsonValue};
|
|
||||||
use std::collections::{BTreeMap, BinaryHeap, VecDeque};
|
|
||||||
use std::fmt::Display;
|
|
||||||
use std::io::Write;
|
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
|
||||||
use std::sync::{mpsc, Arc};
|
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
use std::{cmp, io};
|
|
||||||
use tokio_stream::Stream;
|
|
||||||
use tracing::{error, trace, warn};
|
|
||||||
|
|
||||||
pub struct PluginCommand<T> {
|
|
||||||
pub plugin_id: PluginId,
|
|
||||||
pub cmd: T,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Serialize> Serialize for PluginCommand<T> {
|
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: Serializer,
|
|
||||||
{
|
|
||||||
let mut v = serde_json::to_value(&self.cmd).map_err(ser::Error::custom)?;
|
|
||||||
v["params"]["plugin_id"] = json!(self.plugin_id);
|
|
||||||
v.serialize(serializer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'de, T: Deserialize<'de>> Deserialize<'de> for PluginCommand<T> {
|
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct PluginIdHelper {
|
|
||||||
plugin_id: PluginId,
|
|
||||||
}
|
|
||||||
let v = JsonValue::deserialize(deserializer)?;
|
|
||||||
let plugin_id = PluginIdHelper::deserialize(&v)
|
|
||||||
.map_err(de::Error::custom)?
|
|
||||||
.plugin_id;
|
|
||||||
let cmd = T::deserialize(v).map_err(de::Error::custom)?;
|
|
||||||
Ok(PluginCommand { plugin_id, cmd })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct RpcState<W: Write> {
|
|
||||||
rx_queue: Mutex<VecDeque<Result<RpcObject, ReadError>>>,
|
|
||||||
rx_cvar: Condvar,
|
|
||||||
writer: Mutex<W>,
|
|
||||||
id: AtomicUsize,
|
|
||||||
pending: Mutex<BTreeMap<usize, ResponseHandler>>,
|
|
||||||
timers: Mutex<BinaryHeap<Timer>>,
|
|
||||||
needs_exit: AtomicBool,
|
|
||||||
is_blocking: AtomicBool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<W: Write> RpcState<W> {
|
|
||||||
/// Creates a new `RawPeer` instance.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `writer` - An object implementing the `Write` trait, used for sending messages.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A new `RawPeer` instance wrapped in an `Arc`.
|
|
||||||
pub fn new(writer: W) -> Self {
|
|
||||||
RpcState {
|
|
||||||
rx_queue: Mutex::new(VecDeque::new()),
|
|
||||||
rx_cvar: Condvar::new(),
|
|
||||||
writer: Mutex::new(writer),
|
|
||||||
id: AtomicUsize::new(0),
|
|
||||||
pending: Mutex::new(BTreeMap::new()),
|
|
||||||
timers: Mutex::new(BinaryHeap::new()),
|
|
||||||
needs_exit: AtomicBool::new(false),
|
|
||||||
is_blocking: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_blocking(&self) -> bool {
|
|
||||||
self.is_blocking.load(Ordering::Acquire)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct RawPeer<W: Write + 'static>(pub(crate) Arc<RpcState<W>>);
|
|
||||||
|
|
||||||
impl<W: Write + Send + 'static> Peer for RawPeer<W> {
|
|
||||||
fn box_clone(&self) -> Arc<dyn Peer> {
|
|
||||||
Arc::new((*self).clone())
|
|
||||||
}
|
|
||||||
fn send_rpc_notification(&self, method: &str, params: &JsonValue) {
|
|
||||||
if let Err(e) = self.send(&json!({
|
|
||||||
"method": method,
|
|
||||||
"params": params,
|
|
||||||
})) {
|
|
||||||
error!(
|
|
||||||
"send error on send_rpc_notification method {}: {}",
|
|
||||||
method, e
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_rpc_request(&self, method: &str, params: &JsonValue, f: CloneableCallback) {
|
|
||||||
self.send_rpc(method, params, ResponseHandler::StreamCallback(Arc::new(f)));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn async_send_rpc_request(&self, method: &str, params: &JsonValue, f: Box<dyn OneShotCallback>) {
|
|
||||||
self.send_rpc(method, params, ResponseHandler::Callback(f));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send_rpc_request(&self, method: &str, params: &JsonValue) -> Result<JsonValue, SidecarError> {
|
|
||||||
let (tx, rx) = mpsc::channel();
|
|
||||||
self.0.is_blocking.store(true, Ordering::Release);
|
|
||||||
self.send_rpc(method, params, ResponseHandler::Chan(tx));
|
|
||||||
rx.recv().unwrap_or(Err(SidecarError::PeerDisconnect))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn request_is_pending(&self) -> bool {
|
|
||||||
let queue = self.0.rx_queue.lock();
|
|
||||||
!queue.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn schedule_timer(&self, after: Instant, token: usize) {
|
|
||||||
self.0.timers.lock().push(Timer {
|
|
||||||
fire_after: after,
|
|
||||||
token,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<W: Write> RawPeer<W> {
|
|
||||||
/// Sends a JSON value to the peer.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `json` - A reference to a `JsonValue` to be sent.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// A `Result` indicating success or an `io::Error` if the write operation fails.
|
|
||||||
///
|
|
||||||
/// # Notes
|
|
||||||
///
|
|
||||||
/// This function serializes the JSON value, appends a newline, and writes it to the underlying writer.
|
|
||||||
fn send(&self, json: &JsonValue) -> Result<(), io::Error> {
|
|
||||||
let mut s = serde_json::to_string(json).unwrap();
|
|
||||||
s.push('\n');
|
|
||||||
self.0.writer.lock().write_all(s.as_bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends a response to a previous RPC request.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `result` - The `Response` to be sent.
|
|
||||||
/// * `id` - The ID of the request being responded to.
|
|
||||||
///
|
|
||||||
/// # Notes
|
|
||||||
///
|
|
||||||
/// This function constructs a JSON response and sends it using the `send` method.
|
|
||||||
/// It handles both successful results and errors.
|
|
||||||
pub(crate) fn respond(&self, result: Response, id: u64) {
|
|
||||||
let mut response = json!({ "id": id });
|
|
||||||
match result {
|
|
||||||
Ok(result) => match result {
|
|
||||||
ResponsePayload::Json(value) => response["result"] = value,
|
|
||||||
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_) => {
|
|
||||||
error!("stream response not supported")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Err(error) => response["error"] = json!(error),
|
|
||||||
};
|
|
||||||
if let Err(e) = self.send(&response) {
|
|
||||||
error!("[RPC] error {} sending response to RPC {:?}", e, id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sends an RPC request.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `method` - The name of the RPC method to be called.
|
|
||||||
/// * `params` - The parameters for the RPC call.
|
|
||||||
/// * `response_handler` - A `ResponseHandler` to handle the response.
|
|
||||||
///
|
|
||||||
/// # Notes
|
|
||||||
///
|
|
||||||
/// This function generates a unique ID for the request, stores the response handler,
|
|
||||||
/// and sends the RPC request. If sending fails, it immediately invokes the response handler with an error.
|
|
||||||
fn send_rpc(&self, method: &str, params: &JsonValue, response_handler: ResponseHandler) {
|
|
||||||
trace!("[RPC] call method: {} params: {:?}", method, params);
|
|
||||||
let id = self.0.id.fetch_add(1, Ordering::Relaxed);
|
|
||||||
{
|
|
||||||
let mut pending = self.0.pending.lock();
|
|
||||||
pending.insert(id, response_handler);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call the ResponseHandler if the send fails. Otherwise, the response will be
|
|
||||||
// called in handle_response.
|
|
||||||
if let Err(e) = self.send(&json!({
|
|
||||||
"id": id,
|
|
||||||
"method": method,
|
|
||||||
"params": params,
|
|
||||||
})) {
|
|
||||||
let mut pending = self.0.pending.lock();
|
|
||||||
if let Some(rh) = pending.remove(&id) {
|
|
||||||
rh.invoke(Err(SidecarError::Io(e)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Processes an incoming response to an RPC request.
|
|
||||||
///
|
|
||||||
/// This function is responsible for handling responses received from the peer, matching them
|
|
||||||
/// to their corresponding requests, and invoking the appropriate callbacks. It supports both
|
|
||||||
/// one-time responses and streaming responses.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `&self` - A reference to the `RawPeer` instance.
|
|
||||||
/// * `request_id: u64` - The unique identifier of the request to which this is a response.
|
|
||||||
/// * `resp: Result<ResponsePayload, SidecarError>` - The response payload or an error.
|
|
||||||
///
|
|
||||||
/// # Behavior
|
|
||||||
///
|
|
||||||
/// 1. Retrieves and removes the response handler for the given `request_id` from the pending requests.
|
|
||||||
/// 2. Determines if the response is part of a stream.
|
|
||||||
/// 3. For streaming responses:
|
|
||||||
/// - If it's not the end of the stream, re-inserts the stream callback for future messages.
|
|
||||||
/// - If it's the end of the stream, logs this information.
|
|
||||||
/// 4. Converts the response payload to JSON.
|
|
||||||
/// 5. Invokes the response handler with the JSON data or error.
|
|
||||||
///
|
|
||||||
/// # Concurrency
|
|
||||||
///
|
|
||||||
/// This function uses mutex locks to ensure thread-safe access to shared data structures.
|
|
||||||
/// It's designed to be called from multiple threads safely.
|
|
||||||
///
|
|
||||||
/// # Error Handling
|
|
||||||
///
|
|
||||||
/// - If no handler is found for the `request_id`, an error is logged.
|
|
||||||
/// - If a non-stream response payload is `None`, a warning is logged.
|
|
||||||
/// - Errors in the response are propagated to the response handler.
|
|
||||||
pub(crate) fn handle_response(
|
|
||||||
&self,
|
|
||||||
request_id: u64,
|
|
||||||
resp: Result<ResponsePayload, SidecarError>,
|
|
||||||
) {
|
|
||||||
let request_id = request_id as usize;
|
|
||||||
let handler = {
|
|
||||||
let mut pending = self.0.pending.lock();
|
|
||||||
pending.remove(&request_id)
|
|
||||||
};
|
|
||||||
let is_stream = resp.as_ref().map(|resp| resp.is_stream()).unwrap_or(false);
|
|
||||||
match handler {
|
|
||||||
Some(response_handler) => {
|
|
||||||
if is_stream {
|
|
||||||
let is_stream_end = resp
|
|
||||||
.as_ref()
|
|
||||||
.map(|resp| resp.is_stream_end())
|
|
||||||
.unwrap_or(false);
|
|
||||||
if !is_stream_end {
|
|
||||||
// when steam is not end, we need to put the stream callback back to pending in order to
|
|
||||||
// receive the next stream message.
|
|
||||||
if let Some(callback) = response_handler.get_stream_callback() {
|
|
||||||
let mut pending = self.0.pending.lock();
|
|
||||||
pending.insert(request_id, ResponseHandler::StreamCallback(callback));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
trace!("[RPC] {} stream end", request_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let json = resp.map(|resp| resp.into_json());
|
|
||||||
match json {
|
|
||||||
Ok(Some(json)) => {
|
|
||||||
response_handler.invoke(Ok(json));
|
|
||||||
},
|
|
||||||
Ok(None) => {
|
|
||||||
if !is_stream {
|
|
||||||
warn!("[RPC] only stream response can be None");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
response_handler.invoke(Err(err));
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => error!("[RPC] id {}'s handle not found", request_id),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a message from the receive queue if available.
|
|
||||||
pub(crate) fn try_get_rx(&self) -> Option<Result<RpcObject, ReadError>> {
|
|
||||||
let mut queue = self.0.rx_queue.lock();
|
|
||||||
queue.pop_front()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a message from the receive queue, waiting for at most `Duration`
|
|
||||||
/// and returning `None` if no message is available.
|
|
||||||
pub(crate) fn get_rx_timeout(&self, dur: Duration) -> Option<Result<RpcObject, ReadError>> {
|
|
||||||
let mut queue = self.0.rx_queue.lock();
|
|
||||||
let result = self.0.rx_cvar.wait_for(&mut queue, dur);
|
|
||||||
if result.timed_out() {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
queue.pop_front()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Adds a message to the receive queue. The message should only
|
|
||||||
/// be `None` if the read thread is exiting.
|
|
||||||
pub(crate) fn put_rpc_object(&self, json: Result<RpcObject, ReadError>) {
|
|
||||||
let mut queue = self.0.rx_queue.lock();
|
|
||||||
queue.push_back(json);
|
|
||||||
self.0.rx_cvar.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Checks the status of the most imminent timer.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// - `Some(Ok(usize))`: If the most imminent timer has expired, returns its token.
|
|
||||||
/// - `Some(Err(Duration))`: If the most imminent timer has not yet expired, returns the time until it expires.
|
|
||||||
/// - `None`: If no timers are registered.
|
|
||||||
pub(crate) fn check_timers(&self) -> Option<Result<usize, Duration>> {
|
|
||||||
let mut timers = self.0.timers.lock();
|
|
||||||
match timers.peek() {
|
|
||||||
None => return None,
|
|
||||||
Some(t) => {
|
|
||||||
let now = Instant::now();
|
|
||||||
if t.fire_after > now {
|
|
||||||
return Some(Err(t.fire_after - now));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
Some(Ok(timers.pop().unwrap().token))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// send disconnect error to pending requests.
|
|
||||||
pub(crate) fn disconnect(&self) {
|
|
||||||
trace!("[RPC] disconnecting peer");
|
|
||||||
let mut pending = self.0.pending.lock();
|
|
||||||
let ids = pending.keys().cloned().collect::<Vec<_>>();
|
|
||||||
for id in &ids {
|
|
||||||
let callback = pending.remove(id).unwrap();
|
|
||||||
callback.invoke(Err(SidecarError::PeerDisconnect));
|
|
||||||
}
|
|
||||||
self.0.needs_exit.store(true, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Checks if the RPC system needs to exit.
|
|
||||||
pub(crate) fn needs_exit(&self) -> bool {
|
|
||||||
self.0.needs_exit.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn reset_needs_exit(&self) {
|
|
||||||
self.0.needs_exit.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<W: Write> Clone for RawPeer<W> {
|
|
||||||
fn clone(&self) -> Self {
|
|
||||||
RawPeer(self.0.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub enum ResponsePayload {
|
|
||||||
Json(JsonValue),
|
|
||||||
Streaming(JsonValue),
|
|
||||||
StreamEnd(JsonValue),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResponsePayload {
|
|
||||||
pub fn empty_json() -> Self {
|
|
||||||
ResponsePayload::Json(json!({}))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_stream(&self) -> bool {
|
|
||||||
matches!(
|
|
||||||
self,
|
|
||||||
ResponsePayload::Streaming(_) | ResponsePayload::StreamEnd(_)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_stream_end(&self) -> bool {
|
|
||||||
matches!(self, ResponsePayload::StreamEnd(_))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_json(self) -> Option<JsonValue> {
|
|
||||||
match self {
|
|
||||||
ResponsePayload::Json(v) => Some(v),
|
|
||||||
ResponsePayload::Streaming(v) => Some(v),
|
|
||||||
ResponsePayload::StreamEnd(_) => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for ResponsePayload {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
ResponsePayload::Json(v) => write!(f, "{}", v),
|
|
||||||
ResponsePayload::Streaming(_) => write!(f, "stream start"),
|
|
||||||
ResponsePayload::StreamEnd(_) => write!(f, "stream end"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type Response = Result<ResponsePayload, RemoteError>;
|
|
||||||
|
|
||||||
pub trait ResponseStream: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
|
|
||||||
|
|
||||||
impl<T> ResponseStream for T where T: Stream<Item = Result<JsonValue, SidecarError>> + Unpin + Send {}
|
|
||||||
|
|
||||||
enum ResponseHandler {
|
|
||||||
Chan(mpsc::Sender<Result<JsonValue, SidecarError>>),
|
|
||||||
Callback(Box<dyn OneShotCallback>),
|
|
||||||
StreamCallback(Arc<CloneableCallback>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResponseHandler {
|
|
||||||
pub fn get_stream_callback(&self) -> Option<Arc<CloneableCallback>> {
|
|
||||||
match self {
|
|
||||||
ResponseHandler::StreamCallback(cb) => Some(cb.clone()),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait OneShotCallback: Send {
|
|
||||||
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Send + FnOnce(Result<JsonValue, SidecarError>)> OneShotCallback for F {
|
|
||||||
fn call(self: Box<Self>, result: Result<JsonValue, SidecarError>) {
|
|
||||||
(self)(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait Callback: Send + Sync {
|
|
||||||
fn call(&self, result: Result<JsonValue, SidecarError>);
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Send + Sync + Fn(Result<JsonValue, SidecarError>)> Callback for F {
|
|
||||||
fn call(&self, result: Result<JsonValue, SidecarError>) {
|
|
||||||
(*self)(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct CloneableCallback {
|
|
||||||
callback: Arc<dyn Callback>,
|
|
||||||
}
|
|
||||||
impl CloneableCallback {
|
|
||||||
pub fn new<C: Callback + 'static>(callback: C) -> Self {
|
|
||||||
CloneableCallback {
|
|
||||||
callback: Arc::new(callback),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn call(&self, result: Result<JsonValue, SidecarError>) {
|
|
||||||
self.callback.call(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResponseHandler {
|
|
||||||
fn invoke(self, result: Result<JsonValue, SidecarError>) {
|
|
||||||
match self {
|
|
||||||
ResponseHandler::Chan(tx) => {
|
|
||||||
let _ = tx.send(result);
|
|
||||||
},
|
|
||||||
ResponseHandler::StreamCallback(cb) => {
|
|
||||||
cb.call(result);
|
|
||||||
},
|
|
||||||
ResponseHandler::Callback(f) => f.call(result),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
|
||||||
struct Timer {
|
|
||||||
fire_after: Instant,
|
|
||||||
token: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Ord for Timer {
|
|
||||||
fn cmp(&self, other: &Timer) -> cmp::Ordering {
|
|
||||||
other.fire_after.cmp(&self.fire_after)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialOrd for Timer {
|
|
||||||
fn partial_cmp(&self, other: &Timer) -> Option<cmp::Ordering> {
|
|
||||||
Some(self.cmp(other))
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,174 +0,0 @@
|
|||||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
||||||
use serde_json::{json, Value as JsonValue};
|
|
||||||
use std::{fmt, io};
|
|
||||||
|
|
||||||
/// The error type of `tauri-utils`.
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum SidecarError {
|
|
||||||
/// An IO error occurred on the underlying communication channel.
|
|
||||||
#[error(transparent)]
|
|
||||||
Io(#[from] io::Error),
|
|
||||||
/// The peer returned an error.
|
|
||||||
#[error("Remote error: {0}")]
|
|
||||||
RemoteError(RemoteError),
|
|
||||||
/// The peer closed the connection.
|
|
||||||
#[error("Peer closed the connection.")]
|
|
||||||
PeerDisconnect,
|
|
||||||
/// The peer sent a response containing the id, but was malformed.
|
|
||||||
#[error("Invalid response.")]
|
|
||||||
InvalidResponse,
|
|
||||||
|
|
||||||
#[error(transparent)]
|
|
||||||
Internal(#[from] anyhow::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ReadError {
|
|
||||||
/// An error occurred in the underlying stream
|
|
||||||
Io(io::Error),
|
|
||||||
/// The message was not valid JSON.
|
|
||||||
Json(serde_json::Error),
|
|
||||||
/// The message was not a JSON object.
|
|
||||||
NotObject(String),
|
|
||||||
/// The the method and params were not recognized by the handler.
|
|
||||||
UnknownRequest(serde_json::Error),
|
|
||||||
/// The peer closed the connection.
|
|
||||||
Disconnect(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, thiserror::Error)]
|
|
||||||
pub enum RemoteError {
|
|
||||||
/// The JSON was valid, but was not a correctly formed request.
|
|
||||||
///
|
|
||||||
/// This Error is used internally, and should not be returned by
|
|
||||||
/// clients.
|
|
||||||
#[error("Invalid request: {0:?}")]
|
|
||||||
InvalidRequest(Option<JsonValue>),
|
|
||||||
|
|
||||||
#[error("Invalid response: {0}")]
|
|
||||||
InvalidResponse(JsonValue),
|
|
||||||
|
|
||||||
#[error("Parse response: {0}")]
|
|
||||||
ParseResponse(JsonValue),
|
|
||||||
/// A custom error, defined by the client.
|
|
||||||
#[error("Custom error: {message}")]
|
|
||||||
Custom {
|
|
||||||
code: i64,
|
|
||||||
message: String,
|
|
||||||
data: Option<JsonValue>,
|
|
||||||
},
|
|
||||||
/// An error that cannot be represented by an error object.
|
|
||||||
///
|
|
||||||
/// This error is intended to accommodate clients that return arbitrary
|
|
||||||
/// error values. It should not be used for new errors.
|
|
||||||
#[error("Unknown error: {0}")]
|
|
||||||
Unknown(JsonValue),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ReadError {
|
|
||||||
/// Returns `true` iff this is the `ReadError::Disconnect` variant.
|
|
||||||
pub fn is_disconnect(&self) -> bool {
|
|
||||||
matches!(*self, ReadError::Disconnect(_))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for ReadError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
ReadError::Io(ref err) => write!(f, "I/O Error: {:?}", err),
|
|
||||||
ReadError::Json(ref err) => write!(f, "JSON Error: {:?}", err),
|
|
||||||
ReadError::NotObject(s) => write!(f, "Expected JSON object, found: {}", s),
|
|
||||||
ReadError::UnknownRequest(ref err) => write!(f, "Unknown request: {:?}", err),
|
|
||||||
ReadError::Disconnect(reason) => write!(f, "Peer closed the connection, reason: {}", reason),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<serde_json::Error> for ReadError {
|
|
||||||
fn from(err: serde_json::Error) -> ReadError {
|
|
||||||
ReadError::Json(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<io::Error> for ReadError {
|
|
||||||
fn from(err: io::Error) -> ReadError {
|
|
||||||
ReadError::Io(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<serde_json::Error> for RemoteError {
|
|
||||||
fn from(err: serde_json::Error) -> RemoteError {
|
|
||||||
RemoteError::InvalidRequest(Some(json!(err.to_string())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<RemoteError> for SidecarError {
|
|
||||||
fn from(err: RemoteError) -> SidecarError {
|
|
||||||
SidecarError::RemoteError(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize)]
|
|
||||||
struct ErrorHelper {
|
|
||||||
code: i64,
|
|
||||||
message: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
data: Option<JsonValue>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'de> Deserialize<'de> for RemoteError {
|
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let v = JsonValue::deserialize(deserializer)?;
|
|
||||||
let resp = match ErrorHelper::deserialize(&v) {
|
|
||||||
Ok(resp) => resp,
|
|
||||||
Err(_) => return Ok(RemoteError::Unknown(v)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(match resp.code {
|
|
||||||
-32600 => RemoteError::InvalidRequest(resp.data),
|
|
||||||
_ => RemoteError::Custom {
|
|
||||||
code: resp.code,
|
|
||||||
message: resp.message,
|
|
||||||
data: resp.data,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Serialize for RemoteError {
|
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: Serializer,
|
|
||||||
{
|
|
||||||
let (code, message, data) = match self {
|
|
||||||
RemoteError::InvalidRequest(ref d) => (-32600, "Invalid request".to_string(), d.clone()),
|
|
||||||
RemoteError::Custom {
|
|
||||||
code,
|
|
||||||
ref message,
|
|
||||||
ref data,
|
|
||||||
} => (*code, message.clone(), data.clone()),
|
|
||||||
RemoteError::Unknown(_) => {
|
|
||||||
panic!("The 'Unknown' error variant is not intended for client use.")
|
|
||||||
},
|
|
||||||
RemoteError::InvalidResponse(resp) => (
|
|
||||||
-1,
|
|
||||||
"Invalid response".to_string(),
|
|
||||||
Some(json!(resp.to_string())),
|
|
||||||
),
|
|
||||||
RemoteError::ParseResponse(resp) => (
|
|
||||||
-1,
|
|
||||||
"Invalid response".to_string(),
|
|
||||||
Some(json!(resp.to_string())),
|
|
||||||
),
|
|
||||||
};
|
|
||||||
let err = ErrorHelper {
|
|
||||||
code,
|
|
||||||
message,
|
|
||||||
data,
|
|
||||||
};
|
|
||||||
err.serialize(serializer)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,3 +0,0 @@
|
|||||||
pub mod core;
|
|
||||||
pub mod error;
|
|
||||||
pub mod manager;
|
|
@ -1,201 +0,0 @@
|
|||||||
use crate::core::parser::ResponseParser;
|
|
||||||
use crate::core::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
|
|
||||||
use crate::core::rpc_loop::Handler;
|
|
||||||
use crate::core::rpc_peer::{PluginCommand, ResponsePayload};
|
|
||||||
use crate::error::{ReadError, RemoteError, SidecarError};
|
|
||||||
use anyhow::anyhow;
|
|
||||||
use lib_infra::util::{get_operating_system, OperatingSystem};
|
|
||||||
use parking_lot::Mutex;
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::io;
|
|
||||||
|
|
||||||
use std::sync::atomic::{AtomicI64, Ordering};
|
|
||||||
use std::sync::{Arc, Weak};
|
|
||||||
use tracing::{error, info, instrument, trace, warn};
|
|
||||||
|
|
||||||
pub struct SidecarManager {
|
|
||||||
state: Arc<Mutex<SidecarState>>,
|
|
||||||
plugin_id_counter: Arc<AtomicI64>,
|
|
||||||
operating_system: OperatingSystem,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for SidecarManager {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SidecarManager {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
SidecarManager {
|
|
||||||
state: Arc::new(Mutex::new(SidecarState {
|
|
||||||
plugins: Vec::new(),
|
|
||||||
})),
|
|
||||||
plugin_id_counter: Arc::new(Default::default()),
|
|
||||||
operating_system: get_operating_system(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, SidecarError> {
|
|
||||||
if self.operating_system.is_not_desktop() {
|
|
||||||
return Err(SidecarError::Internal(anyhow!(
|
|
||||||
"plugin not supported on this platform"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
let plugin_id = PluginId::from(self.plugin_id_counter.fetch_add(1, Ordering::SeqCst));
|
|
||||||
let weak_state = WeakSidecarState(Arc::downgrade(&self.state));
|
|
||||||
start_plugin_process(plugin_info, plugin_id, weak_state).await?;
|
|
||||||
Ok(plugin_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_plugin(&self, plugin_id: PluginId) -> Result<Weak<Plugin>, SidecarError> {
|
|
||||||
let state = self.state.lock();
|
|
||||||
let plugin = state
|
|
||||||
.plugins
|
|
||||||
.iter()
|
|
||||||
.find(|p| p.id == plugin_id)
|
|
||||||
.ok_or(anyhow!("plugin not found"))?;
|
|
||||||
Ok(Arc::downgrade(plugin))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip(self), err)]
|
|
||||||
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), SidecarError> {
|
|
||||||
if self.operating_system.is_not_desktop() {
|
|
||||||
return Err(SidecarError::Internal(anyhow!(
|
|
||||||
"plugin not supported on this platform"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("[RPC] removing plugin {:?}", id);
|
|
||||||
self.state.lock().plugin_disconnect(id, Ok(()));
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<Arc<Plugin>, SidecarError> {
|
|
||||||
if self.operating_system.is_not_desktop() {
|
|
||||||
return Err(SidecarError::Internal(anyhow!(
|
|
||||||
"plugin not supported on this platform"
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let state = self.state.lock();
|
|
||||||
let plugin = state
|
|
||||||
.plugins
|
|
||||||
.iter()
|
|
||||||
.find(|p| p.id == id)
|
|
||||||
.ok_or(anyhow!("plugin not found"))?;
|
|
||||||
plugin.initialize(init_params)?;
|
|
||||||
|
|
||||||
Ok(plugin.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_request<P: ResponseParser>(
|
|
||||||
&self,
|
|
||||||
id: PluginId,
|
|
||||||
method: &str,
|
|
||||||
request: Value,
|
|
||||||
) -> Result<P::ValueType, SidecarError> {
|
|
||||||
let state = self.state.lock();
|
|
||||||
let plugin = state
|
|
||||||
.plugins
|
|
||||||
.iter()
|
|
||||||
.find(|p| p.id == id)
|
|
||||||
.ok_or(anyhow!("plugin not found"))?;
|
|
||||||
let resp = plugin.request(method, &request)?;
|
|
||||||
let value = P::parse_json(resp)?;
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn async_send_request<P: ResponseParser>(
|
|
||||||
&self,
|
|
||||||
id: PluginId,
|
|
||||||
method: &str,
|
|
||||||
request: Value,
|
|
||||||
) -> Result<P::ValueType, SidecarError> {
|
|
||||||
let plugin = self
|
|
||||||
.state
|
|
||||||
.lock()
|
|
||||||
.plugins
|
|
||||||
.iter()
|
|
||||||
.find(|p| p.id == id)
|
|
||||||
.ok_or(anyhow!("plugin not found"))
|
|
||||||
.cloned()?;
|
|
||||||
let value = plugin.async_request::<P>(method, &request).await?;
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct SidecarState {
|
|
||||||
plugins: Vec<Arc<Plugin>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SidecarState {
|
|
||||||
pub fn plugin_connect(&mut self, plugin: Result<Plugin, io::Error>) {
|
|
||||||
match plugin {
|
|
||||||
Ok(plugin) => {
|
|
||||||
info!("[RPC] {} connected", plugin);
|
|
||||||
self.plugins.push(Arc::new(plugin));
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
warn!("plugin failed to connect: {:?}", err);
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn plugin_disconnect(
|
|
||||||
&mut self,
|
|
||||||
id: PluginId,
|
|
||||||
error: Result<(), ReadError>,
|
|
||||||
) -> Option<Arc<Plugin>> {
|
|
||||||
if let Err(err) = error {
|
|
||||||
error!("[RPC] plugin {:?} exited with result {:?}", id, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
let running_idx = self.plugins.iter().position(|p| p.id == id);
|
|
||||||
match running_idx {
|
|
||||||
Some(idx) => {
|
|
||||||
let plugin = self.plugins.remove(idx);
|
|
||||||
plugin.shutdown();
|
|
||||||
Some(plugin)
|
|
||||||
},
|
|
||||||
None => {
|
|
||||||
warn!("[RPC] plugin {:?} not found", id);
|
|
||||||
None
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct WeakSidecarState(Weak<Mutex<SidecarState>>);
|
|
||||||
|
|
||||||
impl WeakSidecarState {
|
|
||||||
pub fn upgrade(&self) -> Option<Arc<Mutex<SidecarState>>> {
|
|
||||||
self.0.upgrade()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn plugin_connect(&self, plugin: Result<Plugin, io::Error>) {
|
|
||||||
if let Some(state) = self.upgrade() {
|
|
||||||
state.lock().plugin_connect(plugin)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn plugin_exit(&self, plugin: PluginId, error: Result<(), ReadError>) {
|
|
||||||
if let Some(core) = self.upgrade() {
|
|
||||||
core.lock().plugin_disconnect(plugin, error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handler for WeakSidecarState {
|
|
||||||
type Request = PluginCommand<String>;
|
|
||||||
|
|
||||||
fn handle_request(
|
|
||||||
&mut self,
|
|
||||||
_ctx: &RpcCtx,
|
|
||||||
rpc: Self::Request,
|
|
||||||
) -> Result<ResponsePayload, RemoteError> {
|
|
||||||
trace!("handling request: {:?}", rpc.cmd);
|
|
||||||
Ok(ResponsePayload::empty_json())
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user