mirror of https://github.com/AppFlowy-IO/AppFlowy.git
synced 2025-07-26 03:23:01 +00:00

chore: support multiple source search in chat (#7906)

* chore: support multiple source search in chat
* chore: clippy
@@ -217,7 +217,6 @@ impl AIManager {
     let summary = select_chat_summary(&mut conn, chat_id).unwrap_or_default();

     let model = self.get_active_model(&chat_id.to_string()).await;
     trace!("[AI Plugin] notify open chat: {}", chat_id);
     self
       .local_ai
       .open_chat(&workspace_id, chat_id, &model.name, rag_ids, summary)
@@ -240,7 +239,7 @@ impl AIManager {
       .await
     {
       Ok(settings) => {
-        local_ai.set_rag_ids(&chat_id, &settings.rag_ids);
+        local_ai.set_rag_ids(&chat_id, &settings.rag_ids).await;
         let rag_ids = settings
           .rag_ids
           .into_iter()
@@ -712,7 +711,7 @@ impl AIManager {

     let user_service = self.user_service.clone();
     let external_service = self.external_service.clone();
-    self.local_ai.set_rag_ids(chat_id, &rag_ids);
+    self.local_ai.set_rag_ids(chat_id, &rag_ids).await;

     let rag_ids = rag_ids
       .into_iter()
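Note: `set_rag_ids` becomes async in this commit because rag-id updates now go through the per-chat `RwLock` (see `LLMChatController::set_rag_ids` further down), so every call site gains an `.await`. A minimal sketch of the new call shape, with a hypothetical `manager` handle and field access that may not match the real visibility:

    // Illustrative only: the rag-id setter must now be awaited.
    async fn refresh_rag_ids(manager: &AIManager, chat_id: &Uuid, rag_ids: Vec<String>) {
      manager.local_ai.set_rag_ids(chat_id, &rag_ids).await;
    }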
@@ -1,6 +1,7 @@
 use crate::embeddings::document_indexer::split_text_into_chunks;
 use crate::embeddings::embedder::{Embedder, OllamaEmbedder};
 use crate::embeddings::indexer::{EmbeddingModel, IndexerProvider};
+use crate::local_ai::chat::retriever::MultipleSourceRetrieverStore;
 use async_trait::async_trait;
 use flowy_ai_pub::cloud::CollabType;
 use flowy_ai_pub::entities::{RAG_IDS, SOURCE_ID};
@@ -9,10 +10,8 @@ use flowy_sqlite_vec::db::VectorSqliteDB;
 use flowy_sqlite_vec::entities::{EmbeddedContent, SqliteEmbeddedDocument};
 use futures::stream::{self, StreamExt};
 use langchain_rust::llm::client::OllamaClient;
-use langchain_rust::{
-  schemas::Document,
-  vectorstore::{VecStoreOptions, VectorStore},
-};
+use langchain_rust::schemas::Document;
+use langchain_rust::vectorstore::{VecStoreOptions, VectorStore};
 use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
 use serde_json::Value;
 use std::collections::HashMap;
@@ -86,6 +85,80 @@ impl SqliteVectorStore {
   }
 }

+#[async_trait]
+impl MultipleSourceRetrieverStore for SqliteVectorStore {
+  fn retriever_name(&self) -> &'static str {
+    "Sqlite Multiple Source Retriever"
+  }
+
+  async fn read_documents(
+    &self,
+    workspace_id: &Uuid,
+    query: &str,
+    limit: usize,
+    rag_ids: &[String],
+    score_threshold: f32,
+    _full_search: bool,
+  ) -> FlowyResult<Vec<Document>> {
+    let vector_db = match self.vector_db.upgrade() {
+      Some(db) => db,
+      None => return Err(FlowyError::internal().with_context("Vector database not initialized")),
+    };
+
+    // Create embedder and generate embedding for query
+    let embedder = self.create_embedder()?;
+    let request = GenerateEmbeddingsRequest::new(
+      embedder.model().name().to_string(),
+      EmbeddingsInput::Single(query.to_string()),
+    );
+
+    let embedding = embedder.embed(request).await?.embeddings;
+    if embedding.is_empty() {
+      return Ok(Vec::new());
+    }
+
+    debug_assert!(embedding.len() == 1);
+    let query_embedding = embedding.first().unwrap();
+
+    // Perform similarity search in the database
+    let results = vector_db
+      .search_with_score(
+        &workspace_id.to_string(),
+        rag_ids,
+        query_embedding,
+        limit as i32,
+        score_threshold,
+      )
+      .await?;
+
+    trace!(
+      "[VectorStore] Found {} results for query:{}, rag_ids: {:?}, score_threshold: {}",
+      results.len(),
+      query,
+      rag_ids,
+      score_threshold
+    );
+
+    // Convert results to Documents
+    let documents = results
+      .into_iter()
+      .map(|result| {
+        let mut metadata = HashMap::new();
+
+        if let Some(map) = result.metadata.as_ref().and_then(|v| v.as_object()) {
+          for (key, value) in map {
+            metadata.insert(key.clone(), value.clone());
+          }
+        }
+
+        Document::new(result.content).with_metadata(metadata)
+      })
+      .collect();
+
+    Ok(documents)
+  }
+}
+
 #[async_trait]
 impl VectorStore for SqliteVectorStore {
   type Options = VecStoreOptions<Value>;
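The impl above exposes the SQLite vector store to the new multi-source retriever. A minimal sketch of querying it directly, assuming a `store: SqliteVectorStore` handle (ids and query text are illustrative):

    // Fetch up to 5 chunks scoped to two rag ids, keeping hits whose similarity
    // score clears 0.2; each hit carries its source id in metadata under SOURCE_ID.
    async fn demo(store: &SqliteVectorStore, workspace_id: &Uuid) -> FlowyResult<()> {
      let rag_ids = vec!["view-a".to_string(), "view-b".to_string()];
      let docs = store
        .read_documents(workspace_id, "how do I export a page?", 5, &rag_ids, 0.2, false)
        .await?;
      for doc in docs {
        println!("{:?} -> {}", doc.metadata.get(SOURCE_ID), doc.page_content);
      }
      Ok(())
    }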
@@ -215,74 +288,23 @@ impl VectorStore for SqliteVectorStore {

     // Return empty result if workspace_id is missing
     let workspace_id = match workspace_id {
-      Some(id) => id.to_string(),
+      Some(id) => id,
       None => {
         warn!("[VectorStore] Missing workspace_id in filters. Returning empty result.");
         return Ok(Vec::new());
       },
     };

-    // Get the vector database
-    let vector_db = match self.vector_db.upgrade() {
-      Some(db) => db,
-      None => return Err("Vector database not initialized".into()),
-    };
-
-    // Create embedder and generate embedding for query
-    let embedder = self.create_embedder()?;
-    let request = GenerateEmbeddingsRequest::new(
-      embedder.model().name().to_string(),
-      EmbeddingsInput::Single(query.to_string()),
-    );
-
-    let embedding = match embedder.embed(request).await {
-      Ok(result) => result.embeddings,
-      Err(e) => return Err(Box::new(e)),
-    };
-
-    if embedding.is_empty() {
-      return Ok(Vec::new());
-    }
-
-    let score_threshold = opt.score_threshold.unwrap_or(0.4);
-    debug_assert!(embedding.len() == 1);
-    let query_embedding = embedding.first().unwrap();
-
-    // Perform similarity search in the database
-    let results = vector_db
-      .search_with_score(
+    self
+      .read_documents(
         &workspace_id,
         query,
+        limit,
         &rag_ids,
-        query_embedding,
-        limit as i32,
-        score_threshold,
+        opt.score_threshold.unwrap_or(0.4),
+        true,
       )
-      .await?;
-
-    trace!(
-      "[VectorStore] Found {} results for query:{}, rag_ids: {:?}, score_threshold: {}",
-      results.len(),
-      query,
-      rag_ids,
-      score_threshold
-    );
-
-    // Convert results to Documents
-    let documents = results
-      .into_iter()
-      .map(|result| {
-        let mut metadata = HashMap::new();
-
-        if let Some(map) = result.metadata.as_ref().and_then(|v| v.as_object()) {
-          for (key, value) in map {
-            metadata.insert(key.clone(), value.clone());
-          }
-        }
-
-        Document::new(result.content).with_metadata(metadata)
-      })
-      .collect();
-
-    Ok(documents)
+      .await
+      .map_err(|err| Box::new(err) as Box<dyn Error>)
   }
 }
@@ -1,6 +1,7 @@
 use crate::local_ai::chat::llm::LLMOllama;
 use crate::SqliteVectorStore;
 use flowy_error::{FlowyError, FlowyResult};
+use flowy_sqlite_vec::entities::EmbeddedContent;
 use langchain_rust::language_models::llm::LLM;
 use langchain_rust::prompt::TemplateFormat;
 use langchain_rust::prompt::{PromptFromatter, PromptTemplate};
@@ -10,6 +11,7 @@ use ollama_rs::generation::parameters::{FormatType, JsonStructure};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use std::fmt::Debug;
 use tracing::trace;
 use uuid::Uuid;

@@ -60,13 +62,13 @@ pub struct ContextQuestion {
   pub object_id: String,
 }

-pub struct RelatedQuestionChain {
+pub struct ContextRelatedQuestionChain {
   workspace_id: Uuid,
   llm: LLMOllama,
   store: SqliteVectorStore,
 }

-impl RelatedQuestionChain {
+impl ContextRelatedQuestionChain {
   pub fn new(workspace_id: Uuid, ollama: LLMOllama, store: SqliteVectorStore) -> Self {
     let format = FormatType::StructuredJson(JsonStructure::new::<ContextQuestionsResponse>());
     Self {
@@ -76,25 +78,16 @@ impl RelatedQuestionChain {
     }
   }

-  pub async fn generate_questions(&self, rag_ids: &[String]) -> FlowyResult<Vec<ContextQuestion>> {
-    trace!(
-      "[embedding] Generating context related questions for RAG IDs: {:?}",
-      rag_ids
-    );
-
-    let context = self
-      .store
-      .select_all_embedded_content(&self.workspace_id.to_string(), rag_ids, 3)
-      .await?;
-
-    trace!(
-      "[embedding] Generating related questions base on: {:?}",
-      context,
-    );
-
-    let context_str = json!(context).to_string();
+  pub async fn generate_questions_from_context<T>(
+    &self,
+    rag_ids: &[T],
+    context: &str,
+  ) -> FlowyResult<Vec<ContextQuestion>>
+  where
+    T: AsRef<str>,
+  {
     let input_variables = prompt_args! {
-      "context" => context_str,
+      "context" => context,
     };

     let template = PromptTemplate::new(
@@ -116,8 +109,42 @@ impl RelatedQuestionChain {
     // filter out questions that are not in the rag_ids
     parsed_result
       .questions
-      .retain(|v| rag_ids.contains(&v.object_id));
+      .retain(|v| rag_ids.iter().any(|id| id.as_ref() == v.object_id));

     Ok(parsed_result.questions)
   }
+
+  pub async fn generate_questions<T>(
+    &self,
+    rag_ids: &[T],
+  ) -> FlowyResult<(String, Vec<ContextQuestion>)>
+  where
+    T: AsRef<str> + Debug,
+  {
+    trace!(
+      "[embedding] Generating context related questions for RAG IDs: {:?}",
+      rag_ids
+    );
+
+    let rag_ids_str: Vec<String> = rag_ids.iter().map(|id| id.as_ref().to_string()).collect();
+    let context = self
+      .store
+      .select_all_embedded_content(&self.workspace_id.to_string(), &rag_ids_str, 3)
+      .await?;
+
+    trace!(
+      "[embedding] Generating related questions base on: {:?}",
+      context,
+    );
+
+    let context_str = embedded_documents_to_context_str(context);
+    self
+      .generate_questions_from_context(rag_ids, &context_str)
+      .await
+      .map(|questions| (context_str, questions))
+  }
 }

+pub fn embedded_documents_to_context_str(documents: Vec<EmbeddedContent>) -> String {
+  json!(documents).to_string()
+}
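The context handed to the question chain is simply the JSON serialization of the embedded chunks. For intuition, a sketch of `embedded_documents_to_context_str` in use; the diff constructs `EmbeddedContent` literals with `object_id` and `content` fields, so this assumes those fields stay public:

    // Illustrative: two chunks serialize into one JSON array string that fills
    // the prompt's "context" slot (exact key order depends on the serde derive).
    let docs = vec![
      EmbeddedContent { object_id: "view-1".into(), content: "Rust borrow checker notes".into() },
      EmbeddedContent { object_id: "view-2".into(), content: "JS event loop notes".into() },
    ];
    let context = embedded_documents_to_context_str(docs);
    assert!(context.starts_with('['));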
@@ -1,11 +1,17 @@
-use crate::local_ai::chat::chains::context_question_chain::RelatedQuestionChain;
+use crate::local_ai::chat::chains::context_question_chain::{
+  embedded_documents_to_context_str, ContextRelatedQuestionChain,
+};
+use crate::local_ai::chat::chains::related_question_chain::RelatedQuestionChain;
 use crate::local_ai::chat::llm::LLMOllama;
 use crate::local_ai::chat::retriever::AFRetriever;
 use crate::SqliteVectorStore;
 use arc_swap::ArcSwap;
 use async_stream::stream;
 use async_trait::async_trait;
 use flowy_ai_pub::cloud::{ContextSuggestedQuestion, QuestionStreamValue};
-use flowy_ai_pub::entities::{RAG_IDS, SOURCE_ID};
+use flowy_ai_pub::entities::SOURCE_ID;
 use flowy_error::{FlowyError, FlowyResult};
 use flowy_sqlite_vec::entities::EmbeddedContent;
 use futures::Stream;
 use futures_util::{pin_mut, StreamExt};
 use langchain_rust::chain::{
@@ -15,11 +21,9 @@ use langchain_rust::chain::{
 use langchain_rust::language_models::{GenerateResult, TokenUsage};
 use langchain_rust::memory::SimpleMemory;
 use langchain_rust::prompt::{FormatPrompter, PromptArgs};
-use langchain_rust::schemas::{BaseMemory, Document, Message, Retriever, StreamData};
-use langchain_rust::vectorstore::{VecStoreOptions, VectorStore};
+use langchain_rust::schemas::{BaseMemory, Document, Message, StreamData};
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
 use std::error::Error;
 use std::{collections::HashMap, pin::Pin, sync::Arc};
 use tokio::sync::Mutex;
 use tokio_util::either::Either;
@@ -37,15 +41,16 @@ const CONVERSATIONAL_RETRIEVAL_QA_DEFAULT_INPUT_KEY: &str = "question";

 pub struct ConversationalRetrieverChain {
   pub(crate) ollama: LLMOllama,
-  pub(crate) retriever: AFRetriever,
+  pub(crate) retriever: Box<dyn AFRetriever>,
   pub memory: Arc<Mutex<dyn BaseMemory>>,
   pub(crate) combine_documents_chain: Box<dyn Chain>,
   pub(crate) condense_question_chain: Box<dyn Chain>,
-  pub(crate) context_question_chain: Option<RelatedQuestionChain>,
+  pub(crate) context_question_chain: Option<ContextRelatedQuestionChain>,
   pub(crate) rephrase_question: bool,
   pub(crate) return_source_documents: bool,
   pub(crate) input_key: String,
   pub(crate) output_key: String,
+  latest_context: ArcSwap<String>,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
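The new `latest_context` field is an `arc_swap::ArcSwap`, so read paths can snapshot the most recent retrieval context from `&self` without a lock while the retrieval path atomically swaps in a new value. A standalone sketch of that pattern (not from the commit):

    use arc_swap::ArcSwap;
    use std::sync::Arc;

    // Writers atomically replace the value; readers take a cheap snapshot.
    let latest_context: ArcSwap<String> = ArcSwap::from_pointee(String::new());
    latest_context.store(Arc::new("serialized context".to_string()));
    let snapshot = latest_context.load();
    assert_eq!(snapshot.as_str(), "serialized context");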
@@ -91,6 +96,27 @@ impl ConversationalRetrieverChain {
     Ok((question, token_usage))
   }

+  pub async fn get_related_questions(&self, question: &str) -> Result<Vec<String>, FlowyError> {
+    let context = self.latest_context.load();
+    let rag_ids = self.retriever.get_rag_ids();
+
+    if context.is_empty() {
+      trace!("[Chat] No context available. Generating related questions");
+      let chain = RelatedQuestionChain::new(self.ollama.clone());
+      chain.generate_related_question(question).await
+    } else if let Some(c) = self.context_question_chain.as_ref() {
+      trace!(
+        "[Chat] found context:{}. Generating context related questions",
+        context
+      );
+      c.generate_questions_from_context(&rag_ids, &context)
+        .await
+        .map(|questions| questions.into_iter().map(|q| q.content).collect())
+    } else {
+      Ok(vec![])
+    }
+  }
+
   async fn get_documents_or_result(
     &self,
     question: &str,
@@ -101,7 +127,7 @@ impl ConversationalRetrieverChain {
     } else {
       let documents = self
         .retriever
-        .get_relevant_documents(question)
+        .retrieve_documents(question)
         .await
         .map_err(|e| ChainError::RetrieverError(e.to_string()))?;

@@ -115,7 +141,8 @@ impl ConversationalRetrieverChain {
       if let Some(c) = self.context_question_chain.as_ref() {
         let rag_ids = rag_ids.iter().map(|v| v.to_string()).collect::<Vec<_>>();
         match c.generate_questions(&rag_ids).await {
-          Ok(questions) => {
+          Ok((context, questions)) => {
+            self.latest_context.store(Arc::new(context));
             trace!("[embedding]: context related questions: {:?}", questions);
             suggested_questions = questions
               .into_iter()
@@ -134,7 +161,7 @@ impl ConversationalRetrieverChain {
         }
       }

-      return if suggested_questions.is_empty() {
+      if suggested_questions.is_empty() {
         Ok(Either::Right(StreamValue::ContextSuggested {
           value: CAN_NOT_ANSWER_WITH_CONTEXT.to_string(),
           suggested_questions,
@@ -144,10 +171,27 @@ impl ConversationalRetrieverChain {
           value: ANSWER_WITH_SUGGESTED_QUESTION.to_string(),
           suggested_questions,
         }))
-      };
-    }
+      }
+    } else {
+      let embedded_docs = documents
+        .iter()
+        .flat_map(|d| {
+          let object_id = d
+            .metadata
+            .get(SOURCE_ID)
+            .and_then(|v| v.as_str().map(|v| v.to_string()))?;
+          Some(EmbeddedContent {
+            content: d.page_content.clone(),
+            object_id,
+          })
+        })
+        .collect::<Vec<_>>();

-      Ok(Either::Left(documents))
+      let context = embedded_documents_to_context_str(embedded_docs);
+      self.latest_context.store(Arc::new(context));
+
+      Ok(Either::Left(documents))
+    }
     }
   }
 }
@@ -426,7 +470,7 @@ impl Chain for ConversationalRetrieverChain {
 pub struct ConversationalRetrieverChainBuilder {
   workspace_id: Uuid,
   llm: LLMOllama,
-  retriever: AFRetriever,
+  retriever: Box<dyn AFRetriever>,
   memory: Option<Arc<Mutex<dyn BaseMemory>>>,
   prompt: Option<Box<dyn FormatPrompter>>,
   rephrase_question: bool,
@@ -439,7 +483,7 @@ impl ConversationalRetrieverChainBuilder {
   pub fn new(
     workspace_id: Uuid,
     llm: LLMOllama,
-    retriever: AFRetriever,
+    retriever: Box<dyn AFRetriever>,
     store: Option<SqliteVectorStore>,
   ) -> Self {
     ConversationalRetrieverChainBuilder {
@@ -496,7 +540,7 @@ impl ConversationalRetrieverChainBuilder {

     let context_question_chain = self
       .store
-      .map(|store| RelatedQuestionChain::new(self.workspace_id, self.llm.clone(), store));
+      .map(|store| ContextRelatedQuestionChain::new(self.workspace_id, self.llm.clone(), store));

     Ok(ConversationalRetrieverChain {
       ollama: self.llm,
@@ -509,67 +553,11 @@ impl ConversationalRetrieverChainBuilder {
       return_source_documents: self.return_source_documents,
       input_key: self.input_key,
       output_key: self.output_key,
+      latest_context: Default::default(),
     })
   }
 }

-// Retriever is a retriever for vector stores.
-pub type RetrieverOption = VecStoreOptions<Value>;
-pub struct AFRetriever {
-  vector_store: Option<Box<dyn VectorStore<Options = RetrieverOption>>>,
-  num_docs: usize,
-  options: RetrieverOption,
-}
-impl AFRetriever {
-  pub fn new<V: Into<Box<dyn VectorStore<Options = RetrieverOption>>>>(
-    vector_store: Option<V>,
-    num_docs: usize,
-    options: RetrieverOption,
-  ) -> Self {
-    AFRetriever {
-      vector_store: vector_store.map(Into::into),
-      num_docs,
-      options,
-    }
-  }
-  pub fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
-    trace!("[VectorStore] retriever {:p}", self);
-    let filters = self.options.filters.get_or_insert_with(|| json!({}));
-    filters[RAG_IDS] = json!(new_rag_ids);
-  }
-
-  pub fn get_rag_ids(&self) -> Vec<&str> {
-    trace!("[VectorStore] retriever {:p}", self);
-    self
-      .options
-      .filters
-      .as_ref()
-      .and_then(|filters| filters.get(RAG_IDS).and_then(|rag_ids| rag_ids.as_array()))
-      .map(|rag_ids| rag_ids.iter().filter_map(|id| id.as_str()).collect())
-      .unwrap_or_default()
-  }
-}
-
-#[async_trait]
-impl Retriever for AFRetriever {
-  async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
-    trace!(
-      "[VectorStore] filters: {:?}, retrieving documents for query: {}",
-      self.options.filters,
-      query,
-    );
-
-    match self.vector_store.as_ref() {
-      None => Ok(vec![]),
-      Some(vector_store) => {
-        vector_store
-          .similarity_search(query, self.num_docs, &self.options)
-          .await
-      },
-    }
-  }
-}
-
 /// Deduplicates metadata from a list of documents by merging metadata entries with the same keys
 fn deduplicate_metadata(documents: &[Document]) -> Vec<QuestionStreamValue> {
   let mut merged_metadata: HashMap<String, QuestionStreamValue> = HashMap::new();
@@ -6,17 +6,9 @@ use ollama_rs::generation::parameters::{FormatType, JsonStructure};
 use schemars::JsonSchema;
 use serde::Deserialize;

-const SUMMARIZE_SYSTEM_PROMPT: &str = r#"
-As an AppFlowy AI assistant, your task is to generate three medium-length, relevant, and informative questions based on the provided conversation history.
-The output should only return a JSON instance that conforms to the JSON schema below.
-
-{
-  "questions": [
-    "What are the key skills needed to tackle a black diamond slope in snowboarding?",
-    "How does the difficulty of black diamond trails compare across different ski resorts?",
-    "Can you provide tips for snowboarders preparing to try a black diamond trail for the first time?"
-  ]
-}
+const SYSTEM_PROMPT: &str = r#"
+You are the AppFlowy AI assistant. Given the conversation history, generate exactly three medium-length, relevant, and informative questions.
+Respond with a single JSON object matching the schema below—and nothing else. If you can’t generate questions, return {}.
 "#;

 #[derive(Debug, Deserialize, JsonSchema)]
@@ -36,9 +28,9 @@ impl RelatedQuestionChain {
     }
   }

-  pub async fn related_question(&self, question: &str) -> FlowyResult<Vec<String>> {
+  pub async fn generate_related_question(&self, question: &str) -> FlowyResult<Vec<String>> {
     let messages = vec![
-      Message::new_system_message(SUMMARIZE_SYSTEM_PROMPT),
+      Message::new_system_message(SYSTEM_PROMPT),
       Message::new_human_message(question),
     ];
@@ -9,7 +9,7 @@ use langchain_rust::template_jinja2;
 use std::sync::{Arc, RwLock};

 const QA_CONTEXT_TEMPLATE: &str = r#"
-Only Use the context provided below to formulate your answer. Do not use any other information. If the context doesn't contain sufficient information to answer the question, respond with "I don't know".
+Only Use the context provided below to formulate your answer. Do not use any other information.
+Do not reference external knowledge or information outside the context.

 ##Context##
@@ -1,9 +1,11 @@
 use crate::local_ai::chat::chains::conversation_chain::{
-  AFRetriever, ConversationalRetrieverChain, ConversationalRetrieverChainBuilder, RetrieverOption,
+  ConversationalRetrieverChain, ConversationalRetrieverChainBuilder,
 };
-use crate::local_ai::chat::chains::related_question_chain::RelatedQuestionChain;
 use crate::local_ai::chat::format_prompt::AFContextPrompt;
 use crate::local_ai::chat::llm::LLMOllama;
+use crate::local_ai::chat::retriever::multi_source_retriever::MultipleSourceRetriever;
+use crate::local_ai::chat::retriever::sqlite_retriever::RetrieverOption;
+use crate::local_ai::chat::retriever::{AFRetriever, MultipleSourceRetrieverStore};
 use crate::local_ai::chat::summary_memory::SummaryMemory;
 use crate::local_ai::chat::LLMChatInfo;
 use crate::SqliteVectorStore;
@@ -27,6 +29,7 @@ use uuid::Uuid;
 pub struct LLMChat {
   store: Option<SqliteVectorStore>,
   chain: ConversationalRetrieverChain,
+  #[allow(dead_code)]
   client: Arc<Ollama>,
   prompt: AFContextPrompt,
   info: LLMChatInfo,
@@ -38,6 +41,7 @@ impl LLMChat {
     client: Arc<Ollama>,
     store: Option<SqliteVectorStore>,
     user_service: Option<Weak<dyn AIUserService>>,
+    retriever_sources: Vec<Weak<dyn MultipleSourceRetrieverStore>>,
   ) -> FlowyResult<Self> {
     let response_format = ResponseFormat::default();
     let formatter = create_formatter_prompt_with_format(&response_format, &info.rag_ids);
@@ -47,7 +51,12 @@ impl LLMChat {
       .map(|v| v.into())
       .unwrap_or(SimpleMemory::new().into());

-    let retriever = create_retriever(&info.workspace_id, info.rag_ids.clone(), store.clone());
+    let retriever = create_retriever(
+      &info.workspace_id,
+      info.rag_ids.clone(),
+      store.clone(),
+      retriever_sources,
+    );
     let builder =
       ConversationalRetrieverChainBuilder::new(info.workspace_id, llm, retriever, store.clone())
         .rephrase_question(false)
@@ -64,19 +73,7 @@ impl LLMChat {
   }

   pub async fn get_related_question(&self, user_message: String) -> FlowyResult<Vec<String>> {
-    let chain = RelatedQuestionChain::new(LLMOllama::new(
-      &self.info.model,
-      self.client.clone(),
-      None,
-      None,
-    ));
-    let questions = chain.related_question(&user_message).await?;
-    trace!(
-      "related questions: {:?} for message: {}",
-      questions,
-      user_message
-    );
-    Ok(questions)
+    self.chain.get_related_questions(&user_message).await
   }

   pub fn set_chat_model(&mut self, model: &str) {
@@ -198,17 +195,40 @@ fn create_retriever(
   workspace_id: &Uuid,
   rag_ids: Vec<String>,
   store: Option<SqliteVectorStore>,
-) -> AFRetriever {
+  retrievers_sources: Vec<Weak<dyn MultipleSourceRetrieverStore>>,
+) -> Box<dyn AFRetriever> {
   trace!(
     "[VectorStore]: {} create retriever with rag_ids: {:?}",
     workspace_id,
     rag_ids,
   );
-  let options = VecStoreOptions::default()
-    .with_score_threshold(0.2)
-    .with_filters(json!({RAG_IDS: rag_ids, "workspace_id": workspace_id}));
-
-  AFRetriever::new(store, 5, options)
+  let mut stores: Vec<Arc<dyn MultipleSourceRetrieverStore>> = vec![];
+  if let Some(store) = store {
+    stores.push(Arc::new(store));
+  }
+
+  for source in retrievers_sources {
+    if let Some(source) = source.upgrade() {
+      stores.push(source);
+    }
+  }
+
+  trace!(
+    "[VectorStore]: use retrievers sources: {:?}",
+    stores
+      .iter()
+      .map(|s| s.retriever_name())
+      .collect::<Vec<_>>()
+  );
+
+  Box::new(MultipleSourceRetriever::new(
+    *workspace_id,
+    stores,
+    rag_ids.clone(),
+    5,
+    0.2,
+  ))
 }

 fn map_chain_error(err: ChainError) -> FlowyError {
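The extra retriever sources arrive as `Weak` references, so a chat never keeps a search index alive by itself; `create_retriever` upgrades each one and silently skips sources whose owner has gone away. A standalone sketch of that ownership pattern (the `my_store` value is hypothetical):

    use std::sync::{Arc, Weak};

    // The owner holds the Arc; the retriever only holds a Weak handle, and the
    // source simply drops out of the fan-out once the owner releases it.
    let owned: Arc<dyn MultipleSourceRetrieverStore> = Arc::new(my_store);
    let handle: Weak<dyn MultipleSourceRetrieverStore> = Arc::downgrade(&owned);
    assert!(handle.upgrade().is_some());
    drop(owned);
    assert!(handle.upgrade().is_none());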
@@ -2,11 +2,12 @@ pub mod chains;
 mod format_prompt;
 pub mod llm;
 pub mod llm_chat;
+pub mod retriever;
 mod summary_memory;

-use crate::local_ai::chat::chains::related_question_chain::RelatedQuestionChain;
 use crate::local_ai::chat::llm::LLMOllama;
 use crate::local_ai::chat::llm_chat::LLMChat;
+use crate::local_ai::chat::retriever::MultipleSourceRetrieverStore;
 use crate::local_ai::completion::chain::CompletionChain;
 use crate::local_ai::database::summary::DatabaseSummaryChain;
 use crate::local_ai::database::translate::DatabaseTranslateChain;
@@ -28,7 +29,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
 use tokio::sync::RwLock;
-use tracing::trace;
+use tracing::warn;
 use uuid::Uuid;

 type OllamaClientRef = Arc<RwLock<Option<Weak<Ollama>>>>;
@@ -41,11 +42,14 @@ pub struct LLMChatInfo {
   pub summary: String,
 }

+pub type RetrieversSources = RwLock<Vec<Arc<dyn MultipleSourceRetrieverStore>>>;
+
 pub struct LLMChatController {
-  chat_by_id: DashMap<Uuid, LLMChat>,
+  chat_by_id: DashMap<Uuid, Arc<RwLock<LLMChat>>>,
   store: RwLock<Option<SqliteVectorStore>>,
   client: OllamaClientRef,
   user_service: Weak<dyn AIUserService>,
+  retriever_sources: RetrieversSources,
 }
 impl LLMChatController {
   pub fn new(user_service: Weak<dyn AIUserService>) -> Self {
@@ -54,9 +58,14 @@ impl LLMChatController {
       chat_by_id: DashMap::new(),
       client: Default::default(),
       user_service,
+      retriever_sources: Default::default(),
     }
   }

+  pub async fn set_retriever_sources(&self, sources: Vec<Arc<dyn MultipleSourceRetrieverStore>>) {
+    *self.retriever_sources.write().await = sources;
+  }
+
   pub async fn is_ready(&self) -> bool {
     self.client.read().await.is_some()
   }
@@ -72,9 +81,9 @@ impl LLMChatController {
     *self.store.write().await = Some(store);
   }

-  pub fn set_rag_ids(&self, chat_id: &Uuid, rag_ids: &[String]) {
-    if let Some(mut chat) = self.chat_by_id.get_mut(chat_id) {
-      chat.set_rag_ids(rag_ids.to_vec());
+  pub async fn set_rag_ids(&self, chat_id: &Uuid, rag_ids: &[String]) {
+    if let Some(chat) = self.get_chat(chat_id) {
+      chat.write().await.set_rag_ids(rag_ids.to_vec());
     }
   }

@@ -90,10 +99,22 @@ impl LLMChatController {
       .ok_or_else(|| FlowyError::local_ai().with_context("Ollama client has been dropped"))?
       .clone();
     let entry = self.chat_by_id.entry(info.chat_id);
+
+    let retriever_sources = self
+      .retriever_sources
+      .read()
+      .await
+      .iter()
+      .map(Arc::downgrade)
+      .collect();
     if let Entry::Vacant(e) = entry {
-      let chat = LLMChat::new(info, client, store, Some(self.user_service.clone()))?;
-      e.insert(chat);
+      let chat = LLMChat::new(
+        info,
+        client,
+        store,
+        Some(self.user_service.clone()),
+        retriever_sources,
+      )?;
+      e.insert(Arc::new(RwLock::new(chat)));
     }
     Ok(())
   }
@@ -107,6 +128,10 @@ impl LLMChatController {
     self.chat_by_id.remove(chat_id);
   }

+  pub fn get_chat(&self, chat_id: &Uuid) -> Option<Arc<RwLock<LLMChat>>> {
+    self.chat_by_id.get(chat_id).map(|c| c.value().clone())
+  }
+
   pub async fn summarize_database_row(
     &self,
     model_name: &str,
@@ -185,38 +210,36 @@ impl LLMChatController {

   pub async fn get_related_question(
     &self,
-    model_name: &str,
+    _model_name: &str,
     chat_id: &Uuid,
     _message_id: i64,
   ) -> FlowyResult<Vec<String>> {
-    let client = self
-      .client
-      .read()
-      .await
-      .clone()
-      .ok_or(FlowyError::local_ai())?
-      .upgrade()
-      .ok_or(FlowyError::local_ai())?;
-
-    let user_service = self.user_service.upgrade().ok_or(FlowyError::local_ai())?;
-    let uid = user_service.user_id()?;
-    let conn = user_service.sqlite_connection(uid)?;
-    let message = select_latest_user_message(conn, &chat_id.to_string(), ChatAuthorType::Human)?;
-
-    let chain = RelatedQuestionChain::new(LLMOllama::new(model_name, client, None, None));
-    let questions = chain.related_question(&message.content).await?;
-    trace!(
-      "related questions: {:?} for message: {}",
-      questions,
-      message.content
-    );
-    Ok(questions)
+    match self.get_chat(chat_id) {
+      None => {
+        warn!(
+          "[Chat] Chat with id {} not found, unable to get related question",
+          chat_id
+        );
+        Ok(vec![])
+      },
+      Some(chat) => {
+        let user_service = self.user_service.upgrade().ok_or(FlowyError::local_ai())?;
+        let uid = user_service.user_id()?;
+        let conn = user_service.sqlite_connection(uid)?;
+        let message =
+          select_latest_user_message(conn, &chat_id.to_string(), ChatAuthorType::Human)?;
+        chat
+          .read()
+          .await
+          .get_related_question(message.content)
+          .await
+      },
+    }
   }

   pub async fn ask_question(&self, chat_id: &Uuid, question: &str) -> FlowyResult<String> {
-    if let Some(chat) = self.chat_by_id.get(chat_id) {
-      let chat = chat.value();
-      let response = chat.ask_question(question).await;
+    if let Some(chat) = self.get_chat(chat_id) {
+      let response = chat.read().await.ask_question(question).await;
       return response;
     }

@@ -230,10 +253,9 @@ impl LLMChatController {
     format: ResponseFormat,
     model_name: &str,
   ) -> FlowyResult<StreamAnswer> {
-    if let Some(mut chat) = self.chat_by_id.get_mut(chat_id) {
-      chat.set_chat_model(model_name);
-
-      let response = chat.stream_question(question, format).await;
+    if let Some(chat) = self.get_chat(chat_id) {
+      chat.write().await.set_chat_model(model_name);
+      let response = chat.write().await.stream_question(question, format).await;
       return response;
     }
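Each chat now lives behind `Arc<RwLock<LLMChat>>`, so callers clone the handle out of the `DashMap` and release the map guard before any `.await`, rather than holding a shard lock across suspension points. A minimal sketch of the access pattern (illustrative):

    // The DashMap guard lives only for the duration of get_chat; the cloned
    // Arc is then safe to lock asynchronously.
    async fn rename_model(controller: &LLMChatController, chat_id: &Uuid, model: &str) {
      if let Some(chat) = controller.get_chat(chat_id) {
        chat.write().await.set_chat_model(model);
      }
    }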
@@ -0,0 +1,31 @@
+use async_trait::async_trait;
+use flowy_error::FlowyResult;
+pub use langchain_rust::schemas::Document as LangchainDocument;
+use std::error::Error;
+use uuid::Uuid;
+
+pub mod multi_source_retriever;
+pub mod sqlite_retriever;
+
+#[async_trait]
+pub trait AFRetriever: Send + Sync + 'static {
+  fn get_rag_ids(&self) -> Vec<&str>;
+  fn set_rag_ids(&mut self, new_rag_ids: Vec<String>);
+
+  async fn retrieve_documents(&self, query: &str)
+    -> Result<Vec<LangchainDocument>, Box<dyn Error>>;
+}
+
+#[async_trait]
+pub trait MultipleSourceRetrieverStore: Send + Sync {
+  fn retriever_name(&self) -> &'static str;
+  async fn read_documents(
+    &self,
+    workspace_id: &Uuid,
+    query: &str,
+    limit: usize,
+    rag_ids: &[String],
+    score_threshold: f32,
+    full_search: bool,
+  ) -> FlowyResult<Vec<LangchainDocument>>;
+}
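Any additional search backend can plug into chat by implementing `MultipleSourceRetrieverStore`. A minimal sketch of a custom in-memory source; everything here is illustrative and not part of the commit:

    use async_trait::async_trait;
    use flowy_error::FlowyResult;
    use uuid::Uuid;

    // Hypothetical source that serves a fixed list of notes.
    struct StaticNotesStore {
      notes: Vec<(String, String)>, // (object_id, content)
    }

    #[async_trait]
    impl MultipleSourceRetrieverStore for StaticNotesStore {
      fn retriever_name(&self) -> &'static str {
        "Static Notes Retriever"
      }

      async fn read_documents(
        &self,
        _workspace_id: &Uuid,
        query: &str,
        limit: usize,
        rag_ids: &[String],
        _score_threshold: f32,
        _full_search: bool,
      ) -> FlowyResult<Vec<LangchainDocument>> {
        // Naive substring match, scoped to the requested rag_ids.
        let docs = self
          .notes
          .iter()
          .filter(|(id, content)| rag_ids.contains(id) && content.contains(query))
          .take(limit)
          .map(|(_, content)| LangchainDocument::new(content.clone()))
          .collect();
        Ok(docs)
      }
    }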
@@ -0,0 +1,120 @@
+use crate::local_ai::chat::retriever::{AFRetriever, MultipleSourceRetrieverStore};
+use async_trait::async_trait;
+use futures::future::join_all;
+use langchain_rust::schemas::Document;
+use std::error::Error;
+use std::sync::Arc;
+use tracing::{error, trace};
+use uuid::Uuid;
+
+pub struct MultipleSourceRetriever {
+  workspace_id: Uuid,
+  vector_stores: Vec<Arc<dyn MultipleSourceRetrieverStore>>,
+  num_docs: usize,
+  rag_ids: Vec<String>,
+  full_search: bool,
+  score_threshold: f32,
+}
+
+impl MultipleSourceRetriever {
+  pub fn new<V: Into<Arc<dyn MultipleSourceRetrieverStore>>>(
+    workspace_id: Uuid,
+    vector_stores: Vec<V>,
+    rag_ids: Vec<String>,
+    num_docs: usize,
+    score_threshold: f32,
+  ) -> Self {
+    MultipleSourceRetriever {
+      workspace_id,
+      vector_stores: vector_stores.into_iter().map(|v| v.into()).collect(),
+      num_docs,
+      rag_ids,
+      full_search: false,
+      score_threshold,
+    }
+  }
+
+  pub fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
+    self.rag_ids = new_rag_ids;
+  }
+
+  pub fn get_rag_ids(&self) -> Vec<&str> {
+    self
+      .rag_ids
+      .iter()
+      .map(|id| id.as_str())
+      .collect::<Vec<&str>>()
+  }
+}
+
+#[async_trait]
+impl AFRetriever for MultipleSourceRetriever {
+  fn get_rag_ids(&self) -> Vec<&str> {
+    self
+      .rag_ids
+      .iter()
+      .map(|id| id.as_str())
+      .collect::<Vec<&str>>()
+  }
+
+  fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
+    self.rag_ids = new_rag_ids;
+  }
+
+  async fn retrieve_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
+    trace!(
+      "[VectorStore] filters: {:?}, retrieving documents for query: {}",
+      self.rag_ids,
+      query,
+    );
+
+    // Create futures for each vector store search
+    let search_futures = self
+      .vector_stores
+      .iter()
+      .map(|vector_store| {
+        let vector_store = vector_store.clone();
+        let query = query.to_string();
+        let num_docs = self.num_docs;
+        let full_search = self.full_search;
+        let rag_ids = self.rag_ids.clone();
+        let workspace_id = self.workspace_id;
+        let score_threshold = self.score_threshold;
+
+        async move {
+          vector_store
+            .read_documents(
+              &workspace_id,
+              &query,
+              num_docs,
+              &rag_ids,
+              score_threshold,
+              full_search,
+            )
+            .await
+            .map(|docs| (vector_store.retriever_name(), docs))
+        }
+      })
+      .collect::<Vec<_>>();
+
+    let search_results = join_all(search_futures).await;
+    let mut results = Vec::new();
+    for result in search_results {
+      if let Ok((retriever_name, docs)) = result {
+        trace!(
+          "[VectorStore] {} found {} results, scores: {:?}",
+          retriever_name,
+          docs.len(),
+          docs.iter().map(|doc| doc.score).collect::<Vec<_>>()
+        );
+        results.extend(docs);
+      } else {
+        error!(
+          "[VectorStore] Failed to retrieve documents: {}",
+          result.unwrap_err()
+        );
+      }
+    }
+
+    Ok(results)
+  }
+}
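With the pieces above, wiring a retriever that fans out over several sources might look like this. A sketch only: `sqlite_store` and `tantivy_store` stand in for any `MultipleSourceRetrieverStore` implementations:

    use std::sync::Arc;
    use uuid::Uuid;

    // Both stores are queried concurrently and their documents concatenated,
    // mirroring what create_retriever does in llm_chat.rs above.
    async fn build_and_query(
      sqlite_store: Arc<dyn MultipleSourceRetrieverStore>,
      tantivy_store: Arc<dyn MultipleSourceRetrieverStore>,
      workspace_id: Uuid,
    ) -> Result<(), Box<dyn std::error::Error>> {
      let retriever = MultipleSourceRetriever::new(
        workspace_id,
        vec![sqlite_store, tantivy_store],
        vec!["view-1".to_string()], // rag ids in scope
        5,                          // max docs per source
        0.2,                        // score threshold
      );
      let docs = retriever.retrieve_documents("export to markdown").await?;
      println!("retrieved {} documents", docs.len());
      Ok(())
    }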
@@ -0,0 +1,71 @@
+use crate::local_ai::chat::retriever::AFRetriever;
+use async_trait::async_trait;
+use flowy_ai_pub::entities::RAG_IDS;
+use langchain_rust::schemas::{Document, Retriever};
+use langchain_rust::vectorstore::{VecStoreOptions, VectorStore};
+use serde_json::{json, Value};
+use std::error::Error;
+use tracing::trace;
+
+// Retriever is a retriever for vector stores.
+pub type RetrieverOption = VecStoreOptions<Value>;
+
+pub struct SqliteVecRetriever {
+  vector_store: Option<Box<dyn VectorStore<Options = RetrieverOption>>>,
+  num_docs: usize,
+  options: RetrieverOption,
+}
+
+impl SqliteVecRetriever {
+  pub fn new<V: Into<Box<dyn VectorStore<Options = RetrieverOption>>>>(
+    vector_store: Option<V>,
+    num_docs: usize,
+    options: RetrieverOption,
+  ) -> Self {
+    SqliteVecRetriever {
+      vector_store: vector_store.map(Into::into),
+      num_docs,
+      options,
+    }
+  }
+}
+
+#[async_trait]
+impl AFRetriever for SqliteVecRetriever {
+  fn get_rag_ids(&self) -> Vec<&str> {
+    self
+      .options
+      .filters
+      .as_ref()
+      .and_then(|filters| filters.get(RAG_IDS).and_then(|rag_ids| rag_ids.as_array()))
+      .map(|rag_ids| rag_ids.iter().filter_map(|id| id.as_str()).collect())
+      .unwrap_or_default()
+  }
+
+  fn set_rag_ids(&mut self, new_rag_ids: Vec<String>) {
+    let filters = self.options.filters.get_or_insert_with(|| json!({}));
+    filters[RAG_IDS] = json!(new_rag_ids);
+  }
+
+  async fn retrieve_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
+    trace!(
+      "[VectorStore] filters: {:?}, retrieving documents for query: {}",
+      self.options.filters,
+      query,
+    );
+
+    match self.vector_store.as_ref() {
+      None => Ok(vec![]),
+      Some(vector_store) => {
+        vector_store
+          .similarity_search(query, self.num_docs, &self.options)
+          .await
+      },
+    }
+  }
+}
+
+#[async_trait]
+impl Retriever for SqliteVecRetriever {
+  async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, Box<dyn Error>> {
+    self.retrieve_documents(query).await
+  }
+}
@@ -26,7 +26,7 @@ use serde::{Deserialize, Serialize};
 use std::ops::Deref;
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
-use tracing::{debug, error, info, instrument, trace};
+use tracing::{debug, error, info, instrument, trace, warn};
 use uuid::Uuid;

 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -189,7 +189,7 @@ impl LocalAIController {
       return;
     }

-    self.llm_controller.set_rag_ids(chat_id, rag_ids);
+    self.llm_controller.set_rag_ids(chat_id, rag_ids).await;
   }

   pub async fn open_chat(
@@ -201,14 +201,17 @@ impl LocalAIController {
     summary: String,
   ) -> FlowyResult<()> {
     if !self.is_enabled() {
+      warn!("[chat] local ai is disabled, skip open chat");
       return Ok(());
     }

     // Only keep one chat open at a time. Since loading multiple models at the same time will cause
     // memory issues.
     if let Some(current_chat_id) = self.current_chat_id.load().as_ref() {
-      debug!("[Local AI] close previous chat: {}", current_chat_id);
-      self.close_chat(current_chat_id);
+      if current_chat_id.as_ref() != chat_id {
+        debug!("[Chat] close previous chat: {}", current_chat_id);
+        self.close_chat(current_chat_id);
+      }
     }

     let info = LLMChatInfo {
@@ -219,6 +222,7 @@ impl LocalAIController {
       summary,
     };
     self.current_chat_id.store(Some(Arc::new(*chat_id)));
+    trace!("[Chat] open chat: {}", chat_id);
     self.llm_controller.open_chat(info).await?;
     Ok(())
   }
@@ -133,7 +133,7 @@ async fn local_ollama_test_chat_related_question() {
   let ollama = LLMOllama::default().with_model("llama3.1");
   let chain = RelatedQuestionChain::new(ollama);
   let resp = chain
-    .related_question("Compare rust with JS")
+    .generate_related_question("Compare rust with JS")
     .await
     .unwrap();
@@ -76,7 +76,14 @@ impl TestContext {
       summary: "".to_string(),
     };

-    LLMChat::new(info, self.ollama.clone(), Some(self.store.clone()), None).unwrap()
+    LLMChat::new(
+      info,
+      self.ollama.clone(),
+      Some(self.store.clone()),
+      None,
+      vec![],
+    )
+    .unwrap()
   }
 }
@@ -5,14 +5,18 @@ use collab::preclude::{Collab, StateVector};
 use collab::util::is_change_since_sv;
 use collab_entity::CollabType;
 use flowy_ai::ai_manager::{AIExternalService, AIManager};
+use flowy_ai::local_ai::chat::retriever::{LangchainDocument, MultipleSourceRetrieverStore};
 use flowy_ai::local_ai::controller::LocalAIController;
 use flowy_ai_pub::cloud::ChatCloudService;
+use flowy_ai_pub::entities::{SOURCE, SOURCE_ID, SOURCE_NAME};
 use flowy_ai_pub::persistence::AFCollabMetadata;
 use flowy_ai_pub::user_service::AIUserService;
 use flowy_error::{FlowyError, FlowyResult};
 use flowy_folder::ViewLayout;
 use flowy_folder_pub::cloud::{FolderCloudService, FullSyncCollabParams};
 use flowy_folder_pub::query::FolderService;
+use flowy_search_pub::tantivy_state::DocumentTantivyState;
+use flowy_server::util::tanvity_local_search;
 use flowy_sqlite::kv::KVStorePreferences;
 use flowy_sqlite::DBConnection;
 use flowy_storage_pub::storage::StorageService;
@@ -20,9 +24,11 @@ use flowy_user::services::authenticate_user::AuthenticateUser;
 use flowy_user_pub::entities::WorkspaceType;
 use lib_infra::async_trait::async_trait;
 use lib_infra::util::timestamp;
+use serde_json::json;
 use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
+use tokio::sync::RwLock;
 use tracing::{debug, error, info};
 use uuid::Uuid;

@@ -205,3 +211,64 @@ impl AIUserService for ChatUserServiceImpl {
     ))
   }
 }
+
+#[derive(Clone)]
+pub struct MultiSourceVSTanvityImpl {
+  state: Option<Weak<RwLock<DocumentTantivyState>>>,
+}
+
+impl MultiSourceVSTanvityImpl {
+  pub fn new(state: Option<Weak<RwLock<DocumentTantivyState>>>) -> Self {
+    Self { state }
+  }
+}
+
+#[async_trait]
+impl MultipleSourceRetrieverStore for MultiSourceVSTanvityImpl {
+  fn retriever_name(&self) -> &'static str {
+    "Tanvity Multiple Source Retriever"
+  }
+
+  async fn read_documents(
+    &self,
+    workspace_id: &Uuid,
+    query: &str,
+    limit: usize,
+    rag_ids: &[String],
+    score_threshold: f32,
+    _full_search: bool,
+  ) -> FlowyResult<Vec<LangchainDocument>> {
+    let docs = tanvity_local_search(
+      &self.state,
+      workspace_id,
+      query,
+      Some(rag_ids.to_vec()),
+      limit,
+      score_threshold,
+    )
+    .await;
+
+    match docs {
+      None => Ok(vec![]),
+      Some(docs) => Ok(
+        docs
+          .into_iter()
+          .map(|v| LangchainDocument {
+            page_content: v.content,
+            metadata: json!({
+              SOURCE_ID: v.object_id,
+              SOURCE: "appflowy",
+              SOURCE_NAME: "document",
+            })
+            .as_object()
+            .unwrap()
+            .clone()
+            .into_iter()
+            .collect(),
+            score: v.score,
+          })
+          .collect(),
+      ),
+    }
+  }
+}
@@ -1,3 +1,4 @@
+use crate::deps_resolve::MultiSourceVSTanvityImpl;
 use crate::AppFlowyCoreConfig;
 use arc_swap::{ArcSwap, ArcSwapOption};
 use collab::entity::EncodedCollab;
@@ -73,6 +74,13 @@ impl ServerProvider {
   }

   async fn set_tanvity_state(&self, tanvity_state: Option<Weak<RwLock<DocumentTantivyState>>>) {
+    let tanvity_store = Arc::new(MultiSourceVSTanvityImpl::new(tanvity_state.clone()));
+    self
+      .local_ai
+      .set_retriever_sources(vec![tanvity_store])
+      .await;
+
     match self.providers.try_get(self.auth_type.load().as_ref()) {
       TryResult::Present(r) => {
         r.set_tanvity_state(tanvity_state).await;
@@ -37,6 +37,7 @@ pub struct TanvitySearchResponseItem {
   pub icon: Option<ResultIcon>,
   pub workspace_id: String,
   pub content: String,
+  pub score: f32,
 }

 #[derive(Default, Debug, Clone, PartialEq, Eq)]

@@ -305,6 +305,8 @@ impl DocumentTantivyState {
     workspace_id: &Uuid,
     query: &str,
     object_ids: Option<Vec<String>>,
+    limit: usize,
+    score_threshold: f32,
   ) -> FlowyResult<Vec<TanvitySearchResponseItem>> {
     let workspace_id = workspace_id.to_string();
     let reader = self.reader.clone();
@@ -319,7 +321,7 @@ impl DocumentTantivyState {
     qp.set_field_fuzzy(self.field_name, true, 2, true);

     let query = qp.parse_query(query)?;
-    let top_docs = searcher.search(&query, &tantivy::collector::TopDocs::with_limit(10))?;
+    let top_docs = searcher.search(&query, &tantivy::collector::TopDocs::with_limit(limit))?;

     let mut results = Vec::with_capacity(top_docs.len());
     let mut seen_ids = std::collections::HashSet::new();
@@ -333,7 +335,12 @@ impl DocumentTantivyState {
       }
     });

-    for (_score, doc_address) in top_docs {
+    for (score, doc_address) in top_docs {
+      // Skip results that don't meet the score threshold
+      if score < score_threshold {
+        continue;
+      }
+
       let retrieved: TantivyDocument = searcher.doc(doc_address)?;
       // Pull out each stored field using cached field references
       let workspace_id_str = retrieved
@@ -416,6 +423,7 @@ impl DocumentTantivyState {
         icon,
         workspace_id: workspace_id_str,
         content,
+        score,
       });
     }
@@ -50,7 +50,7 @@ impl SearchHandler for DocumentLocalSearchHandler {
         );
       },
       Some(state) => {
-        match state.read().await.search(&workspace_id, &query, None) {
+        match state.read().await.search(&workspace_id, &query, None, 10, 0.4) {
           Ok(items) => {
             trace!("[Tanvity] local document search result: {:?}", items);
             if items.is_empty() {
@@ -39,7 +39,7 @@ where
   }

   trace!("[Search] Local AI search returned no results, falling back to local search");
-  let items = tanvity_local_search(&self.state, workspace_id, &query)
+  let items = tanvity_local_search(&self.state, workspace_id, &query, None, 10, 0.4)
     .await
     .unwrap_or_default();
   Ok(items)
@@ -44,7 +44,7 @@ impl SearchCloudService for LocalSearchServiceImpl {
   }

   trace!("[Search] Local AI search returned no results, falling back to local search");
-  let items = tanvity_local_search(&self.state, workspace_id, &query)
+  let items = tanvity_local_search(&self.state, workspace_id, &query, None, 10, 0.4)
     .await
     .unwrap_or_default();
   Ok(items)
@@ -23,6 +23,9 @@ pub async fn tanvity_local_search(
   state: &Option<Weak<RwLock<DocumentTantivyState>>>,
   workspace_id: &Uuid,
   query: &str,
+  object_ids: Option<Vec<String>>,
+  limit: usize,
+  score_threshold: f32,
 ) -> Option<Vec<SearchDocumentResponseItem>> {
   match state.as_ref().and_then(|v| v.upgrade()) {
     None => {
@@ -30,7 +33,11 @@ pub async fn tanvity_local_search(
       None
     },
     Some(state) => {
-      let results = state.read().await.search(workspace_id, query, None).ok()?;
+      let results = state
+        .read()
+        .await
+        .search(workspace_id, query, object_ids, limit, score_threshold)
+        .ok()?;
       let items = results
         .into_iter()
         .flat_map(|v| tanvity_document_to_search_document(*workspace_id, v))
@@ -49,7 +56,7 @@ pub(crate) fn tanvity_document_to_search_document(
   Some(SearchDocumentResponseItem {
     object_id,
     workspace_id,
-    score: 1.0,
+    score: doc.score as f64,
     content_type: Some(SearchContentType::PlainText),
     content: doc.content,
     preview: None,