chore: save chat config

This commit is contained in:
nathan
2024-06-26 16:06:43 +08:00
parent ae20547f8b
commit c9d61e543b
26 changed files with 501 additions and 147 deletions

View File

@ -1669,6 +1669,7 @@ dependencies = [
"flowy-derive",
"flowy-error",
"flowy-notification",
"flowy-sidecar",
"flowy-sqlite",
"futures",
"lib-dispatch",
@ -2147,6 +2148,7 @@ dependencies = [
"anyhow",
"crossbeam-utils",
"dotenv",
"lib-infra",
"log",
"once_cell",
"parking_lot 0.12.1",
@ -3115,6 +3117,7 @@ dependencies = [
"atomic_refcell",
"brotli",
"bytes",
"cfg-if",
"chrono",
"futures",
"futures-core",

View File

@ -68,6 +68,7 @@ collab-integrate = { workspace = true, path = "collab-integrate" }
flowy-date = { workspace = true, path = "flowy-date" }
flowy-chat = { workspace = true, path = "flowy-chat" }
flowy-chat-pub = { workspace = true, path = "flowy-chat-pub" }
flowy-sidecar = { workspace = true, path = "flowy-sidecar" }
anyhow = "1.0"
tracing = "0.1.40"
bytes = "1.5.0"

View File

@ -27,6 +27,7 @@ tokio.workspace = true
futures.workspace = true
allo-isolate = { version = "^0.1", features = ["catch-unwind"] }
log = "0.4.21"
flowy-sidecar = { workspace = true }
[build-dependencies]
flowy-codegen.workspace = true

View File

@ -1,7 +1,7 @@
use crate::entities::{
ChatMessageErrorPB, ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB,
};
use crate::manager::ChatUserService;
use crate::manager::{ChatService, ChatUserService};
use crate::notification::{send_notification, ChatNotification};
use crate::persistence::{insert_chat_messages, select_chat_messages, ChatMessageTable};
use allo_isolate::Isolate;
@ -25,7 +25,7 @@ pub struct Chat {
chat_id: String,
uid: i64,
user_service: Arc<dyn ChatUserService>,
cloud_service: Arc<dyn ChatCloudService>,
chat_service: Arc<ChatService>,
prev_message_state: Arc<RwLock<PrevMessageState>>,
latest_message_id: Arc<AtomicI64>,
stop_stream: Arc<AtomicBool>,
@ -37,12 +37,12 @@ impl Chat {
uid: i64,
chat_id: String,
user_service: Arc<dyn ChatUserService>,
cloud_service: Arc<dyn ChatCloudService>,
chat_service: Arc<ChatService>,
) -> Chat {
Chat {
uid,
chat_id,
cloud_service,
chat_service,
user_service,
prev_message_state: Arc::new(RwLock::new(PrevMessageState::HasMore)),
latest_message_id: Default::default(),
@ -92,7 +92,7 @@ impl Chat {
let workspace_id = self.user_service.workspace_id()?;
let question = self
.cloud_service
.chat_service
.send_question(&workspace_id, &self.chat_id, message, message_type)
.await
.map_err(|err| {
@ -109,7 +109,7 @@ impl Chat {
let stop_stream = self.stop_stream.clone();
let chat_id = self.chat_id.clone();
let question_id = question.message_id;
let cloud_service = self.cloud_service.clone();
let cloud_service = self.chat_service.clone();
let user_service = self.user_service.clone();
tokio::spawn(async move {
let mut text_sink = IsolateSink::new(Isolate::new(text_stream_port));
@ -300,7 +300,7 @@ impl Chat {
);
let workspace_id = self.user_service.workspace_id()?;
let chat_id = self.chat_id.clone();
let cloud_service = self.cloud_service.clone();
let cloud_service = self.chat_service.clone();
let user_service = self.user_service.clone();
let uid = self.uid;
let prev_message_state = self.prev_message_state.clone();
@ -369,7 +369,7 @@ impl Chat {
) -> Result<RepeatedRelatedQuestionPB, FlowyError> {
let workspace_id = self.user_service.workspace_id()?;
let resp = self
.cloud_service
.chat_service
.get_related_message(&workspace_id, &self.chat_id, message_id)
.await?;
@ -391,7 +391,7 @@ impl Chat {
);
let workspace_id = self.user_service.workspace_id()?;
let answer = self
.cloud_service
.chat_service
.generate_answer(&workspace_id, &self.chat_id, question_message_id)
.await?;

View File

@ -2,10 +2,16 @@ use crate::chat::Chat;
use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB};
use crate::persistence::{insert_chat, ChatTable};
use dashmap::DashMap;
use flowy_chat_pub::cloud::{ChatCloudService, ChatMessageType};
use flowy_chat_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageStream, ChatMessageType, MessageCursor,
RepeatedChatMessage, RepeatedRelatedQuestion, StreamAnswer,
};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sidecar::manager::SidecarManager;
use flowy_sqlite::DBConnection;
use lib_infra::future::FutureResult;
use lib_infra::util::timestamp;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tracing::trace;
@ -17,7 +23,7 @@ pub trait ChatUserService: Send + Sync + 'static {
}
pub struct ChatManager {
cloud_service: Arc<dyn ChatCloudService>,
chat_service: Arc<ChatService>,
user_service: Arc<dyn ChatUserService>,
chats: Arc<DashMap<String, Arc<Chat>>>,
}
@ -27,10 +33,12 @@ impl ChatManager {
cloud_service: Arc<dyn ChatCloudService>,
user_service: impl ChatUserService,
) -> ChatManager {
let sidecar_manager = Arc::new(SidecarManager::new());
let chat_service = Arc::new(ChatService::new(cloud_service, sidecar_manager));
let user_service = Arc::new(user_service);
Self {
cloud_service,
chat_service,
user_service,
chats: Arc::new(DashMap::new()),
}
@ -43,7 +51,7 @@ impl ChatManager {
self.user_service.user_id().unwrap(),
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service.clone(),
self.chat_service.clone(),
))
});
@ -64,7 +72,7 @@ impl ChatManager {
pub async fn create_chat(&self, uid: &i64, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
let workspace_id = self.user_service.workspace_id()?;
self
.cloud_service
.chat_service
.create_chat(uid, &workspace_id, chat_id)
.await?;
save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?;
@ -73,7 +81,7 @@ impl ChatManager {
self.user_service.user_id().unwrap(),
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service.clone(),
self.chat_service.clone(),
));
self.chats.insert(chat_id.to_string(), chat.clone());
Ok(chat)
@ -101,7 +109,7 @@ impl ChatManager {
self.user_service.user_id().unwrap(),
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service.clone(),
self.chat_service.clone(),
));
self.chats.insert(chat_id.to_string(), chat.clone());
Ok(chat)
@ -183,8 +191,107 @@ fn save_chat(conn: DBConnection, chat_id: &str) -> FlowyResult<()> {
chat_id: chat_id.to_string(),
created_at: timestamp(),
name: "".to_string(),
local_model_path: "".to_string(),
local_model_name: "".to_string(),
local_enabled: false,
sync_to_cloud: true,
};
insert_chat(conn, &row)?;
Ok(())
}
pub struct ChatService {
cloud_service: Arc<dyn ChatCloudService>,
sidecar_manager: Arc<SidecarManager>,
}
impl ChatService {
pub fn new(
cloud_service: Arc<dyn ChatCloudService>,
sidecar_manager: Arc<SidecarManager>,
) -> Self {
Self {
cloud_service,
sidecar_manager,
}
}
}
impl ChatCloudService for ChatService {
fn create_chat(
&self,
uid: &i64,
workspace_id: &str,
chat_id: &str,
) -> FutureResult<(), FlowyError> {
self.cloud_service.create_chat(uid, workspace_id, chat_id)
}
async fn send_chat_message(
&self,
workspace_id: &str,
chat_id: &str,
message: &str,
message_type: ChatMessageType,
) -> Result<ChatMessageStream, FlowyError> {
todo!()
}
fn send_question(
&self,
workspace_id: &str,
chat_id: &str,
message: &str,
message_type: ChatMessageType,
) -> FutureResult<ChatMessage, FlowyError> {
todo!()
}
fn save_answer(
&self,
workspace_id: &str,
chat_id: &str,
message: &str,
question_id: i64,
) -> FutureResult<ChatMessage, FlowyError> {
todo!()
}
async fn stream_answer(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
) -> Result<StreamAnswer, FlowyError> {
todo!()
}
fn get_chat_messages(
&self,
workspace_id: &str,
chat_id: &str,
offset: MessageCursor,
limit: u64,
) -> FutureResult<RepeatedChatMessage, FlowyError> {
todo!()
}
fn get_related_message(
&self,
workspace_id: &str,
chat_id: &str,
message_id: i64,
) -> FutureResult<RepeatedRelatedQuestion, FlowyError> {
todo!()
}
fn generate_answer(
&self,
workspace_id: &str,
chat_id: &str,
question_message_id: i64,
) -> FutureResult<ChatMessage, FlowyError> {
todo!()
}
}

View File

@ -1,9 +1,10 @@
use diesel::sqlite::SqliteConnection;
use flowy_sqlite::upsert::excluded;
use flowy_sqlite::{
diesel,
query_dsl::*,
schema::{chat_table, chat_table::dsl},
DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable,
AsChangeset, DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable,
};
#[derive(Clone, Default, Queryable, Insertable, Identifiable)]
@ -13,6 +14,22 @@ pub struct ChatTable {
pub chat_id: String,
pub created_at: i64,
pub name: String,
pub local_model_path: String,
pub local_model_name: String,
pub local_enabled: bool,
pub sync_to_cloud: bool,
}
#[derive(AsChangeset, Identifiable, Default, Debug)]
#[diesel(table_name = chat_table)]
#[diesel(primary_key(chat_id))]
pub struct ChatTableChangeset {
pub chat_id: String,
pub name: Option<String>,
pub local_model_path: Option<String>,
pub local_model_name: Option<String>,
pub local_enabled: Option<bool>,
pub sync_to_cloud: Option<bool>,
}
pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult<usize> {
@ -27,6 +44,15 @@ pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult<
.execute(&mut *conn)
}
pub fn update_chat_local_model(
conn: &mut SqliteConnection,
changeset: ChatTableChangeset,
) -> QueryResult<usize> {
let filter = dsl::chat_table.filter(chat_table::chat_id.eq(changeset.chat_id.clone()));
let affected_row = diesel::update(filter).set(changeset).execute(conn)?;
Ok(affected_row)
}
#[allow(dead_code)]
pub fn read_chat(mut conn: DBConnection, chat_id_val: &str) -> QueryResult<ChatTable> {
let row = dsl::chat_table

View File

@ -9,7 +9,7 @@ use flowy_server_pub::af_cloud_config::AFCloudConfiguration;
use flowy_server_pub::supabase_config::SupabaseConfiguration;
use flowy_user::services::entities::URL_SAFE_ENGINE;
use lib_infra::file_util::copy_dir_recursive;
use lib_infra::util::Platform;
use lib_infra::util::OperatingSystem;
use crate::integrate::log::create_log_filter;
@ -94,7 +94,7 @@ impl AppFlowyCoreConfig {
},
Some(config) => make_user_data_folder(&custom_application_path, &config.base_url),
};
let log_filter = create_log_filter("info".to_owned(), vec![], Platform::from(&platform));
let log_filter = create_log_filter("info".to_owned(), vec![], OperatingSystem::from(&platform));
AppFlowyCoreConfig {
app_version,
@ -112,7 +112,7 @@ impl AppFlowyCoreConfig {
self.log_filter = create_log_filter(
level.to_owned(),
with_crates,
Platform::from(&self.platform),
OperatingSystem::from(&self.platform),
);
self
}

View File

@ -1,4 +1,4 @@
use lib_infra::util::Platform;
use lib_infra::util::OperatingSystem;
use lib_log::stream_log::StreamLogSender;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@ -8,7 +8,7 @@ use crate::AppFlowyCoreConfig;
static INIT_LOG: AtomicBool = AtomicBool::new(false);
pub(crate) fn init_log(
config: &AppFlowyCoreConfig,
platform: &Platform,
platform: &OperatingSystem,
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
) {
#[cfg(debug_assertions)]
@ -25,11 +25,15 @@ pub(crate) fn init_log(
}
}
pub fn create_log_filter(level: String, with_crates: Vec<String>, platform: Platform) -> String {
pub fn create_log_filter(
level: String,
with_crates: Vec<String>,
platform: OperatingSystem,
) -> String {
let mut level = std::env::var("RUST_LOG").unwrap_or(level);
#[cfg(debug_assertions)]
if matches!(platform, Platform::IOS) {
if matches!(platform, OperatingSystem::IOS) {
level = "trace".to_string();
}

View File

@ -25,7 +25,7 @@ use flowy_user::user_manager::UserManager;
use lib_dispatch::prelude::*;
use lib_dispatch::runtime::AFPluginRuntime;
use lib_infra::priority_task::{TaskDispatcher, TaskRunner};
use lib_infra::util::Platform;
use lib_infra::util::OperatingSystem;
use lib_log::stream_log::StreamLogSender;
use module::make_plugins;
@ -69,7 +69,7 @@ impl AppFlowyCore {
runtime: Arc<AFPluginRuntime>,
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
) -> Self {
let platform = Platform::from(&config.platform);
let platform = OperatingSystem::from(&config.platform);
#[allow(clippy::if_same_then_else)]
if cfg!(debug_assertions) {

View File

@ -16,6 +16,7 @@ tracing.workspace = true
crossbeam-utils = "0.8.20"
log = "0.4.21"
parking_lot.workspace = true
lib-infra.workspace = true
[dev-dependencies]

View File

@ -2,3 +2,4 @@
CHAT_BIN_PATH=
LOCAL_AI_ROOT_PATH=
LOCAL_AI_CHAT_MODEL_NAME=
LOCAL_AI_EMBEDDING_MODEL_NAME=

View File

@ -17,6 +17,9 @@ pub enum Error {
/// The peer sent a response containing the id, but was malformed.
#[error("Invalid response.")]
InvalidResponse,
#[error(transparent)]
Internal(#[from] anyhow::Error),
}
#[derive(Debug)]

View File

@ -1,9 +1,10 @@
use crate::error::{ReadError, RemoteError};
use crate::error::{Error, ReadError, RemoteError};
use crate::parser::ResponseParser;
use crate::plugin::{start_plugin_process, Plugin, PluginId, PluginInfo, RpcCtx};
use crate::rpc_loop::Handler;
use crate::rpc_peer::PluginCommand;
use anyhow::{anyhow, Result};
use anyhow::anyhow;
use lib_infra::util::{get_operating_system, OperatingSystem};
use parking_lot::Mutex;
use serde_json::{json, Value};
use std::io;
@ -14,6 +15,7 @@ use tracing::{trace, warn};
pub struct SidecarManager {
state: Arc<Mutex<SidecarState>>,
plugin_id_counter: Arc<AtomicI64>,
operating_system: OperatingSystem,
}
impl SidecarManager {
@ -23,23 +25,46 @@ impl SidecarManager {
plugins: Vec::new(),
})),
plugin_id_counter: Arc::new(Default::default()),
operating_system: get_operating_system(),
}
}
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId> {
pub async fn create_plugin(&self, plugin_info: PluginInfo) -> Result<PluginId, Error> {
if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!(
"plugin not supported on this platform"
)));
}
let plugin_id = PluginId::from(self.plugin_id_counter.fetch_add(1, Ordering::SeqCst));
let weak_state = WeakSidecarState(Arc::downgrade(&self.state));
start_plugin_process(plugin_info, plugin_id, weak_state).await?;
Ok(plugin_id)
}
pub async fn remove_plugin(&self, id: PluginId) -> Result<()> {
let mut state = self.state.lock();
state.plugin_disconnect(id, Ok(()));
pub async fn remove_plugin(&self, id: PluginId) -> Result<(), Error> {
if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!(
"plugin not supported on this platform"
)));
}
let state = self.state.lock();
let plugin = state
.plugins
.iter()
.find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?;
plugin.shutdown();
Ok(())
}
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<()> {
pub fn init_plugin(&self, id: PluginId, init_params: Value) -> Result<(), Error> {
if self.operating_system.is_not_desktop() {
return Err(Error::Internal(anyhow!(
"plugin not supported on this platform"
)));
}
let state = self.state.lock();
let plugin = state
.plugins
@ -56,7 +81,7 @@ impl SidecarManager {
id: PluginId,
method: &str,
request: Value,
) -> Result<P::ValueType> {
) -> Result<P::ValueType, Error> {
let state = self.state.lock();
let plugin = state
.plugins
@ -67,6 +92,23 @@ impl SidecarManager {
let value = P::parse_response(resp)?;
Ok(value)
}
pub async fn async_send_request<P: ResponseParser>(
&self,
id: PluginId,
method: &str,
request: Value,
) -> Result<P::ValueType, Error> {
let state = self.state.lock();
let plugin = state
.plugins
.iter()
.find(|p| p.id == id)
.ok_or(anyhow!("plugin not found"))?;
let resp = plugin.async_send_request(method, &request).await?;
let value = P::parse_response(resp)?;
Ok(value)
}
}
pub struct SidecarState {

View File

@ -75,3 +75,19 @@ impl ResponseParser for ChatResponseParser {
return Err(RemoteError::InvalidResponse(json));
}
}
pub struct SimilarityResponseParser;
impl ResponseParser for SimilarityResponseParser {
type ValueType = f64;
fn parse_response(json: Value) -> Result<Self::ValueType, RemoteError> {
if json.is_object() {
if let Some(data) = json.get("data") {
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) {
return Ok(score);
}
}
}
return Err(RemoteError::InvalidResponse(json));
}
}

View File

@ -8,6 +8,7 @@ use std::io::BufReader;
use std::process::{Child, Stdio};
use anyhow::anyhow;
use std::thread;
use std::time::Instant;
use tracing::{error, info};
@ -26,6 +27,13 @@ impl From<i64> for PluginId {
pub trait Callback: Send {
fn call(self: Box<Self>, result: Result<Value, Error>);
}
impl<F: Send + FnOnce(Result<Value, Error>)> Callback for F {
fn call(self: Box<F>, result: Result<Value, Error>) {
(*self)(result)
}
}
/// The `Peer` trait defines the interface for the opposite side of the RPC channel,
/// designed to be used behind a pointer or as a trait object.
pub trait Peer: Send + 'static {
@ -74,6 +82,21 @@ impl Plugin {
self.peer.send_rpc_request(method, params)
}
pub async fn async_send_request(&self, method: &str, params: &Value) -> Result<Value, Error> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.peer.send_rpc_request_async(
method,
params,
Box::new(move |result| {
let _ = tx.send(result);
}),
);
let value = rx
.await
.map_err(|err| Error::Internal(anyhow!("error waiting for async response: {:?}", err)))??;
Ok(value)
}
pub fn shutdown(&self) {
if let Err(err) = self.peer.send_rpc_request("shutdown", &json!({})) {
error!("error sending shutdown to plugin {}: {:?}", self.name, err);

View File

@ -38,7 +38,6 @@ impl RpcObject {
return Err("RPC response must contain exactly one of 'error' or 'result' fields.".into());
}
let result = self.0.as_object_mut().and_then(|obj| obj.remove("result"));
match result {
Some(r) => Ok(Ok(r)),
None => {

View File

@ -1,94 +0,0 @@
use anyhow::Result;
use flowy_sidecar::manager::SidecarManager;
use flowy_sidecar::parser::ChatResponseParser;
use flowy_sidecar::plugin::PluginInfo;
use serde_json::json;
use std::sync::Once;
use tracing_subscriber::fmt::Subscriber;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
#[tokio::test]
async fn load_chat_model_test() {
if let Ok(config) = LocalAIConfiguration::new() {
let manager = SidecarManager::new();
let info = PluginInfo {
name: "chat".to_string(),
exec_path: config.chat_bin_path.clone(),
};
let plugin_id = manager.create_plugin(info).await.unwrap();
manager
.init_plugin(
plugin_id,
json!({
"absolute_chat_model_path":config.chat_model_absolute_path(),
}),
)
.unwrap();
let _json = json!({
"plugin_id": "example_plugin_id",
"method": "initialize",
"params": {
"absolute_chat_model_path":config.chat_model_absolute_path(),
}
});
let chat_id = uuid::Uuid::new_v4().to_string();
let resp = manager
.send_request::<ChatResponseParser>(
plugin_id,
"handle",
json!({"chat_id": chat_id, "method": "answer", "params": {"content": "hello world"}}),
)
.unwrap();
eprintln!("chat response: {:?}", resp);
}
}
pub struct LocalAIConfiguration {
root: String,
chat_bin_path: String,
chat_model_name: String,
}
impl LocalAIConfiguration {
pub fn new() -> Result<Self> {
dotenv::dotenv().ok();
setup_log();
// load from .env
let root = dotenv::var("LOCAL_AI_ROOT_PATH")?;
let chat_bin_path = dotenv::var("CHAT_BIN_PATH")?;
let chat_model = dotenv::var("LOCAL_AI_CHAT_MODEL_NAME")?;
Ok(Self {
root,
chat_bin_path,
chat_model_name: chat_model,
})
}
pub fn chat_model_absolute_path(&self) -> String {
format!("{}/{}", self.root, self.chat_model_name)
}
}
pub fn setup_log() {
static START: Once = Once::new();
START.call_once(|| {
let level = "trace";
let mut filters = vec![];
filters.push(format!("flowy_sidecar={}", level));
std::env::set_var("RUST_LOG", filters.join(","));
let subscriber = Subscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_line_number(true)
.with_ansi(true)
.finish();
subscriber.try_init().unwrap();
});
}

View File

@ -0,0 +1,15 @@
use crate::util::LocalAITest;
#[tokio::test]
async fn load_chat_model_test() {
if let Ok(test) = LocalAITest::new() {
let plugin_id = test.init_chat_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let resp = test.send_message(&chat_id, plugin_id, "hello world").await;
eprintln!("chat response: {:?}", resp);
let embedding_plugin_id = test.init_embedding_plugin().await;
let score = test.calculate_similarity(embedding_plugin_id, &resp, "Hello! How can I help you today? Is there something specific you would like to know or discuss").await;
assert!(score > 0.8);
}
}

View File

@ -0,0 +1,2 @@
pub mod chat_test;
pub mod util;

View File

@ -0,0 +1,148 @@
use anyhow::Result;
use flowy_sidecar::manager::SidecarManager;
use flowy_sidecar::parser::{ChatResponseParser, SimilarityResponseParser};
use flowy_sidecar::plugin::{PluginId, PluginInfo};
use serde_json::json;
use std::sync::Once;
use tracing_subscriber::fmt::Subscriber;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
pub struct LocalAITest {
config: LocalAIConfiguration,
manager: SidecarManager,
}
impl LocalAITest {
pub fn new() -> Result<Self> {
let config = LocalAIConfiguration::new()?;
let manager = SidecarManager::new();
Ok(Self { config, manager })
}
pub async fn init_chat_plugin(&self) -> PluginId {
let info = PluginInfo {
name: "chat".to_string(),
exec_path: self.config.chat_bin_path.clone(),
};
let plugin_id = self.manager.create_plugin(info).await.unwrap();
self
.manager
.init_plugin(
plugin_id,
json!({
"absolute_chat_model_path":self.config.chat_model_absolute_path(),
}),
)
.unwrap();
plugin_id
}
pub async fn init_embedding_plugin(&self) -> PluginId {
let info = PluginInfo {
name: "embedding".to_string(),
exec_path: self.config.embedding_bin_path.clone(),
};
let plugin_id = self.manager.create_plugin(info).await.unwrap();
let embedding_model_path = self.config.embedding_model_absolute_path();
self
.manager
.init_plugin(
plugin_id,
json!({
"absolute_model_path":embedding_model_path,
}),
)
.unwrap();
plugin_id
}
pub async fn send_message(&self, chat_id: &str, plugin_id: PluginId, message: &str) -> String {
let resp = self
.manager
.async_send_request::<ChatResponseParser>(
plugin_id,
"handle",
json!({"chat_id": chat_id, "method": "answer", "params": {"content": message}}),
)
.await
.unwrap();
resp
}
pub async fn calculate_similarity(
&self,
plugin_id: PluginId,
message1: &str,
message2: &str,
) -> f64 {
self
.manager
.async_send_request::<SimilarityResponseParser>(
plugin_id,
"handle",
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}}),
)
.await
.unwrap()
}
}
pub struct LocalAIConfiguration {
root: String,
chat_bin_path: String,
chat_model_name: String,
embedding_bin_path: String,
embedding_model_name: String,
}
impl LocalAIConfiguration {
pub fn new() -> Result<Self> {
dotenv::dotenv().ok();
setup_log();
// load from .env
let root = dotenv::var("LOCAL_AI_ROOT_PATH")?;
let chat_bin_path = dotenv::var("CHAT_BIN_PATH")?;
let chat_model_name = dotenv::var("LOCAL_AI_CHAT_MODEL_NAME")?;
let embedding_bin_path = dotenv::var("EMBEDDING_BIN_PATH")?;
let embedding_model_name = dotenv::var("LOCAL_AI_EMBEDDING_MODEL_NAME")?;
Ok(Self {
root,
chat_bin_path,
chat_model_name,
embedding_bin_path,
embedding_model_name,
})
}
pub fn chat_model_absolute_path(&self) -> String {
format!("{}/{}", self.root, self.chat_model_name)
}
pub fn embedding_model_absolute_path(&self) -> String {
format!("{}/{}", self.root, self.embedding_model_name)
}
}
pub fn setup_log() {
static START: Once = Once::new();
START.call_once(|| {
let level = "trace";
let mut filters = vec![];
filters.push(format!("flowy_sidecar={}", level));
std::env::set_var("RUST_LOG", filters.join(","));
let subscriber = Subscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_line_number(true)
.with_ansi(true)
.finish();
subscriber.try_init().unwrap();
});
}

View File

@ -0,0 +1 @@
-- This file should undo anything in `up.sql`

View File

@ -0,0 +1,13 @@
-- Your SQL goes here
ALTER TABLE chat_table ADD COLUMN local_model_path TEXT NOT NULL DEFAULT '';
ALTER TABLE chat_table ADD COLUMN local_model_name TEXT NOT NULL DEFAULT '';
ALTER TABLE chat_table ADD COLUMN local_enabled BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE chat_table ADD COLUMN sync_to_cloud BOOLEAN NOT NULL DEFAULT TRUE;
CREATE TABLE chat_local_setting_table
(
chat_id TEXT PRIMARY KEY NOT NULL,
local_model_path TEXT NOT NULL,
local_model_name TEXT NOT NULL DEFAULT ''
);

View File

@ -1,5 +1,13 @@
// @generated automatically by Diesel CLI.
diesel::table! {
chat_local_setting_table (chat_id) {
chat_id -> Text,
local_model_path -> Text,
local_model_name -> Text,
}
}
diesel::table! {
chat_message_table (message_id) {
message_id -> BigInt,
@ -17,6 +25,10 @@ diesel::table! {
chat_id -> Text,
created_at -> BigInt,
name -> Text,
local_model_path -> Text,
local_model_name -> Text,
local_enabled -> Bool,
sync_to_cloud -> Bool,
}
}
@ -102,6 +114,7 @@ diesel::table! {
}
diesel::allow_tables_to_appear_in_same_query!(
chat_local_setting_table,
chat_message_table,
chat_table,
collab_snapshot,

View File

@ -23,6 +23,7 @@ tracing.workspace = true
atomic_refcell = "0.1"
allo-isolate = { version = "^0.1", features = ["catch-unwind"], optional = true }
futures = "0.3.30"
cfg-if = "1.0.0"
[dev-dependencies]
rand = "0.8.5"

View File

@ -67,9 +67,8 @@ pub fn md5<T: AsRef<[u8]>>(data: T) -> String {
let md5 = format!("{:x}", md5::compute(data));
md5
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Platform {
pub enum OperatingSystem {
Unknown,
Windows,
Linux,
@ -78,33 +77,62 @@ pub enum Platform {
Android,
}
impl Platform {
impl OperatingSystem {
pub fn is_not_ios(&self) -> bool {
!matches!(self, Platform::IOS)
!matches!(self, OperatingSystem::IOS)
}
pub fn is_desktop(&self) -> bool {
matches!(
self,
OperatingSystem::Windows | OperatingSystem::Linux | OperatingSystem::MacOS
)
}
pub fn is_not_desktop(&self) -> bool {
!self.is_desktop()
}
}
impl From<String> for Platform {
impl From<String> for OperatingSystem {
fn from(s: String) -> Self {
Platform::from(s.as_str())
OperatingSystem::from(s.as_str())
}
}
impl From<&String> for Platform {
impl From<&String> for OperatingSystem {
fn from(s: &String) -> Self {
Platform::from(s.as_str())
OperatingSystem::from(s.as_str())
}
}
impl From<&str> for Platform {
impl From<&str> for OperatingSystem {
fn from(s: &str) -> Self {
match s {
"windows" => Platform::Windows,
"linux" => Platform::Linux,
"macos" => Platform::MacOS,
"ios" => Platform::IOS,
"android" => Platform::Android,
_ => Platform::Unknown,
"windows" => OperatingSystem::Windows,
"linux" => OperatingSystem::Linux,
"macos" => OperatingSystem::MacOS,
"ios" => OperatingSystem::IOS,
"android" => OperatingSystem::Android,
_ => OperatingSystem::Unknown,
}
}
}
pub fn get_operating_system() -> OperatingSystem {
cfg_if::cfg_if! {
if #[cfg(target_os = "android")] {
OperatingSystem::Android
} else if #[cfg(target_os = "ios")] {
OperatingSystem::IOS
} else if #[cfg(target_os = "macos")] {
OperatingSystem::MacOS
} else if #[cfg(target_os = "windows")] {
OperatingSystem::Windows
} else if #[cfg(target_os = "linux")] {
OperatingSystem::Linux
} else {
OperatingSystem::Unknown
}
}
}

View File

@ -4,7 +4,7 @@ use std::sync::{Arc, RwLock};
use chrono::Local;
use lazy_static::lazy_static;
use lib_infra::util::Platform;
use lib_infra::util::OperatingSystem;
use tracing::subscriber::set_global_default;
use tracing_appender::rolling::Rotation;
use tracing_appender::{non_blocking::WorkerGuard, rolling::RollingFileAppender};
@ -29,7 +29,7 @@ pub struct Builder {
env_filter: String,
file_appender: RollingFileAppender,
#[allow(dead_code)]
platform: Platform,
platform: OperatingSystem,
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
}
@ -37,7 +37,7 @@ impl Builder {
pub fn new(
name: &str,
directory: &str,
platform: &Platform,
platform: &OperatingSystem,
stream_log_sender: Option<Arc<dyn StreamLogSender>>,
) -> Self {
let file_appender = RollingFileAppender::builder()