feat: AI chat (#5383)

* chore: ai type

* chore: use patch to fix version issue

* chore: update

* chore: update

* chore: integrate client api

* chore: add schema

* chore: setup event

* chore: add event test

* chore: add test

* chore: update test

* chore: load chat message

* chore: load chat message

* chore: chat ui

* chore: disable create chat

* chore: update client api

* chore: disable chat

* chore: ui theme

* chore: ui theme

* chore: copy message

* chore: fix test

* chore: show error

* chore: update bloc

* chore: update test

* chore: lint

* chore: icon

* chore: hover

* chore: show unsupported page

* chore: adjust mobile ui

* chore: adjust view title bar

* chore: return related question

* chore: error page

* chore: error page

* chore: code format

* chore: prompt

* chore: fix test

* chore: ui adjust

* chore: disable create chat

* chore: add loading page

* chore: fix test

* chore: disable chat action

* chore: add maximum text limit
This commit is contained in:
Nathan.fooo
2024-06-03 14:27:28 +08:00
committed by GitHub
parent 4d42c9ea68
commit aec7bc847e
114 changed files with 5473 additions and 282 deletions

View File

@ -0,0 +1,35 @@
[package]
name = "flowy-chat"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
flowy-derive.workspace = true
flowy-notification = { workspace = true }
flowy-error = { path = "../flowy-error", features = [
"impl_from_dispatch_error",
"impl_from_collab_folder",
] }
lib-dispatch = { workspace = true }
tracing.workspace = true
uuid.workspace = true
strum_macros = "0.21"
protobuf.workspace = true
bytes.workspace = true
validator = { version = "0.16.0", features = ["derive"] }
lib-infra = { workspace = true }
flowy-chat-pub.workspace = true
dashmap = "5.5"
flowy-sqlite = { workspace = true }
tokio.workspace = true
futures.workspace = true
[build-dependencies]
flowy-codegen.workspace = true
[features]
dart = ["flowy-codegen/dart", "flowy-notification/dart"]
tauri_ts = ["flowy-codegen/ts", "flowy-notification/tauri_ts"]
web_ts = ["flowy-codegen/ts", "flowy-notification/web_ts"]

View File

@ -0,0 +1,3 @@
# Check out the FlowyConfig (located in flowy_toml.rs) for more details.
proto_input = ["src/entities.rs", "src/event_map.rs", "src/notification.rs"]
event_files = ["src/event_map.rs"]

View File

@ -0,0 +1,40 @@
fn main() {
#[cfg(feature = "dart")]
{
flowy_codegen::protobuf_file::dart_gen(env!("CARGO_PKG_NAME"));
flowy_codegen::dart_event::gen(env!("CARGO_PKG_NAME"));
}
#[cfg(feature = "tauri_ts")]
{
flowy_codegen::ts_event::gen(env!("CARGO_PKG_NAME"), flowy_codegen::Project::Tauri);
flowy_codegen::protobuf_file::ts_gen(
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_NAME"),
flowy_codegen::Project::Tauri,
);
flowy_codegen::ts_event::gen(env!("CARGO_PKG_NAME"), flowy_codegen::Project::Tauri);
flowy_codegen::protobuf_file::ts_gen(
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_NAME"),
flowy_codegen::Project::TauriApp,
);
}
#[cfg(feature = "web_ts")]
{
flowy_codegen::ts_event::gen(
"folder",
flowy_codegen::Project::Web {
relative_path: "../../".to_string(),
},
);
flowy_codegen::protobuf_file::ts_gen(
env!("CARGO_PKG_NAME"),
"folder",
flowy_codegen::Project::Web {
relative_path: "../../".to_string(),
},
);
}
}

View File

@ -0,0 +1,471 @@
use crate::entities::{
ChatMessageErrorPB, ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB,
};
use crate::manager::ChatUserService;
use crate::notification::{send_notification, ChatNotification};
use crate::persistence::{
insert_answer_message, insert_chat_messages, select_chat_messages, ChatMessageTable,
};
use flowy_chat_pub::cloud::{
ChatAuthorType, ChatCloudService, ChatMessage, ChatMessageType, MessageCursor,
};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::DBConnection;
use futures::StreamExt;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, instrument, trace};
enum PrevMessageState {
HasMore,
NoMore,
Loading,
}
pub struct Chat {
chat_id: String,
uid: i64,
user_service: Arc<dyn ChatUserService>,
cloud_service: Arc<dyn ChatCloudService>,
prev_message_state: Arc<RwLock<PrevMessageState>>,
latest_message_id: Arc<AtomicI64>,
}
impl Chat {
pub fn new(
uid: i64,
chat_id: String,
user_service: Arc<dyn ChatUserService>,
cloud_service: Arc<dyn ChatCloudService>,
) -> Chat {
Chat {
uid,
chat_id,
cloud_service,
user_service,
prev_message_state: Arc::new(RwLock::new(PrevMessageState::HasMore)),
latest_message_id: Default::default(),
}
}
pub fn close(&self) {}
#[allow(dead_code)]
pub async fn pull_latest_message(&self, limit: i64) {
let latest_message_id = self
.latest_message_id
.load(std::sync::atomic::Ordering::Relaxed);
if latest_message_id > 0 {
let _ = self
.load_remote_chat_messages(limit, None, Some(latest_message_id))
.await;
}
}
#[instrument(level = "info", skip_all, err)]
pub async fn send_chat_message(
&self,
message: &str,
message_type: ChatMessageType,
) -> Result<(), FlowyError> {
if message.len() > 2000 {
return Err(FlowyError::text_too_long().with_context("Exceeds maximum message 2000 length"));
}
let uid = self.user_service.user_id()?;
let workspace_id = self.user_service.workspace_id()?;
stream_send_chat_messages(
uid,
workspace_id,
self.chat_id.clone(),
message.to_string(),
message_type,
self.cloud_service.clone(),
self.user_service.clone(),
);
Ok(())
}
/// Load chat messages for a given `chat_id`.
///
/// 1. When opening a chat:
/// - Loads local chat messages.
/// - `after_message_id` and `before_message_id` are `None`.
/// - Spawns a task to load messages from the remote server, notifying the user when the remote messages are loaded.
///
/// 2. Loading more messages in an existing chat with `after_message_id`:
/// - `after_message_id` is the last message ID in the current chat messages.
///
/// 3. Loading more messages in an existing chat with `before_message_id`:
/// - `before_message_id` is the first message ID in the current chat messages.
pub async fn load_prev_chat_messages(
&self,
limit: i64,
before_message_id: Option<i64>,
) -> Result<ChatMessageListPB, FlowyError> {
trace!(
"Loading old messages: chat_id={}, limit={}, before_message_id={:?}",
self.chat_id,
limit,
before_message_id
);
let messages = self
.load_local_chat_messages(limit, None, before_message_id)
.await?;
// If the number of messages equals the limit, then no need to load more messages from remote
let has_more = !messages.is_empty();
if messages.len() == limit as usize {
return Ok(ChatMessageListPB {
messages,
has_more,
total: 0,
});
}
if matches!(
*self.prev_message_state.read().await,
PrevMessageState::HasMore
) {
*self.prev_message_state.write().await = PrevMessageState::Loading;
if let Err(err) = self
.load_remote_chat_messages(limit, before_message_id, None)
.await
{
error!("Failed to load previous chat messages: {}", err);
}
}
Ok(ChatMessageListPB {
messages,
has_more,
total: 0,
})
}
pub async fn load_latest_chat_messages(
&self,
limit: i64,
after_message_id: Option<i64>,
) -> Result<ChatMessageListPB, FlowyError> {
trace!(
"Loading new messages: chat_id={}, limit={}, after_message_id={:?}",
self.chat_id,
limit,
after_message_id,
);
let messages = self
.load_local_chat_messages(limit, after_message_id, None)
.await?;
trace!(
"Loaded local chat messages: chat_id={}, messages={}",
self.chat_id,
messages.len()
);
// If the number of messages equals the limit, then no need to load more messages from remote
let has_more = !messages.is_empty();
let _ = self
.load_remote_chat_messages(limit, None, after_message_id)
.await;
Ok(ChatMessageListPB {
messages,
has_more,
total: 0,
})
}
async fn load_remote_chat_messages(
&self,
limit: i64,
before_message_id: Option<i64>,
after_message_id: Option<i64>,
) -> FlowyResult<()> {
trace!(
"Loading chat messages from remote: chat_id={}, limit={}, before_message_id={:?}, after_message_id={:?}",
self.chat_id,
limit,
before_message_id,
after_message_id
);
let workspace_id = self.user_service.workspace_id()?;
let chat_id = self.chat_id.clone();
let cloud_service = self.cloud_service.clone();
let user_service = self.user_service.clone();
let uid = self.uid;
let prev_message_state = self.prev_message_state.clone();
let latest_message_id = self.latest_message_id.clone();
tokio::spawn(async move {
let cursor = match (before_message_id, after_message_id) {
(Some(bid), _) => MessageCursor::BeforeMessageId(bid),
(_, Some(aid)) => MessageCursor::AfterMessageId(aid),
_ => MessageCursor::NextBack,
};
match cloud_service
.get_chat_messages(&workspace_id, &chat_id, cursor.clone(), limit as u64)
.await
{
Ok(resp) => {
// Save chat messages to local disk
if let Err(err) = save_chat_message(
user_service.sqlite_connection(uid)?,
&chat_id,
resp.messages.clone(),
) {
error!("Failed to save chat:{} messages: {}", chat_id, err);
}
// Update latest message ID
if !resp.messages.is_empty() {
latest_message_id.store(
resp.messages[0].message_id,
std::sync::atomic::Ordering::Relaxed,
);
}
let pb = ChatMessageListPB::from(resp);
trace!(
"Loaded chat messages from remote: chat_id={}, messages={}",
chat_id,
pb.messages.len()
);
if matches!(cursor, MessageCursor::BeforeMessageId(_)) {
if pb.has_more {
*prev_message_state.write().await = PrevMessageState::HasMore;
} else {
*prev_message_state.write().await = PrevMessageState::NoMore;
}
send_notification(&chat_id, ChatNotification::DidLoadPrevChatMessage)
.payload(pb)
.send();
} else {
send_notification(&chat_id, ChatNotification::DidLoadLatestChatMessage)
.payload(pb)
.send();
}
},
Err(err) => error!("Failed to load chat messages: {}", err),
}
Ok::<(), FlowyError>(())
});
Ok(())
}
pub async fn get_related_question(
&self,
message_id: i64,
) -> Result<RepeatedRelatedQuestionPB, FlowyError> {
let workspace_id = self.user_service.workspace_id()?;
let resp = self
.cloud_service
.get_related_message(&workspace_id, &self.chat_id, message_id)
.await?;
trace!(
"Related messages: chat_id={}, message_id={}, messages:{:?}",
self.chat_id,
message_id,
resp.items
);
Ok(RepeatedRelatedQuestionPB::from(resp))
}
#[instrument(level = "debug", skip_all, err)]
pub async fn generate_answer(&self, question_message_id: i64) -> FlowyResult<ChatMessagePB> {
let workspace_id = self.user_service.workspace_id()?;
let resp = self
.cloud_service
.generate_answer(&workspace_id, &self.chat_id, question_message_id)
.await?;
save_answer(
self.user_service.sqlite_connection(self.uid)?,
&self.chat_id,
resp.clone(),
question_message_id,
)?;
let pb = ChatMessagePB::from(resp);
Ok(pb)
}
async fn load_local_chat_messages(
&self,
limit: i64,
after_message_id: Option<i64>,
before_message_id: Option<i64>,
) -> Result<Vec<ChatMessagePB>, FlowyError> {
let conn = self.user_service.sqlite_connection(self.uid)?;
let records = select_chat_messages(
conn,
&self.chat_id,
limit,
after_message_id,
before_message_id,
)?;
let messages = records
.into_iter()
.map(|record| ChatMessagePB {
message_id: record.message_id,
content: record.content,
created_at: record.created_at,
author_type: record.author_type,
author_id: record.author_id,
has_following: false,
reply_message_id: record.reply_message_id,
})
.collect::<Vec<_>>();
Ok(messages)
}
}
fn stream_send_chat_messages(
uid: i64,
workspace_id: String,
chat_id: String,
message_content: String,
message_type: ChatMessageType,
cloud_service: Arc<dyn ChatCloudService>,
user_service: Arc<dyn ChatUserService>,
) {
tokio::spawn(async move {
trace!(
"Sending chat message: chat_id={}, message={}, type={:?}",
chat_id,
message_content,
message_type
);
let mut messages = Vec::with_capacity(2);
let stream_result = cloud_service
.send_chat_message(&workspace_id, &chat_id, &message_content, message_type)
.await;
// By default, stream only returns two messages:
// 1. user message
// 2. ai response message
match stream_result {
Ok(mut stream) => {
while let Some(result) = stream.next().await {
match result {
Ok(message) => {
let mut pb = ChatMessagePB::from(message.clone());
if matches!(message.author.author_type, ChatAuthorType::Human) {
pb.has_following = true;
send_notification(&chat_id, ChatNotification::LastUserSentMessage)
.payload(pb.clone())
.send();
}
//
send_notification(&chat_id, ChatNotification::DidReceiveChatMessage)
.payload(pb)
.send();
messages.push(message);
},
Err(err) => {
error!("Failed to send chat message: {}", err);
let pb = ChatMessageErrorPB {
chat_id: chat_id.clone(),
content: message_content.clone(),
error_message: "Service Temporarily Unavailable".to_string(),
};
send_notification(&chat_id, ChatNotification::StreamChatMessageError)
.payload(pb)
.send();
break;
},
}
}
},
Err(err) => {
error!("Failed to send chat message: {}", err);
let pb = ChatMessageErrorPB {
chat_id: chat_id.clone(),
content: message_content.clone(),
error_message: err.to_string(),
};
send_notification(&chat_id, ChatNotification::StreamChatMessageError)
.payload(pb)
.send();
return;
},
}
if messages.is_empty() {
return;
}
trace!(
"Saving chat messages to local disk: chat_id={}, messages:{:?}",
chat_id,
messages
);
// Insert chat messages to local disk
if let Err(err) = user_service.sqlite_connection(uid).and_then(|conn| {
let records = messages
.into_iter()
.map(|message| ChatMessageTable {
message_id: message.message_id,
chat_id: chat_id.clone(),
content: message.content,
created_at: message.created_at.timestamp(),
author_type: message.author.author_type as i64,
author_id: message.author.author_id.to_string(),
reply_message_id: message.reply_message_id,
})
.collect::<Vec<_>>();
insert_chat_messages(conn, &records)?;
// Mark chat as finished
send_notification(&chat_id, ChatNotification::FinishAnswerQuestion).send();
Ok(())
}) {
error!("Failed to save chat messages: {}", err);
}
});
}
fn save_chat_message(
conn: DBConnection,
chat_id: &str,
messages: Vec<ChatMessage>,
) -> FlowyResult<()> {
let records = messages
.into_iter()
.map(|message| ChatMessageTable {
message_id: message.message_id,
chat_id: chat_id.to_string(),
content: message.content,
created_at: message.created_at.timestamp(),
author_type: message.author.author_type as i64,
author_id: message.author.author_id.to_string(),
reply_message_id: message.reply_message_id,
})
.collect::<Vec<_>>();
insert_chat_messages(conn, &records)?;
Ok(())
}
fn save_answer(
conn: DBConnection,
chat_id: &str,
message: ChatMessage,
question_message_id: i64,
) -> FlowyResult<()> {
let record = ChatMessageTable {
message_id: message.message_id,
chat_id: chat_id.to_string(),
content: message.content,
created_at: message.created_at.timestamp(),
author_type: message.author.author_type as i64,
author_id: message.author.author_id.to_string(),
reply_message_id: message.reply_message_id,
};
insert_answer_message(conn, question_message_id, record)?;
Ok(())
}

