chore: update

This commit is contained in:
nathan 2024-07-03 19:20:51 +08:00
parent 6cd3407d20
commit e4f4a128be
5 changed files with 54 additions and 19 deletions

View File

@ -1685,6 +1685,7 @@ dependencies = [
"protobuf", "protobuf",
"serde", "serde",
"serde_json", "serde_json",
"simsimd",
"strum_macros 0.21.1", "strum_macros 0.21.1",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
@ -5145,6 +5146,15 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "simsimd"
version = "4.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "siphasher" name = "siphasher"
version = "0.3.11" version = "0.3.11"

View File

@ -40,6 +40,7 @@ parking_lot.workspace = true
dotenv = "0.15.0" dotenv = "0.15.0"
uuid.workspace = true uuid.workspace = true
tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] } tracing-subscriber = { version = "0.3.17", features = ["registry", "env-filter", "ansi", "json"] }
simsimd = "4.4.0"
[build-dependencies] [build-dependencies]
flowy-codegen.workspace = true flowy-codegen.workspace = true

View File

@ -15,36 +15,48 @@ impl EmbeddingPluginOperation {
EmbeddingPluginOperation { plugin } EmbeddingPluginOperation { plugin }
} }
pub async fn calculate_similarity( pub async fn get_embeddings(&self, message: &str) -> Result<Vec<Vec<f64>>, SidecarError> {
&self,
message1: &str,
message2: &str,
) -> Result<f64, SidecarError> {
let plugin = self let plugin = self
.plugin .plugin
.upgrade() .upgrade()
.ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?; .ok_or(SidecarError::Internal(anyhow!("Plugin is dropped")))?;
let params = let params = json!({"method": "get_embeddings", "params": {"input": message }});
json!({"method": "calculate_similarity", "params": {"src": message1, "dest": message2}});
plugin plugin
.async_request::<SimilarityResponseParser>("handle", &params) .async_request::<EmbeddingResponseParse>("handle", &params)
.await .await
} }
} }
pub struct SimilarityResponseParser; pub struct EmbeddingResponseParse;
impl ResponseParser for SimilarityResponseParser { impl ResponseParser for EmbeddingResponseParse {
type ValueType = f64; type ValueType = Vec<Vec<f64>>;
fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> { fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
if json.is_object() { if json.is_object() {
if let Some(data) = json.get("data") { if let Some(data) = json.get("data") {
if let Some(score) = data.get("score").and_then(|v| v.as_f64()) { if let Some(embeddings) = data.get("embeddings") {
return Ok(score); if let Some(array) = embeddings.as_array() {
let mut result = Vec::new();
for item in array {
if let Some(inner_array) = item.as_array() {
let mut inner_result = Vec::new();
for num in inner_array {
if let Some(value) = num.as_f64() {
inner_result.push(value);
} else {
return Err(RemoteError::ParseResponse(json));
}
}
result.push(inner_result);
} else {
return Err(RemoteError::ParseResponse(json));
}
}
return Ok(result);
}
} }
} }
} }
Err(RemoteError::ParseResponse(json)) Err(RemoteError::ParseResponse(json))
} }
} }

View File

@ -13,7 +13,7 @@ async fn load_chat_model_test() {
let embedding_plugin_id = test.init_embedding_plugin().await; let embedding_plugin_id = test.init_embedding_plugin().await;
let score = test.calculate_similarity(embedding_plugin_id, &resp, "Hello! How can I help you today? Is there something specific you would like to know or discuss").await; let score = test.calculate_similarity(embedding_plugin_id, &resp, "Hello! How can I help you today? Is there something specific you would like to know or discuss").await;
assert!(score > 0.8); assert!(score > 0.9, "score: {}", score);
// let questions = test.related_question(&chat_id, plugin_id).await; // let questions = test.related_question(&chat_id, plugin_id).await;
// assert_eq!(questions.len(), 3); // assert_eq!(questions.len(), 3);

View File

@ -10,6 +10,8 @@ use flowy_chat::local_ai::chat_plugin::ChatPluginOperation;
use flowy_chat::local_ai::embedding_plugin::EmbeddingPluginOperation; use flowy_chat::local_ai::embedding_plugin::EmbeddingPluginOperation;
use flowy_sidecar::core::plugin::{PluginId, PluginInfo}; use flowy_sidecar::core::plugin::{PluginId, PluginInfo};
use flowy_sidecar::error::SidecarError; use flowy_sidecar::error::SidecarError;
use simsimd::SpatialSimilarity;
use std::f64;
use tracing_subscriber::fmt::Subscriber; use tracing_subscriber::fmt::Subscriber;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
@ -104,13 +106,23 @@ impl LocalAITest {
) -> f64 { ) -> f64 {
let plugin = self.manager.get_plugin(plugin_id).await.unwrap(); let plugin = self.manager.get_plugin(plugin_id).await.unwrap();
let operation = EmbeddingPluginOperation::new(plugin); let operation = EmbeddingPluginOperation::new(plugin);
operation let left = operation.get_embeddings(message1).await.unwrap();
.calculate_similarity(message1, message2) let right = operation.get_embeddings(message2).await.unwrap();
.await
.unwrap() let actual_embedding_flat = flatten_vec(left);
let expected_embedding_flat = flatten_vec(right);
let distance = f64::cosine(&actual_embedding_flat, &expected_embedding_flat)
.expect("Vectors must be of the same length");
distance.cos()
} }
} }
// Function to flatten Vec<Vec<f64>> into Vec<f64>
fn flatten_vec(vec: Vec<Vec<f64>>) -> Vec<f64> {
vec.into_iter().flatten().collect()
}
pub struct LocalAIConfiguration { pub struct LocalAIConfiguration {
model_dir: String, model_dir: String,
chat_bin_path: PathBuf, chat_bin_path: PathBuf,