receive ws message continuedly by useing box::pin to fix the msg_tx not copyable

This commit is contained in:
appflowy 2021-09-19 12:54:28 +08:00
parent 2b32f2111f
commit aa8536149f
3 changed files with 45 additions and 113 deletions

View File

@ -1,23 +1,16 @@
use crate::{errors::WsError, MsgReceiver, MsgSender, WsMessage}; use crate::{errors::WsError, MsgReceiver, MsgSender};
use flowy_net::errors::ServerError; use flowy_net::errors::ServerError;
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures_core::{future::BoxFuture, ready};
use futures_core::{future::BoxFuture, ready, Stream}; use futures_util::{FutureExt, StreamExt, TryStreamExt};
use futures_util::{
future,
future::{Either, Select},
pin_mut,
FutureExt,
StreamExt,
};
use pin_project::pin_project; use pin_project::pin_project;
use std::{ use std::{
collections::HashMap, fmt,
future::Future, future::Future,
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tokio::{net::TcpStream, task::JoinHandle}; use tokio::net::TcpStream;
use tokio_tungstenite::{ use tokio_tungstenite::{
connect_async, connect_async,
tungstenite::{handshake::client::Response, http::StatusCode, Error, Message}, tungstenite::{handshake::client::Response, http::StatusCode, Error, Message},
@ -63,124 +56,70 @@ impl Future for WsConnection {
loop { loop {
return match ready!(self.as_mut().project().fut.poll(cx)) { return match ready!(self.as_mut().project().fut.poll(cx)) {
Ok((stream, _)) => { Ok((stream, _)) => {
log::debug!("🐴 ws connect success"); log::debug!("🐴 ws connect success: {:?}", error);
let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap());
Poll::Ready(Ok(WsStream::new(msg_tx, ws_rx, stream))) Poll::Ready(Ok(WsStream::new(msg_tx, ws_rx, stream)))
}, },
Err(error) => Poll::Ready(Err(error_to_flowy_response(error))), Err(error) => {
log::debug!("🐴 ws connect failed: {:?}", error);
Poll::Ready(Err(error_to_flowy_response(error)))
},
}; };
} }
} }
} }
type Fut = BoxFuture<'static, Result<(), WsError>>;
#[pin_project] #[pin_project]
pub struct WsStream { pub struct WsStream {
msg_tx: MsgSender,
#[pin] #[pin]
fut: Option<(BoxFuture<'static, ()>, BoxFuture<'static, ()>)>, inner: Option<(Fut, Fut)>,
} }
impl WsStream { impl WsStream {
pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self { pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
let (ws_write, ws_read) = stream.split(); let (ws_write, ws_read) = stream.split();
let to_ws = ws_rx.map(Ok).forward(ws_write);
let from_ws = ws_read.for_each(|message| async {
// handle_new_message(msg_tx.clone(), message)
});
// pin_mut!(to_ws, from_ws);
Self { Self {
msg_tx, inner: Some((
fut: Some((
Box::pin(async move { Box::pin(async move {
let _ = from_ws.await; let _ = ws_read.for_each(|message| async { post_message(msg_tx.clone(), message) }).await;
Ok(())
}), }),
Box::pin(async move { Box::pin(async move {
let _ = to_ws.await; let _ = ws_rx.map(Ok).forward(ws_write).await?;
Ok(())
}), }),
)), )),
} }
} }
} }
impl fmt::Debug for WsStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WsStream").finish() }
}
impl Future for WsStream { impl Future for WsStream {
type Output = (); type Output = Result<(), WsError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let (mut a, mut b) = self.fut.take().unwrap(); let (mut left, mut right) = self.inner.take().unwrap();
match a.poll_unpin(cx) { match left.poll_unpin(cx) {
Poll::Ready(x) => Poll::Ready(()), Poll::Ready(l) => Poll::Ready(l),
Poll::Pending => match b.poll_unpin(cx) { Poll::Pending => {
Poll::Ready(x) => Poll::Ready(()), //
Poll::Pending => { match right.poll_unpin(cx) {
// self.fut = Some((a, b)); Poll::Ready(r) => Poll::Ready(r),
Poll::Pending Poll::Pending => {
}, self.inner = Some((left, right));
Poll::Pending
},
}
}, },
} }
} }
} }
// pub struct WsStream { fn post_message(tx: MsgSender, message: Result<Message, Error>) {
// msg_tx: Option<MsgSender>,
// ws_rx: Option<MsgReceiver>,
// stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
// }
//
// impl WsStream {
// pub fn new(msg_tx: MsgSender, ws_rx: MsgReceiver, stream:
// WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self { Self {
// msg_tx: Some(msg_tx),
// ws_rx: Some(ws_rx),
// stream: Some(stream),
// }
// }
//
// pub fn start(mut self) -> JoinHandle<()> {
// let (msg_tx, ws_rx) = (self.msg_tx.take().unwrap(),
// self.ws_rx.take().unwrap()); let (ws_write, ws_read) =
// self.stream.take().unwrap().split(); tokio::spawn(async move {
// let to_ws = ws_rx.map(Ok).forward(ws_write);
// let from_ws = ws_read.for_each(|message| async {
// handle_new_message(msg_tx.clone(), message) }); pin_mut!(to_ws,
// from_ws);
//
// match future::select(to_ws, from_ws).await {
// Either::Left(_l) => {
// log::info!("ws left");
// },
// Either::Right(_r) => {
// log::info!("ws right");
// },
// }
// })
// }
// }
//
// impl Future for WsStream {
// type Output = ();
// fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) ->
// Poll<Self::Output> { let (msg_tx, ws_rx) =
// (self.msg_tx.take().unwrap(), self.ws_rx.take().unwrap()); let
// (ws_write, ws_read) = self.stream.take().unwrap().split(); let to_ws
// = ws_rx.map(Ok).forward(ws_write); let from_ws =
// ws_read.for_each(|message| async { handle_new_message(msg_tx.clone(),
// message) }); pin_mut!(to_ws, from_ws);
//
// loop {
// match ready!(Pin::new(&mut future::select(to_ws,
// from_ws)).poll(cx)) { Either::Left(a) => {
// //
// return Poll::Ready(());
// },
// Either::Right(b) => {
// //
// return Poll::Ready(());
// },
// }
// }
// }
// }
fn handle_new_message(tx: MsgSender, message: Result<Message, Error>) {
match message { match message {
Ok(Message::Binary(bytes)) => match tx.unbounded_send(Message::Binary(bytes)) { Ok(Message::Binary(bytes)) => match tx.unbounded_send(Message::Binary(bytes)) {
Ok(_) => {}, Ok(_) => {},

View File

@ -62,3 +62,7 @@ impl std::convert::From<protobuf::ProtobufError> for WsError {
impl std::convert::From<futures_channel::mpsc::TrySendError<Message>> for WsError { impl std::convert::From<futures_channel::mpsc::TrySendError<Message>> for WsError {
fn from(error: TrySendError<Message>) -> Self { WsError::internal().context(error) } fn from(error: TrySendError<Message>) -> Self { WsError::internal().context(error) }
} }
impl std::convert::From<tokio_tungstenite::tungstenite::Error> for WsError {
fn from(error: tokio_tungstenite::tungstenite::Error) -> Self { WsError::internal().context(error) }
}

View File

@ -1,14 +1,8 @@
use crate::{connect::WsConnection, errors::WsError, WsMessage}; use crate::{connect::WsConnection, errors::WsError, WsMessage};
use flowy_net::errors::ServerError; use flowy_net::errors::ServerError;
use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures_channel::mpsc::{UnboundedReceiver, UnboundedSender};
use futures_core::{future::BoxFuture, ready, Stream}; use futures_core::{ready, Stream};
use futures_util::{
future,
future::{Either, Select},
pin_mut,
FutureExt,
StreamExt,
};
use pin_project::pin_project; use pin_project::pin_project;
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -17,13 +11,8 @@ use std::{
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tokio::{net::TcpStream, task::JoinHandle}; use tokio::task::JoinHandle;
use tokio_tungstenite::{ use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
connect_async,
tungstenite::{handshake::client::Response, http::StatusCode, Error, Message},
MaybeTlsStream,
WebSocketStream,
};
pub type MsgReceiver = UnboundedReceiver<Message>; pub type MsgReceiver = UnboundedReceiver<Message>;
pub type MsgSender = UnboundedSender<Message>; pub type MsgSender = UnboundedSender<Message>;
@ -56,7 +45,7 @@ impl WsController {
} }
pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> { pub fn connect(&mut self, addr: String) -> Result<JoinHandle<()>, ServerError> {
log::debug!("🐴 Try to connect: {}", &addr); log::debug!("🐴 ws connect: {}", &addr);
let (connection, handlers) = self.make_connect(addr); let (connection, handlers) = self.make_connect(addr);
Ok(tokio::spawn(async { Ok(tokio::spawn(async {
tokio::select! { tokio::select! {