View File

@ -0,0 +1,190 @@
use flowy_chat_pub::cloud::{
ChatMessage, RelatedQuestion, RepeatedChatMessage, RepeatedRelatedQuestion,
};
use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
use lib_infra::validator_fn::required_not_empty_str;
use validator::Validate;
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct SendChatPayloadPB {
#[pb(index = 1)]
#[validate(custom = "required_not_empty_str")]
pub chat_id: String,
#[pb(index = 2)]
#[validate(custom = "required_not_empty_str")]
pub message: String,
#[pb(index = 3)]
pub message_type: ChatMessageTypePB,
}
#[derive(Debug, Default, Clone, ProtoBuf_Enum, PartialEq, Eq, Copy)]
pub enum ChatMessageTypePB {
#[default]
System = 0,
User = 1,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct LoadPrevChatMessagePB {
#[pb(index = 1)]
#[validate(custom = "required_not_empty_str")]
pub chat_id: String,
#[pb(index = 2)]
pub limit: i64,
#[pb(index = 4, one_of)]
pub before_message_id: Option<i64>,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct LoadNextChatMessagePB {
#[pb(index = 1)]
#[validate(custom = "required_not_empty_str")]
pub chat_id: String,
#[pb(index = 2)]
pub limit: i64,
#[pb(index = 4, one_of)]
pub after_message_id: Option<i64>,
}
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct ChatMessageListPB {
#[pb(index = 1)]
pub has_more: bool,
#[pb(index = 2)]
pub messages: Vec<ChatMessagePB>,
/// If the total number of messages is 0, then the total number of messages is unknown.
#[pb(index = 3)]
pub total: i64,
}
impl From<RepeatedChatMessage> for ChatMessageListPB {
fn from(repeated_chat_message: RepeatedChatMessage) -> Self {
let messages = repeated_chat_message
.messages
.into_iter()
.map(ChatMessagePB::from)
.collect();
ChatMessageListPB {
has_more: repeated_chat_message.has_more,
messages,
total: repeated_chat_message.total,
}
}
}
#[derive(Debug, Clone, Default, ProtoBuf)]
pub struct ChatMessagePB {
#[pb(index = 1)]
pub message_id: i64,
#[pb(index = 2)]
pub content: String,
#[pb(index = 3)]
pub created_at: i64,
#[pb(index = 4)]
pub author_type: i64,
#[pb(index = 5)]
pub author_id: String,
#[pb(index = 6)]
pub has_following: bool,
#[pb(index = 7, one_of)]
pub reply_message_id: Option<i64>,
}
#[derive(Debug, Clone, Default, ProtoBuf)]
pub struct ChatMessageErrorPB {
#[pb(index = 1)]
pub chat_id: String,
#[pb(index = 2)]
pub content: String,
#[pb(index = 3)]
pub error_message: String,
}
impl From<ChatMessage> for ChatMessagePB {
fn from(chat_message: ChatMessage) -> Self {
ChatMessagePB {
message_id: chat_message.message_id,
content: chat_message.content,
created_at: chat_message.created_at.timestamp(),
author_type: chat_message.author.author_type as i64,
author_id: chat_message.author.author_id.to_string(),
has_following: false,
reply_message_id: None,
}
}
}
#[derive(Debug, Clone, Default, ProtoBuf)]
pub struct RepeatedChatMessagePB {
#[pb(index = 1)]
items: Vec<ChatMessagePB>,
}
impl From<Vec<ChatMessage>> for RepeatedChatMessagePB {
fn from(messages: Vec<ChatMessage>) -> Self {
RepeatedChatMessagePB {
items: messages.into_iter().map(ChatMessagePB::from).collect(),
}
}
}
#[derive(Debug, Clone, Default, ProtoBuf)]
pub struct ChatMessageIdPB {
#[pb(index = 1)]
pub chat_id: String,
#[pb(index = 2)]
pub message_id: i64,
}
#[derive(Debug, Clone, Default, ProtoBuf)]
pub struct RelatedQuestionPB {
#[pb(index = 1)]
pub content: String,
}
impl From<RelatedQuestion> for RelatedQuestionPB {
fn from(value: RelatedQuestion) -> Self {
RelatedQuestionPB {
content: value.content,
}
}
}
#[derive(Debug, Clone, Default, ProtoBuf)]
pub struct RepeatedRelatedQuestionPB {
#[pb(index = 1)]
pub message_id: i64,
#[pb(index = 2)]
pub items: Vec<RelatedQuestionPB>,
}
impl From<RepeatedRelatedQuestion> for RepeatedRelatedQuestionPB {
fn from(value: RepeatedRelatedQuestion) -> Self {
RepeatedRelatedQuestionPB {
message_id: value.message_id,
items: value
.items
.into_iter()
.map(RelatedQuestionPB::from)
.collect(),
}
}
}

View File

@ -0,0 +1,93 @@
use flowy_chat_pub::cloud::ChatMessageType;
use std::sync::{Arc, Weak};
use validator::Validate;
use flowy_error::{FlowyError, FlowyResult};
use lib_dispatch::prelude::{data_result_ok, AFPluginData, AFPluginState, DataResult};
use crate::entities::*;
use crate::manager::ChatManager;
fn upgrade_chat_manager(
chat_manager: AFPluginState<Weak<ChatManager>>,
) -> FlowyResult<Arc<ChatManager>> {
let chat_manager = chat_manager
.upgrade()
.ok_or(FlowyError::internal().with_context("The chat manager is already dropped"))?;
Ok(chat_manager)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn send_chat_message_handler(
data: AFPluginData<SendChatPayloadPB>,
chat_manager: AFPluginState<Weak<ChatManager>>,
) -> Result<(), FlowyError> {
let chat_manager = upgrade_chat_manager(chat_manager)?;
let data = data.into_inner();
data.validate()?;
let message_type = match data.message_type {
ChatMessageTypePB::System => ChatMessageType::System,
ChatMessageTypePB::User => ChatMessageType::User,
};
chat_manager
.send_chat_message(&data.chat_id, &data.message, message_type)
.await?;
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn load_prev_message_handler(
data: AFPluginData<LoadPrevChatMessagePB>,
chat_manager: AFPluginState<Weak<ChatManager>>,
) -> DataResult<ChatMessageListPB, FlowyError> {
let chat_manager = upgrade_chat_manager(chat_manager)?;
let data = data.into_inner();
data.validate()?;
let messages = chat_manager
.load_prev_chat_messages(&data.chat_id, data.limit, data.before_message_id)
.await?;
data_result_ok(messages)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn load_next_message_handler(
data: AFPluginData<LoadNextChatMessagePB>,
chat_manager: AFPluginState<Weak<ChatManager>>,
) -> DataResult<ChatMessageListPB, FlowyError> {
let chat_manager = upgrade_chat_manager(chat_manager)?;
let data = data.into_inner();
data.validate()?;
let messages = chat_manager
.load_latest_chat_messages(&data.chat_id, data.limit, data.after_message_id)
.await?;
data_result_ok(messages)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_related_question_handler(
data: AFPluginData<ChatMessageIdPB>,
chat_manager: AFPluginState<Weak<ChatManager>>,
) -> DataResult<RepeatedRelatedQuestionPB, FlowyError> {
let chat_manager = upgrade_chat_manager(chat_manager)?;
let data = data.into_inner();
let messages = chat_manager
.get_related_questions(&data.chat_id, data.message_id)
.await?;
data_result_ok(messages)
}
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_answer_handler(
data: AFPluginData<ChatMessageIdPB>,
chat_manager: AFPluginState<Weak<ChatManager>>,
) -> DataResult<ChatMessagePB, FlowyError> {
let chat_manager = upgrade_chat_manager(chat_manager)?;
let data = data.into_inner();
let message = chat_manager
.generate_answer(&data.chat_id, data.message_id)
.await?;
data_result_ok(message)
}

View File

@ -0,0 +1,40 @@
use std::sync::Weak;
use strum_macros::Display;
use flowy_derive::{Flowy_Event, ProtoBuf_Enum};
use lib_dispatch::prelude::*;
use crate::event_handler::*;
use crate::manager::ChatManager;
pub fn init(chat_manager: Weak<ChatManager>) -> AFPlugin {
AFPlugin::new()
.name("Flowy-Chat")
.state(chat_manager)
.event(ChatEvent::SendMessage, send_chat_message_handler)
.event(ChatEvent::LoadPrevMessage, load_prev_message_handler)
.event(ChatEvent::LoadNextMessage, load_next_message_handler)
.event(ChatEvent::GetRelatedQuestion, get_related_question_handler)
.event(ChatEvent::GetAnswerForQuestion, get_answer_handler)
}
#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)]
#[event_err = "FlowyError"]
pub enum ChatEvent {
/// Create a new workspace
#[event(input = "LoadPrevChatMessagePB", output = "ChatMessageListPB")]
LoadPrevMessage = 0,
#[event(input = "LoadNextChatMessagePB", output = "ChatMessageListPB")]
LoadNextMessage = 1,
#[event(input = "SendChatPayloadPB")]
SendMessage = 2,
#[event(input = "ChatMessageIdPB", output = "RepeatedRelatedQuestionPB")]
GetRelatedQuestion = 3,
#[event(input = "ChatMessageIdPB", output = "ChatMessagePB")]
GetAnswerForQuestion = 4,
}

View File

@ -0,0 +1,9 @@
mod event_handler;
pub mod event_map;
mod chat;
pub mod entities;
pub mod manager;
pub mod notification;
mod persistence;
mod protobuf;

View File

@ -0,0 +1,182 @@
use crate::chat::Chat;
use crate::entities::{ChatMessageListPB, ChatMessagePB, RepeatedRelatedQuestionPB};
use crate::persistence::{insert_chat, ChatTable};
use dashmap::DashMap;
use flowy_chat_pub::cloud::{ChatCloudService, ChatMessageType};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::DBConnection;
use lib_infra::util::timestamp;
use std::sync::Arc;
use tracing::{instrument, trace};
pub trait ChatUserService: Send + Sync + 'static {
fn user_id(&self) -> Result<i64, FlowyError>;
fn device_id(&self) -> Result<String, FlowyError>;
fn workspace_id(&self) -> Result<String, FlowyError>;
fn sqlite_connection(&self, uid: i64) -> Result<DBConnection, FlowyError>;
}
pub struct ChatManager {
cloud_service: Arc<dyn ChatCloudService>,
user_service: Arc<dyn ChatUserService>,
chats: Arc<DashMap<String, Arc<Chat>>>,
}
impl ChatManager {
pub fn new(
cloud_service: Arc<dyn ChatCloudService>,
user_service: impl ChatUserService,
) -> ChatManager {
let user_service = Arc::new(user_service);
Self {
cloud_service,
user_service,
chats: Arc::new(DashMap::new()),
}
}
pub async fn open_chat(&self, chat_id: &str) -> Result<(), FlowyError> {
trace!("open chat: {}", chat_id);
self.chats.entry(chat_id.to_string()).or_insert_with(|| {
Arc::new(Chat::new(
self.user_service.user_id().unwrap(),
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service.clone(),
))
});
Ok(())
}
pub async fn close_chat(&self, _chat_id: &str) -> Result<(), FlowyError> {
Ok(())
}
pub async fn delete_chat(&self, chat_id: &str) -> Result<(), FlowyError> {
if let Some((_, chat)) = self.chats.remove(chat_id) {
chat.close();
}
Ok(())
}
pub async fn create_chat(&self, uid: &i64, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
let workspace_id = self.user_service.workspace_id()?;
self
.cloud_service
.create_chat(uid, &workspace_id, chat_id)
.await?;
save_chat(self.user_service.sqlite_connection(*uid)?, chat_id)?;
let chat = Arc::new(Chat::new(
self.user_service.user_id().unwrap(),
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service.clone(),
));
self.chats.insert(chat_id.to_string(), chat.clone());
Ok(chat)
}
#[instrument(level = "info", skip_all, err)]
pub async fn send_chat_message(
&self,
chat_id: &str,
message: &str,
message_type: ChatMessageType,
) -> Result<(), FlowyError> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
chat.send_chat_message(message, message_type).await?;
Ok(())
}
pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
let chat = self.chats.get(chat_id).as_deref().cloned();
match chat {
None => {
let chat = Arc::new(Chat::new(
self.user_service.user_id().unwrap(),
chat_id.to_string(),
self.user_service.clone(),
self.cloud_service.clone(),
));
self.chats.insert(chat_id.to_string(), chat.clone());
Ok(chat)
},
Some(chat) => Ok(chat),
}
}
/// Load chat messages for a given `chat_id`.
///
/// 1. When opening a chat:
/// - Loads local chat messages.
/// - `after_message_id` and `before_message_id` are `None`.
/// - Spawns a task to load messages from the remote server, notifying the user when the remote messages are loaded.
///
/// 2. Loading more messages in an existing chat with `after_message_id`:
/// - `after_message_id` is the last message ID in the current chat messages.
///
/// 3. Loading more messages in an existing chat with `before_message_id`:
/// - `before_message_id` is the first message ID in the current chat messages.
///
/// 4. `after_message_id` and `before_message_id` cannot be specified at the same time.
pub async fn load_prev_chat_messages(
&self,
chat_id: &str,
limit: i64,
before_message_id: Option<i64>,
) -> Result<ChatMessageListPB, FlowyError> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
let list = chat
.load_prev_chat_messages(limit, before_message_id)
.await?;
Ok(list)
}
pub async fn load_latest_chat_messages(
&self,
chat_id: &str,
limit: i64,
after_message_id: Option<i64>,
) -> Result<ChatMessageListPB, FlowyError> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
let list = chat
.load_latest_chat_messages(limit, after_message_id)
.await?;
Ok(list)
}
pub async fn get_related_questions(
&self,
chat_id: &str,
message_id: i64,
) -> Result<RepeatedRelatedQuestionPB, FlowyError> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
let resp = chat.get_related_question(message_id).await?;
Ok(resp)
}
pub async fn generate_answer(
&self,
chat_id: &str,
question_message_id: i64,
) -> Result<ChatMessagePB, FlowyError> {
let chat = self.get_or_create_chat_instance(chat_id).await?;
let resp = chat.generate_answer(question_message_id).await?;
Ok(resp)
}
}
fn save_chat(conn: DBConnection, chat_id: &str) -> FlowyResult<()> {
let row = ChatTable {
chat_id: chat_id.to_string(),
created_at: timestamp(),
name: "".to_string(),
};
insert_chat(conn, &row)?;
Ok(())
}

