mirror of
https://github.com/AppFlowy-IO/AppFlowy.git
synced 2024-08-30 18:12:39 +00:00
chore: save chat config
This commit is contained in:
3
frontend/rust-lib/Cargo.lock
generated
3
frontend/rust-lib/Cargo.lock
generated
@ -1669,6 +1669,7 @@ dependencies = [
|
||||
"flowy-derive",
|
||||
"flowy-error",
|
||||
"flowy-notification",
|
||||
"flowy-sidecar",
|
||||
"flowy-sqlite",
|
||||
"futures",
|
||||
"lib-dispatch",
|
||||
@ -2147,6 +2148,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"crossbeam-utils",
|
||||
"dotenv",
|
||||
"lib-infra",
|
||||
"log",
|
||||
"once_cell",
|
||||
"parking_lot 0.12.1",
|
||||
@ -3115,6 +3117,7 @@ dependencies = [
|
||||
"atomic_refcell",
|
||||
"brotli",
|
||||
"bytes",
|
||||
"cfg-if",
|
||||
"chrono",
|
||||
"futures",
|
||||
"futures-core",
|
||||
|
@ -68,6 +68,7 @@ collab-integrate = { workspace = true, path = "collab-integrate" }
|
||||
flowy-date = { workspace = true, path = "flowy-date" }
|
||||
flowy-chat = { workspace = true, path = "flowy-chat" }
|
||||
flowy-chat-pub = { workspace = true, path = "flowy-chat-pub" }
|
||||
flowy-sidecar = { workspace = true, path = "flowy-sidecar" }
|
||||
anyhow = "1.0"
|
||||
tracing = "0.1.40"
|
||||
bytes = "1.5.0"
|
||||
|
@ -27,6 +27,7 @@ tokio.workspace = true
|
||||
futures.workspace = true
|
||||
allo-isolate = { version = "^0.1", features = ["catch-unwind"] }
|
||||
log = "0.4.21"
|
||||
flowy-sidecar = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
flowy-codegen.workspace = true
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::entities::{
|
||||
ChatMessageErrorPB, ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB,
|
||||
};
|
||||
use crate::manager::ChatUserService;
|
||||
use crate::manager::{ChatService, ChatUserService};
|
||||
use crate::notification::{send_notification, ChatNotification};
|
||||
use crate::persistence::{insert_chat_messages, select_chat_messages, ChatMessageTable};
|
||||
use allo_isolate::Isolate;
|
||||
@ -25,7 +25,7 @@ pub struct Chat {
|
||||
chat_id: String,
|
||||
uid: i64,
|
||||
user_service: Arc<dyn ChatUserService>,
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
chat_service: Arc<ChatService>,
|
||||
prev_message_state: Arc<RwLock<PrevMessageState>>,
|
||||
latest_message_id: Arc<AtomicI64>,
|
||||
stop_stream: Arc<AtomicBool>,
|
||||
@ -37,12 +37,12 @@ impl Chat {
|
||||
uid: i64,
|
||||
chat_id: String,
|
||||
user_service: Arc<dyn ChatUserService>,
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
chat_service: Arc<ChatService>,
|
||||
) -> Chat {
|
||||
Chat {
|
||||
uid,
|
||||
chat_id,
|
||||
cloud_service,
|
||||
chat_service,
|
||||
user_service,
|
||||
prev_message_state: Arc::new(RwLock::new(PrevMessageState::HasMore)),
|
||||
latest_message_id: Default::default(),
|
||||
@ -92,7 +92,7 @@ impl Chat {
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
|
||||
let question = self
|
||||
.cloud_service
|
||||
.chat_service
|
||||
.send_question(&workspace_id, &self.chat_id, message, message_type)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
@ -109,7 +109,7 @@ impl Chat {
|
||||
let stop_stream = self.stop_stream.clone();
|
||||
let chat_id = self.chat_id.clone();
|
||||
let question_id = question.message_id;
|
||||
let cloud_service = self.cloud_service.clone();
|
||||
let cloud_service = self.chat_service.clone();
|
||||
let user_service = self.user_service.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut text_sink = IsolateSink::new(Isolate::new(text_stream_port));
|
||||
@ -300,7 +300,7 @@ impl Chat {
|
||||
);
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
let chat_id = self.chat_id.clone();
|
||||
let cloud_service = self.cloud_service.clone();
|
||||
let cloud_service = self.chat_service.clone();
|
||||
let user_service = self.user_service.clone();
|
||||
let uid = self.uid;
|
||||
let prev_message_state = self.prev_message_state.clone();
|
||||
@ -369,7 +369,7 @@ impl Chat {
|
||||
) -> Result<RepeatedRelatedQuestionPB, FlowyError> {
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
let resp = self
|
||||
.cloud_service
|
||||
.chat_service
|
||||
.get_related_message(&workspace_id, &self.chat_id, message_id)
|
||||
.await?;
|
||||
|
||||
@ -391,7 +391,7 @@ impl Chat {
|
||||
);
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
let answer = self
|
||||
.cloud_service
|
||||
.chat_service
|
||||
.generate_answer(&workspace_id, &self.chat_id, question_message_id)
|
||||
.await?;
|
||||
|
||||
|
@ -2,10 +2,16 @@ use crate::chat::Chat;
|
||||
use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB};
|
||||
use crate::persistence::{insert_chat, ChatTable};
|
||||
use dashmap::DashMap;
|
||||
use flowy_chat_pub::cloud::{ChatCloudService, ChatMessageType};
|
||||
use flowy_chat_pub::cloud::{
|
||||
ChatCloudService, ChatMessage, ChatMessageStream, ChatMessageType, MessageCursor,
|
||||
RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer,
|
||||
};
|
||||
use flowy_error::{FlowyError, FlowyResult};
|
||||
use flowy_sidecar::manager::SidecarManager;
|
||||
use flowy_sqlite::DBConnection;
|
||||
use lib_infra::future::FutureResult;
|
||||
use lib_infra::util::timestamp;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use tracing::trace;
|
||||
|
||||
@ -17,7 +23,7 @@ pub trait ChatUserService: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
pub struct ChatManager {
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
chat_service: Arc<ChatService>,
|
||||
user_service: Arc<dyn ChatUserService>,
|
||||
chats: Arc<DashMap<String, Arc<Chat>>>,
|
||||
}
|
||||
@ -27,10 +33,12 @@ impl ChatManager {
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
user_service: impl ChatUserService,
|
||||
) -> ChatManager {
|
||||
let sidecar_manager = Arc::new(SidecarManager::new());
|
||||
let chat_service = Arc::new(ChatService::new(cloud_service, sidecar_manager));
|
||||
let user_service = Arc::new(user_service);
|
||||
|
||||
Self {
|
||||
cloud_service,
|
||||
chat_service,
|
||||
user_service,
|
||||
chats: Arc::new(DashMap::new()),
|
||||
}
|
||||
@ -43,7 +51,7 @@ impl ChatManager {
|
||||
self.user_service.user_id().unwrap(),
|
||||
chat_id.to_string(),
|
||||
self.user_service.clone(),
|
||||
self.cloud_service.clone(),
|
||||
self.chat_service.clone(),
|
||||
))
|
||||
});
|
||||
|
||||
@ -64,7 +72,7 @@ impl ChatManager {
|
||||
pub async fn create_chat(&self, uid: &i64, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
|
||||
let workspace_id = self.user_service.workspace_id()?;
|
||||
self
|
||||
.cloud_service
|
||||
.chat_service
|
||||
.create_chat(uid, &workspace_id, chat_id)
|
||||
.await?;
|
||||
save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?;
|
||||
@ -73,7 +81,7 @@ impl ChatManager {
|
||||
self.user_service.user_id().unwrap(),
|
||||
chat_id.to_string(),
|
||||
self.user_service.clone(),
|
||||
self.cloud_service.clone(),
|
||||
self.chat_service.clone(),
|
||||
));
|
||||
self.chats.insert(chat_id.to_string(), chat.clone());
|
||||
Ok(chat)
|
||||
@ -101,7 +109,7 @@ impl ChatManager {
|
||||
self.user_service.user_id().unwrap(),
|
||||
chat_id.to_string(),
|
||||
self.user_service.clone(),
|
||||
self.cloud_service.clone(),
|
||||
self.chat_service.clone(),
|
||||
));
|
||||
self.chats.insert(chat_id.to_string(), chat.clone());
|
||||
Ok(chat)
|
||||
@ -183,8 +191,107 @@ fn save_chat(conn: DBConnection, chat_id: &str) -> FlowyResult<()> {
|
||||
chat_id: chat_id.to_string(),
|
||||
created_at: timestamp(),
|
||||
name: "".to_string(),
|
||||
local_model_path: "".to_string(),
|
||||
local_model_name: "".to_string(),
|
||||
local_enabled: false,
|
||||
sync_to_cloud: true,
|
||||
};
|
||||
|
||||
insert_chat(conn, &row)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct ChatService {
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
sidecar_manager: Arc<SidecarManager>,
|
||||
}
|
||||
|
||||
impl ChatService {
|
||||
pub fn new(
|
||||
cloud_service: Arc<dyn ChatCloudService>,
|
||||
sidecar_manager: Arc<SidecarManager>,
|
||||
) -> Self {
|
||||
Self {
|
||||
cloud_service,
|
||||
sidecar_manager,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCloudService for ChatService {
|
||||
fn create_chat(
|
||||
&self,
|
||||
uid: &i64,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
) -> FutureResult<(), FlowyError> {
|
||||
self.cloud_service.create_chat(uid, workspace_id, chat_id)
|
||||
}
|
||||
|
||||
async fn send_chat_message(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message: &str,
|
||||
message_type: ChatMessageType,
|
||||
) -> Result<ChatMessageStream, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn send_question(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message: &str,
|
||||
message_type: ChatMessageType,
|
||||
) -> FutureResult<ChatMessage, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn save_answer(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message: &str,
|
||||
question_id: i64,
|
||||
) -> FutureResult<ChatMessage, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn stream_answer(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
) -> Result<StreamAnswer, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn get_chat_messages(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
offset: MessageCursor,
|
||||
limit: u64,
|
||||
) -> FutureResult<RepeatedChatMessage, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn get_related_message(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
message_id: i64,
|
||||
) -> FutureResult<RepeatedRelatedQuestion, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn generate_answer(
|
||||
&self,
|
||||
workspace_id: &str,
|
||||
chat_id: &str,
|
||||
question_message_id: i64,
|
||||
) -> FutureResult<ChatMessage, FlowyError> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
use diesel::sqlite::SqliteConnection;
|
||||
use flowy_sqlite::upsert::excluded;
|
||||
use flowy_sqlite::{
|
||||
diesel,
|
||||
query_dsl::*,
|
||||
schema::{chat_table, chat_table::dsl},
|
||||
DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable,
|
||||
AsChangeset, DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable,
|
||||
};
|
||||
|
||||
#[derive(Clone, Default, Queryable, Insertable, Identifiable)]
|
||||
@ -13,6 +14,22 @@ pub struct ChatTable {
|
||||
pub chat_id: String,
|
||||
pub created_at: i64,
|
||||
pub name: String,
|
||||
pub local_model_path: String,
|
||||
pub local_model_name: String,
|
||||
pub local_enabled: bool,
|
||||
pub sync_to_cloud: bool,
|
||||
}
|
||||
|
||||
#[derive(AsChangeset, Identifiable, Default, Debug)]
|
||||
#[diesel(table_name = chat_table)]
|
||||
#[diesel(primary_key(chat_id))]
|
||||
pub struct ChatTableChangeset {
|
||||
pub chat_id: String,
|
||||
pub name: Option<String>,
|
||||
pub local_model_path: Option<String>,
|
||||
pub local_model_name: Option<String>,
|
||||
pub local_enabled: Option<bool>,
|
||||
pub sync_to_cloud: Option<bool>,
|
||||
}
|
||||
|
||||
pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult<usize> {
|
||||
@ -27,6 +44,15 @@ pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult<
|
||||
.execute(&mut *conn)
|
||||
}
|
||||
|
||||
pub fn update_chat_local_model(
|
||||
conn: &mut SqliteConnection,
|
||||
changeset: ChatTableChangeset,
|
||||
) -> QueryResult<usize> {
|
||||
let filter = dsl::chat_table.filter(chat_table::chat_id.eq(changeset.chat_id.clone()));
|
||||
let affected_row = diesel::update(filter).set(changeset).execute(conn)?;
|
||||
Ok(affected_row)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn read_chat(mut conn: DBConnection, chat_id_val: &str) -> QueryResult<ChatTable> {
|
||||
let row = dsl::chat_table
|
||||
|
@ -9,7 +9,7 @@ use flowy_server_pub::af_cloud_config::AFCloudConfiguration;
|
||||
use flowy_server_pub::supabase_config::SupabaseConfiguration;
|
||||
use flowy_user::services::entities::URL_SAFE_ENGINE;
|
||||
use lib_infra::file_util::copy_dir_recursive;
|
||||
use lib_infra::util::Platform;
|
||||
use lib_infra::util::OperatingSystem;
|
||||
|
||||
use crate::integrate::log::create_log_filter;
|
||||
|
||||
@ -94,7 +94,7 @@ impl AppFlowyCoreConfig {
|
||||
},
|
||||
Some(config) => make_user_data_folder(&custom_application_path, &config.base_url),
|
||||
};
|
||||
let log_filter = create_log_filter("info".to_owned(), vec![], Platform::from(&platform));
|
||||
let log_filter = create_log_filter("info".to_owned(), vec![], OperatingSystem::from(&platform));
|
||||
|
||||
AppFlowyCoreConfig {
|
||||
app_version,
|
||||
@ -112,7 +112,7 @@ impl AppFlowyCoreConfig {
|
||||
self.log_filter = create_log_filter(
|
||||
level.to_owned(),
|
||||
with_crates,
|
||||
Platform::from(&self.platform),
|
||||
OperatingSystem::from(&self.platform),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use lib_infra::util::Platform;
|
||||
use lib_infra::util::OperatingSystem;
|
||||
use lib_log::stream_log::StreamLogSender;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
@ -8,7 +8,7 @@ use crate::AppFlowyCoreConfig;
|
||||
static INIT_LOG: AtomicBool = AtomicBool::new(false);
|
||||
pub(crate) fn init_log(
|
||||
config: &AppFlowyCoreConfig,
|
||||
platform: &Platform,
|
||||
platform: &OperatingSystem,
|
||||
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
@ -25,11 +25,15 @@ pub(crate) fn init_log(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_log_filter(level: String, with_crates: Vec<String>, platform: Platform) -> String {
|
||||
pub fn create_log_filter(
|
||||
level: String,
|
||||
with_crates: Vec<String>,
|
||||
platform: OperatingSystem,
|
||||
) -> String {
|
||||
let mut level = std::env::var("RUST_LOG").unwrap_or(level);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
if matches!(platform, Platform::IOS) {
|
||||
if matches!(platform, OperatingSystem::IOS) {
|
||||
level = "trace".to_string();
|
||||
}
|
||||
|
||||
|
@ -25,7 +25,7 @@ use flowy_user::user_manager::UserManager;
|
||||
use lib_dispatch::prelude::*;
|
||||
use lib_dispatch::runtime::AFPluginRuntime;
|
||||
use lib_infra::priority_task::{TaskDispatcher, TaskRunner};
|
||||
use lib_infra::util::Platform;
|
||||
use lib_infra::util::OperatingSystem;
|
||||
use lib_log::stream_log::StreamLogSender;
|
||||
use module::make_plugins;
|
||||
|
||||
@ -69,7 +69,7 @@ impl AppFlowyCore {
|
||||
runtime: Arc<AFPluginRuntime>,
|
||||
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
|
||||
) -> Self {
|
||||
let platform = Platform::from(&config.platform);
|
||||
let platform = OperatingSystem::from(&config.platform);
|
||||
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
if cfg!(debug_assertions) {
|
||||
|
@ -16,6 +16,7 @@ tracing.workspace = true
|
||||
crossbeam-utils = "0.8.20"
|
||||
log = "0.4.21"
|
||||
parking_lot.workspace = true
|
||||
lib-infra.workspace = true
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
|
@ -2,3 +2,4 @@
|
||||
CHAT_BIN_PATH=
|
||||
LOCAL_AI_ROOT_PATH=
|
||||
LOCAL_AI_CHAT_MODEL_NAME=
|
||||
LOCAL_AI_EMBEDDING_MODEL_NAME=
|
||||
|
@ -17,6 +17,9 @@ pub enum Error {
|
||||
/// The peer sent a response containing the id, but was malformed.
|
||||
#[error("Invalid response.")]
|
||||
InvalidResponse,
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -1,9 +1,10 @@
|
||||
use crate::error::{ReadError, RemoteError};
|
||||
use crate::error::{Error, ReadError, RemoteError};
|
||||
use crate::parser::ResponseParser;
|
||||
use crate::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
|
||||
use crate::rpc_loop::Handler;
|
||||
use crate::rpc_peer::PluginCommand;
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::anyhow;
|
||||
use lib_infra::util::{get_operating_system, OperatingSystem};
|
||||
use parking_lot::Mutex;
|
||||
use serde_json::{json, Value};
|
||||
use std::io;
|
||||
@ -14,6 +15,7 @@ use tracing::{trace, warn};
|
||||
pub struct SidecarManager {
|
||||
state: Arc<Mutex<SidecarState>>,
|
||||
plugin_id_counter: Arc<AtomicI64>,
|
||||
operating_system: OperatingSystem,
|
||||
}
|
||||
|
||||
impl SidecarManager {
|
||||
@ -23,23 +25,46 @@ impl SidecarManager {
|
||||
plugins: Vec::new(),
|
||||
})),
|
||||
plugin_id_counter: Arc::new(Default::default()),
|
||||
operating_system: get_operating_system(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId> {
|
||||
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, Error> {
|
||||
if self.operating_system.is_not_desktop() {
|
||||
return Err(Error::Internal(anyhow!(
|
||||
"plugin not supported on this platform"
|
||||
)));
|
||||
}
|
||||
let plugin_id = PluginId::from(self.plugin_id_counter.fetch_add(1, Ordering::SeqCst));
|
||||
let weak_state = WeakSidecarState(Arc::downgrade(&self.state));
|
||||
start_plugin_process(plugin_info, plugin_id, weak_state).await?;
|
||||
Ok(plugin_id)
|
||||
}
|
||||
|
||||
pub async fn remove_plugin(&self, id: PluginId) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
state.plugin_disconnect(id, Ok(()));
|
||||
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> {
|
||||
if self.operating_system.is_not_desktop() {
|
||||
return Err(Error::Internal(anyhow!(
|
||||
"plugin not supported on this platform"
|
||||
)));
|
||||
}
|
||||
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
.iter()
|
||||
.find(|p| p.id == id)
|
||||
.ok_or(anyhow!("plugin not found"))?;
|
||||
plugin.shutdown();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<()> {
|
||||
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> {
|
||||
if self.operating_system.is_not_desktop() {
|
||||
return Err(Error::Internal(anyhow!(
|
||||
"plugin not supported on this platform"
|
||||
)));
|
||||
}
|
||||
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
@ -56,7 +81,7 @@ impl SidecarManager {
|
||||
id: PluginId,
|
||||
method: &str,
|
||||
request: Value,
|
||||
) -> Result<P::ValueType> {
|
||||
) -> Result<P::ValueType, Error> {
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
@ -67,6 +92,23 @@ impl SidecarManager {
|
||||
let value = P::parse_response(resp)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
pub async fn async_send_request<P: ResponseParser>(
|
||||
&self,
|
||||
id: PluginId,
|
||||
method: &str,
|
||||
request: Value,
|
||||
) -> Result<P::ValueType, Error> {
|
||||
let state = self.state.lock();
|
||||
let plugin = state
|
||||
.plugins
|
||||
.iter()
|
||||
.find(|p| p.id == id)
|
||||
.ok_or(anyhow!("plugin not found"))?;
|
||||
let resp = plugin.async_send_request(method, &request).await?;
|
||||
let value = P::parse_response(resp)?;
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SidecarState {
|
||||
|
@ -75,3 +75,19 @@ impl ResponseParser for ChatResponseParser {
|
||||
return Err(RemoteError::InvalidResponse(json));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SimilarityResponseParser;
|
||||
impl ResponseParser for SimilarityResponseParser {
|
||||
type ValueType = f64;
|
||||
|
||||
fn parse_response(json: Value) -> Result<Self::ValueType, RemoteError> {
|
||||
if json.is_object() {
|
||||
if let Some(data) = json.get("data") {
|
||||
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
|
||||
return Ok(score);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(RemoteError::InvalidResponse(json));
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ use std::io::BufReader;
|
||||
|
||||
use std::process::{Child, Stdio};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use std::thread;
|
||||
use std::time::Instant;
|
||||
use tracing::{error, info};
|
||||
@ -26,6 +27,13 @@ impl From<i64> for PluginId {
|
||||
pub trait Callback: Send {
|
||||
fn call(self: Box<Self>, result: Result<Value, Error>);
|
||||
}
|
||||
|
||||
impl<F: Send + FnOnce(Result<Value, Error>)> Callback for F {
|
||||
fn call(self: Box<F>, result: Result<Value, Error>) {
|
||||
(*self)(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// The `Peer` trait defines the interface for the opposite side of the RPC channel,
|
||||
/// designed to be used behind a pointer or as a trait object.
|
||||
pub trait Peer: Send + 'static {
|
||||
@ -74,6 +82,21 @@ impl Plugin {
|
||||
self.peer.send_rpc_request(method, params)
|
||||
}
|
||||
|
||||
pub async fn async_send_request(&self, method: &str, params: &Value) -> Result<Value, Error> {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
self.peer.send_rpc_request_async(
|
||||
method,
|
||||
params,
|
||||
Box::new(move |result| {
|
||||
let _ = tx.send(result);
|
||||
}),
|
||||
);
|
||||
let value = rx
|
||||
.await
|
||||
.map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {
|
||||
error!("error sending shutdown to plugin {}: {:?}", self.name, err);
|
||||
|
@ -38,7 +38,6 @@ impl RpcObject {
|
||||
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
|
||||
}
|
||||
let result = self.0.as_object_mut().and_then(|obj| obj.remove("result"));
|
||||
|
||||
match result {
|
||||
Some(r) => Ok(Ok(r)),
|
||||
None => {
|
||||
|
@ -1,94 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use flowy_sidecar::manager::SidecarManager;
|
||||
use flowy_sidecar::parser::ChatResponseParser;
|
||||
use flowy_sidecar::plugin::PluginInfo;
|
||||
use serde_json::json;
|
||||
use std::sync::Once;
|
||||
|
||||
use tracing_subscriber::fmt::Subscriber;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_chat_model_test() {
|
||||
if let Ok(config) = LocalAIConfiguration::new() {
|
||||
let manager = SidecarManager::new();
|
||||
let info = PluginInfo {
|
||||
name: "chat".to_string(),
|
||||
exec_path: config.chat_bin_path.clone(),
|
||||
};
|
||||
let plugin_id = manager.create_plugin(info).await.unwrap();
|
||||
manager
|
||||
.init_plugin(
|
||||
plugin_id,
|
||||
json!({
|
||||
"absolute_chat_model_path":config.chat_model_absolute_path(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let _json = json!({
|
||||
"plugin_id": "example_plugin_id",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"absolute_chat_model_path":config.chat_model_absolute_path(),
|
||||
}
|
||||
});
|
||||
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let resp = manager
|
||||
.send_request::<ChatResponseParser>(
|
||||
plugin_id,
|
||||
"handle",
|
||||
json!({"chat_id": chat_id, "method": "answer", "params": {"content": "hello world"}}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
eprintln!("chat response: {:?}", resp);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalAIConfiguration {
|
||||
root: String,
|
||||
chat_bin_path: String,
|
||||
chat_model_name: String,
|
||||
}
|
||||
|
||||
impl LocalAIConfiguration {
|
||||
pub fn new() -> Result<Self> {
|
||||
dotenv::dotenv().ok();
|
||||
setup_log();
|
||||
|
||||
// load from .env
|
||||
let root = dotenv::var("LOCAL_AI_ROOT_PATH")?;
|
||||
let chat_bin_path = dotenv::var("CHAT_BIN_PATH")?;
|
||||
let chat_model = dotenv::var("LOCAL_AI_CHAT_MODEL_NAME")?;
|
||||
|
||||
Ok(Self {
|
||||
root,
|
||||
chat_bin_path,
|
||||
chat_model_name: chat_model,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn chat_model_absolute_path(&self) -> String {
|
||||
format!("{}/{}", self.root, self.chat_model_name)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn setup_log() {
|
||||
static START: Once = Once::new();
|
||||
START.call_once(|| {
|
||||
let level = "trace";
|
||||
let mut filters = vec![];
|
||||
filters.push(format!("flowy_sidecar={}", level));
|
||||
std::env::set_var("RUST_LOG", filters.join(","));
|
||||
|
||||
let subscriber = Subscriber::builder()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.with_line_number(true)
|
||||
.with_ansi(true)
|
||||
.finish();
|
||||
subscriber.try_init().unwrap();
|
||||
});
|
||||
}
|
15
frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs
Normal file
15
frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use crate::util::LocalAITest;
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_chat_model_test() {
|
||||
if let Ok(test) = LocalAITest::new() {
|
||||
let plugin_id = test.init_chat_plugin().await;
|
||||
let chat_id = uuid::Uuid::new_v4().to_string();
|
||||
let resp = test.send_message(&chat_id, plugin_id, "hello world").await;
|
||||
eprintln!("chat response: {:?}", resp);
|
||||
|
||||
let embedding_plugin_id = test.init_embedding_plugin().await;
|
||||
let score = test.calculate_similarity(embedding_plugin_id, &resp, "Hello! How can I help you today? Is there something specific you would like to know or discuss").await;
|
||||
assert!(score > 0.8);
|
||||
}
|
||||
}
|
2
frontend/rust-lib/flowy-sidecar/tests/main.rs
Normal file
2
frontend/rust-lib/flowy-sidecar/tests/main.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod chat_test;
|
||||
pub mod util;
|
148
frontend/rust-lib/flowy-sidecar/tests/util.rs
Normal file
148
frontend/rust-lib/flowy-sidecar/tests/util.rs
Normal file
@ -0,0 +1,148 @@
|
||||
use anyhow::Result;
|
||||
use flowy_sidecar::manager::SidecarManager;
|
||||
use flowy_sidecar::parser::{ChatResponseParser, SimilarityResponseParser};
|
||||
use flowy_sidecar::plugin::{PluginId, PluginInfo};
|
||||
use serde_json::json;
|
||||
use std::sync::Once;
|
||||
|
||||
use tracing_subscriber::fmt::Subscriber;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub struct LocalAITest {
|
||||
config: LocalAIConfiguration,
|
||||
manager: SidecarManager,
|
||||
}
|
||||
|
||||
impl LocalAITest {
|
||||
pub fn new() -> Result<Self> {
|
||||
let config = LocalAIConfiguration::new()?;
|
||||
let manager = SidecarManager::new();
|
||||
|
||||
Ok(Self { config, manager })
|
||||
}
|
||||
pub async fn init_chat_plugin(&self) -> PluginId {
|
||||
let info = PluginInfo {
|
||||
name: "chat".to_string(),
|
||||
exec_path: self.config.chat_bin_path.clone(),
|
||||
};
|
||||
let plugin_id = self.manager.create_plugin(info).await.unwrap();
|
||||
self
|
||||
.manager
|
||||
.init_plugin(
|
||||
plugin_id,
|
||||
json!({
|
||||
"absolute_chat_model_path":self.config.chat_model_absolute_path(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
plugin_id
|
||||
}
|
||||
|
||||
pub async fn init_embedding_plugin(&self) -> PluginId {
|
||||
let info = PluginInfo {
|
||||
name: "embedding".to_string(),
|
||||
exec_path: self.config.embedding_bin_path.clone(),
|
||||
};
|
||||
let plugin_id = self.manager.create_plugin(info).await.unwrap();
|
||||
let embedding_model_path = self.config.embedding_model_absolute_path();
|
||||
self
|
||||
.manager
|
||||
.init_plugin(
|
||||
plugin_id,
|
||||
json!({
|
||||
"absolute_model_path":embedding_model_path,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
plugin_id
|
||||
}
|
||||
|
||||
pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String {
|
||||
let resp = self
|
||||
.manager
|
||||
.async_send_request::<ChatResponseParser>(
|
||||
plugin_id,
|
||||
"handle",
|
||||
json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
resp
|
||||
}
|
||||
|
||||
pub async fn calculate_similarity(
|
||||
&self,
|
||||
plugin_id: PluginId,
|
||||
message1: &str,
|
||||
message2: &str,
|
||||
) -> f64 {
|
||||
self
|
||||
.manager
|
||||
.async_send_request::<SimilarityResponseParser>(
|
||||
plugin_id,
|
||||
"handle",
|
||||
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}}),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalAIConfiguration {
|
||||
root: String,
|
||||
chat_bin_path: String,
|
||||
chat_model_name: String,
|
||||
embedding_bin_path: String,
|
||||
embedding_model_name: String,
|
||||
}
|
||||
|
||||
impl LocalAIConfiguration {
|
||||
pub fn new() -> Result<Self> {
|
||||
dotenv::dotenv().ok();
|
||||
setup_log();
|
||||
|
||||
// load from .env
|
||||
let root = dotenv::var("LOCAL_AI_ROOT_PATH")?;
|
||||
let chat_bin_path = dotenv::var("CHAT_BIN_PATH")?;
|
||||
let chat_model_name = dotenv::var("LOCAL_AI_CHAT_MODEL_NAME")?;
|
||||
|
||||
let embedding_bin_path = dotenv::var("EMBEDDING_BIN_PATH")?;
|
||||
let embedding_model_name = dotenv::var("LOCAL_AI_EMBEDDING_MODEL_NAME")?;
|
||||
|
||||
Ok(Self {
|
||||
root,
|
||||
chat_bin_path,
|
||||
chat_model_name,
|
||||
embedding_bin_path,
|
||||
embedding_model_name,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn chat_model_absolute_path(&self) -> String {
|
||||
format!("{}/{}", self.root, self.chat_model_name)
|
||||
}
|
||||
|
||||
pub fn embedding_model_absolute_path(&self) -> String {
|
||||
format!("{}/{}", self.root, self.embedding_model_name)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn setup_log() {
|
||||
static START: Once = Once::new();
|
||||
START.call_once(|| {
|
||||
let level = "trace";
|
||||
let mut filters = vec![];
|
||||
filters.push(format!("flowy_sidecar={}", level));
|
||||
std::env::set_var("RUST_LOG", filters.join(","));
|
||||
|
||||
let subscriber = Subscriber::builder()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.with_line_number(true)
|
||||
.with_ansi(true)
|
||||
.finish();
|
||||
subscriber.try_init().unwrap();
|
||||
});
|
||||
}
|
@ -0,0 +1 @@
|
||||
-- This file should undo anything in `up.sql`
|
@ -0,0 +1,13 @@
|
||||
-- Your SQL goes here
|
||||
ALTER TABLE chat_table ADD COLUMN local_model_path TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE chat_table ADD COLUMN local_model_name TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE chat_table ADD COLUMN local_enabled BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
ALTER TABLE chat_table ADD COLUMN sync_to_cloud BOOLEAN NOT NULL DEFAULT TRUE;
|
||||
|
||||
|
||||
CREATE TABLE chat_local_setting_table
|
||||
(
|
||||
chat_id TEXT PRIMARY KEY NOT NULL,
|
||||
local_model_path TEXT NOT NULL,
|
||||
local_model_name TEXT NOT NULL DEFAULT ''
|
||||
);
|
@ -1,5 +1,13 @@
|
||||
// @generated automatically by Diesel CLI.
|
||||
|
||||
diesel::table! {
|
||||
chat_local_setting_table (chat_id) {
|
||||
chat_id -> Text,
|
||||
local_model_path -> Text,
|
||||
local_model_name -> Text,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
chat_message_table (message_id) {
|
||||
message_id -> BigInt,
|
||||
@ -17,6 +25,10 @@ diesel::table! {
|
||||
chat_id -> Text,
|
||||
created_at -> BigInt,
|
||||
name -> Text,
|
||||
local_model_path -> Text,
|
||||
local_model_name -> Text,
|
||||
local_enabled -> Bool,
|
||||
sync_to_cloud -> Bool,
|
||||
}
|
||||
}
|
||||
|
||||
@ -102,6 +114,7 @@ diesel::table! {
|
||||
}
|
||||
|
||||
diesel::allow_tables_to_appear_in_same_query!(
|
||||
chat_local_setting_table,
|
||||
chat_message_table,
|
||||
chat_table,
|
||||
collab_snapshot,
|
||||
|
@ -23,6 +23,7 @@ tracing.workspace = true
|
||||
atomic_refcell = "0.1"
|
||||
allo-isolate = { version = "^0.1", features = ["catch-unwind"], optional = true }
|
||||
futures = "0.3.30"
|
||||
cfg-if = "1.0.0"
|
||||
|
||||
[dev-dependencies]
|
||||
rand = "0.8.5"
|
||||
|
@ -67,9 +67,8 @@ pub fn md5<T: AsRef<[u8]>>(data: T) -> String {
|
||||
let md5 = format!("{:x}", md5::compute(data));
|
||||
md5
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Platform {
|
||||
pub enum OperatingSystem {
|
||||
Unknown,
|
||||
Windows,
|
||||
Linux,
|
||||
@ -78,33 +77,62 @@ pub enum Platform {
|
||||
Android,
|
||||
}
|
||||
|
||||
impl Platform {
|
||||
impl OperatingSystem {
|
||||
pub fn is_not_ios(&self) -> bool {
|
||||
!matches!(self, Platform::IOS)
|
||||
!matches!(self, OperatingSystem::IOS)
|
||||
}
|
||||
|
||||
pub fn is_desktop(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
OperatingSystem::Windows | OperatingSystem::Linux | OperatingSystem::MacOS
|
||||
)
|
||||
}
|
||||
|
||||
pub fn is_not_desktop(&self) -> bool {
|
||||
!self.is_desktop()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Platform {
|
||||
impl From<String> for OperatingSystem {
|
||||
fn from(s: String) -> Self {
|
||||
Platform::from(s.as_str())
|
||||
OperatingSystem::from(s.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&String> for Platform {
|
||||
impl From<&String> for OperatingSystem {
|
||||
fn from(s: &String) -> Self {
|
||||
Platform::from(s.as_str())
|
||||
OperatingSystem::from(s.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Platform {
|
||||
impl From<&str> for OperatingSystem {
|
||||
fn from(s: &str) -> Self {
|
||||
match s {
|
||||
"windows" => Platform::Windows,
|
||||
"linux" => Platform::Linux,
|
||||
"macos" => Platform::MacOS,
|
||||
"ios" => Platform::IOS,
|
||||
"android" => Platform::Android,
|
||||
_ => Platform::Unknown,
|
||||
"windows" => OperatingSystem::Windows,
|
||||
"linux" => OperatingSystem::Linux,
|
||||
"macos" => OperatingSystem::MacOS,
|
||||
"ios" => OperatingSystem::IOS,
|
||||
"android" => OperatingSystem::Android,
|
||||
_ => OperatingSystem::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_operating_system() -> OperatingSystem {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_os = "android")] {
|
||||
OperatingSystem::Android
|
||||
} else if #[cfg(target_os = "ios")] {
|
||||
OperatingSystem::IOS
|
||||
} else if #[cfg(target_os = "macos")] {
|
||||
OperatingSystem::MacOS
|
||||
} else if #[cfg(target_os = "windows")] {
|
||||
OperatingSystem::Windows
|
||||
} else if #[cfg(target_os = "linux")] {
|
||||
OperatingSystem::Linux
|
||||
} else {
|
||||
OperatingSystem::Unknown
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ use std::sync::{Arc, RwLock};
|
||||
|
||||
use chrono::Local;
|
||||
use lazy_static::lazy_static;
|
||||
use lib_infra::util::Platform;
|
||||
use lib_infra::util::OperatingSystem;
|
||||
use tracing::subscriber::set_global_default;
|
||||
use tracing_appender::rolling::Rotation;
|
||||
use tracing_appender::{non_blocking::WorkerGuard, rolling::RollingFileAppender};
|
||||
@ -29,7 +29,7 @@ pub struct Builder {
|
||||
env_filter: String,
|
||||
file_appender: RollingFileAppender,
|
||||
#[allow(dead_code)]
|
||||
platform: Platform,
|
||||
platform: OperatingSystem,
|
||||
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@ impl Builder {
|
||||
pub fn new(
|
||||
name: &str,
|
||||
directory: &str,
|
||||
platform: &Platform,
|
||||
platform: &OperatingSystem,
|
||||
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
|
||||
) -> Self {
|
||||
let file_appender = RollingFileAppender::builder()
|
||||
|
Reference in New Issue
Block a user