Mathias Mogensen 54c9d12171
feat: support switch model (#5575)
* feat: ai settings page

* chore: integrate client api

* chore: replace open ai calls

* chore: disable gen image from ai

* chore: clippy

* chore: remove learn about ai

* chore: fix warnings

* chore: fix restart button title

* chore: remove await

* chore: remove loading indicator

---------

Co-authored-by: nathan <nathan@appflowy.io>
Co-authored-by: Lucas.Xu <lucas.xu@appflowy.io>
2024-06-25 07:59:38 +08:00

137 lines
4.2 KiB
Rust

use crate::chat_manager::ChatUserService;
use crate::entities::{CompleteTextPB, CompleteTextTaskPB, CompletionTypePB};
use allo_isolate::Isolate;
use dashmap::DashMap;
use flowy_chat_pub::cloud::{ChatCloudService, CompletionType};
use flowy_error::{FlowyError, FlowyResult};
use futures::{SinkExt, StreamExt};
use lib_infra::isolate_stream::IsolateSink;
use std::sync::{Arc, Weak};
use tokio::select;
use tracing::{error, trace};
/// Registry and factory for streaming AI text-completion tasks.
///
/// Every task started through [`AITools::create_complete_task`] is registered
/// under its task id together with the sender half of a cancellation channel,
/// so that [`AITools::cancel_complete_task`] can stop it later.
pub struct AITools {
  /// Cancellation senders keyed by task id.
  tasks: Arc<DashMap<String, tokio::sync::mpsc::Sender<()>>>,
  cloud_service: Weak<dyn ChatCloudService>,
  user_service: Weak<dyn ChatUserService>,
}

impl AITools {
  /// Creates an empty task registry.
  ///
  /// Both services are held as `Weak` references; they are upgraded only when
  /// a task is created or run.
  pub fn new(
    cloud_service: Weak<dyn ChatCloudService>,
    user_service: Weak<dyn ChatUserService>,
  ) -> Self {
    Self {
      tasks: Arc::new(DashMap::new()),
      cloud_service,
      user_service,
    }
  }

  /// Starts a streaming completion for `complete` and returns its task id.
  ///
  /// The completion runs in the background. Its output is delivered through
  /// the stream port carried by `complete`, not through the return value.
  ///
  /// # Errors
  /// Returns an internal error if the user service has been dropped, and
  /// propagates any error from resolving the current workspace id.
  pub async fn create_complete_task(
    &self,
    complete: CompleteTextPB,
  ) -> FlowyResult<CompleteTextTaskPB> {
    let workspace_id = self
      .user_service
      .upgrade()
      .ok_or_else(FlowyError::internal)?
      .workspace_id()?;

    // Tasks that end on their own (finished, failed) are never passed to
    // `cancel_complete_task`, so their senders would pile up in the map
    // forever. A sender reports closed once its task has dropped the
    // receiver, i.e. the task is over, so those entries can be discarded.
    self.tasks.retain(|_, stop_tx| !stop_tx.is_closed());

    let (stop_tx, stop_rx) = tokio::sync::mpsc::channel(1);
    let task = ToolTask::new(workspace_id, complete, self.cloud_service.clone(), stop_rx);
    let task_id = task.task_id.clone();
    self.tasks.insert(task_id.clone(), stop_tx);
    task.start().await;
    Ok(CompleteTextTaskPB { task_id })
  }

  /// Signals the task with `task_id` to stop and removes it from the registry.
  ///
  /// Unknown ids are ignored.
  pub async fn cancel_complete_task(&self, task_id: &str) {
    if let Some((_, stop_tx)) = self.tasks.remove(task_id) {
      // Sending fails only when the task has already ended; nothing to stop then.
      let _ = stop_tx.send(()).await;
    }
  }
}
pub struct ToolTask {
workspace_id: String,
task_id: String,
stop_rx: tokio::sync::mpsc::Receiver<()>,
context: CompleteTextPB,
cloud_service: Weak<dyn ChatCloudService>,
}
impl ToolTask {
pub fn new(
workspace_id: String,
context: CompleteTextPB,
cloud_service: Weak<dyn ChatCloudService>,
stop_rx: tokio::sync::mpsc::Receiver<()>,
) -> Self {
Self {
workspace_id,
task_id: uuid::Uuid::new_v4().to_string(),
context,
cloud_service,
stop_rx,
}
}
pub async fn start(mut self) {
tokio::spawn(async move {
let mut sink = IsolateSink::new(Isolate::new(self.context.stream_port));
match self.cloud_service.upgrade() {
None => {},
Some(cloud_service) => {
let complete_type = match self.context.completion_type {
CompletionTypePB::UnknownCompletionType => CompletionType::ImproveWriting,
CompletionTypePB::ImproveWriting => CompletionType::ImproveWriting,
CompletionTypePB::SpellingAndGrammar => CompletionType::SpellingAndGrammar,
CompletionTypePB::MakeShorter => CompletionType::MakeShorter,
CompletionTypePB::MakeLonger => CompletionType::MakeLonger,
CompletionTypePB::ContinueWriting => CompletionType::ContinueWriting,
};
let _ = sink.send("start:".to_string()).await;
match cloud_service
.stream_complete(&self.workspace_id, &self.context.text, complete_type)
.await
{
Ok(mut stream) => loop {
select! {
_ = self.stop_rx.recv() => {
return;
},
result = stream.next() => {
match result {
Some(Ok(data)) => {
let s = String::from_utf8(data.to_vec()).unwrap_or_default();
trace!("stream completion data: {}", s);
let _ = sink.send(format!("data:{}", s)).await;
},
Some(Err(error)) => {
error!("stream error: {}", error);
let _ = sink.send(format!("error:{}", error)).await;
return;
},
None => {
let _ = sink.send(format!("finish:{}", self.task_id)).await;
return;
},
}
}
}
},
Err(error) => {
error!("stream complete error: {}", error);
let _ = sink.send(format!("error:{}", error)).await;
},
}
},
}
});
}
}