View File

@ -0,0 +1,40 @@
use flowy_derive::ProtoBuf_Enum;
use flowy_notification::NotificationBuilder;
const CHAT_OBSERVABLE_SOURCE: &str = "Chat";
#[derive(ProtoBuf_Enum, Debug, Default)]
pub enum ChatNotification {
#[default]
Unknown = 0,
DidLoadLatestChatMessage = 1,
DidLoadPrevChatMessage = 2,
DidReceiveChatMessage = 3,
StreamChatMessageError = 4,
FinishAnswerQuestion = 5,
LastUserSentMessage = 6,
}
impl std::convert::From<ChatNotification> for i32 {
fn from(notification: ChatNotification) -> Self {
notification as i32
}
}
impl std::convert::From<i32> for ChatNotification {
fn from(notification: i32) -> Self {
match notification {
1 => ChatNotification::DidLoadLatestChatMessage,
2 => ChatNotification::DidLoadPrevChatMessage,
3 => ChatNotification::DidReceiveChatMessage,
4 => ChatNotification::StreamChatMessageError,
5 => ChatNotification::FinishAnswerQuestion,
6 => ChatNotification::LastUserSentMessage,
_ => ChatNotification::Unknown,
}
}
}
#[tracing::instrument(level = "trace")]
pub(crate) fn send_notification(id: &str, ty: ChatNotification) -> NotificationBuilder {
NotificationBuilder::new(id, ty, CHAT_OBSERVABLE_SOURCE)
}

