diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index 215fbdd43b..d787ea79c3 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -1685,6 +1685,7 @@ dependencies = [ "protobuf", "serde", "serde_json", + "simsimd", "strum_macros 0.21.1", "tokio", "tokio-stream", @@ -5145,6 +5146,15 @@ dependencies = [ "time", ] +[[package]] +name = "simsimd" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4" +dependencies = [ + "cc", +] + [[package]] name = "siphasher" version = "0.3.11" diff --git a/frontend/rust-lib/flowy-chat/Cargo.toml b/frontend/rust-lib/flowy-chat/Cargo.toml index 879b707c1a..4ac2c5dad5 100644 --- a/frontend/rust-lib/flowy-chat/Cargo.toml +++ b/frontend/rust-lib/flowy-chat/Cargo.toml @@ -40,6 +40,7 @@ parking_lot.workspace = true dotenv = "0.15.0" uuid.workspace = true tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] } +simsimd = "4.4.0" [build-dependencies] flowy-codegen.workspace = true diff --git a/frontend/rust-lib/flowy-chat/src/local_ai/embedding_plugin.rs b/frontend/rust-lib/flowy-chat/src/local_ai/embedding_plugin.rs index beb0bf75a9..93675172a8 100644 --- a/frontend/rust-lib/flowy-chat/src/local_ai/embedding_plugin.rs +++ b/frontend/rust-lib/flowy-chat/src/local_ai/embedding_plugin.rs @@ -15,36 +15,48 @@ impl EmbeddingPluginOperation { EmbeddingPluginOperation { plugin } } - pub async fn calculate_similarity( - &self, - message1: &str, - message2: &str, - ) -> Result { + pub async fn get_embeddings(&self, message: &str) -> Result>, SidecarError> { let plugin = self .plugin .upgrade() .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?; - let params = - json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}}); + let params = json!({"method": "get_embeddings", "params": {"input": message }}); plugin - .async_request::("handle", ¶ms) + .async_request::("handle", ¶ms) .await } } -pub struct SimilarityResponseParser; -impl ResponseParser for SimilarityResponseParser { - type ValueType = f64; +pub struct EmbeddingResponseParse; +impl ResponseParser for EmbeddingResponseParse { + type ValueType = Vec>; fn parse_json(json: JsonValue) -> Result { if json.is_object() { if let Some(data) = json.get("data") { - if let Some(score) = data.get("score").and_then(|v| v.as_f64()) { - return Ok(score); + if let Some(embeddings) = data.get("embeddings") { + if let Some(array) = embeddings.as_array() { + let mut result = Vec::new(); + for item in array { + if let Some(inner_array) = item.as_array() { + let mut inner_result = Vec::new(); + for num in inner_array { + if let Some(value) = num.as_f64() { + inner_result.push(value); + } else { + return Err(RemoteError::ParseResponse(json)); + } + } + result.push(inner_result); + } else { + return Err(RemoteError::ParseResponse(json)); + } + } + return Ok(result); + } } } } - Err(RemoteError::ParseResponse(json)) } } diff --git a/frontend/rust-lib/flowy-chat/tests/chat_test/mod.rs b/frontend/rust-lib/flowy-chat/tests/chat_test/mod.rs index 2730bf96d5..62d279d12f 100644 --- a/frontend/rust-lib/flowy-chat/tests/chat_test/mod.rs +++ b/frontend/rust-lib/flowy-chat/tests/chat_test/mod.rs @@ -13,7 +13,7 @@ async fn load_chat_model_test() { let embedding_plugin_id = test.init_embedding_plugin().await; let score = test.calculate_similarity(embedding_plugin_id, &resp, "Hello! How can I help you today? Is there something specific you would like to know or discuss").await; - assert!(score > 0.8); + assert!(score > 0.9, "score: {}", score); // let questions = test.related_question(&chat_id, plugin_id).await; // assert_eq!(questions.len(), 3); diff --git a/frontend/rust-lib/flowy-chat/tests/util.rs b/frontend/rust-lib/flowy-chat/tests/util.rs index 49757529f0..a38979c828 100644 --- a/frontend/rust-lib/flowy-chat/tests/util.rs +++ b/frontend/rust-lib/flowy-chat/tests/util.rs @@ -10,6 +10,8 @@ use flowy_chat::local_ai::chat_plugin::ChatPluginOperation; use flowy_chat::local_ai::embedding_plugin::EmbeddingPluginOperation; use flowy_sidecar::core::plugin::{PluginId, PluginInfo}; use flowy_sidecar::error::SidecarError; +use simsimd::SpatialSimilarity; +use std::f64; use tracing_subscriber::fmt::Subscriber; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; @@ -104,13 +106,23 @@ impl LocalAITest { ) -> f64 { let plugin = self.manager.get_plugin(plugin_id).await.unwrap(); let operation = EmbeddingPluginOperation::new(plugin); - operation - .calculate_similarity(message1, message2) - .await - .unwrap() + let left = operation.get_embeddings(message1).await.unwrap(); + let right = operation.get_embeddings(message2).await.unwrap(); + + let actual_embedding_flat = flatten_vec(left); + let expected_embedding_flat = flatten_vec(right); + let distance = f64::cosine(&actual_embedding_flat, &expected_embedding_flat) + .expect("Vectors must be of the same length"); + + distance.cos() } } +// Function to flatten Vec> into Vec +fn flatten_vec(vec: Vec>) -> Vec { + vec.into_iter().flatten().collect() +} + pub struct LocalAIConfiguration { model_dir: String, chat_bin_path: PathBuf,