From 4afadf57dc2328ee030b3de5664c5de4c995566a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Thu, 22 Apr 2021 21:37:27 +0200 Subject: [PATCH] move connect code to channel and get rid of unwraps --- network/protocol/src/quic.rs | 4 +- network/src/channel.rs | 151 ++++++++++++++++++++++++++++++----- network/src/scheduler.rs | 97 +++------------------- 3 files changed, 149 insertions(+), 103 deletions(-) diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index 0e76e1fe32..a10764491b 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -222,7 +222,9 @@ where if is_reliable(&promises) { self.reliable_buffers.insert(sid, BytesMut::new()); //Send a empty message to notify local drain of stream - self.drain.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)).await?; + self.drain + .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)) + .await?; } event.to_frame().write_bytes(&mut self.main_buffer); self.drain diff --git a/network/src/channel.rs b/network/src/channel.rs index 872c0647cf..fe3bff971e 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,3 +1,4 @@ +use crate::api::NetworkConnectError; use async_trait::async_trait; use bytes::BytesMut; use network_protocol::{ @@ -9,10 +10,12 @@ use network_protocol::{ use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, + net, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - sync::mpsc, + sync::{mpsc, oneshot}, }; use tokio_stream::StreamExt; +use tracing::{info, trace}; #[allow(clippy::large_enum_variant)] #[derive(Debug)] @@ -40,6 +43,23 @@ pub(crate) enum RecvProtocols { } impl Protocols { + const MPSC_CHANNEL_BOUND: usize = 1000; + + pub(crate) async fn with_tcp_connect( + addr: std::net::SocketAddr, + cid: Cid, + metrics: Arc, + ) -> Result { + let stream = match net::TcpStream::connect(addr).await { + Ok(stream) => stream, + Err(e) => { + return Err(crate::api::NetworkConnectError::Io(e)); + }, + }; + info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); + Ok(Protocols::new_tcp(stream, cid, metrics)) + } + pub(crate) fn new_tcp( stream: tokio::net::TcpStream, cid: Cid, @@ -59,6 +79,49 @@ impl Protocols { Protocols::Tcp((sp, rp)) } + pub(crate) async fn with_mpsc_connect( + addr: u64, + cid: Cid, + metrics: Arc, + ) -> Result { + let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) { + Some(s) => s.clone(), + None => { + return Err(NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "no mpsc listen on this addr", + ))); + }, + }; + let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); + let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); + if mpsc_s + .send((remote_to_local_s, local_to_remote_oneshot_s)) + .is_err() + { + return Err(NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "mpsc pipe broke during connect", + ))); + } + let local_to_remote_s = match local_to_remote_oneshot_r.await { + Ok(s) => s, + Err(e) => { + return Err(NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + e, + ))); + }, + }; + info!(?addr, "Connecting Mpsc"); + Ok(Self::new_mpsc( + local_to_remote_s, + remote_to_local_r, + cid, + metrics, + )) + } + pub(crate) fn new_mpsc( sender: mpsc::Sender, receiver: mpsc::Receiver, @@ -72,6 +135,46 @@ impl Protocols { Protocols::Mpsc((sp, rp)) } + pub(crate) async fn with_quic_connect( + addr: std::net::SocketAddr, + config: quinn::ClientConfig, + name: String, + cid: Cid, + metrics: Arc, + ) -> Result { + let config = config.clone(); + let endpoint = quinn::Endpoint::builder(); + let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) { + Ok(e) => e, + Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)), + }; + + info!("Connecting Quic to: {}", &addr); + let connecting = endpoint.connect_with(config, &addr, &name).map_err(|e| { + trace!(?e, "error setting up quic"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + })?; + let connection = connecting.await.map_err(|e| { + trace!(?e, "error with quic connection"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + })?; + Protocols::new_quic(connection, false, cid, metrics) + .await + .map_err(|e| { + trace!(?e, "error with quic"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + }) + } + #[cfg(feature = "quic")] pub(crate) async fn new_quic( mut connection: quinn::NewConnection, @@ -81,14 +184,18 @@ impl Protocols { ) -> Result { let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let (sendstream, recvstream) = if listen { + let (sendstream, recvstream) = if listen { connection.connection.open_bi().await? } else { - connection.bi_streams.next().await.expect("none").expect("dasdasd") + connection + .bi_streams + .next() + .await + .ok_or_else(|| quinn::ConnectionError::LocallyClosed)?? }; - let (recvstreams_s,recvstreams_r) = mpsc::unbounded_channel(); + let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel(); let streams_s_clone = recvstreams_s.clone(); - let (sendstreams_s,sendstreams_r) = mpsc::unbounded_channel(); + let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel(); let sp = QuicSendProtocol::new( QuicDrain { con: connection.connection.clone(), @@ -261,7 +368,12 @@ impl UnreliableSink for MpscSink { /////////////////////////////////////// //// QUIC #[cfg(feature = "quic")] -type QuicStream = (BytesMut, Result, quinn::ReadError>, quinn::RecvStream, Option); +type QuicStream = ( + BytesMut, + Result, quinn::ReadError>, + quinn::RecvStream, + Option, +); #[cfg(feature = "quic")] #[derive(Debug)] @@ -284,7 +396,11 @@ pub struct QuicSink { } #[cfg(feature = "quic")] -fn spawn_new(mut recvstream: quinn::RecvStream, sid: Option, streams_s: &mpsc::UnboundedSender) { +fn spawn_new( + mut recvstream: quinn::RecvStream, + sid: Option, + streams_s: &mpsc::UnboundedSender, +) { let streams_s_clone = streams_s.clone(); tokio::spawn(async move { let mut buffer = BytesMut::new(); @@ -301,19 +417,16 @@ impl UnreliableDrain for QuicDrain { async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { match match data.stream { - QuicDataFormatStream::Main => { - self.main.write_all(&data.data).await - }, + QuicDataFormatStream::Main => self.main.write_all(&data.data).await, QuicDataFormatStream::Unreliable => unimplemented!(), QuicDataFormatStream::Reliable(sid) => { use std::collections::hash_map::Entry; tracing::trace!(?sid, "Reliable"); match self.reliables.entry(sid) { - Entry::Occupied(mut occupied) => { - occupied.get_mut().write_all(&data.data).await - }, + Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await, Entry::Vacant(vacant) => { - // IF the buffer is empty this was created localy and WE are allowed to open_bi(), if not, we NEED to block on sendstreams_r + // IF the buffer is empty this was created localy and WE are allowed to + // open_bi(), if not, we NEED to block on sendstreams_r if data.data.is_empty() { match self.con.open_bi().await { Ok((mut sendstream, recvstream)) => { @@ -327,14 +440,17 @@ impl UnreliableDrain for QuicDrain { Err(_) => return Err(ProtocolError::Closed), } } else { - let sendstream = self.sendstreams_r.recv().await.ok_or(ProtocolError::Closed)?; + let sendstream = self + .sendstreams_r + .recv() + .await + .ok_or(ProtocolError::Closed)?; vacant.insert(sendstream).write_all(&data.data).await } }, } }, - } - { + } { Ok(()) => Ok(()), Err(_) => Err(ProtocolError::Closed), } @@ -391,7 +507,6 @@ impl UnreliableSink for QuicSink { Err(_) => Err(ProtocolError::Closed), }?; - let streams_s_clone = self.recvstreams_s.clone(); tokio::spawn(async move { buffer.resize(1500, 0u8); diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 475e34371f..1e8d5c69a8 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -34,7 +34,7 @@ use tracing::*; // - c: channel/handshake lazy_static::lazy_static! { - static ref MPSC_POOL: Mutex>> = { + pub(crate) static ref MPSC_POOL: Mutex>> = { Mutex::new(HashMap::new()) }; } @@ -197,94 +197,23 @@ impl Scheduler { let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let metrics = Arc::clone(&self.protocol_metrics); self.metrics.connect_request(&addr); - let (protocol, handshake) = match addr { - ConnectAddr::Tcp(addr) => { - let stream = match net::TcpStream::connect(addr).await { - Ok(stream) => stream, - Err(e) => { - pid_sender.send(Err(NetworkConnectError::Io(e))).unwrap(); - continue; - }, - }; - info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - (Protocols::new_tcp(stream, cid, metrics), false) - }, + let protocol = match addr { + ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await, #[cfg(feature = "quic")] ConnectAddr::Quic(addr, ref config, name) => { - let config = config.clone(); - let endpoint = quinn::Endpoint::builder(); - let (endpoint, _) = endpoint.bind(&"[::]:0".parse().unwrap()).expect("FIXME"); - - let connecting = endpoint.connect_with(config, &addr, &name).expect("FIXME"); - let connection = connecting.await.expect("FIXME"); - ( - Protocols::new_quic(connection, false, cid, metrics).await.unwrap(), - false, - ) - //pid_sender.send(Ok(())).unwrap(); + Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await }, - ConnectAddr::Mpsc(addr) => { - let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { - Some(s) => s.clone(), - None => { - pid_sender - .send(Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "no mpsc listen on this addr", - )))) - .unwrap(); - continue; - }, - }; - let (remote_to_local_s, remote_to_local_r) = - mpsc::channel(Self::MPSC_CHANNEL_BOUND); - let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); - mpsc_s - .send((remote_to_local_s, local_to_remote_oneshot_s)) - .unwrap(); - let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap(); - info!(?addr, "Connecting Mpsc"); - ( - Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, metrics), - false, - ) - }, - /* */ - //ProtocolConnectAddr::Udp(addr) => { - //#[cfg(feature = "metrics")] - //self.metrics - //.connect_requests_total - //.with_label_values(&["udp"]) - //.inc(); - //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - //Ok(socket) => Arc::new(socket), - //Err(e) => { - //pid_sender.send(Err(e)).unwrap(); - //continue; - //}, - //}; - //if let Err(e) = socket.connect(addr).await { - //pid_sender.send(Err(e)).unwrap(); - //continue; - //}; - //info!("Connecting Udp to: {}", addr); - //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); - //let protocol = UdpProtocol::new( - //Arc::clone(&socket), - //addr, - //#[cfg(feature = "metrics")] - //Arc::clone(&self.metrics), - //udp_data_receiver, - //); - //self.runtime.spawn( - //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - //.instrument(tracing::info_span!("udp", ?addr)), - //); - //(Protocols::Udp(protocol), true) - //}, + ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await, _ => unimplemented!(), }; - self.init_protocol(protocol, cid, Some(pid_sender), handshake) + let protocol = match protocol { + Ok(p) => p, + Err(e) => { + pid_sender.send(Err(e)).unwrap(); + continue; + }, + }; + self.init_protocol(protocol, cid, Some(pid_sender), false) .await; } trace!("Stop connect_mgr");