View File

@ -0,0 +1,106 @@
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::upsert::excluded;
use flowy_sqlite::{
diesel, insert_into,
query_dsl::*,
schema::{chat_message_table, chat_message_table::dsl},
DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable,
};
#[derive(Queryable, Insertable, Identifiable)]
#[diesel(table_name = chat_message_table)]
#[diesel(primary_key(message_id))]
pub struct ChatMessageTable {
pub message_id: i64,
pub chat_id: String,
pub content: String,
pub created_at: i64,
pub author_type: i64,
pub author_id: String,
pub reply_message_id: Option<i64>,
}
pub fn insert_chat_messages(
mut conn: DBConnection,
new_messages: &[ChatMessageTable],
) -> FlowyResult<()> {
conn.immediate_transaction(|conn| {
for message in new_messages {
let _ = insert_into(chat_message_table::table)
.values(message)
.on_conflict(chat_message_table::message_id)
.do_update()
.set((
chat_message_table::content.eq(excluded(chat_message_table::content)),
chat_message_table::created_at.eq(excluded(chat_message_table::created_at)),
chat_message_table::author_type.eq(excluded(chat_message_table::author_type)),
chat_message_table::author_id.eq(excluded(chat_message_table::author_id)),
chat_message_table::reply_message_id.eq(excluded(chat_message_table::reply_message_id)),
))
.execute(conn)?;
}
Ok::<(), FlowyError>(())
})?;
Ok(())
}
pub fn insert_answer_message(
mut conn: DBConnection,
question_message_id: i64,
message: ChatMessageTable,
) -> FlowyResult<()> {
conn.immediate_transaction(|conn| {
// Step 1: Get the message with the given question_message_id
let question_message = dsl::chat_message_table
.filter(chat_message_table::message_id.eq(question_message_id))
.first::<ChatMessageTable>(conn)?;
// Step 2: Use reply_message_id from the retrieved message to delete the existing message
if let Some(reply_id) = question_message.reply_message_id {
diesel::delete(dsl::chat_message_table.filter(chat_message_table::message_id.eq(reply_id)))
.execute(conn)?;
}
// Step 3: Insert the new message
let _ = insert_into(chat_message_table::table)
.values(message)
.on_conflict(chat_message_table::message_id)
.do_update()
.set((
chat_message_table::content.eq(excluded(chat_message_table::content)),
chat_message_table::created_at.eq(excluded(chat_message_table::created_at)),
chat_message_table::author_type.eq(excluded(chat_message_table::author_type)),
chat_message_table::author_id.eq(excluded(chat_message_table::author_id)),
chat_message_table::reply_message_id.eq(excluded(chat_message_table::reply_message_id)),
))
.execute(conn)?;
Ok::<(), FlowyError>(())
})?;
Ok(())
}
pub fn select_chat_messages(
mut conn: DBConnection,
chat_id_val: &str,
limit_val: i64,
after_message_id: Option<i64>,
before_message_id: Option<i64>,
) -> QueryResult<Vec<ChatMessageTable>> {
let mut query = dsl::chat_message_table
.filter(chat_message_table::chat_id.eq(chat_id_val))
.into_boxed();
if let Some(after_message_id) = after_message_id {
query = query.filter(chat_message_table::message_id.gt(after_message_id));
}
if let Some(before_message_id) = before_message_id {
query = query.filter(chat_message_table::message_id.lt(before_message_id));
}
query = query
.order((chat_message_table::message_id.desc(),))
.limit(limit_val);
let messages: Vec<ChatMessageTable> = query.load::<ChatMessageTable>(&mut *conn)?;
Ok(messages)
}

