diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index 9becfd0ae3..6d3e80a402 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -1669,6 +1669,7 @@ dependencies = [ "flowy-derive", "flowy-error", "flowy-notification", + "flowy-sidecar", "flowy-sqlite", "futures", "lib-dispatch", @@ -2147,6 +2148,7 @@ dependencies = [ "anyhow", "crossbeam-utils", "dotenv", + "lib-infra", "log", "once_cell", "parking_lot 0.12.1", @@ -3115,6 +3117,7 @@ dependencies = [ "atomic_refcell", "brotli", "bytes", + "cfg-if", "chrono", "futures", "futures-core", diff --git a/frontend/rust-lib/Cargo.toml b/frontend/rust-lib/Cargo.toml index 5ac6119944..5f721c120b 100644 --- a/frontend/rust-lib/Cargo.toml +++ b/frontend/rust-lib/Cargo.toml @@ -68,6 +68,7 @@ collab-integrate = { workspace = true, path = "collab-integrate" } flowy-date = { workspace = true, path = "flowy-date" } flowy-chat = { workspace = true, path = "flowy-chat" } flowy-chat-pub = { workspace = true, path = "flowy-chat-pub" } +flowy-sidecar = { workspace = true, path = "flowy-sidecar" } anyhow = "1.0" tracing = "0.1.40" bytes = "1.5.0" diff --git a/frontend/rust-lib/flowy-chat/Cargo.toml b/frontend/rust-lib/flowy-chat/Cargo.toml index 1abbb62a17..05608a284e 100644 --- a/frontend/rust-lib/flowy-chat/Cargo.toml +++ b/frontend/rust-lib/flowy-chat/Cargo.toml @@ -27,6 +27,7 @@ tokio.workspace = true futures.workspace = true allo-isolate = { version = "^0.1", features = ["catch-unwind"] } log = "0.4.21" +flowy-sidecar = { workspace = true } [build-dependencies] flowy-codegen.workspace = true diff --git a/frontend/rust-lib/flowy-chat/src/chat.rs b/frontend/rust-lib/flowy-chat/src/chat.rs index cb32c342c8..ad461d50da 100644 --- a/frontend/rust-lib/flowy-chat/src/chat.rs +++ b/frontend/rust-lib/flowy-chat/src/chat.rs @@ -1,7 +1,7 @@ use crate::entities::{ ChatMessageErrorPB, ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB, }; -use crate::manager::ChatUserService; +use crate::manager::{ChatService, ChatUserService}; use crate::notification::{send_notification, ChatNotification}; use crate::persistence::{insert_chat_messages, select_chat_messages, ChatMessageTable}; use allo_isolate::Isolate; @@ -25,7 +25,7 @@ pub struct Chat { chat_id: String, uid: i64, user_service: Arc, - cloud_service: Arc, + chat_service: Arc, prev_message_state: Arc>, latest_message_id: Arc, stop_stream: Arc, @@ -37,12 +37,12 @@ impl Chat { uid: i64, chat_id: String, user_service: Arc, - cloud_service: Arc, + chat_service: Arc, ) -> Chat { Chat { uid, chat_id, - cloud_service, + chat_service, user_service, prev_message_state: Arc::new(RwLock::new(PrevMessageState::HasMore)), latest_message_id: Default::default(), @@ -92,7 +92,7 @@ impl Chat { let workspace_id = self.user_service.workspace_id()?; let question = self - .cloud_service + .chat_service .send_question(&workspace_id, &self.chat_id, message, message_type) .await .map_err(|err| { @@ -109,7 +109,7 @@ impl Chat { let stop_stream = self.stop_stream.clone(); let chat_id = self.chat_id.clone(); let question_id = question.message_id; - let cloud_service = self.cloud_service.clone(); + let cloud_service = self.chat_service.clone(); let user_service = self.user_service.clone(); tokio::spawn(async move { let mut text_sink = IsolateSink::new(Isolate::new(text_stream_port)); @@ -300,7 +300,7 @@ impl Chat { ); let workspace_id = self.user_service.workspace_id()?; let chat_id = self.chat_id.clone(); - let cloud_service = self.cloud_service.clone(); + let cloud_service = self.chat_service.clone(); let user_service = self.user_service.clone(); let uid = self.uid; let prev_message_state = self.prev_message_state.clone(); @@ -369,7 +369,7 @@ impl Chat { ) -> Result { let workspace_id = self.user_service.workspace_id()?; let resp = self - .cloud_service + .chat_service .get_related_message(&workspace_id, &self.chat_id, message_id) .await?; @@ -391,7 +391,7 @@ impl Chat { ); let workspace_id = self.user_service.workspace_id()?; let answer = self - .cloud_service + .chat_service .generate_answer(&workspace_id, &self.chat_id, question_message_id) .await?; diff --git a/frontend/rust-lib/flowy-chat/src/manager.rs b/frontend/rust-lib/flowy-chat/src/manager.rs index b72cdfb87d..8f7d60f9be 100644 --- a/frontend/rust-lib/flowy-chat/src/manager.rs +++ b/frontend/rust-lib/flowy-chat/src/manager.rs @@ -2,10 +2,16 @@ use crate::chat::Chat; use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB}; use crate::persistence::{insert_chat, ChatTable}; use dashmap::DashMap; -use flowy_chat_pub::cloud::{ChatCloudService, ChatMessageType}; +use flowy_chat_pub::cloud::{ + ChatCloudService, ChatMessage, ChatMessageStream, ChatMessageType, MessageCursor, + RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer, +}; use flowy_error::{FlowyError, FlowyResult}; +use flowy_sidecar::manager::SidecarManager; use flowy_sqlite::DBConnection; +use lib_infra::future::FutureResult; use lib_infra::util::timestamp; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use tracing::trace; @@ -17,7 +23,7 @@ pub trait ChatUserService: Send + Sync + 'static { } pub struct ChatManager { - cloud_service: Arc, + chat_service: Arc, user_service: Arc, chats: Arc>>, } @@ -27,10 +33,12 @@ impl ChatManager { cloud_service: Arc, user_service: impl ChatUserService, ) -> ChatManager { + let sidecar_manager = Arc::new(SidecarManager::new()); + let chat_service = Arc::new(ChatService::new(cloud_service, sidecar_manager)); let user_service = Arc::new(user_service); Self { - cloud_service, + chat_service, user_service, chats: Arc::new(DashMap::new()), } @@ -43,7 +51,7 @@ impl ChatManager { self.user_service.user_id().unwrap(), chat_id.to_string(), self.user_service.clone(), - self.cloud_service.clone(), + self.chat_service.clone(), )) }); @@ -64,7 +72,7 @@ impl ChatManager { pub async fn create_chat(&self, uid: &i64, chat_id: &str) -> Result, FlowyError> { let workspace_id = self.user_service.workspace_id()?; self - .cloud_service + .chat_service .create_chat(uid, &workspace_id, chat_id) .await?; save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?; @@ -73,7 +81,7 @@ impl ChatManager { self.user_service.user_id().unwrap(), chat_id.to_string(), self.user_service.clone(), - self.cloud_service.clone(), + self.chat_service.clone(), )); self.chats.insert(chat_id.to_string(), chat.clone()); Ok(chat) @@ -101,7 +109,7 @@ impl ChatManager { self.user_service.user_id().unwrap(), chat_id.to_string(), self.user_service.clone(), - self.cloud_service.clone(), + self.chat_service.clone(), )); self.chats.insert(chat_id.to_string(), chat.clone()); Ok(chat) @@ -183,8 +191,107 @@ fn save_chat(conn: DBConnection, chat_id: &str) -> FlowyResult<()> { chat_id: chat_id.to_string(), created_at: timestamp(), name: "".to_string(), + local_model_path: "".to_string(), + local_model_name: "".to_string(), + local_enabled: false, + sync_to_cloud: true, }; insert_chat(conn, &row)?; Ok(()) } + +pub struct ChatService { + cloud_service: Arc, + sidecar_manager: Arc, +} + +impl ChatService { + pub fn new( + cloud_service: Arc, + sidecar_manager: Arc, + ) -> Self { + Self { + cloud_service, + sidecar_manager, + } + } +} + +impl ChatCloudService for ChatService { + fn create_chat( + &self, + uid: &i64, + workspace_id: &str, + chat_id: &str, + ) -> FutureResult<(), FlowyError> { + self.cloud_service.create_chat(uid, workspace_id, chat_id) + } + + async fn send_chat_message( + &self, + workspace_id: &str, + chat_id: &str, + message: &str, + message_type: ChatMessageType, + ) -> Result { + todo!() + } + + fn send_question( + &self, + workspace_id: &str, + chat_id: &str, + message: &str, + message_type: ChatMessageType, + ) -> FutureResult { + todo!() + } + + fn save_answer( + &self, + workspace_id: &str, + chat_id: &str, + message: &str, + question_id: i64, + ) -> FutureResult { + todo!() + } + + async fn stream_answer( + &self, + workspace_id: &str, + chat_id: &str, + message_id: i64, + ) -> Result { + todo!() + } + + fn get_chat_messages( + &self, + workspace_id: &str, + chat_id: &str, + offset: MessageCursor, + limit: u64, + ) -> FutureResult { + todo!() + } + + fn get_related_message( + &self, + workspace_id: &str, + chat_id: &str, + message_id: i64, + ) -> FutureResult { + todo!() + } + + fn generate_answer( + &self, + workspace_id: &str, + chat_id: &str, + question_message_id: i64, + ) -> FutureResult { + todo!() + } +} diff --git a/frontend/rust-lib/flowy-chat/src/persistence/chat_sql.rs b/frontend/rust-lib/flowy-chat/src/persistence/chat_sql.rs index 1fd0480c54..487a5245e0 100644 --- a/frontend/rust-lib/flowy-chat/src/persistence/chat_sql.rs +++ b/frontend/rust-lib/flowy-chat/src/persistence/chat_sql.rs @@ -1,9 +1,10 @@ +use diesel::sqlite::SqliteConnection; use flowy_sqlite::upsert::excluded; use flowy_sqlite::{ diesel, query_dsl::*, schema::{chat_table, chat_table::dsl}, - DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable, + AsChangeset, DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable, }; #[derive(Clone, Default, Queryable, Insertable, Identifiable)] @@ -13,6 +14,22 @@ pub struct ChatTable { pub chat_id: String, pub created_at: i64, pub name: String, + pub local_model_path: String, + pub local_model_name: String, + pub local_enabled: bool, + pub sync_to_cloud: bool, +} + +#[derive(AsChangeset, Identifiable, Default, Debug)] +#[diesel(table_name = chat_table)] +#[diesel(primary_key(chat_id))] +pub struct ChatTableChangeset { + pub chat_id: String, + pub name: Option, + pub local_model_path: Option, + pub local_model_name: Option, + pub local_enabled: Option, + pub sync_to_cloud: Option, } pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult { @@ -27,6 +44,15 @@ pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult< .execute(&mut *conn) } +pub fn update_chat_local_model( + conn: &mut SqliteConnection, + changeset: ChatTableChangeset, +) -> QueryResult { + let filter = dsl::chat_table.filter(chat_table::chat_id.eq(changeset.chat_id.clone())); + let affected_row = diesel::update(filter).set(changeset).execute(conn)?; + Ok(affected_row) +} + #[allow(dead_code)] pub fn read_chat(mut conn: DBConnection, chat_id_val: &str) -> QueryResult { let row = dsl::chat_table diff --git a/frontend/rust-lib/flowy-core/src/config.rs b/frontend/rust-lib/flowy-core/src/config.rs index fd8fbe335f..c910064a0a 100644 --- a/frontend/rust-lib/flowy-core/src/config.rs +++ b/frontend/rust-lib/flowy-core/src/config.rs @@ -9,7 +9,7 @@ use flowy_server_pub::af_cloud_config::AFCloudConfiguration; use flowy_server_pub::supabase_config::SupabaseConfiguration; use flowy_user::services::entities::URL_SAFE_ENGINE; use lib_infra::file_util::copy_dir_recursive; -use lib_infra::util::Platform; +use lib_infra::util::OperatingSystem; use crate::integrate::log::create_log_filter; @@ -94,7 +94,7 @@ impl AppFlowyCoreConfig { }, Some(config) => make_user_data_folder(&custom_application_path, &config.base_url), }; - let log_filter = create_log_filter("info".to_owned(), vec![], Platform::from(&platform)); + let log_filter = create_log_filter("info".to_owned(), vec![], OperatingSystem::from(&platform)); AppFlowyCoreConfig { app_version, @@ -112,7 +112,7 @@ impl AppFlowyCoreConfig { self.log_filter = create_log_filter( level.to_owned(), with_crates, - Platform::from(&self.platform), + OperatingSystem::from(&self.platform), ); self } diff --git a/frontend/rust-lib/flowy-core/src/integrate/log.rs b/frontend/rust-lib/flowy-core/src/integrate/log.rs index ddf8d32277..932c52a783 100644 --- a/frontend/rust-lib/flowy-core/src/integrate/log.rs +++ b/frontend/rust-lib/flowy-core/src/integrate/log.rs @@ -1,4 +1,4 @@ -use lib_infra::util::Platform; +use lib_infra::util::OperatingSystem; use lib_log::stream_log::StreamLogSender; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -8,7 +8,7 @@ use crate::AppFlowyCoreConfig; static INIT_LOG: AtomicBool = AtomicBool::new(false); pub(crate) fn init_log( config: &AppFlowyCoreConfig, - platform: &Platform, + platform: &OperatingSystem, stream_log_sender: Option>, ) { #[cfg(debug_assertions)] @@ -25,11 +25,15 @@ pub(crate) fn init_log( } } -pub fn create_log_filter(level: String, with_crates: Vec, platform: Platform) -> String { +pub fn create_log_filter( + level: String, + with_crates: Vec, + platform: OperatingSystem, +) -> String { let mut level = std::env::var("RUST_LOG").unwrap_or(level); #[cfg(debug_assertions)] - if matches!(platform, Platform::IOS) { + if matches!(platform, OperatingSystem::IOS) { level = "trace".to_string(); } diff --git a/frontend/rust-lib/flowy-core/src/lib.rs b/frontend/rust-lib/flowy-core/src/lib.rs index 11bba9de8c..6511724914 100644 --- a/frontend/rust-lib/flowy-core/src/lib.rs +++ b/frontend/rust-lib/flowy-core/src/lib.rs @@ -25,7 +25,7 @@ use flowy_user::user_manager::UserManager; use lib_dispatch::prelude::*; use lib_dispatch::runtime::AFPluginRuntime; use lib_infra::priority_task::{TaskDispatcher, TaskRunner}; -use lib_infra::util::Platform; +use lib_infra::util::OperatingSystem; use lib_log::stream_log::StreamLogSender; use module::make_plugins; @@ -69,7 +69,7 @@ impl AppFlowyCore { runtime: Arc, stream_log_sender: Option>, ) -> Self { - let platform = Platform::from(&config.platform); + let platform = OperatingSystem::from(&config.platform); #[allow(clippy::if_same_then_else)] if cfg!(debug_assertions) { diff --git a/frontend/rust-lib/flowy-sidecar/Cargo.toml b/frontend/rust-lib/flowy-sidecar/Cargo.toml index 8b4d14ac4c..e05c482816 100644 --- a/frontend/rust-lib/flowy-sidecar/Cargo.toml +++ b/frontend/rust-lib/flowy-sidecar/Cargo.toml @@ -16,6 +16,7 @@ tracing.workspace = true crossbeam-utils = "0.8.20" log = "0.4.21" parking_lot.workspace = true +lib-infra.workspace = true [dev-dependencies] diff --git a/frontend/rust-lib/flowy-sidecar/dev.env b/frontend/rust-lib/flowy-sidecar/dev.env index 79cb1fb482..60562f76d7 100644 --- a/frontend/rust-lib/flowy-sidecar/dev.env +++ b/frontend/rust-lib/flowy-sidecar/dev.env @@ -2,3 +2,4 @@ CHAT_BIN_PATH= LOCAL_AI_ROOT_PATH= LOCAL_AI_CHAT_MODEL_NAME= +LOCAL_AI_EMBEDDING_MODEL_NAME= diff --git a/frontend/rust-lib/flowy-sidecar/src/error.rs b/frontend/rust-lib/flowy-sidecar/src/error.rs index f1ed1b5a6a..93afef0919 100644 --- a/frontend/rust-lib/flowy-sidecar/src/error.rs +++ b/frontend/rust-lib/flowy-sidecar/src/error.rs @@ -17,6 +17,9 @@ pub enum Error { /// The peer sent a response containing the id, but was malformed. #[error("Invalid response.")] InvalidResponse, + + #[error(transparent)] + Internal(#[from] anyhow::Error), } #[derive(Debug)] diff --git a/frontend/rust-lib/flowy-sidecar/src/manager.rs b/frontend/rust-lib/flowy-sidecar/src/manager.rs index f088c8cbfd..1c016e6b89 100644 --- a/frontend/rust-lib/flowy-sidecar/src/manager.rs +++ b/frontend/rust-lib/flowy-sidecar/src/manager.rs @@ -1,9 +1,10 @@ -use crate::error::{ReadError, RemoteError}; +use crate::error::{Error, ReadError, RemoteError}; use crate::parser::ResponseParser; use crate::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx}; use crate::rpc_loop::Handler; use crate::rpc_peer::PluginCommand; -use anyhow::{anyhow, Result}; +use anyhow::anyhow; +use lib_infra::util::{get_operating_system, OperatingSystem}; use parking_lot::Mutex; use serde_json::{json, Value}; use std::io; @@ -14,6 +15,7 @@ use tracing::{trace, warn}; pub struct SidecarManager { state: Arc>, plugin_id_counter: Arc, + operating_system: OperatingSystem, } impl SidecarManager { @@ -23,23 +25,46 @@ impl SidecarManager { plugins: Vec::new(), })), plugin_id_counter: Arc::new(Default::default()), + operating_system: get_operating_system(), } } - pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result { + pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result { + if self.operating_system.is_not_desktop() { + return Err(Error::Internal(anyhow!( + "plugin not supported on this platform" + ))); + } let plugin_id = PluginId::from(self.plugin_id_counter.fetch_add(1, Ordering::SeqCst)); let weak_state = WeakSidecarState(Arc::downgrade(&self.state)); start_plugin_process(plugin_info, plugin_id, weak_state).await?; Ok(plugin_id) } - pub async fn remove_plugin(&self, id: PluginId) -> Result<()> { - let mut state = self.state.lock(); - state.plugin_disconnect(id, Ok(())); + pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> { + if self.operating_system.is_not_desktop() { + return Err(Error::Internal(anyhow!( + "plugin not supported on this platform" + ))); + } + + let state = self.state.lock(); + let plugin = state + .plugins + .iter() + .find(|p| p.id == id) + .ok_or(anyhow!("plugin not found"))?; + plugin.shutdown(); Ok(()) } - pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<()> { + pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> { + if self.operating_system.is_not_desktop() { + return Err(Error::Internal(anyhow!( + "plugin not supported on this platform" + ))); + } + let state = self.state.lock(); let plugin = state .plugins @@ -56,7 +81,7 @@ impl SidecarManager { id: PluginId, method: &str, request: Value, - ) -> Result { + ) -> Result { let state = self.state.lock(); let plugin = state .plugins @@ -67,6 +92,23 @@ impl SidecarManager { let value = P::parse_response(resp)?; Ok(value) } + + pub async fn async_send_request( + &self, + id: PluginId, + method: &str, + request: Value, + ) -> Result { + let state = self.state.lock(); + let plugin = state + .plugins + .iter() + .find(|p| p.id == id) + .ok_or(anyhow!("plugin not found"))?; + let resp = plugin.async_send_request(method, &request).await?; + let value = P::parse_response(resp)?; + Ok(value) + } } pub struct SidecarState { diff --git a/frontend/rust-lib/flowy-sidecar/src/parser.rs b/frontend/rust-lib/flowy-sidecar/src/parser.rs index 64ec178e5c..e65285e075 100644 --- a/frontend/rust-lib/flowy-sidecar/src/parser.rs +++ b/frontend/rust-lib/flowy-sidecar/src/parser.rs @@ -75,3 +75,19 @@ impl ResponseParser for ChatResponseParser { return Err(RemoteError::InvalidResponse(json)); } } + +pub struct SimilarityResponseParser; +impl ResponseParser for SimilarityResponseParser { + type ValueType = f64; + + fn parse_response(json: Value) -> Result { + if json.is_object() { + if let Some(data) = json.get("data") { + if let Some(score) = data.get("score").and_then(|v| v.as_f64()) { + return Ok(score); + } + } + } + return Err(RemoteError::InvalidResponse(json)); + } +} diff --git a/frontend/rust-lib/flowy-sidecar/src/plugin.rs b/frontend/rust-lib/flowy-sidecar/src/plugin.rs index e0fcc82474..6df3ec248f 100644 --- a/frontend/rust-lib/flowy-sidecar/src/plugin.rs +++ b/frontend/rust-lib/flowy-sidecar/src/plugin.rs @@ -8,6 +8,7 @@ use std::io::BufReader; use std::process::{Child, Stdio}; +use anyhow::anyhow; use std::thread; use std::time::Instant; use tracing::{error, info}; @@ -26,6 +27,13 @@ impl From for PluginId { pub trait Callback: Send { fn call(self: Box, result: Result); } + +impl)> Callback for F { + fn call(self: Box, result: Result) { + (*self)(result) + } +} + /// The `Peer` trait defines the interface for the opposite side of the RPC channel, /// designed to be used behind a pointer or as a trait object. pub trait Peer: Send + 'static { @@ -74,6 +82,21 @@ impl Plugin { self.peer.send_rpc_request(method, params) } + pub async fn async_send_request(&self, method: &str, params: &Value) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.peer.send_rpc_request_async( + method, + params, + Box::new(move |result| { + let _ = tx.send(result); + }), + ); + let value = rx + .await + .map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??; + Ok(value) + } + pub fn shutdown(&self) { if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) { error!("error sending shutdown to plugin {}: {:?}", self.name, err); diff --git a/frontend/rust-lib/flowy-sidecar/src/rpc_object.rs b/frontend/rust-lib/flowy-sidecar/src/rpc_object.rs index 8a193e0fe3..2059df69e3 100644 --- a/frontend/rust-lib/flowy-sidecar/src/rpc_object.rs +++ b/frontend/rust-lib/flowy-sidecar/src/rpc_object.rs @@ -38,7 +38,6 @@ impl RpcObject { return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into()); } let result = self.0.as_object_mut().and_then(|obj| obj.remove("result")); - match result { Some(r) => Ok(Ok(r)), None => { diff --git a/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs b/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs deleted file mode 100644 index 3764e912de..0000000000 --- a/frontend/rust-lib/flowy-sidecar/tests/chat_bin_test.rs +++ /dev/null @@ -1,94 +0,0 @@ -use anyhow::Result; -use flowy_sidecar::manager::SidecarManager; -use flowy_sidecar::parser::ChatResponseParser; -use flowy_sidecar::plugin::PluginInfo; -use serde_json::json; -use std::sync::Once; - -use tracing_subscriber::fmt::Subscriber; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::EnvFilter; - -#[tokio::test] -async fn load_chat_model_test() { - if let Ok(config) = LocalAIConfiguration::new() { - let manager = SidecarManager::new(); - let info = PluginInfo { - name: "chat".to_string(), - exec_path: config.chat_bin_path.clone(), - }; - let plugin_id = manager.create_plugin(info).await.unwrap(); - manager - .init_plugin( - plugin_id, - json!({ - "absolute_chat_model_path":config.chat_model_absolute_path(), - }), - ) - .unwrap(); - - let _json = json!({ - "plugin_id": "example_plugin_id", - "method": "initialize", - "params": { - "absolute_chat_model_path":config.chat_model_absolute_path(), - } - }); - - let chat_id = uuid::Uuid::new_v4().to_string(); - let resp = manager - .send_request::( - plugin_id, - "handle", - json!({"chat_id": chat_id, "method": "answer", "params": {"content": "hello world"}}), - ) - .unwrap(); - - eprintln!("chat response: {:?}", resp); - } -} - -pub struct LocalAIConfiguration { - root: String, - chat_bin_path: String, - chat_model_name: String, -} - -impl LocalAIConfiguration { - pub fn new() -> Result { - dotenv::dotenv().ok(); - setup_log(); - - // load from .env - let root = dotenv::var("LOCAL_AI_ROOT_PATH")?; - let chat_bin_path = dotenv::var("CHAT_BIN_PATH")?; - let chat_model = dotenv::var("LOCAL_AI_CHAT_MODEL_NAME")?; - - Ok(Self { - root, - chat_bin_path, - chat_model_name: chat_model, - }) - } - - pub fn chat_model_absolute_path(&self) -> String { - format!("{}/{}", self.root, self.chat_model_name) - } -} - -pub fn setup_log() { - static START: Once = Once::new(); - START.call_once(|| { - let level = "trace"; - let mut filters = vec![]; - filters.push(format!("flowy_sidecar={}", level)); - std::env::set_var("RUST_LOG", filters.join(",")); - - let subscriber = Subscriber::builder() - .with_env_filter(EnvFilter::from_default_env()) - .with_line_number(true) - .with_ansi(true) - .finish(); - subscriber.try_init().unwrap(); - }); -} diff --git a/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs b/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs new file mode 100644 index 0000000000..030d15ab6e --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/tests/chat_test/mod.rs @@ -0,0 +1,15 @@ +use crate::util::LocalAITest; + +#[tokio::test] +async fn load_chat_model_test() { + if let Ok(test) = LocalAITest::new() { + let plugin_id = test.init_chat_plugin().await; + let chat_id = uuid::Uuid::new_v4().to_string(); + let resp = test.send_message(&chat_id, plugin_id, "hello world").await; + eprintln!("chat response: {:?}", resp); + + let embedding_plugin_id = test.init_embedding_plugin().await; + let score = test.calculate_similarity(embedding_plugin_id, &resp, "Hello! How can I help you today? Is there something specific you would like to know or discuss").await; + assert!(score > 0.8); + } +} diff --git a/frontend/rust-lib/flowy-sidecar/tests/main.rs b/frontend/rust-lib/flowy-sidecar/tests/main.rs new file mode 100644 index 0000000000..8a6b230211 --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/tests/main.rs @@ -0,0 +1,2 @@ +pub mod chat_test; +pub mod util; diff --git a/frontend/rust-lib/flowy-sidecar/tests/util.rs b/frontend/rust-lib/flowy-sidecar/tests/util.rs new file mode 100644 index 0000000000..b6886d0ad3 --- /dev/null +++ b/frontend/rust-lib/flowy-sidecar/tests/util.rs @@ -0,0 +1,148 @@ +use anyhow::Result; +use flowy_sidecar::manager::SidecarManager; +use flowy_sidecar::parser::{ChatResponseParser, SimilarityResponseParser}; +use flowy_sidecar::plugin::{PluginId, PluginInfo}; +use serde_json::json; +use std::sync::Once; + +use tracing_subscriber::fmt::Subscriber; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::EnvFilter; + +pub struct LocalAITest { + config: LocalAIConfiguration, + manager: SidecarManager, +} + +impl LocalAITest { + pub fn new() -> Result { + let config = LocalAIConfiguration::new()?; + let manager = SidecarManager::new(); + + Ok(Self { config, manager }) + } + pub async fn init_chat_plugin(&self) -> PluginId { + let info = PluginInfo { + name: "chat".to_string(), + exec_path: self.config.chat_bin_path.clone(), + }; + let plugin_id = self.manager.create_plugin(info).await.unwrap(); + self + .manager + .init_plugin( + plugin_id, + json!({ + "absolute_chat_model_path":self.config.chat_model_absolute_path(), + }), + ) + .unwrap(); + + plugin_id + } + + pub async fn init_embedding_plugin(&self) -> PluginId { + let info = PluginInfo { + name: "embedding".to_string(), + exec_path: self.config.embedding_bin_path.clone(), + }; + let plugin_id = self.manager.create_plugin(info).await.unwrap(); + let embedding_model_path = self.config.embedding_model_absolute_path(); + self + .manager + .init_plugin( + plugin_id, + json!({ + "absolute_model_path":embedding_model_path, + }), + ) + .unwrap(); + plugin_id + } + + pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String { + let resp = self + .manager + .async_send_request::( + plugin_id, + "handle", + json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}}), + ) + .await + .unwrap(); + + resp + } + + pub async fn calculate_similarity( + &self, + plugin_id: PluginId, + message1: &str, + message2: &str, + ) -> f64 { + self + .manager + .async_send_request::( + plugin_id, + "handle", + json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}}), + ) + .await + .unwrap() + } +} + +pub struct LocalAIConfiguration { + root: String, + chat_bin_path: String, + chat_model_name: String, + embedding_bin_path: String, + embedding_model_name: String, +} + +impl LocalAIConfiguration { + pub fn new() -> Result { + dotenv::dotenv().ok(); + setup_log(); + + // load from .env + let root = dotenv::var("LOCAL_AI_ROOT_PATH")?; + let chat_bin_path = dotenv::var("CHAT_BIN_PATH")?; + let chat_model_name = dotenv::var("LOCAL_AI_CHAT_MODEL_NAME")?; + + let embedding_bin_path = dotenv::var("EMBEDDING_BIN_PATH")?; + let embedding_model_name = dotenv::var("LOCAL_AI_EMBEDDING_MODEL_NAME")?; + + Ok(Self { + root, + chat_bin_path, + chat_model_name, + embedding_bin_path, + embedding_model_name, + }) + } + + pub fn chat_model_absolute_path(&self) -> String { + format!("{}/{}", self.root, self.chat_model_name) + } + + pub fn embedding_model_absolute_path(&self) -> String { + format!("{}/{}", self.root, self.embedding_model_name) + } +} + +pub fn setup_log() { + static START: Once = Once::new(); + START.call_once(|| { + let level = "trace"; + let mut filters = vec![]; + filters.push(format!("flowy_sidecar={}", level)); + std::env::set_var("RUST_LOG", filters.join(",")); + + let subscriber = Subscriber::builder() + .with_env_filter(EnvFilter::from_default_env()) + .with_line_number(true) + .with_ansi(true) + .finish(); + subscriber.try_init().unwrap(); + }); +} diff --git a/frontend/rust-lib/flowy-sqlite/migrations/2024-06-26-015936_chat_setting/down.sql b/frontend/rust-lib/flowy-sqlite/migrations/2024-06-26-015936_chat_setting/down.sql new file mode 100644 index 0000000000..d9a93fe9a1 --- /dev/null +++ b/frontend/rust-lib/flowy-sqlite/migrations/2024-06-26-015936_chat_setting/down.sql @@ -0,0 +1 @@ +-- This file should undo anything in `up.sql` diff --git a/frontend/rust-lib/flowy-sqlite/migrations/2024-06-26-015936_chat_setting/up.sql b/frontend/rust-lib/flowy-sqlite/migrations/2024-06-26-015936_chat_setting/up.sql new file mode 100644 index 0000000000..2361adeb34 --- /dev/null +++ b/frontend/rust-lib/flowy-sqlite/migrations/2024-06-26-015936_chat_setting/up.sql @@ -0,0 +1,13 @@ +-- Your SQL goes here +ALTER TABLE chat_table ADD COLUMN local_model_path TEXT NOT NULL DEFAULT ''; +ALTER TABLE chat_table ADD COLUMN local_model_name TEXT NOT NULL DEFAULT ''; +ALTER TABLE chat_table ADD COLUMN local_enabled BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE chat_table ADD COLUMN sync_to_cloud BOOLEAN NOT NULL DEFAULT TRUE; + + +CREATE TABLE chat_local_setting_table +( + chat_id TEXT PRIMARY KEY NOT NULL, + local_model_path TEXT NOT NULL, + local_model_name TEXT NOT NULL DEFAULT '' +); \ No newline at end of file diff --git a/frontend/rust-lib/flowy-sqlite/src/schema.rs b/frontend/rust-lib/flowy-sqlite/src/schema.rs index 49fcc254d1..6033f40b5a 100644 --- a/frontend/rust-lib/flowy-sqlite/src/schema.rs +++ b/frontend/rust-lib/flowy-sqlite/src/schema.rs @@ -1,5 +1,13 @@ // @generated automatically by Diesel CLI. +diesel::table! { + chat_local_setting_table (chat_id) { + chat_id -> Text, + local_model_path -> Text, + local_model_name -> Text, + } +} + diesel::table! { chat_message_table (message_id) { message_id -> BigInt, @@ -17,6 +25,10 @@ diesel::table! { chat_id -> Text, created_at -> BigInt, name -> Text, + local_model_path -> Text, + local_model_name -> Text, + local_enabled -> Bool, + sync_to_cloud -> Bool, } } @@ -102,6 +114,7 @@ diesel::table! { } diesel::allow_tables_to_appear_in_same_query!( + chat_local_setting_table, chat_message_table, chat_table, collab_snapshot, diff --git a/frontend/rust-lib/lib-infra/Cargo.toml b/frontend/rust-lib/lib-infra/Cargo.toml index 6b902537ee..416df9b04d 100644 --- a/frontend/rust-lib/lib-infra/Cargo.toml +++ b/frontend/rust-lib/lib-infra/Cargo.toml @@ -23,6 +23,7 @@ tracing.workspace = true atomic_refcell = "0.1" allo-isolate = { version = "^0.1", features = ["catch-unwind"], optional = true } futures = "0.3.30" +cfg-if = "1.0.0" [dev-dependencies] rand = "0.8.5" diff --git a/frontend/rust-lib/lib-infra/src/util.rs b/frontend/rust-lib/lib-infra/src/util.rs index 823095a26f..b67646e4be 100644 --- a/frontend/rust-lib/lib-infra/src/util.rs +++ b/frontend/rust-lib/lib-infra/src/util.rs @@ -67,9 +67,8 @@ pub fn md5>(data: T) -> String { let md5 = format!("{:x}", md5::compute(data)); md5 } - #[derive(Debug, Clone, PartialEq, Eq)] -pub enum Platform { +pub enum OperatingSystem { Unknown, Windows, Linux, @@ -78,33 +77,62 @@ pub enum Platform { Android, } -impl Platform { +impl OperatingSystem { pub fn is_not_ios(&self) -> bool { - !matches!(self, Platform::IOS) + !matches!(self, OperatingSystem::IOS) + } + + pub fn is_desktop(&self) -> bool { + matches!( + self, + OperatingSystem::Windows | OperatingSystem::Linux | OperatingSystem::MacOS + ) + } + + pub fn is_not_desktop(&self) -> bool { + !self.is_desktop() } } -impl From for Platform { +impl From for OperatingSystem { fn from(s: String) -> Self { - Platform::from(s.as_str()) + OperatingSystem::from(s.as_str()) } } -impl From<&String> for Platform { +impl From<&String> for OperatingSystem { fn from(s: &String) -> Self { - Platform::from(s.as_str()) + OperatingSystem::from(s.as_str()) } } -impl From<&str> for Platform { +impl From<&str> for OperatingSystem { fn from(s: &str) -> Self { match s { - "windows" => Platform::Windows, - "linux" => Platform::Linux, - "macos" => Platform::MacOS, - "ios" => Platform::IOS, - "android" => Platform::Android, - _ => Platform::Unknown, + "windows" => OperatingSystem::Windows, + "linux" => OperatingSystem::Linux, + "macos" => OperatingSystem::MacOS, + "ios" => OperatingSystem::IOS, + "android" => OperatingSystem::Android, + _ => OperatingSystem::Unknown, } } } + +pub fn get_operating_system() -> OperatingSystem { + cfg_if::cfg_if! { + if #[cfg(target_os = "android")] { + OperatingSystem::Android + } else if #[cfg(target_os = "ios")] { + OperatingSystem::IOS + } else if #[cfg(target_os = "macos")] { + OperatingSystem::MacOS + } else if #[cfg(target_os = "windows")] { + OperatingSystem::Windows + } else if #[cfg(target_os = "linux")] { + OperatingSystem::Linux + } else { + OperatingSystem::Unknown + } + } +} diff --git a/frontend/rust-lib/lib-log/src/lib.rs b/frontend/rust-lib/lib-log/src/lib.rs index a328be254c..a956368883 100644 --- a/frontend/rust-lib/lib-log/src/lib.rs +++ b/frontend/rust-lib/lib-log/src/lib.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, RwLock}; use chrono::Local; use lazy_static::lazy_static; -use lib_infra::util::Platform; +use lib_infra::util::OperatingSystem; use tracing::subscriber::set_global_default; use tracing_appender::rolling::Rotation; use tracing_appender::{non_blocking::WorkerGuard, rolling::RollingFileAppender}; @@ -29,7 +29,7 @@ pub struct Builder { env_filter: String, file_appender: RollingFileAppender, #[allow(dead_code)] - platform: Platform, + platform: OperatingSystem, stream_log_sender: Option>, } @@ -37,7 +37,7 @@ impl Builder { pub fn new( name: &str, directory: &str, - platform: &Platform, + platform: &OperatingSystem, stream_log_sender: Option>, ) -> Self { let file_appender = RollingFileAppender::builder()