use std::sync::Arc; use appflowy_integrate::RocksCollabDB; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; use flowy_error::internal_error; use flowy_sqlite::ConnectionPool; use flowy_sqlite::{ kv::KV, query_dsl::*, schema::{user_table, user_table::dsl}, DBConnection, ExpressionMethods, UserDatabaseConnection, }; use crate::entities::{ SignInParams, SignInResponse, SignUpParams, SignUpResponse, UpdateUserProfileParams, UserProfile, }; use crate::entities::{UserProfilePB, UserSettingPB}; use crate::event_map::UserStatusCallback; use crate::{ errors::FlowyError, event_map::UserCloudService, notification::*, services::database::{UserDB, UserTable, UserTableChangeset}, }; pub struct UserSessionConfig { root_dir: String, /// Used as the key of `Session` when saving session information to KV. session_cache_key: String, } impl UserSessionConfig { /// The `root_dir` represents as the root of the user folders. It must be unique for each /// users. pub fn new(name: &str, root_dir: &str) -> Self { let session_cache_key = format!("{}_session_cache", name); Self { root_dir: root_dir.to_owned(), session_cache_key, } } } pub struct UserSession { database: UserDB, session_config: UserSessionConfig, cloud_service: Arc, user_status_callback: RwLock>>, } impl UserSession { pub fn new(session_config: UserSessionConfig, cloud_service: Arc) -> Self { let db = UserDB::new(&session_config.root_dir); let user_status_callback = RwLock::new(None); Self { database: db, session_config, cloud_service, user_status_callback, } } pub async fn init(&self, user_status_callback: C) { if let Ok(session) = self.get_session() { let _ = user_status_callback .did_sign_in(&session.token, session.user_id) .await; } *self.user_status_callback.write().await = Some(Arc::new(user_status_callback)); } pub fn db_connection(&self) -> Result { let user_id = self.get_session()?.user_id; self.database.get_connection(user_id) } // The caller will be not 'Sync' before of the return value, // PooledConnection is not sync. You can use // db_connection_pool function to require the ConnectionPool that is 'Sync'. // // let pool = self.db_connection_pool()?; // let conn: PooledConnection = pool.get()?; pub fn db_pool(&self) -> Result, FlowyError> { let user_id = self.get_session()?.user_id; self.database.get_pool(user_id) } pub fn get_collab_db(&self) -> Result, FlowyError> { let user_id = self.get_session()?.user_id; self.database.get_kv_db(user_id) } #[tracing::instrument(level = "debug", skip(self))] pub async fn sign_in(&self, params: SignInParams) -> Result { if self.is_user_login(¶ms.email) { match self.get_user_profile().await { Ok(profile) => { send_sign_in_notification() .payload::(profile.clone().into()) .send(); Ok(profile) }, Err(err) => Err(err), } } else { let resp = self.cloud_service.sign_in(params).await?; let session: Session = resp.clone().into(); self.set_session(Some(session))?; let user_profile: UserProfile = self.save_user(resp.into()).await?.into(); let _ = self .user_status_callback .read() .await .as_ref() .unwrap() .did_sign_in(&user_profile.token, user_profile.id) .await; send_sign_in_notification() .payload::(user_profile.clone().into()) .send(); Ok(user_profile) } } #[tracing::instrument(level = "debug", skip(self))] pub async fn sign_up(&self, params: SignUpParams) -> Result { if self.is_user_login(¶ms.email) { self.get_user_profile().await } else { let resp = self.cloud_service.sign_up(params).await?; let session: Session = resp.clone().into(); self.set_session(Some(session))?; let user_table = self.save_user(resp.into()).await?; let user_profile: UserProfile = user_table.into(); let _ = self .user_status_callback .read() .await .as_ref() .unwrap() .did_sign_up(&user_profile) .await; Ok(user_profile) } } #[tracing::instrument(level = "debug", skip(self))] pub async fn sign_out(&self) -> Result<(), FlowyError> { let session = self.get_session()?; let uid = session.user_id.to_string(); let _ = diesel::delete(dsl::user_table.filter(dsl::id.eq(&uid))) .execute(&*(self.db_connection()?))?; self.database.close_user_db(session.user_id)?; self.set_session(None)?; let _ = self .user_status_callback .read() .await .as_ref() .unwrap() .did_expired(&session.token, session.user_id) .await; self.sign_out_on_server(&session.token).await?; Ok(()) } #[tracing::instrument(level = "debug", skip(self))] pub async fn update_user_profile( &self, params: UpdateUserProfileParams, ) -> Result<(), FlowyError> { let session = self.get_session()?; let changeset = UserTableChangeset::new(params.clone()); diesel_update_table!(user_table, changeset, &*self.db_connection()?); let user_profile = self.get_user_profile().await?; let profile_pb: UserProfilePB = user_profile.into(); send_notification(&session.token, UserNotification::DidUpdateUserProfile) .payload(profile_pb) .send(); self.update_user_on_server(&session.token, params).await?; Ok(()) } pub async fn init_user(&self) -> Result<(), FlowyError> { Ok(()) } pub async fn check_user(&self) -> Result { let (user_id, token) = self.get_session()?.into_part(); let user_id = user_id.to_string(); let user = dsl::user_table .filter(user_table::id.eq(&user_id)) .first::(&*(self.db_connection()?))?; self.read_user_profile_on_server(&token)?; Ok(user.into()) } pub async fn get_user_profile(&self) -> Result { let (user_id, token) = self.get_session()?.into_part(); let user_id = user_id.to_string(); let user = dsl::user_table .filter(user_table::id.eq(&user_id)) .first::(&*(self.db_connection()?))?; self.read_user_profile_on_server(&token)?; Ok(user.into()) } pub fn user_dir(&self) -> Result { let session = self.get_session()?; Ok(format!( "{}/{}", self.session_config.root_dir, session.user_id )) } pub fn user_setting(&self) -> Result { let user_setting = UserSettingPB { user_folder: self.user_dir()?, }; Ok(user_setting) } pub fn user_id(&self) -> Result { Ok(self.get_session()?.user_id) } pub fn user_name(&self) -> Result { Ok(self.get_session()?.name) } pub fn token(&self) -> Result { Ok(self.get_session()?.token) } } impl UserSession { fn read_user_profile_on_server(&self, _token: &str) -> Result<(), FlowyError> { Ok(()) } async fn update_user_on_server( &self, token: &str, params: UpdateUserProfileParams, ) -> Result<(), FlowyError> { let server = self.cloud_service.clone(); let token = token.to_owned(); let _ = tokio::spawn(async move { match server.update_user(&token, params).await { Ok(_) => {}, Err(e) => { // TODO: retry? tracing::error!("update user profile failed: {:?}", e); }, } }) .await; Ok(()) } async fn sign_out_on_server(&self, token: &str) -> Result<(), FlowyError> { let server = self.cloud_service.clone(); let token = token.to_owned(); let _ = tokio::spawn(async move { match server.sign_out(&token).await { Ok(_) => {}, Err(e) => tracing::error!("Sign out failed: {:?}", e), } }) .await; Ok(()) } async fn save_user(&self, user: UserTable) -> Result { let conn = self.db_connection()?; let _ = diesel::insert_into(user_table::table) .values(user.clone()) .execute(&*conn)?; Ok(user) } fn set_session(&self, session: Option) -> Result<(), FlowyError> { tracing::debug!("Set user session: {:?}", session); match &session { None => KV::remove(&self.session_config.session_cache_key), Some(session) => { KV::set_object(&self.session_config.session_cache_key, session.clone()) .map_err(internal_error)?; }, } Ok(()) } fn get_session(&self) -> Result { match KV::get_object::(&self.session_config.session_cache_key) { None => Err(FlowyError::unauthorized()), Some(session) => Ok(session), } } fn is_user_login(&self, email: &str) -> bool { match self.get_session() { Ok(session) => session.email == email, Err(_) => false, } } } pub async fn update_user( _cloud_service: Arc, pool: Arc, params: UpdateUserProfileParams, ) -> Result<(), FlowyError> { let changeset = UserTableChangeset::new(params); let conn = pool.get()?; diesel_update_table!(user_table, changeset, &*conn); Ok(()) } impl UserDatabaseConnection for UserSession { fn get_connection(&self) -> Result { self.db_connection().map_err(|e| format!("{:?}", e)) } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] struct Session { user_id: i64, token: String, email: String, #[serde(default)] name: String, } impl std::convert::From for Session { fn from(resp: SignInResponse) -> Self { Session { user_id: resp.user_id, token: resp.token, email: resp.email, name: resp.name, } } } impl std::convert::From for Session { fn from(resp: SignUpResponse) -> Self { Session { user_id: resp.user_id, token: resp.token, email: resp.email, name: resp.name, } } } impl Session { pub fn into_part(self) -> (i64, String) { (self.user_id, self.token) } } impl std::convert::From for Session { fn from(s: String) -> Self { match serde_json::from_str(&s) { Ok(s) => s, Err(e) => { tracing::error!("Deserialize string to Session failed: {:?}", e); Session::default() }, } } } impl std::convert::From for String { fn from(session: Session) -> Self { match serde_json::to_string(&session) { Ok(s) => s, Err(e) => { tracing::error!("Serialize session to string failed: {:?}", e); "".to_string() }, } } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] struct OldSession { user_id: String, token: String, email: String, #[serde(default)] name: String, }