View File

@ -0,0 +1,52 @@
use flowy_sqlite::upsert::excluded;
use flowy_sqlite::{
diesel,
query_dsl::*,
schema::{chat_table, chat_table::dsl},
DBConnection, ExpressionMethods, Identifiable, Insertable, QueryResult, Queryable,
};
#[derive(Clone, Default, Queryable, Insertable, Identifiable)]
#[diesel(table_name = chat_table)]
#[diesel(primary_key(chat_id))]
pub struct ChatTable {
pub chat_id: String,
pub created_at: i64,
pub name: String,
}
pub fn insert_chat(mut conn: DBConnection, new_chat: &ChatTable) -> QueryResult<usize> {
diesel::insert_into(chat_table::table)
.values(new_chat)
.on_conflict(chat_table::chat_id)
.do_update()
.set((
chat_table::created_at.eq(excluded(chat_table::created_at)),
chat_table::name.eq(excluded(chat_table::name)),
))
.execute(&mut *conn)
}
#[allow(dead_code)]
pub fn read_chat(mut conn: DBConnection, chat_id_val: &str) -> QueryResult<ChatTable> {
let row = dsl::chat_table
.filter(chat_table::chat_id.eq(chat_id_val))
.first::<ChatTable>(&mut *conn)?;
Ok(row)
}
#[allow(dead_code)]
pub fn update_chat_name(
mut conn: DBConnection,
chat_id_val: &str,
new_name: &str,
) -> QueryResult<usize> {
diesel::update(dsl::chat_table.filter(chat_table::chat_id.eq(chat_id_val)))
.set(chat_table::name.eq(new_name))
.execute(&mut *conn)
}
#[allow(dead_code)]
pub fn delete_chat(mut conn: DBConnection, chat_id_val: &str) -> QueryResult<usize> {
diesel::delete(dsl::chat_table.filter(chat_table::chat_id.eq(chat_id_val))).execute(&mut *conn)
}

View File

@ -0,0 +1,5 @@
mod chat_message_sql;
mod chat_sql;
pub use chat_message_sql::*;
pub use chat_sql::*;