diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart index d16d79e397..f7b35639a3 100644 --- a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbenum.dart @@ -11,9 +11,13 @@ import 'package:protobuf/protobuf.dart' as $pb; class ErrorCode extends $pb.ProtobufEnum { static const ErrorCode InternalError = ErrorCode._(0, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'InternalError'); + static const ErrorCode DuplicateSource = ErrorCode._(1, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'DuplicateSource'); + static const ErrorCode UnsupportedMessage = ErrorCode._(2, const $core.bool.fromEnvironment('protobuf.omit_enum_names') ? '' : 'UnsupportedMessage'); static const $core.List values = [ InternalError, + DuplicateSource, + UnsupportedMessage, ]; static final $core.Map<$core.int, ErrorCode> _byValue = $pb.ProtobufEnum.initByValue(values); diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart index 629328d718..6dd1db5d5e 100644 --- a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/errors.pbjson.dart @@ -13,11 +13,13 @@ const ErrorCode$json = const { '1': 'ErrorCode', '2': const [ const {'1': 'InternalError', '2': 0}, + const {'1': 'DuplicateSource', '2': 1}, + const {'1': 'UnsupportedMessage', '2': 2}, ], }; /// Descriptor for `ErrorCode`. Decode as a `google.protobuf.EnumDescriptorProto`. -final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAA'); +final $typed_data.Uint8List errorCodeDescriptor = $convert.base64Decode('CglFcnJvckNvZGUSEQoNSW50ZXJuYWxFcnJvchAAEhMKD0R1cGxpY2F0ZVNvdXJjZRABEhYKElVuc3VwcG9ydGVkTWVzc2FnZRAC'); @$core.Deprecated('Use wsErrorDescriptor instead') const WsError$json = const { '1': 'WsError', diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pb.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pb.dart new file mode 100644 index 0000000000..1067674ec2 --- /dev/null +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pb.dart @@ -0,0 +1,72 @@ +/// +// Generated code. Do not modify. +// source: msg.proto +// +// @dart = 2.12 +// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +class WsMessage extends $pb.GeneratedMessage { + static final $pb.BuilderInfo _i = $pb.BuilderInfo(const $core.bool.fromEnvironment('protobuf.omit_message_names') ? '' : 'WsMessage', createEmptyInstance: create) + ..aOS(1, const $core.bool.fromEnvironment('protobuf.omit_field_names') ? '' : 'source') + ..a<$core.List<$core.int>>(2, const $core.bool.fromEnvironment('protobuf.omit_field_names') ? '' : 'data', $pb.PbFieldType.OY) + ..hasRequiredFields = false + ; + + WsMessage._() : super(); + factory WsMessage({ + $core.String? source, + $core.List<$core.int>? data, + }) { + final _result = create(); + if (source != null) { + _result.source = source; + } + if (data != null) { + _result.data = data; + } + return _result; + } + factory WsMessage.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory WsMessage.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + WsMessage clone() => WsMessage()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + WsMessage copyWith(void Function(WsMessage) updates) => super.copyWith((message) => updates(message as WsMessage)) as WsMessage; // ignore: deprecated_member_use + $pb.BuilderInfo get info_ => _i; + @$core.pragma('dart2js:noInline') + static WsMessage create() => WsMessage._(); + WsMessage createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static WsMessage getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static WsMessage? _defaultInstance; + + @$pb.TagNumber(1) + $core.String get source => $_getSZ(0); + @$pb.TagNumber(1) + set source($core.String v) { $_setString(0, v); } + @$pb.TagNumber(1) + $core.bool hasSource() => $_has(0); + @$pb.TagNumber(1) + void clearSource() => clearField(1); + + @$pb.TagNumber(2) + $core.List<$core.int> get data => $_getN(1); + @$pb.TagNumber(2) + set data($core.List<$core.int> v) { $_setBytes(1, v); } + @$pb.TagNumber(2) + $core.bool hasData() => $_has(1); + @$pb.TagNumber(2) + void clearData() => clearField(2); +} + diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbenum.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbenum.dart new file mode 100644 index 0000000000..59dcf67a9f --- /dev/null +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbenum.dart @@ -0,0 +1,7 @@ +/// +// Generated code. Do not modify. +// source: msg.proto +// +// @dart = 2.12 +// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields + diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbjson.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbjson.dart new file mode 100644 index 0000000000..8c19e9c91c --- /dev/null +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbjson.dart @@ -0,0 +1,21 @@ +/// +// Generated code. Do not modify. +// source: msg.proto +// +// @dart = 2.12 +// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields,deprecated_member_use_from_same_package + +import 'dart:core' as $core; +import 'dart:convert' as $convert; +import 'dart:typed_data' as $typed_data; +@$core.Deprecated('Use wsMessageDescriptor instead') +const WsMessage$json = const { + '1': 'WsMessage', + '2': const [ + const {'1': 'source', '3': 1, '4': 1, '5': 9, '10': 'source'}, + const {'1': 'data', '3': 2, '4': 1, '5': 12, '10': 'data'}, + ], +}; + +/// Descriptor for `WsMessage`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List wsMessageDescriptor = $convert.base64Decode('CglXc01lc3NhZ2USFgoGc291cmNlGAEgASgJUgZzb3VyY2USEgoEZGF0YRgCIAEoDFIEZGF0YQ=='); diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbserver.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbserver.dart new file mode 100644 index 0000000000..e6a7eccb26 --- /dev/null +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/msg.pbserver.dart @@ -0,0 +1,9 @@ +/// +// Generated code. Do not modify. +// source: msg.proto +// +// @dart = 2.12 +// ignore_for_file: annotate_overrides,camel_case_types,unnecessary_const,non_constant_identifier_names,library_prefixes,unused_import,unused_shown_name,return_of_invalid_type,unnecessary_this,prefer_final_fields,deprecated_member_use_from_same_package + +export 'msg.pb.dart'; + diff --git a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/protobuf.dart b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/protobuf.dart index 92eb134641..3d5e1cc240 100644 --- a/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/protobuf.dart +++ b/app_flowy/packages/flowy_sdk/lib/protobuf/flowy-ws/protobuf.dart @@ -1,2 +1,3 @@ // Auto-generated, do not edit export './errors.pb.dart'; +export './msg.pb.dart'; diff --git a/backend/src/middleware/auth_middleware.rs b/backend/src/middleware/auth_middleware.rs index 97e30dfb5e..34b4615203 100644 --- a/backend/src/middleware/auth_middleware.rs +++ b/backend/src/middleware/auth_middleware.rs @@ -56,7 +56,7 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { let mut authenticate_pass: bool = false; for ignore_route in IGNORE_ROUTES.iter() { - log::info!("ignore: {}, path: {}", ignore_route, req.path()); + // log::info!("ignore: {}, path: {}", ignore_route, req.path()); if req.path().starts_with(ignore_route) { authenticate_pass = true; break; diff --git a/backend/tests/api/helper.rs b/backend/tests/api/helper.rs index 3c11f2f31b..297980203e 100644 --- a/backend/tests/api/helper.rs +++ b/backend/tests/api/helper.rs @@ -13,7 +13,7 @@ use sqlx::{Connection, Executor, PgConnection, PgPool}; use uuid::Uuid; pub struct TestServer { - pub address: String, + pub host: String, pub port: u16, pub pg_pool: PgPool, pub user_token: Option, @@ -30,12 +30,12 @@ impl TestServer { } pub async fn sign_in(&self, params: SignInParams) -> Result { - let url = format!("{}/api/auth", self.address); + let url = format!("{}/api/auth", self.http_addr()); user_sign_in_request(params, &url).await } pub async fn sign_out(&self) { - let url = format!("{}/api/auth", self.address); + let url = format!("{}/api/auth", self.http_addr()); let _ = user_sign_out_request(self.user_token(), &url) .await .unwrap(); @@ -54,7 +54,7 @@ impl TestServer { } pub async fn get_user_profile(&self) -> UserProfile { - let url = format!("{}/api/user", self.address); + let url = format!("{}/api/user", self.http_addr()); let user_profile = get_user_profile_request(self.user_token(), &url) .await .unwrap(); @@ -62,12 +62,12 @@ impl TestServer { } pub async fn update_user_profile(&self, params: UpdateUserParams) -> Result<(), UserError> { - let url = format!("{}/api/user", self.address); + let url = format!("{}/api/user", self.http_addr()); update_user_profile_request(self.user_token(), params, &url).await } pub async fn create_workspace(&self, params: CreateWorkspaceParams) -> Workspace { - let url = format!("{}/api/workspace", self.address); + let url = format!("{}/api/workspace", self.http_addr()); let workspace = create_workspace_request(self.user_token(), params, &url) .await .unwrap(); @@ -75,7 +75,7 @@ impl TestServer { } pub async fn read_workspaces(&self, params: QueryWorkspaceParams) -> RepeatedWorkspace { - let url = format!("{}/api/workspace", self.address); + let url = format!("{}/api/workspace", self.http_addr()); let workspaces = read_workspaces_request(self.user_token(), params, &url) .await .unwrap(); @@ -83,21 +83,21 @@ impl TestServer { } pub async fn update_workspace(&self, params: UpdateWorkspaceParams) { - let url = format!("{}/api/workspace", self.address); + let url = format!("{}/api/workspace", self.http_addr()); update_workspace_request(self.user_token(), params, &url) .await .unwrap(); } pub async fn delete_workspace(&self, params: DeleteWorkspaceParams) { - let url = format!("{}/api/workspace", self.address); + let url = format!("{}/api/workspace", self.http_addr()); delete_workspace_request(self.user_token(), params, &url) .await .unwrap(); } pub async fn create_app(&self, params: CreateAppParams) -> App { - let url = format!("{}/api/app", self.address); + let url = format!("{}/api/app", self.http_addr()); let app = create_app_request(self.user_token(), params, &url) .await .unwrap(); @@ -105,7 +105,7 @@ impl TestServer { } pub async fn read_app(&self, params: QueryAppParams) -> Option { - let url = format!("{}/api/app", self.address); + let url = format!("{}/api/app", self.http_addr()); let app = read_app_request(self.user_token(), params, &url) .await .unwrap(); @@ -113,21 +113,21 @@ impl TestServer { } pub async fn update_app(&self, params: UpdateAppParams) { - let url = format!("{}/api/app", self.address); + let url = format!("{}/api/app", self.http_addr()); update_app_request(self.user_token(), params, &url) .await .unwrap(); } pub async fn delete_app(&self, params: DeleteAppParams) { - let url = format!("{}/api/app", self.address); + let url = format!("{}/api/app", self.http_addr()); delete_app_request(self.user_token(), params, &url) .await .unwrap(); } pub async fn create_view(&self, params: CreateViewParams) -> View { - let url = format!("{}/api/view", self.address); + let url = format!("{}/api/view", self.http_addr()); let view = create_view_request(self.user_token(), params, &url) .await .unwrap(); @@ -135,7 +135,7 @@ impl TestServer { } pub async fn read_view(&self, params: QueryViewParams) -> Option { - let url = format!("{}/api/view", self.address); + let url = format!("{}/api/view", self.http_addr()); let view = read_view_request(self.user_token(), params, &url) .await .unwrap(); @@ -143,21 +143,21 @@ impl TestServer { } pub async fn update_view(&self, params: UpdateViewParams) { - let url = format!("{}/api/view", self.address); + let url = format!("{}/api/view", self.http_addr()); update_view_request(self.user_token(), params, &url) .await .unwrap(); } pub async fn delete_view(&self, params: DeleteViewParams) { - let url = format!("{}/api/view", self.address); + let url = format!("{}/api/view", self.http_addr()); delete_view_request(self.user_token(), params, &url) .await .unwrap(); } pub async fn read_doc(&self, params: QueryDocParams) -> Option { - let url = format!("{}/api/doc", self.address); + let url = format!("{}/api/doc", self.http_addr()); let doc = read_doc_request(self.user_token(), params, &url) .await .unwrap(); @@ -175,13 +175,19 @@ impl TestServer { } pub(crate) async fn register(&self, params: SignUpParams) -> SignUpResponse { - let url = format!("{}/api/register", self.address); + let url = format!("{}/api/register", self.http_addr()); let response = user_sign_up_request(params, &url).await.unwrap(); response } + pub(crate) fn http_addr(&self) -> String { format!("http://{}", self.host) } + pub(crate) fn ws_addr(&self) -> String { - format!("{}/ws/{}", self.address, self.user_token.as_ref().unwrap()) + format!( + "ws://{}/ws/{}", + self.host, + self.user_token.as_ref().unwrap() + ) } } pub async fn spawn_server() -> TestServer { @@ -206,7 +212,7 @@ pub async fn spawn_server() -> TestServer { }); TestServer { - address: format!("http://localhost:{}", application_port), + host: format!("localhost:{}", application_port), port: application_port, pg_pool: get_connection_pool(&configuration.database) .await diff --git a/backend/tests/api/ws.rs b/backend/tests/api/ws.rs index 61d09d6af0..f2fc908d60 100644 --- a/backend/tests/api/ws.rs +++ b/backend/tests/api/ws.rs @@ -6,5 +6,5 @@ async fn ws_connect() { let server = TestServer::new().await; let mut controller = WsController::new(); let addr = server.ws_addr(); - let _ = controller.connect(addr).await.unwrap(); + let _ = controller.connect(addr).unwrap().await; } diff --git a/rust-lib/flowy-ast/src/ty_ext.rs b/rust-lib/flowy-ast/src/ty_ext.rs index a20d1e9006..8f1fb5ceaf 100644 --- a/rust-lib/flowy-ast/src/ty_ext.rs +++ b/rust-lib/flowy-ast/src/ty_ext.rs @@ -60,7 +60,7 @@ pub fn parse_ty<'a>(ctxt: &Ctxt, ty: &'a syn::Type) -> Option> { "Vec" => generate_vec_ty_info(ctxt, seg, bracketed), "Option" => generate_option_ty_info(ctxt, ty, seg, bracketed), _ => { - panic!("Unsupported ty") + panic!("Unsupported ty {}", seg.ident.to_string()) }, } } else { diff --git a/rust-lib/flowy-derive/src/derive_cache/derive_cache.rs b/rust-lib/flowy-derive/src/derive_cache/derive_cache.rs index a56895fb31..f4f4815f6a 100644 --- a/rust-lib/flowy-derive/src/derive_cache/derive_cache.rs +++ b/rust-lib/flowy-derive/src/derive_cache/derive_cache.rs @@ -54,6 +54,7 @@ pub fn category_from_str(type_str: &str) -> TypeCategory { | "RepeatedView" | "WorkspaceError" | "WsError" + | "WsMessage" | "CreateDocParams" | "Doc" | "SaveDocParams" diff --git a/rust-lib/flowy-dispatch/tests/api/module.rs b/rust-lib/flowy-dispatch/tests/api/module.rs index 8dfac5a6dc..4b9e8d06be 100644 --- a/rust-lib/flowy-dispatch/tests/api/module.rs +++ b/rust-lib/flowy-dispatch/tests/api/module.rs @@ -1,4 +1,3 @@ -use crate::helper::*; use flowy_dispatch::prelude::*; use std::sync::Arc; diff --git a/rust-lib/flowy-net/src/config.rs b/rust-lib/flowy-net/src/config.rs index 0696108fa2..ab1fe8ec70 100644 --- a/rust-lib/flowy-net/src/config.rs +++ b/rust-lib/flowy-net/src/config.rs @@ -1,20 +1,20 @@ use lazy_static::lazy_static; -pub const HOST: &'static str = "http://localhost:8000"; - +pub const HOST: &'static str = "localhost:8000"; +pub const SCHEMA: &'static str = "http://"; pub const HEADER_TOKEN: &'static str = "token"; lazy_static! { - pub static ref SIGN_UP_URL: String = format!("{}/api/register", HOST); - pub static ref SIGN_IN_URL: String = format!("{}/api/auth", HOST); - pub static ref SIGN_OUT_URL: String = format!("{}/api/auth", HOST); - pub static ref USER_PROFILE_URL: String = format!("{}/api/user", HOST); + pub static ref SIGN_UP_URL: String = format!("{}/{}/api/register", SCHEMA, HOST); + pub static ref SIGN_IN_URL: String = format!("{}/{}/api/auth", SCHEMA, HOST); + pub static ref SIGN_OUT_URL: String = format!("{}/{}/api/auth", SCHEMA, HOST); + pub static ref USER_PROFILE_URL: String = format!("{}/{}/api/user", SCHEMA, HOST); // - pub static ref WORKSPACE_URL: String = format!("{}/api/workspace", HOST); - pub static ref APP_URL: String = format!("{}/api/app", HOST); - pub static ref VIEW_URL: String = format!("{}/api/view", HOST); - pub static ref DOC_URL: String = format!("{}/api/doc", HOST); + pub static ref WORKSPACE_URL: String = format!("{}/{}/api/workspace", SCHEMA, HOST); + pub static ref APP_URL: String = format!("{}/{}/api/app", SCHEMA, HOST); + pub static ref VIEW_URL: String = format!("{}/{}/api/view", SCHEMA, HOST); + pub static ref DOC_URL: String = format!("{}/{}/api/doc", SCHEMA, HOST); - pub static ref WS_ADDR: String = format!("ws://localhost:8000/ws"); + pub static ref WS_ADDR: String = format!("ws://{}/ws", HOST); } diff --git a/rust-lib/flowy-user/src/errors.rs b/rust-lib/flowy-user/src/errors.rs index cabaf191bd..6686f3101b 100644 --- a/rust-lib/flowy-user/src/errors.rs +++ b/rust-lib/flowy-user/src/errors.rs @@ -109,6 +109,10 @@ impl std::convert::From<::r2d2::Error> for UserError { fn from(error: r2d2::Error) -> Self { UserError::internal().context(error) } } +impl std::convert::From for UserError { + fn from(error: flowy_ws::errors::WsError) -> Self { UserError::internal().context(error) } +} + // use diesel::result::{Error, DatabaseErrorKind}; // use flowy_sqlite::ErrorKind; impl std::convert::From for UserError { diff --git a/rust-lib/flowy-user/src/services/user/user_session.rs b/rust-lib/flowy-user/src/services/user/user_session.rs index 1cbb783323..350cb3ab69 100644 --- a/rust-lib/flowy-user/src/services/user/user_session.rs +++ b/rust-lib/flowy-user/src/services/user/user_session.rs @@ -18,10 +18,10 @@ use flowy_database::{ }; use flowy_infra::kv::KV; use flowy_sqlite::ConnectionPool; -use flowy_ws::WsController; +use flowy_ws::{WsController, WsMessage, WsMessageHandler}; use parking_lot::RwLock; use serde::{Deserialize, Serialize}; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; pub struct UserSessionConfig { root_dir: String, @@ -47,7 +47,7 @@ pub struct UserSession { #[allow(dead_code)] server: Server, session: RwLock>, - ws: RwLock, + ws_controller: RwLock, status_callback: SessionStatusCallback, } @@ -55,13 +55,13 @@ impl UserSession { pub fn new(config: UserSessionConfig, status_callback: SessionStatusCallback) -> Self { let db = UserDB::new(&config.root_dir); let server = construct_user_server(); - let ws = RwLock::new(WsController::new()); + let ws_controller = RwLock::new(WsController::new()); let user_session = Self { database: db, config, server, session: RwLock::new(None), - ws, + ws_controller, status_callback, }; user_session @@ -172,6 +172,21 @@ impl UserSession { pub fn user_id(&self) -> Result { Ok(self.get_session()?.user_id) } pub fn token(&self) -> Result { Ok(self.get_session()?.token) } + + pub fn add_ws_msg_handler(&self, handler: Arc) -> Result<(), UserError> { + let _ = self.ws_controller.write().add_handler(handler)?; + Ok(()) + } + + pub fn send_ws_msg>(&self, msg: T) -> Result<(), UserError> { + match self.ws_controller.try_read_for(Duration::from_millis(300)) { + None => Err(UserError::internal().context("Send ws message timeout")), + Some(guard) => { + let _ = guard.send_msg(msg)?; + Ok(()) + }, + } + } } impl UserSession { @@ -263,20 +278,7 @@ impl UserSession { fn start_ws_connection(&self, token: &str) -> Result<(), UserError> { let addr = format!("{}/{}", flowy_net::config::WS_ADDR.as_str(), token); - log::debug!("🐴 Try to connect: {}", &addr); - let (conn, handlers) = self.ws.write().make_connect(addr); - tokio::spawn(async { - match conn.await { - Ok(_) => { - log::debug!("🐴 ws connect success"); - let _ = handlers.await; - }, - Err(e) => { - // TODO: retry? - log::error!("ws connect failed: {}", e); - }, - } - }); + let _ = self.ws_controller.write().connect(addr); Ok(()) } } diff --git a/rust-lib/flowy-ws/Flowy.toml b/rust-lib/flowy-ws/Flowy.toml index 642bd7427c..fec0b43d8b 100644 --- a/rust-lib/flowy-ws/Flowy.toml +++ b/rust-lib/flowy-ws/Flowy.toml @@ -1,2 +1,2 @@ -proto_crates = ["src/errors.rs"] +proto_crates = ["src/errors.rs", "src/msg.rs"] event_files = [] \ No newline at end of file diff --git a/rust-lib/flowy-ws/src/connect.rs b/rust-lib/flowy-ws/src/connect.rs new file mode 100644 index 0000000000..6c388278e1 --- /dev/null +++ b/rust-lib/flowy-ws/src/connect.rs @@ -0,0 +1,207 @@ +use crate::{errors::WsError, MsgReceiver, MsgSender, WsMessage}; +use flowy_net::errors::ServerError; +use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; +use futures_core::{future::BoxFuture, ready, Stream}; +use futures_util::{ + future, + future::{Either, Select}, + pin_mut, + FutureExt, + StreamExt, +}; +use pin_project::pin_project; +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{net::TcpStream, task::JoinHandle}; +use tokio_tungstenite::{ + connect_async, + tungstenite::{handshake::client::Response, http::StatusCode, Error, Message}, + MaybeTlsStream, + WebSocketStream, +}; + +#[pin_project] +pub struct WsConnection { + msg_tx: Option, + ws_rx: Option, + #[pin] + fut: BoxFuture<'static, Result<(WebSocketStream>, Response), Error>>, +} + +impl WsConnection { + pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, addr: String) -> Self { + WsConnection { + msg_tx: Some(msg_tx), + ws_rx: Some(ws_rx), + fut: Box::pin(async move { connect_async(&addr).await }), + } + } +} + +impl Future for WsConnection { + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // [[pin]] + // poll async function. The following methods not work. + // 1. + // let f = connect_async(""); + // pin_mut!(f); + // ready!(Pin::new(&mut a).poll(cx)) + // + // 2.ready!(Pin::new(&mut Box::pin(connect_async(""))).poll(cx)) + // + // An async method calls poll multiple times and might return to the executor. A + // single poll call can only return to the executor once and will get + // resumed through another poll invocation. the connect_async call multiple time + // from the beginning. So I use fut to hold the future and continue to + // poll it. (Fix me if i was wrong) + loop { + return match ready!(self.as_mut().project().fut.poll(cx)) { + Ok((stream, _)) => { + log::debug!("🐴 ws connect success"); + let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); + Poll::Ready(Ok(WsStream::new(msg_tx, ws_rx, stream))) + }, + Err(error) => Poll::Ready(Err(error_to_flowy_response(error))), + }; + } + } +} + +#[pin_project] +pub struct WsStream { + msg_tx: MsgSender, + #[pin] + fut: Option<(BoxFuture<'static, ()>, BoxFuture<'static, ()>)>, +} + +impl WsStream { + pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: WebSocketStream>) -> Self { + let (ws_write, ws_read) = stream.split(); + let to_ws = ws_rx.map(Ok).forward(ws_write); + let from_ws = ws_read.for_each(|message| async { + // handle_new_message(msg_tx.clone(), message) + }); + // pin_mut!(to_ws, from_ws); + Self { + msg_tx, + fut: Some(( + Box::pin(async move { + let _ = from_ws.await; + }), + Box::pin(async move { + let _ = to_ws.await; + }), + )), + } + } +} + +impl Future for WsStream { + type Output = (); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let (mut a, mut b) = self.fut.take().unwrap(); + match a.poll_unpin(cx) { + Poll::Ready(x) => Poll::Ready(()), + Poll::Pending => match b.poll_unpin(cx) { + Poll::Ready(x) => Poll::Ready(()), + Poll::Pending => { + // self.fut = Some((a, b)); + Poll::Pending + }, + }, + } + } +} + +// pub struct WsStream { +// msg_tx: Option, +// ws_rx: Option, +// stream: Option>>, +// } +// +// impl WsStream { +// pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: +// WebSocketStream>) -> Self { Self { +// msg_tx: Some(msg_tx), +// ws_rx: Some(ws_rx), +// stream: Some(stream), +// } +// } +// +// pub fn start(mut self) -> JoinHandle<()> { +// let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), +// self.ws_rx.take().unwrap()); let (ws_write, ws_read) = +// self.stream.take().unwrap().split(); tokio::spawn(async move { +// let to_ws = ws_rx.map(Ok).forward(ws_write); +// let from_ws = ws_read.for_each(|message| async { +// handle_new_message(msg_tx.clone(), message) }); pin_mut!(to_ws, +// from_ws); +// +// match future::select(to_ws, from_ws).await { +// Either::Left(_l) => { +// log::info!("ws left"); +// }, +// Either::Right(_r) => { +// log::info!("ws right"); +// }, +// } +// }) +// } +// } +// +// impl Future for WsStream { +// type Output = (); +// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> +// Poll { let (msg_tx, ws_rx) = +// (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); let +// (ws_write, ws_read) = self.stream.take().unwrap().split(); let to_ws +// = ws_rx.map(Ok).forward(ws_write); let from_ws = +// ws_read.for_each(|message| async { handle_new_message(msg_tx.clone(), +// message) }); pin_mut!(to_ws, from_ws); +// +// loop { +// match ready!(Pin::new(&mut future::select(to_ws, +// from_ws)).poll(cx)) { Either::Left(a) => { +// // +// return Poll::Ready(()); +// }, +// Either::Right(b) => { +// // +// return Poll::Ready(()); +// }, +// } +// } +// } +// } + +fn handle_new_message(tx: MsgSender, message: Result) { + match message { + Ok(Message::Binary(bytes)) => match tx.unbounded_send(Message::Binary(bytes)) { + Ok(_) => {}, + Err(e) => log::error!("tx send error: {:?}", e), + }, + Ok(_) => {}, + Err(e) => log::error!("ws read error: {:?}", e), + } +} + +fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> ServerError { + let error = match error { + Error::Http(response) => { + if response.status() == StatusCode::UNAUTHORIZED { + ServerError::unauthorized() + } else { + ServerError::internal().context(response) + } + }, + _ => ServerError::internal().context(error), + }; + + error +} diff --git a/rust-lib/flowy-ws/src/errors.rs b/rust-lib/flowy-ws/src/errors.rs index d0f02dda19..557d7ccbca 100644 --- a/rust-lib/flowy-ws/src/errors.rs +++ b/rust-lib/flowy-ws/src/errors.rs @@ -36,11 +36,15 @@ impl WsError { } static_user_error!(internal, ErrorCode::InternalError); + static_user_error!(duplicate_source, ErrorCode::DuplicateSource); + static_user_error!(unsupported_message, ErrorCode::UnsupportedMessage); } #[derive(Debug, Clone, ProtoBuf_Enum, Display, PartialEq, Eq)] pub enum ErrorCode { - InternalError = 0, + InternalError = 0, + DuplicateSource = 1, + UnsupportedMessage = 2, } impl std::default::Default for ErrorCode { @@ -51,6 +55,10 @@ impl std::convert::From for WsError { fn from(error: ParseError) -> Self { WsError::internal().context(error) } } +impl std::convert::From for WsError { + fn from(error: protobuf::ProtobufError) -> Self { WsError::internal().context(error) } +} + impl std::convert::From> for WsError { fn from(error: TrySendError) -> Self { WsError::internal().context(error) } } diff --git a/rust-lib/flowy-ws/src/lib.rs b/rust-lib/flowy-ws/src/lib.rs index 40e2b632dd..307c5ff600 100644 --- a/rust-lib/flowy-ws/src/lib.rs +++ b/rust-lib/flowy-ws/src/lib.rs @@ -1,5 +1,8 @@ +mod connect; pub mod errors; +mod msg; pub mod protobuf; mod ws; +pub use msg::*; pub use ws::*; diff --git a/rust-lib/flowy-ws/src/msg.rs b/rust-lib/flowy-ws/src/msg.rs new file mode 100644 index 0000000000..d95455286c --- /dev/null +++ b/rust-lib/flowy-ws/src/msg.rs @@ -0,0 +1,38 @@ +use bytes::Bytes; +use flowy_derive::ProtoBuf; +use std::convert::{TryFrom, TryInto}; +use tokio_tungstenite::tungstenite::Message; + +#[derive(ProtoBuf, Debug, Clone, Default)] +pub struct WsMessage { + #[pb(index = 1)] + pub source: String, + + #[pb(index = 2)] + pub data: Vec, +} + +impl std::convert::Into for WsMessage { + fn into(self) -> Message { + let result: Result = self.try_into(); + match result { + Ok(bytes) => Message::Binary(bytes.to_vec()), + Err(e) => { + log::error!("WsMessage serialize error: {:?}", e); + Message::Binary(vec![]) + }, + } + } +} + +impl std::convert::From for WsMessage { + fn from(value: Message) -> Self { + match value { + Message::Binary(bytes) => WsMessage::try_from(Bytes::from(bytes)).unwrap(), + _ => { + log::error!("WsMessage deserialize failed. Unsupported message"); + WsMessage::default() + }, + } + } +} diff --git a/rust-lib/flowy-ws/src/protobuf/model/errors.rs b/rust-lib/flowy-ws/src/protobuf/model/errors.rs index b8130290f8..655a4ec3bd 100644 --- a/rust-lib/flowy-ws/src/protobuf/model/errors.rs +++ b/rust-lib/flowy-ws/src/protobuf/model/errors.rs @@ -216,6 +216,8 @@ impl ::protobuf::reflect::ProtobufValue for WsError { #[derive(Clone,PartialEq,Eq,Debug,Hash)] pub enum ErrorCode { InternalError = 0, + DuplicateSource = 1, + UnsupportedMessage = 2, } impl ::protobuf::ProtobufEnum for ErrorCode { @@ -226,6 +228,8 @@ impl ::protobuf::ProtobufEnum for ErrorCode { fn from_i32(value: i32) -> ::std::option::Option { match value { 0 => ::std::option::Option::Some(ErrorCode::InternalError), + 1 => ::std::option::Option::Some(ErrorCode::DuplicateSource), + 2 => ::std::option::Option::Some(ErrorCode::UnsupportedMessage), _ => ::std::option::Option::None } } @@ -233,6 +237,8 @@ impl ::protobuf::ProtobufEnum for ErrorCode { fn values() -> &'static [Self] { static values: &'static [ErrorCode] = &[ ErrorCode::InternalError, + ErrorCode::DuplicateSource, + ErrorCode::UnsupportedMessage, ]; values } @@ -262,19 +268,24 @@ impl ::protobuf::reflect::ProtobufValue for ErrorCode { static file_descriptor_proto_data: &'static [u8] = b"\ \n\x0cerrors.proto\";\n\x07WsError\x12\x1e\n\x04code\x18\x01\x20\x01(\ - \x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*\ - \x1e\n\tErrorCode\x12\x11\n\rInternalError\x10\0J\xd9\x01\n\x06\x12\x04\ - \0\0\x08\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\x12\x04\x02\ - \0\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\ - \0\x12\x03\x03\x04\x17\n\x0c\n\x05\x04\0\x02\0\x06\x12\x03\x03\x04\r\n\ - \x0c\n\x05\x04\0\x02\0\x01\x12\x03\x03\x0e\x12\n\x0c\n\x05\x04\0\x02\0\ - \x03\x12\x03\x03\x15\x16\n\x0b\n\x04\x04\0\x02\x01\x12\x03\x04\x04\x13\n\ - \x0c\n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\n\n\x0c\n\x05\x04\0\x02\x01\ - \x01\x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\x03\x04\x11\ - \x12\n\n\n\x02\x05\0\x12\x04\x06\0\x08\x01\n\n\n\x03\x05\0\x01\x12\x03\ - \x06\x05\x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\x07\x04\x16\n\x0c\n\x05\x05\ - \0\x02\0\x01\x12\x03\x07\x04\x11\n\x0c\n\x05\x05\0\x02\0\x02\x12\x03\x07\ - \x14\x15b\x06proto3\ + \x0e2\n.ErrorCodeR\x04code\x12\x10\n\x03msg\x18\x02\x20\x01(\tR\x03msg*K\ + \n\tErrorCode\x12\x11\n\rInternalError\x10\0\x12\x13\n\x0fDuplicateSourc\ + e\x10\x01\x12\x16\n\x12UnsupportedMessage\x10\x02J\xab\x02\n\x06\x12\x04\ + \0\0\n\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\x12\x04\x02\0\ + \x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x0f\n\x0b\n\x04\x04\0\x02\0\ + \x12\x03\x03\x04\x17\n\x0c\n\x05\x04\0\x02\0\x06\x12\x03\x03\x04\r\n\x0c\ + \n\x05\x04\0\x02\0\x01\x12\x03\x03\x0e\x12\n\x0c\n\x05\x04\0\x02\0\x03\ + \x12\x03\x03\x15\x16\n\x0b\n\x04\x04\0\x02\x01\x12\x03\x04\x04\x13\n\x0c\ + \n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\n\n\x0c\n\x05\x04\0\x02\x01\x01\ + \x12\x03\x04\x0b\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\x03\x04\x11\x12\n\ + \n\n\x02\x05\0\x12\x04\x06\0\n\x01\n\n\n\x03\x05\0\x01\x12\x03\x06\x05\ + \x0e\n\x0b\n\x04\x05\0\x02\0\x12\x03\x07\x04\x16\n\x0c\n\x05\x05\0\x02\0\ + \x01\x12\x03\x07\x04\x11\n\x0c\n\x05\x05\0\x02\0\x02\x12\x03\x07\x14\x15\ + \n\x0b\n\x04\x05\0\x02\x01\x12\x03\x08\x04\x18\n\x0c\n\x05\x05\0\x02\x01\ + \x01\x12\x03\x08\x04\x13\n\x0c\n\x05\x05\0\x02\x01\x02\x12\x03\x08\x16\ + \x17\n\x0b\n\x04\x05\0\x02\x02\x12\x03\t\x04\x1b\n\x0c\n\x05\x05\0\x02\ + \x02\x01\x12\x03\t\x04\x16\n\x0c\n\x05\x05\0\x02\x02\x02\x12\x03\t\x19\ + \x1ab\x06proto3\ "; static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT; diff --git a/rust-lib/flowy-ws/src/protobuf/model/mod.rs b/rust-lib/flowy-ws/src/protobuf/model/mod.rs index 00f047d293..e082345eb3 100644 --- a/rust-lib/flowy-ws/src/protobuf/model/mod.rs +++ b/rust-lib/flowy-ws/src/protobuf/model/mod.rs @@ -2,3 +2,6 @@ mod errors; pub use errors::*; + +mod msg; +pub use msg::*; diff --git a/rust-lib/flowy-ws/src/protobuf/model/msg.rs b/rust-lib/flowy-ws/src/protobuf/model/msg.rs new file mode 100644 index 0000000000..8437ba0c20 --- /dev/null +++ b/rust-lib/flowy-ws/src/protobuf/model/msg.rs @@ -0,0 +1,250 @@ +// This file is generated by rust-protobuf 2.22.1. Do not edit +// @generated + +// https://github.com/rust-lang/rust-clippy/issues/702 +#![allow(unknown_lints)] +#![allow(clippy::all)] + +#![allow(unused_attributes)] +#![cfg_attr(rustfmt, rustfmt::skip)] + +#![allow(box_pointers)] +#![allow(dead_code)] +#![allow(missing_docs)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(trivial_casts)] +#![allow(unused_imports)] +#![allow(unused_results)] +//! Generated file from `msg.proto` + +/// Generated files are compatible only with the same version +/// of protobuf runtime. +// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_22_1; + +#[derive(PartialEq,Clone,Default)] +pub struct WsMessage { + // message fields + pub source: ::std::string::String, + pub data: ::std::vec::Vec, + // special fields + pub unknown_fields: ::protobuf::UnknownFields, + pub cached_size: ::protobuf::CachedSize, +} + +impl<'a> ::std::default::Default for &'a WsMessage { + fn default() -> &'a WsMessage { + ::default_instance() + } +} + +impl WsMessage { + pub fn new() -> WsMessage { + ::std::default::Default::default() + } + + // string source = 1; + + + pub fn get_source(&self) -> &str { + &self.source + } + pub fn clear_source(&mut self) { + self.source.clear(); + } + + // Param is passed by value, moved + pub fn set_source(&mut self, v: ::std::string::String) { + self.source = v; + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_source(&mut self) -> &mut ::std::string::String { + &mut self.source + } + + // Take field + pub fn take_source(&mut self) -> ::std::string::String { + ::std::mem::replace(&mut self.source, ::std::string::String::new()) + } + + // bytes data = 2; + + + pub fn get_data(&self) -> &[u8] { + &self.data + } + pub fn clear_data(&mut self) { + self.data.clear(); + } + + // Param is passed by value, moved + pub fn set_data(&mut self, v: ::std::vec::Vec) { + self.data = v; + } + + // Mutable pointer to the field. + // If field is not initialized, it is initialized with default value first. + pub fn mut_data(&mut self) -> &mut ::std::vec::Vec { + &mut self.data + } + + // Take field + pub fn take_data(&mut self) -> ::std::vec::Vec { + ::std::mem::replace(&mut self.data, ::std::vec::Vec::new()) + } +} + +impl ::protobuf::Message for WsMessage { + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::ProtobufResult<()> { + while !is.eof()? { + let (field_number, wire_type) = is.read_tag_unpack()?; + match field_number { + 1 => { + ::protobuf::rt::read_singular_proto3_string_into(wire_type, is, &mut self.source)?; + }, + 2 => { + ::protobuf::rt::read_singular_proto3_bytes_into(wire_type, is, &mut self.data)?; + }, + _ => { + ::protobuf::rt::read_unknown_or_skip_group(field_number, wire_type, is, self.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u32 { + let mut my_size = 0; + if !self.source.is_empty() { + my_size += ::protobuf::rt::string_size(1, &self.source); + } + if !self.data.is_empty() { + my_size += ::protobuf::rt::bytes_size(2, &self.data); + } + my_size += ::protobuf::rt::unknown_fields_size(self.get_unknown_fields()); + self.cached_size.set(my_size); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::ProtobufResult<()> { + if !self.source.is_empty() { + os.write_string(1, &self.source)?; + } + if !self.data.is_empty() { + os.write_bytes(2, &self.data)?; + } + os.write_unknown_fields(self.get_unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn get_cached_size(&self) -> u32 { + self.cached_size.get() + } + + fn get_unknown_fields(&self) -> &::protobuf::UnknownFields { + &self.unknown_fields + } + + fn mut_unknown_fields(&mut self) -> &mut ::protobuf::UnknownFields { + &mut self.unknown_fields + } + + fn as_any(&self) -> &dyn (::std::any::Any) { + self as &dyn (::std::any::Any) + } + fn as_any_mut(&mut self) -> &mut dyn (::std::any::Any) { + self as &mut dyn (::std::any::Any) + } + fn into_any(self: ::std::boxed::Box) -> ::std::boxed::Box { + self + } + + fn descriptor(&self) -> &'static ::protobuf::reflect::MessageDescriptor { + Self::descriptor_static() + } + + fn new() -> WsMessage { + WsMessage::new() + } + + fn descriptor_static() -> &'static ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::LazyV2<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::LazyV2::INIT; + descriptor.get(|| { + let mut fields = ::std::vec::Vec::new(); + fields.push(::protobuf::reflect::accessor::make_simple_field_accessor::<_, ::protobuf::types::ProtobufTypeString>( + "source", + |m: &WsMessage| { &m.source }, + |m: &mut WsMessage| { &mut m.source }, + )); + fields.push(::protobuf::reflect::accessor::make_simple_field_accessor::<_, ::protobuf::types::ProtobufTypeBytes>( + "data", + |m: &WsMessage| { &m.data }, + |m: &mut WsMessage| { &mut m.data }, + )); + ::protobuf::reflect::MessageDescriptor::new_pb_name::( + "WsMessage", + fields, + file_descriptor_proto() + ) + }) + } + + fn default_instance() -> &'static WsMessage { + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(WsMessage::new) + } +} + +impl ::protobuf::Clear for WsMessage { + fn clear(&mut self) { + self.source.clear(); + self.data.clear(); + self.unknown_fields.clear(); + } +} + +impl ::std::fmt::Debug for WsMessage { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for WsMessage { + fn as_ref(&self) -> ::protobuf::reflect::ReflectValueRef { + ::protobuf::reflect::ReflectValueRef::Message(self) + } +} + +static file_descriptor_proto_data: &'static [u8] = b"\ + \n\tmsg.proto\"7\n\tWsMessage\x12\x16\n\x06source\x18\x01\x20\x01(\tR\ + \x06source\x12\x12\n\x04data\x18\x02\x20\x01(\x0cR\x04dataJ\x98\x01\n\ + \x06\x12\x04\0\0\x05\x01\n\x08\n\x01\x0c\x12\x03\0\0\x12\n\n\n\x02\x04\0\ + \x12\x04\x02\0\x05\x01\n\n\n\x03\x04\0\x01\x12\x03\x02\x08\x11\n\x0b\n\ + \x04\x04\0\x02\0\x12\x03\x03\x04\x16\n\x0c\n\x05\x04\0\x02\0\x05\x12\x03\ + \x03\x04\n\n\x0c\n\x05\x04\0\x02\0\x01\x12\x03\x03\x0b\x11\n\x0c\n\x05\ + \x04\0\x02\0\x03\x12\x03\x03\x14\x15\n\x0b\n\x04\x04\0\x02\x01\x12\x03\ + \x04\x04\x13\n\x0c\n\x05\x04\0\x02\x01\x05\x12\x03\x04\x04\t\n\x0c\n\x05\ + \x04\0\x02\x01\x01\x12\x03\x04\n\x0e\n\x0c\n\x05\x04\0\x02\x01\x03\x12\ + \x03\x04\x11\x12b\x06proto3\ +"; + +static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT; + +fn parse_descriptor_proto() -> ::protobuf::descriptor::FileDescriptorProto { + ::protobuf::Message::parse_from_bytes(file_descriptor_proto_data).unwrap() +} + +pub fn file_descriptor_proto() -> &'static ::protobuf::descriptor::FileDescriptorProto { + file_descriptor_proto_lazy.get(|| { + parse_descriptor_proto() + }) +} diff --git a/rust-lib/flowy-ws/src/protobuf/proto/errors.proto b/rust-lib/flowy-ws/src/protobuf/proto/errors.proto index 14a5c85098..349df305aa 100644 --- a/rust-lib/flowy-ws/src/protobuf/proto/errors.proto +++ b/rust-lib/flowy-ws/src/protobuf/proto/errors.proto @@ -6,4 +6,6 @@ message WsError { } enum ErrorCode { InternalError = 0; + DuplicateSource = 1; + UnsupportedMessage = 2; } diff --git a/rust-lib/flowy-ws/src/protobuf/proto/msg.proto b/rust-lib/flowy-ws/src/protobuf/proto/msg.proto new file mode 100644 index 0000000000..ce54b34bcd --- /dev/null +++ b/rust-lib/flowy-ws/src/protobuf/proto/msg.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +message WsMessage { + string source = 1; + bytes data = 2; +} diff --git a/rust-lib/flowy-ws/src/ws.rs b/rust-lib/flowy-ws/src/ws.rs index 75b58ff99c..2445789886 100644 --- a/rust-lib/flowy-ws/src/ws.rs +++ b/rust-lib/flowy-ws/src/ws.rs @@ -1,16 +1,23 @@ -use crate::errors::WsError; -use flowy_net::{errors::ServerError, response::FlowyResponse}; +use crate::{connect::WsConnection, errors::WsError, WsMessage}; +use flowy_net::errors::ServerError; use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures_core::{future::BoxFuture, ready, Stream}; -use futures_util::{pin_mut, FutureExt, StreamExt}; +use futures_util::{ + future, + future::{Either, Select}, + pin_mut, + FutureExt, + StreamExt, +}; use pin_project::pin_project; use std::{ + collections::HashMap, future::Future, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use tokio::net::TcpStream; +use tokio::{net::TcpStream, task::JoinHandle}; use tokio_tungstenite::{ connect_async, tungstenite::{handshake::client::Response, http::StatusCode, Error, Message}, @@ -21,37 +28,56 @@ use tokio_tungstenite::{ pub type MsgReceiver = UnboundedReceiver; pub type MsgSender = UnboundedSender; pub trait WsMessageHandler: Sync + Send + 'static { - fn can_handle(&self) -> bool; - fn receive_message(&self, msg: &Message); - fn send_message(&self, sender: Arc); + fn source(&self) -> String; + fn receive_message(&self, msg: WsMessage); } pub struct WsController { sender: Option>, - handlers: Vec>, + handlers: HashMap>, } impl WsController { pub fn new() -> Self { let controller = Self { sender: None, - handlers: vec![], + handlers: HashMap::new(), }; - controller } - pub fn add_handlers(&mut self, handler: Arc) { self.handlers.push(handler); } - - #[allow(dead_code)] - pub async fn connect(&mut self, addr: String) -> Result<(), ServerError> { - let (conn, handlers) = self.make_connect(addr); - let _ = conn.await?; - let _ = tokio::spawn(handlers); + pub fn add_handler(&mut self, handler: Arc) -> Result<(), WsError> { + let source = handler.source(); + if self.handlers.contains_key(&source) { + return Err(WsError::duplicate_source()); + } + self.handlers.insert(source, handler); Ok(()) } - pub fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) { + pub fn connect(&mut self, addr: String) -> Result, ServerError> { + log::debug!("🐴 Try to connect: {}", &addr); + let (connection, handlers) = self.make_connect(addr); + Ok(tokio::spawn(async { + tokio::select! { + result = connection => { + match result { + Ok(stream) => { + tokio::spawn(stream).await; + // stream.start().await; + }, + Err(e) => { + // TODO: retry? + log::error!("ws connect failed {:?}", e); + } + } + }, + result = handlers => log::debug!("handlers completed {:?}", result), + }; + })) + } + + fn make_connect(&mut self, addr: String) -> (WsConnection, WsHandlers) { // Stream User // ┌───────────────┐ ┌──────────────┐ // ┌──────┐ │ ┌─────────┐ │ ┌────────┐ │ ┌────────┐ │ @@ -64,16 +90,15 @@ impl WsController { // └───────────────┘ └──────────────┘ let (msg_tx, msg_rx) = futures_channel::mpsc::unbounded(); let (ws_tx, ws_rx) = futures_channel::mpsc::unbounded(); - let sender = Arc::new(WsSender::new(ws_tx)); let handlers = self.handlers.clone(); - self.sender = Some(sender.clone()); + self.sender = Some(Arc::new(WsSender::new(ws_tx))); (WsConnection::new(msg_tx, ws_rx, addr), WsHandlers::new(handlers, msg_rx)) } - pub fn send_message(&self, msg: Message) -> Result<(), WsError> { - match &self.sender { - None => panic!(), - Some(conn) => conn.send(msg), + pub fn send_msg>(&self, msg: T) -> Result<(), WsError> { + match self.sender.as_ref() { + None => Err(WsError::internal().context("Should call make_connect first")), + Some(sender) => sender.send(msg.into()), } } } @@ -82,11 +107,11 @@ impl WsController { pub struct WsHandlers { #[pin] msg_rx: MsgReceiver, - handlers: Vec>, + handlers: HashMap>, } impl WsHandlers { - fn new(handlers: Vec>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } } + fn new(handlers: HashMap>, msg_rx: MsgReceiver) -> Self { Self { msg_rx, handlers } } } impl Future for WsHandlers { @@ -94,130 +119,31 @@ impl Future for WsHandlers { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match ready!(self.as_mut().project().msg_rx.poll_next(cx)) { - None => return Poll::Ready(()), - Some(message) => self.handlers.iter().for_each(|handler| { - handler.receive_message(&message); - }), - } - } - } -} - -#[pin_project] -pub struct WsConnection { - msg_tx: Option, - ws_rx: Option, - #[pin] - fut: BoxFuture<'static, Result<(WebSocketStream>, Response), Error>>, -} - -impl WsConnection { - pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, addr: String) -> Self { - WsConnection { - msg_tx: Some(msg_tx), - ws_rx: Some(ws_rx), - fut: Box::pin(async move { connect_async(&addr).await }), - } - } -} - -impl Future for WsConnection { - type Output = Result<(), ServerError>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // [[pin]] - // poll async function. The following methods not work. - // 1. - // let f = connect_async(""); - // pin_mut!(f); - // ready!(Pin::new(&mut a).poll(cx)) - // - // 2.ready!(Pin::new(&mut Box::pin(connect_async(""))).poll(cx)) - // - // An async method calls poll multiple times and might return to the executor. A - // single poll call can only return to the executor once and will get - // resumed through another poll invocation. the connect_async call multiple time - // from the beginning. So I use fut to hold the future and continue to - // poll it. (Fix me if i was wrong) - - loop { - return match ready!(self.as_mut().project().fut.poll(cx)) { - Ok((stream, _)) => { - let mut ws_stream = WsStream { - msg_tx: self.msg_tx.take(), - ws_rx: self.ws_rx.take(), - stream: Some(stream), - }; - match Pin::new(&mut ws_stream).poll(cx) { - Poll::Ready(_) => Poll::Ready(Ok(())), - Poll::Pending => Poll::Pending, + None => { + // log::debug!("🐴 ws handler done"); + return Poll::Pending; + }, + Some(message) => { + let message = WsMessage::from(message); + match self.handlers.get(&message.source) { + None => log::error!("Can't find any handler for message: {:?}", message), + Some(handler) => handler.receive_message(message.clone()), } }, - Err(error) => Poll::Ready(Err(error_to_flowy_response(error))), - }; + } } } } -fn error_to_flowy_response(error: tokio_tungstenite::tungstenite::Error) -> ServerError { - let error = match error { - Error::Http(response) => { - if response.status() == StatusCode::UNAUTHORIZED { - ServerError::unauthorized() - } else { - ServerError::internal().context(response) - } - }, - _ => ServerError::internal().context(error), - }; - - error -} - -struct WsStream { - msg_tx: Option, - ws_rx: Option, - stream: Option>>, -} - -impl Future for WsStream { - type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let (tx, rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); - let (ws_write, ws_read) = self.stream.take().unwrap().split(); - let to_ws = rx.map(Ok).forward(ws_write); - let from_ws = ws_read.for_each(|message| async { - match message { - Ok(message) => { - match tx.unbounded_send(message) { - Ok(_) => {}, - Err(e) => log::error!("tx send error: {:?}", e), - }; - }, - Err(e) => log::error!("ws read error: {:?}", e), - } - }); - - pin_mut!(to_ws, from_ws); - log::trace!("🐴 ws start poll stream"); - match to_ws.poll_unpin(cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => match from_ws.poll_unpin(cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, - } - } -} - -pub struct WsSender { +struct WsSender { ws_tx: MsgSender, } impl WsSender { pub fn new(ws_tx: MsgSender) -> Self { Self { ws_tx } } - pub fn send(&self, msg: Message) -> Result<(), WsError> { - let _ = self.ws_tx.unbounded_send(msg)?; + pub fn send(&self, msg: WsMessage) -> Result<(), WsError> { + let _ = self.ws_tx.unbounded_send(msg.into()).map_err(|e| WsError::internal().context(e))?; Ok(()) } }