edit document via actor stream

This commit is contained in:
appflowy 2021-09-30 17:24:02 +08:00
parent 1fdef0914b
commit efb2a607e7
22 changed files with 378 additions and 230 deletions

View File

@ -102,4 +102,5 @@ flowy-test = { path = "../rust-lib/flowy-test" }
flowy-infra = { path = "../rust-lib/flowy-infra" }
flowy-ot = { path = "../rust-lib/flowy-ot" }
flowy-document = { path = "../rust-lib/flowy-document", features = ["flowy_test", "http_server"] }
flowy-sqlite = { path = "../rust-lib/flowy-sqlite" }
flowy-sqlite = { path = "../rust-lib/flowy-sqlite" }
futures-util = "0.3.15"

View File

@ -61,7 +61,6 @@ pub fn run(listener: TcpListener, app_ctx: AppContext) -> Result<Server, std::io
.app_data(app_ctx.pg_pool.clone())
.app_data(app_ctx.ws_bizs.clone())
.app_data(app_ctx.doc_biz.clone())
.app_data(app_ctx.runtime.clone())
})
.listen(listener)?
.run();

View File

@ -6,9 +6,7 @@ use actix::Addr;
use actix_web::web::Data;
use flowy_ws::WsModule;
use sqlx::PgPool;
use std::{io, sync::Arc};
pub type FlowyRuntime = tokio::runtime::Runtime;
use std::sync::Arc;
#[derive(Clone)]
pub struct AppContext {
@ -16,14 +14,12 @@ pub struct AppContext {
pub pg_pool: Data<PgPool>,
pub ws_bizs: Data<WsBizHandlers>,
pub doc_biz: Data<Arc<DocBiz>>,
pub runtime: Data<FlowyRuntime>,
}
impl AppContext {
pub fn new(ws_server: Addr<WsServer>, db_pool: PgPool) -> Self {
let ws_server = Data::new(ws_server);
let pg_pool = Data::new(db_pool);
let runtime = Data::new(runtime().unwrap());
let mut ws_bizs = WsBizHandlers::new();
let doc_biz = Arc::new(DocBiz::new(pg_pool.clone()));
@ -34,15 +30,6 @@ impl AppContext {
pg_pool,
ws_bizs: Data::new(ws_bizs),
doc_biz: Data::new(doc_biz),
runtime,
}
}
}
fn runtime() -> io::Result<tokio::runtime::Runtime> {
tokio::runtime::Builder::new_multi_thread()
.thread_name("flowy-server-rt")
.enable_io()
.enable_time()
.build()
}

View File

@ -1,9 +1,13 @@
use crate::service::{doc::doc::DocManager, util::parse_from_bytes, ws::WsClientData};
use crate::service::{
doc::doc::DocManager,
util::{md5, parse_from_bytes},
ws::{entities::Socket, WsClientData, WsUser},
};
use actix_rt::task::spawn_blocking;
use actix_web::web::Data;
use async_stream::stream;
use flowy_document::protobuf::{Revision, WsDataType, WsDocumentData};
use flowy_net::errors::{internal_error, Result as DocResult};
use flowy_net::errors::{internal_error, Result as DocResult, ServerError};
use futures::stream::StreamExt;
use sqlx::PgPool;
use std::sync::Arc;
@ -11,7 +15,7 @@ use tokio::sync::{mpsc, oneshot};
pub enum DocWsMsg {
ClientData {
data: WsClientData,
client_data: WsClientData,
pool: Data<PgPool>,
ret: oneshot::Sender<DocResult<()>>,
},
@ -50,52 +54,60 @@ impl DocWsMsgActor {
async fn handle_message(&self, msg: DocWsMsg) {
match msg {
DocWsMsg::ClientData { data, pool, ret } => {
ret.send(self.handle_client_data(data, pool).await);
DocWsMsg::ClientData { client_data, pool, ret } => {
let _ = ret.send(self.handle_client_data(client_data, pool).await);
},
}
}
async fn handle_client_data(&self, data: WsClientData, pool: Data<PgPool>) -> DocResult<()> {
let bytes = data.data.clone();
async fn handle_client_data(&self, client_data: WsClientData, pool: Data<PgPool>) -> DocResult<()> {
let WsClientData { user, socket, data } = client_data;
let document_data = spawn_blocking(move || {
let document_data: WsDocumentData = parse_from_bytes(&bytes)?;
let document_data: WsDocumentData = parse_from_bytes(&data)?;
DocResult::Ok(document_data)
})
.await
.map_err(internal_error)??;
match document_data.ty {
WsDataType::Acked => {},
WsDataType::PushRev => self.handle_push_rev(data, document_data.data, pool).await?,
WsDataType::PullRev => {},
WsDataType::Conflict => {},
WsDataType::Acked => Ok(()),
WsDataType::PushRev => self.handle_push_rev(user, socket, document_data.data, pool).await,
WsDataType::PullRev => Ok(()),
WsDataType::Conflict => Ok(()),
}
Ok(())
}
async fn handle_push_rev(
&self,
client_data: WsClientData,
user: Arc<WsUser>,
socket: Socket,
revision_data: Vec<u8>,
pool: Data<PgPool>,
) -> DocResult<()> {
let revision = spawn_blocking(move || {
let revision: Revision = parse_from_bytes(&revision_data)?;
let _ = verify_md5(&revision)?;
DocResult::Ok(revision)
})
.await
.map_err(internal_error)??;
match self.doc_manager.get(&revision.doc_id, pool).await? {
Some(ctx) => {
ctx.apply_revision(client_data, revision).await;
Some(edit_doc) => {
edit_doc.apply_revision(user, socket, revision).await?;
Ok(())
},
None => {
//
log::error!("Document with id: {} not exist", &revision.doc_id);
Ok(())
},
}
}
}
fn verify_md5(revision: &Revision) -> DocResult<()> {
if md5(&revision.delta_data) != revision.md5 {
return Err(ServerError::internal().context("Revision md5 not match"));
}
Ok(())
}

View File

@ -1,20 +1,21 @@
use super::edit_doc::EditDocContext;
use crate::service::{
doc::{
actor::{DocWsMsg, DocWsMsgActor},
edit::EditDoc,
read_doc,
ws_actor::{DocWsMsg, DocWsMsgActor},
},
ws::{WsBizHandler, WsClientData},
};
use actix_web::web::Data;
use dashmap::DashMap;
use flowy_document::protobuf::QueryDocParams;
use flowy_net::errors::ServerError;
use protobuf::Message;
use flowy_net::errors::{internal_error, ServerError};
use sqlx::PgPool;
use std::{sync::Arc, time::Duration};
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot};
use std::sync::Arc;
use tokio::{
sync::{mpsc, oneshot},
task::spawn_blocking,
};
pub struct DocBiz {
pub manager: Arc<DocManager>,
@ -43,11 +44,7 @@ impl WsBizHandler for DocBiz {
let pool = self.pg_pool.clone();
actix_rt::spawn(async move {
let msg = DocWsMsg::ClientData {
data: client_data,
ret,
pool,
};
let msg = DocWsMsg::ClientData { client_data, ret, pool };
match sender.send(msg).await {
Ok(_) => {},
Err(e) => log::error!("{}", e),
@ -61,7 +58,7 @@ impl WsBizHandler for DocBiz {
}
pub struct DocManager {
docs_map: DashMap<String, Arc<EditDocContext>>,
docs_map: DashMap<String, Arc<EditDoc>>,
}
impl DocManager {
@ -71,7 +68,7 @@ impl DocManager {
}
}
pub async fn get(&self, doc_id: &str, pg_pool: Data<PgPool>) -> Result<Option<Arc<EditDocContext>>, ServerError> {
pub async fn get(&self, doc_id: &str, pg_pool: Data<PgPool>) -> Result<Option<Arc<EditDoc>>, ServerError> {
match self.docs_map.get(doc_id) {
None => {
let params = QueryDocParams {
@ -79,7 +76,8 @@ impl DocManager {
..Default::default()
};
let doc = read_doc(pg_pool.get_ref(), params).await?;
let edit_doc = Arc::new(EditDocContext::new(doc)?);
let edit_doc = spawn_blocking(|| EditDoc::new(doc)).await.map_err(internal_error)?;
let edit_doc = Arc::new(edit_doc?);
self.docs_map.insert(doc_id.to_string(), edit_doc.clone());
Ok(Some(edit_doc))
},

View File

@ -0,0 +1,92 @@
use crate::service::{
doc::edit::EditDocContext,
ws::{entities::Socket, WsUser},
};
use async_stream::stream;
use flowy_document::protobuf::Revision;
use flowy_net::errors::{internal_error, Result as DocResult};
use futures::stream::StreamExt;
use std::sync::Arc;
use tokio::{
sync::{mpsc, oneshot},
task::spawn_blocking,
};
#[derive(Clone)]
pub struct EditUser {
user: Arc<WsUser>,
pub(crate) socket: Socket,
}
impl EditUser {
pub fn id(&self) -> String { self.user.id().to_string() }
}
#[derive(Debug)]
pub enum EditMsg {
Revision {
user: Arc<WsUser>,
socket: Socket,
revision: Revision,
ret: oneshot::Sender<DocResult<()>>,
},
DocumentJson {
ret: oneshot::Sender<DocResult<String>>,
},
}
pub struct EditDocActor {
receiver: Option<mpsc::Receiver<EditMsg>>,
edit_context: Arc<EditDocContext>,
}
impl EditDocActor {
pub fn new(receiver: mpsc::Receiver<EditMsg>, edit_context: Arc<EditDocContext>) -> Self {
Self {
receiver: Some(receiver),
edit_context,
}
}
pub async fn run(mut self) {
let mut receiver = self
.receiver
.take()
.expect("DocActor's receiver should only take one time");
let stream = stream! {
loop {
match receiver.recv().await {
Some(msg) => yield msg,
None => break,
}
}
};
stream.for_each(|msg| self.handle_message(msg)).await;
}
async fn handle_message(&self, msg: EditMsg) {
match msg {
EditMsg::Revision {
user,
socket,
revision,
ret,
} => {
// ret.send(self.handle_client_data(client_data, pool).await);
let user = EditUser {
user: user.clone(),
socket: socket.clone(),
};
let _ = ret.send(self.edit_context.apply_revision(user, revision).await);
},
EditMsg::DocumentJson { ret } => {
let edit_context = self.edit_context.clone();
let json = spawn_blocking(move || edit_context.document_json())
.await
.map_err(internal_error);
let _ = ret.send(json);
},
}
}
}

View File

@ -1,8 +1,4 @@
use crate::service::{
util::md5,
ws::{entities::Socket, WsClientData, WsMessageAdaptor, WsUser},
};
use crate::service::{doc::edit::actor::EditUser, util::md5, ws::WsMessageAdaptor};
use byteorder::{BigEndian, WriteBytesExt};
use bytes::Bytes;
use dashmap::DashMap;
@ -19,7 +15,6 @@ use flowy_ot::{
use flowy_ws::WsMessage;
use parking_lot::RwLock;
use protobuf::Message;
use std::{
convert::TryInto,
sync::{
@ -28,12 +23,6 @@ use std::{
},
time::Duration,
};
struct EditUser {
user: Arc<WsUser>,
socket: Socket,
}
pub struct EditDocContext {
doc_id: String,
rev_id: AtomicI64,
@ -54,18 +43,11 @@ impl EditDocContext {
})
}
pub fn doc_json(&self) -> String { self.document.read().to_json() }
pub fn document_json(&self) -> String { self.document.read().to_json() }
#[tracing::instrument(level = "debug", skip(self, client_data, revision))]
pub async fn apply_revision(&self, client_data: WsClientData, revision: Revision) -> Result<(), ServerError> {
let _ = self.verify_md5(&revision)?;
pub async fn apply_revision(&self, user: EditUser, revision: Revision) -> Result<(), ServerError> {
// Opti: find out another way to keep the user socket available.
let user = EditUser {
user: client_data.user.clone(),
socket: client_data.socket.clone(),
};
self.users.insert(client_data.user.id().to_owned(), user);
self.users.insert(user.id(), user.clone());
log::debug!(
"cur_base_rev_id: {}, expect_base_rev_id: {} rev_id: {}",
self.rev_id.load(SeqCst),
@ -73,7 +55,6 @@ impl EditDocContext {
revision.rev_id
);
let cli_socket = client_data.socket;
let cur_rev_id = self.rev_id.load(SeqCst);
if cur_rev_id > revision.rev_id {
// The client document is outdated. Transform the client revision delta and then
@ -86,19 +67,19 @@ impl EditDocContext {
log::debug!("{} client delta: {}", self.doc_id, cli_prime.to_json());
let cli_revision = self.mk_revision(revision.rev_id, cli_prime);
let ws_cli_revision = mk_push_rev_ws_message(&self.doc_id, cli_revision);
cli_socket.do_send(ws_cli_revision).map_err(internal_error)?;
user.socket.do_send(ws_cli_revision).map_err(internal_error)?;
Ok(())
} else if cur_rev_id < revision.rev_id {
if cur_rev_id != revision.base_rev_id {
// The server document is outdated, try to get the missing revision from the
// client.
cli_socket
user.socket
.do_send(mk_pull_rev_ws_message(&self.doc_id, cur_rev_id, revision.rev_id))
.map_err(internal_error)?;
} else {
let delta = Delta::from_bytes(&revision.delta_data).map_err(internal_error)?;
let _ = self.update_document_delta(delta)?;
cli_socket
user.socket
.do_send(mk_acked_ws_message(&revision))
.map_err(internal_error)?;
self.rev_id.fetch_update(SeqCst, SeqCst, |_e| Some(revision.rev_id));
@ -154,13 +135,6 @@ impl EditDocContext {
Ok(())
}
fn verify_md5(&self, revision: &Revision) -> Result<(), ServerError> {
if md5(&revision.delta_data) != revision.md5 {
return Err(ServerError::internal().context("Delta md5 not match"));
}
Ok(())
}
#[tracing::instrument(level = "debug", skip(self, revision))]
async fn save_revision(&self, revision: &Revision) -> Result<(), ServerError> {
// Opti: save with multiple revisions

View File

@ -0,0 +1,56 @@
use crate::service::{
doc::edit::{
actor::{EditDocActor, EditMsg},
EditDocContext,
},
ws::{entities::Socket, WsUser},
};
use flowy_document::protobuf::{Doc, Revision};
use flowy_net::errors::{internal_error, Result as DocResult, ServerError};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
pub struct EditDoc {
sender: mpsc::Sender<EditMsg>,
}
impl EditDoc {
pub fn new(doc: Doc) -> Result<Self, ServerError> {
let (sender, receiver) = mpsc::channel(100);
let edit_context = Arc::new(EditDocContext::new(doc)?);
let actor = EditDocActor::new(receiver, edit_context);
tokio::task::spawn(actor.run());
Ok(Self { sender })
}
#[tracing::instrument(level = "debug", skip(self, user, socket, revision))]
pub async fn apply_revision(
&self,
user: Arc<WsUser>,
socket: Socket,
revision: Revision,
) -> Result<(), ServerError> {
let (ret, rx) = oneshot::channel();
let msg = EditMsg::Revision {
user,
socket,
revision,
ret,
};
let _ = self.send(msg, rx).await?;
Ok(())
}
pub async fn document_json(&self) -> DocResult<String> {
let (ret, rx) = oneshot::channel();
let msg = EditMsg::DocumentJson { ret };
self.send(msg, rx).await?
}
async fn send<T>(&self, msg: EditMsg, rx: oneshot::Receiver<T>) -> DocResult<T> {
let _ = self.sender.send(msg).await.map_err(internal_error)?;
let result = rx.await?;
Ok(result)
}
}

View File

@ -0,0 +1,6 @@
mod actor;
mod context;
mod edit_doc;
pub use context::*;
pub use edit_doc::*;

View File

@ -1,8 +1,8 @@
pub mod crud;
pub mod doc;
pub mod edit_doc;
pub mod router;
mod ws_actor;
pub(crate) use crud::*;
pub use router::*;
mod actor;
pub mod crud;
pub mod doc;
mod edit;
pub mod router;

View File

@ -1,7 +1,8 @@
use crate::service::ws::{WsBizHandlers, WsClient, WsServer, WsUser};
use crate::service::{
user::LoggedUser,
ws::{WsBizHandlers, WsClient, WsServer, WsUser},
};
use actix::Addr;
use crate::{context::FlowyRuntime, service::user::LoggedUser};
use actix_web::{
get,
web::{Data, Path, Payload},
@ -17,14 +18,13 @@ pub async fn establish_ws_connection(
payload: Payload,
token: Path<String>,
server: Data<Addr<WsServer>>,
runtime: Data<FlowyRuntime>,
biz_handlers: Data<WsBizHandlers>,
) -> Result<HttpResponse, Error> {
log::info!("establish_ws_connection");
match LoggedUser::from_token(token.clone()) {
Ok(user) => {
let ws_user = WsUser::new(user.clone());
let client = WsClient::new(ws_user, server.get_ref().clone(), biz_handlers, runtime);
let client = WsClient::new(ws_user, server.get_ref().clone(), biz_handlers);
let result = ws::start(client, &request, payload);
match result {
Ok(response) => Ok(response.into()),

View File

@ -1,11 +1,9 @@
use crate::{
config::{HEARTBEAT_INTERVAL, PING_TIMEOUT},
context::FlowyRuntime,
service::{
user::LoggedUser,
ws::{
entities::{Connect, Disconnect, Socket},
WsBizHandler,
WsBizHandlers,
WsMessageAdaptor,
WsServer,
@ -19,6 +17,7 @@ use bytes::Bytes;
use flowy_ws::WsMessage;
use std::{convert::TryFrom, sync::Arc, time::Instant};
#[derive(Debug)]
pub struct WsUser {
inner: LoggedUser,
}
@ -39,23 +38,16 @@ pub struct WsClient {
user: Arc<WsUser>,
server: Addr<WsServer>,
biz_handlers: Data<WsBizHandlers>,
runtime: Data<FlowyRuntime>,
hb: Instant,
}
impl WsClient {
pub fn new(
user: WsUser,
server: Addr<WsServer>,
biz_handlers: Data<WsBizHandlers>,
runtime: Data<FlowyRuntime>,
) -> Self {
pub fn new(user: WsUser, server: Addr<WsServer>, biz_handlers: Data<WsBizHandlers>) -> Self {
Self {
user: Arc::new(user),
server,
biz_handlers,
hb: Instant::now(),
runtime,
}
}

View File

@ -12,3 +12,23 @@ async fn edit_doc_insert_text() {
])
.await;
}
#[actix_rt::test]
async fn edit_doc_insert_large_text() {
let test = DocumentTest::new().await;
test.run_scripts(vec![
DocScript::ConnectWs,
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
DocScript::SendText(0, "abc"),
/* DocScript::AssertClient(r#"[{"insert":"abc123efg\n"}]"#),
* DocScript::AssertServer(r#"[{"insert":"abc123efg\n"}]"#), */
])
.await;
}

View File

@ -1,13 +1,16 @@
// use crate::helper::*;
use crate::helper::{spawn_server, TestServer};
use actix_web::web::Data;
use backend::service::doc::doc::DocManager;
use flowy_document::{
entities::doc::QueryDocParams,
services::doc::edit_doc_context::EditDocContext as ClientEditDocContext,
};
use flowy_net::config::ServerConfig;
use flowy_test::{workspace::ViewTest, FlowyTest};
use flowy_user::services::user::UserSession;
use futures_util::{stream, stream::StreamExt};
use sqlx::PgPool;
use std::sync::Arc;
use tokio::time::{sleep, Duration};
@ -17,6 +20,7 @@ pub struct DocumentTest {
}
#[derive(Clone)]
pub enum DocScript {
ConnectWs,
SendText(usize, &'static str),
AssertClient(&'static str),
AssertServer(&'static str),
@ -31,42 +35,65 @@ impl DocumentTest {
}
pub async fn run_scripts(self, scripts: Vec<DocScript>) {
init_user(&self.flowy_test).await;
let _ = self.flowy_test.sign_up().await;
let DocumentTest { server, flowy_test } = self;
run_scripts(server, flowy_test, scripts).await;
let script_context = ScriptContext {
client_edit_context: create_doc(&flowy_test).await,
user_session: flowy_test.sdk.user_session.clone(),
doc_manager: server.app_ctx.doc_biz.manager.clone(),
pool: Data::new(server.pg_pool.clone()),
};
run_scripts(script_context, scripts).await;
std::mem::forget(flowy_test);
sleep(Duration::from_secs(5)).await;
}
}
pub async fn run_scripts(server: TestServer, flowy_test: FlowyTest, scripts: Vec<DocScript>) {
let client_edit_context = create_doc(&flowy_test).await;
let doc_id = client_edit_context.doc_id.clone();
#[derive(Clone)]
struct ScriptContext {
client_edit_context: Arc<ClientEditDocContext>,
user_session: Arc<UserSession>,
doc_manager: Arc<DocManager>,
pool: Data<PgPool>,
}
async fn run_scripts(context: ScriptContext, scripts: Vec<DocScript>) {
let mut fut_scripts = vec![];
for script in scripts {
match script {
DocScript::SendText(index, s) => {
client_edit_context.insert(index, s);
},
DocScript::AssertClient(s) => {
let json = client_edit_context.doc_json();
assert_eq(s, &json);
},
DocScript::AssertServer(s) => {
sleep(Duration::from_millis(100)).await;
let pool = server.pg_pool.clone();
let edit_context = server
.app_ctx
.doc_biz
.manager
.get(&doc_id, Data::new(pool))
.await
.unwrap()
.unwrap();
let json = edit_context.doc_json();
assert_eq(s, &json);
},
}
let context = context.clone();
let fut = async move {
match script {
DocScript::ConnectWs => {
let token = context.user_session.token().unwrap();
let _ = context.user_session.start_ws_connection(&token).await.unwrap();
},
DocScript::SendText(index, s) => {
context.client_edit_context.insert(index, s).unwrap();
},
DocScript::AssertClient(s) => {
let json = context.client_edit_context.doc_json();
assert_eq(s, &json);
},
DocScript::AssertServer(s) => {
let edit_doc = context
.doc_manager
.get(&context.client_edit_context.doc_id, context.pool)
.await
.unwrap()
.unwrap();
let json = edit_doc.document_json().await.unwrap();
assert_eq(s, &json);
},
}
};
fut_scripts.push(fut);
}
let mut stream = stream::iter(fut_scripts);
while let Some(script) = stream.next().await {
let _ = script.await;
}
std::mem::forget(flowy_test);
}
fn assert_eq(expect: &str, receive: &str) {
@ -90,10 +117,3 @@ async fn create_doc(flowy_test: &FlowyTest) -> Arc<ClientEditDocContext> {
edit_context
}
async fn init_user(flowy_test: &FlowyTest) {
let _ = flowy_test.sign_up().await;
let user_session = flowy_test.sdk.user_session.clone();
user_session.init_user().await.unwrap();
}

View File

@ -14,6 +14,7 @@ pub(crate) struct DocCache {
impl DocCache {
pub(crate) fn new() -> Self { Self { inner: DashMap::new() } }
#[allow(dead_code)]
pub(crate) fn all_docs(&self) -> Vec<Arc<EditDocContext>> {
self.inner
.iter()

View File

@ -237,7 +237,7 @@ impl TestBuilder {
},
TestOp::DocComposeDelta(doc_index, delta_i) => {
let delta = self.deltas.get(*delta_i).unwrap().as_ref().unwrap();
self.documents[*doc_index].compose_delta(delta);
self.documents[*doc_index].compose_delta(delta).unwrap();
},
TestOp::DocComposePrime(doc_index, prime_i) => {
let delta = self

View File

@ -16,13 +16,14 @@ use flowy_database::{
ExpressionMethods,
UserDatabaseConnection,
};
use flowy_infra::kv::KV;
use flowy_infra::{future::wrap_future, kv::KV};
use flowy_net::config::ServerConfig;
use flowy_sqlite::ConnectionPool;
use flowy_ws::{connect::Retry, WsController, WsMessage, WsMessageHandler, WsSender};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration};
use tokio::task::JoinHandle;
pub struct UserSessionConfig {
root_dir: String,
@ -50,7 +51,7 @@ pub struct UserSession {
#[allow(dead_code)]
server: Server,
session: RwLock<Option<Session>>,
ws_controller: Arc<RwLock<WsController>>,
ws_controller: Arc<WsController>,
status_callback: SessionStatusCallback,
}
@ -58,7 +59,7 @@ impl UserSession {
pub fn new(config: UserSessionConfig, status_callback: SessionStatusCallback) -> Self {
let db = UserDB::new(&config.root_dir);
let server = construct_user_server(&config.server_config);
let ws_controller = Arc::new(RwLock::new(WsController::new()));
let ws_controller = Arc::new(WsController::new());
let user_session = Self {
database: db,
config,
@ -148,7 +149,7 @@ impl UserSession {
pub async fn init_user(&self) -> Result<(), UserError> {
let (_, token) = self.get_session()?.into_part();
let _ = self.start_ws_connection(&token)?;
let _ = self.start_ws_connection(&token).await?;
Ok(())
}
@ -183,22 +184,12 @@ impl UserSession {
pub fn token(&self) -> Result<String, UserError> { Ok(self.get_session()?.token) }
pub fn add_ws_handler(&self, handler: Arc<dyn WsMessageHandler>) {
let _ = self.ws_controller.write().add_handler(handler);
}
pub fn get_ws_sender(&self) -> Result<Arc<WsSender>, UserError> {
match self.ws_controller.try_read_for(Duration::from_millis(300)) {
None => Err(UserError::internal().context("Send ws message timeout")),
Some(guard) => {
let sender = guard.get_sender()?;
Ok(sender)
},
}
let _ = self.ws_controller.add_handler(handler);
}
pub fn send_ws_msg<T: Into<WsMessage>>(&self, msg: T) -> Result<(), UserError> {
let sender = self.get_ws_sender()?;
let _ = sender.send_msg(msg)?;
let sender = self.ws_controller.sender()?;
sender.send_msg(msg)?;
Ok(())
}
}
@ -294,15 +285,10 @@ impl UserSession {
}
}
fn start_ws_connection(&self, token: &str) -> Result<(), UserError> {
pub async fn start_ws_connection(&self, token: &str) -> Result<(), UserError> {
log::debug!("start_ws_connection");
let addr = format!("{}/{}", self.server.ws_addr(), token);
let ws_controller = self.ws_controller.clone();
let retry = Retry::new(&addr, move |addr| {
let _ = ws_controller.write().connect(addr.to_owned());
});
let _ = self.ws_controller.write().connect_with_retry(addr, retry)?;
let _ = self.ws_controller.connect(addr).await?;
Ok(())
}
}

View File

@ -23,6 +23,8 @@ log = "0.4"
protobuf = {version = "2.18.0"}
strum = "0.21"
strum_macros = "0.21"
parking_lot = "0.11"
dashmap = "4.0"
[dev-dependencies]
tokio = {version = "1", features = ["full"]}

View File

@ -86,7 +86,9 @@ impl WsStream {
msg_tx: msg_tx.clone(),
inner: Some((
Box::pin(async move {
let _ = ws_read.for_each(|message| async { post_message(msg_tx.clone(), message) }).await;
let _ = ws_read
.for_each(|message| async { post_message(msg_tx.clone(), message) })
.await;
Ok(())
}),
Box::pin(async move {
@ -135,7 +137,7 @@ fn post_message(tx: MsgSender, message: Result<Message, Error>) {
},
}
}
#[allow(dead_code)]
pub struct Retry<F> {
f: F,
#[allow(dead_code)]
@ -147,6 +149,7 @@ impl<F> Retry<F>
where
F: Fn(&str),
{
#[allow(dead_code)]
pub fn new(addr: &str, f: F) -> Self {
Self {
f,

View File

@ -1,6 +1,6 @@
use bytes::Bytes;
use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
use std::convert::{TryFrom, TryInto};
use std::convert::TryInto;
use tokio_tungstenite::tungstenite::Message as TokioMessage;
// Opti: using four bytes of the data to represent the source

View File

@ -1,13 +1,15 @@
use crate::{
connect::{Retry, WsConnectionFuture},
connect::{WsConnectionFuture, WsStream},
errors::WsError,
WsMessage,
WsModule,
};
use bytes::Bytes;
use flowy_net::errors::ServerError;
use dashmap::DashMap;
use flowy_net::errors::{internal_error, ServerError};
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures_core::{future::BoxFuture, ready, Stream};
use parking_lot::RwLock;
use pin_project::pin_project;
use std::{
collections::HashMap,
@ -17,7 +19,10 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tokio::{sync::RwLock, task::JoinHandle};
use tokio::{
sync::{broadcast, oneshot},
task::JoinHandle,
};
use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Message,
@ -25,6 +30,8 @@ use tokio_tungstenite::tungstenite::{
pub type MsgReceiver = UnboundedReceiver<Message>;
pub type MsgSender = UnboundedSender<Message>;
type Handlers = DashMap<WsModule, Arc<dyn WsMessageHandler>>;
pub trait WsMessageHandler: Sync + Send + 'static {
fn source(&self) -> WsModule;
fn receive_message(&self, msg: WsMessage);
@ -46,6 +53,7 @@ impl WsStateNotify {
}
}
#[derive(Clone)]
pub enum WsState {
Init,
Connected(Arc<WsSender>),
@ -53,37 +61,23 @@ pub enum WsState {
}
pub struct WsController {
handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>,
state_notify: Arc<RwLock<WsStateNotify>>,
#[allow(dead_code)]
addr: Option<String>,
sender: Option<Arc<WsSender>>,
handlers: Handlers,
state_notify: Arc<broadcast::Sender<WsState>>,
sender: RwLock<Option<Arc<WsSender>>>,
}
impl WsController {
pub fn new() -> Self {
let state_notify = Arc::new(RwLock::new(WsStateNotify {
state: WsState::Init,
callback: None,
}));
let (state_notify, _) = broadcast::channel(16);
let controller = Self {
handlers: HashMap::new(),
state_notify,
addr: None,
sender: None,
handlers: DashMap::new(),
sender: RwLock::new(None),
state_notify: Arc::new(state_notify),
};
controller
}
pub async fn state_callback<SC>(&self, callback: SC)
where
SC: Fn(&WsState) + Send + Sync + 'static,
{
(self.state_notify.write().await).callback = Some(Arc::new(callback));
}
pub fn add_handler(&mut self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
pub fn add_handler(&self, handler: Arc<dyn WsMessageHandler>) -> Result<(), WsError> {
let source = handler.source();
if self.handlers.contains_key(&source) {
log::error!("WsSource's {:?} is already registered", source);
@ -92,60 +86,47 @@ impl WsController {
Ok(())
}
pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { self._connect(addr.clone(), None) }
pub fn connect_with_retry<F>(&mut self, addr: String, retry: Retry<F>) -> Result<JoinHandle<()>, ServerError>
where
F: Fn(&str) + Send + Sync + 'static,
{
self._connect(addr, Some(Box::pin(async { retry.await })))
pub async fn connect(&self, addr: String) -> Result<(), ServerError> {
let (ret, rx) = oneshot::channel::<Result<(), ServerError>>();
self._connect(addr.clone(), ret);
rx.await?
}
pub fn get_sender(&self) -> Result<Arc<WsSender>, WsError> {
match &self.sender {
#[allow(dead_code)]
pub fn state_subscribe(&self) -> broadcast::Receiver<WsState> { self.state_notify.subscribe() }
pub fn sender(&self) -> Result<Arc<WsSender>, WsError> {
match &*self.sender.read() {
None => Err(WsError::internal().context("WsSender is not initialized, should call connect first")),
Some(sender) => Ok(sender.clone()),
}
}
fn _connect(&mut self, addr: String, retry: Option<BoxFuture<'static, ()>>) -> Result<JoinHandle<()>, ServerError> {
fn _connect(&self, addr: String, ret: oneshot::Sender<Result<(), ServerError>>) {
log::debug!("🐴 ws connect: {}", &addr);
let (connection, handlers) = self.make_connect(addr.clone());
let state_notify = self.state_notify.clone();
let sender = self
.sender
.read()
.clone()
.expect("Sender should be not empty after calling make_connect");
Ok(tokio::spawn(async move {
tokio::spawn(async move {
match connection.await {
Ok(stream) => {
state_notify.write().await.update_state(WsState::Connected(sender));
tokio::select! {
result = stream => {
match result {
Ok(_) => {},
Err(e) => {
// TODO: retry?
log::error!("ws stream error {:?}", e);
state_notify.write().await.update_state(WsState::Disconnected(e));
}
}
},
result = handlers => log::debug!("handlers completed {:?}", result),
};
state_notify.send(WsState::Connected(sender));
ret.send(Ok(()));
spawn_steam_and_handlers(stream, handlers, state_notify).await;
},
Err(e) => {
log::error!("ws connect {} failed {:?}", addr, e);
state_notify.write().await.update_state(WsState::Disconnected(e));
if let Some(retry) = retry {
tokio::spawn(retry);
}
state_notify.send(WsState::Disconnected(e.clone()));
ret.send(Err(ServerError::internal().context(e)));
},
}
}))
});
}
fn make_connect(&mut self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
fn make_connect(&self, addr: String) -> (WsConnectionFuture, WsHandlerFuture) {
// Stream User
// ┌───────────────┐ ┌──────────────┐
// ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │
@ -159,8 +140,7 @@ impl WsController {
let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded();
let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded();
let handlers = self.handlers.clone();
self.sender = Some(Arc::new(WsSender { ws_tx }));
self.addr = Some(addr.clone());
*self.sender.write() = Some(Arc::new(WsSender { ws_tx }));
(
WsConnectionFuture::new(msg_tx, ws_rx, addr),
WsHandlerFuture::new(handlers, msg_rx),
@ -168,17 +148,36 @@ impl WsController {
}
}
async fn spawn_steam_and_handlers(
stream: WsStream,
handlers: WsHandlerFuture,
state_notify: Arc<broadcast::Sender<WsState>>,
) {
tokio::select! {
result = stream => {
match result {
Ok(_) => {},
Err(e) => {
// TODO: retry?
log::error!("ws stream error {:?}", e);
state_notify.send(WsState::Disconnected(e));
}
}
},
result = handlers => log::debug!("handlers completed {:?}", result),
};
}
#[pin_project]
pub struct WsHandlerFuture {
#[pin]
msg_rx: MsgReceiver,
handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>,
// Opti: Hashmap would be better
handlers: Handlers,
}
impl WsHandlerFuture {
fn new(handlers: HashMap<WsModule, Arc<dyn WsMessageHandler>>, msg_rx: MsgReceiver) -> Self {
Self { msg_rx, handlers }
}
fn new(handlers: Handlers, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } }
fn handler_ws_message(&self, message: Message) {
match message {