move connect code to channel and get rid of unwraps

This commit is contained in:
Marcel Märtens 2021-04-22 21:37:27 +02:00
parent 95b186e29a
commit 4afadf57dc
3 changed files with 149 additions and 103 deletions

View File

@ -222,7 +222,9 @@ where
if is_reliable(&promises) { if is_reliable(&promises) {
self.reliable_buffers.insert(sid, BytesMut::new()); self.reliable_buffers.insert(sid, BytesMut::new());
//Send a empty message to notify local drain of stream //Send a empty message to notify local drain of stream
self.drain.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)).await?; self.drain
.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid))
.await?;
} }
event.to_frame().write_bytes(&mut self.main_buffer); event.to_frame().write_bytes(&mut self.main_buffer);
self.drain self.drain

View File

@ -1,3 +1,4 @@
use crate::api::NetworkConnectError;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::BytesMut; use bytes::BytesMut;
use network_protocol::{ use network_protocol::{
@ -9,10 +10,12 @@ use network_protocol::{
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net,
net::tcp::{OwnedReadHalf, OwnedWriteHalf}, net::tcp::{OwnedReadHalf, OwnedWriteHalf},
sync::mpsc, sync::{mpsc, oneshot},
}; };
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::{info, trace};
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Debug)] #[derive(Debug)]
@ -40,6 +43,23 @@ pub(crate) enum RecvProtocols {
} }
impl Protocols { impl Protocols {
const MPSC_CHANNEL_BOUND: usize = 1000;
pub(crate) async fn with_tcp_connect(
addr: std::net::SocketAddr,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, NetworkConnectError> {
let stream = match net::TcpStream::connect(addr).await {
Ok(stream) => stream,
Err(e) => {
return Err(crate::api::NetworkConnectError::Io(e));
},
};
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap());
Ok(Protocols::new_tcp(stream, cid, metrics))
}
pub(crate) fn new_tcp( pub(crate) fn new_tcp(
stream: tokio::net::TcpStream, stream: tokio::net::TcpStream,
cid: Cid, cid: Cid,
@ -59,6 +79,49 @@ impl Protocols {
Protocols::Tcp((sp, rp)) Protocols::Tcp((sp, rp))
} }
pub(crate) async fn with_mpsc_connect(
addr: u64,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, NetworkConnectError> {
let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) {
Some(s) => s.clone(),
None => {
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"no mpsc listen on this addr",
)));
},
};
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
if mpsc_s
.send((remote_to_local_s, local_to_remote_oneshot_s))
.is_err()
{
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"mpsc pipe broke during connect",
)));
}
let local_to_remote_s = match local_to_remote_oneshot_r.await {
Ok(s) => s,
Err(e) => {
return Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
e,
)));
},
};
info!(?addr, "Connecting Mpsc");
Ok(Self::new_mpsc(
local_to_remote_s,
remote_to_local_r,
cid,
metrics,
))
}
pub(crate) fn new_mpsc( pub(crate) fn new_mpsc(
sender: mpsc::Sender<MpscMsg>, sender: mpsc::Sender<MpscMsg>,
receiver: mpsc::Receiver<MpscMsg>, receiver: mpsc::Receiver<MpscMsg>,
@ -72,6 +135,46 @@ impl Protocols {
Protocols::Mpsc((sp, rp)) Protocols::Mpsc((sp, rp))
} }
pub(crate) async fn with_quic_connect(
addr: std::net::SocketAddr,
config: quinn::ClientConfig,
name: String,
cid: Cid,
metrics: Arc<ProtocolMetrics>,
) -> Result<Self, NetworkConnectError> {
let config = config.clone();
let endpoint = quinn::Endpoint::builder();
let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) {
Ok(e) => e,
Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)),
};
info!("Connecting Quic to: {}", &addr);
let connecting = endpoint.connect_with(config, &addr, &name).map_err(|e| {
trace!(?e, "error setting up quic");
NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
e,
))
})?;
let connection = connecting.await.map_err(|e| {
trace!(?e, "error with quic connection");
NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
e,
))
})?;
Protocols::new_quic(connection, false, cid, metrics)
.await
.map_err(|e| {
trace!(?e, "error with quic");
NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
e,
))
})
}
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
pub(crate) async fn new_quic( pub(crate) async fn new_quic(
mut connection: quinn::NewConnection, mut connection: quinn::NewConnection,
@ -81,14 +184,18 @@ impl Protocols {
) -> Result<Self, quinn::ConnectionError> { ) -> Result<Self, quinn::ConnectionError> {
let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics);
let (sendstream, recvstream) = if listen { let (sendstream, recvstream) = if listen {
connection.connection.open_bi().await? connection.connection.open_bi().await?
} else { } else {
connection.bi_streams.next().await.expect("none").expect("dasdasd") connection
.bi_streams
.next()
.await
.ok_or_else(|| quinn::ConnectionError::LocallyClosed)??
}; };
let (recvstreams_s,recvstreams_r) = mpsc::unbounded_channel(); let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
let streams_s_clone = recvstreams_s.clone(); let streams_s_clone = recvstreams_s.clone();
let (sendstreams_s,sendstreams_r) = mpsc::unbounded_channel(); let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel();
let sp = QuicSendProtocol::new( let sp = QuicSendProtocol::new(
QuicDrain { QuicDrain {
con: connection.connection.clone(), con: connection.connection.clone(),
@ -261,7 +368,12 @@ impl UnreliableSink for MpscSink {
/////////////////////////////////////// ///////////////////////////////////////
//// QUIC //// QUIC
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<Sid>); type QuicStream = (
BytesMut,
Result<Option<usize>, quinn::ReadError>,
quinn::RecvStream,
Option<Sid>,
);
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
#[derive(Debug)] #[derive(Debug)]
@ -284,7 +396,11 @@ pub struct QuicSink {
} }
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
fn spawn_new(mut recvstream: quinn::RecvStream, sid: Option<Sid>, streams_s: &mpsc::UnboundedSender<QuicStream>) { fn spawn_new(
mut recvstream: quinn::RecvStream,
sid: Option<Sid>,
streams_s: &mpsc::UnboundedSender<QuicStream>,
) {
let streams_s_clone = streams_s.clone(); let streams_s_clone = streams_s.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut buffer = BytesMut::new(); let mut buffer = BytesMut::new();
@ -301,19 +417,16 @@ impl UnreliableDrain for QuicDrain {
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
match match data.stream { match match data.stream {
QuicDataFormatStream::Main => { QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
self.main.write_all(&data.data).await
},
QuicDataFormatStream::Unreliable => unimplemented!(), QuicDataFormatStream::Unreliable => unimplemented!(),
QuicDataFormatStream::Reliable(sid) => { QuicDataFormatStream::Reliable(sid) => {
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
tracing::trace!(?sid, "Reliable"); tracing::trace!(?sid, "Reliable");
match self.reliables.entry(sid) { match self.reliables.entry(sid) {
Entry::Occupied(mut occupied) => { Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await,
occupied.get_mut().write_all(&data.data).await
},
Entry::Vacant(vacant) => { Entry::Vacant(vacant) => {
// IF the buffer is empty this was created localy and WE are allowed to open_bi(), if not, we NEED to block on sendstreams_r // IF the buffer is empty this was created localy and WE are allowed to
// open_bi(), if not, we NEED to block on sendstreams_r
if data.data.is_empty() { if data.data.is_empty() {
match self.con.open_bi().await { match self.con.open_bi().await {
Ok((mut sendstream, recvstream)) => { Ok((mut sendstream, recvstream)) => {
@ -327,14 +440,17 @@ impl UnreliableDrain for QuicDrain {
Err(_) => return Err(ProtocolError::Closed), Err(_) => return Err(ProtocolError::Closed),
} }
} else { } else {
let sendstream = self.sendstreams_r.recv().await.ok_or(ProtocolError::Closed)?; let sendstream = self
.sendstreams_r
.recv()
.await
.ok_or(ProtocolError::Closed)?;
vacant.insert(sendstream).write_all(&data.data).await vacant.insert(sendstream).write_all(&data.data).await
} }
}, },
} }
}, },
} } {
{
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(_) => Err(ProtocolError::Closed), Err(_) => Err(ProtocolError::Closed),
} }
@ -391,7 +507,6 @@ impl UnreliableSink for QuicSink {
Err(_) => Err(ProtocolError::Closed), Err(_) => Err(ProtocolError::Closed),
}?; }?;
let streams_s_clone = self.recvstreams_s.clone(); let streams_s_clone = self.recvstreams_s.clone();
tokio::spawn(async move { tokio::spawn(async move {
buffer.resize(1500, 0u8); buffer.resize(1500, 0u8);

View File

@ -34,7 +34,7 @@ use tracing::*;
// - c: channel/handshake // - c: channel/handshake
lazy_static::lazy_static! { lazy_static::lazy_static! {
static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = { pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = {
Mutex::new(HashMap::new()) Mutex::new(HashMap::new())
}; };
} }
@ -197,94 +197,23 @@ impl Scheduler {
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
let metrics = Arc::clone(&self.protocol_metrics); let metrics = Arc::clone(&self.protocol_metrics);
self.metrics.connect_request(&addr); self.metrics.connect_request(&addr);
let (protocol, handshake) = match addr { let protocol = match addr {
ConnectAddr::Tcp(addr) => { ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await,
let stream = match net::TcpStream::connect(addr).await {
Ok(stream) => stream,
Err(e) => {
pid_sender.send(Err(NetworkConnectError::Io(e))).unwrap();
continue;
},
};
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap());
(Protocols::new_tcp(stream, cid, metrics), false)
},
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
ConnectAddr::Quic(addr, ref config, name) => { ConnectAddr::Quic(addr, ref config, name) => {
let config = config.clone(); Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await
let endpoint = quinn::Endpoint::builder();
let (endpoint, _) = endpoint.bind(&"[::]:0".parse().unwrap()).expect("FIXME");
let connecting = endpoint.connect_with(config, &addr, &name).expect("FIXME");
let connection = connecting.await.expect("FIXME");
(
Protocols::new_quic(connection, false, cid, metrics).await.unwrap(),
false,
)
//pid_sender.send(Ok(())).unwrap();
}, },
ConnectAddr::Mpsc(addr) => { ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await,
let mpsc_s = match MPSC_POOL.lock().await.get(&addr) {
Some(s) => s.clone(),
None => {
pid_sender
.send(Err(NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"no mpsc listen on this addr",
))))
.unwrap();
continue;
},
};
let (remote_to_local_s, remote_to_local_r) =
mpsc::channel(Self::MPSC_CHANNEL_BOUND);
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
mpsc_s
.send((remote_to_local_s, local_to_remote_oneshot_s))
.unwrap();
let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap();
info!(?addr, "Connecting Mpsc");
(
Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, metrics),
false,
)
},
/* */
//ProtocolConnectAddr::Udp(addr) => {
//#[cfg(feature = "metrics")]
//self.metrics
//.connect_requests_total
//.with_label_values(&["udp"])
//.inc();
//let socket = match net::UdpSocket::bind("0.0.0.0:0").await {
//Ok(socket) => Arc::new(socket),
//Err(e) => {
//pid_sender.send(Err(e)).unwrap();
//continue;
//},
//};
//if let Err(e) = socket.connect(addr).await {
//pid_sender.send(Err(e)).unwrap();
//continue;
//};
//info!("Connecting Udp to: {}", addr);
//let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::<Vec<u8>>();
//let protocol = UdpProtocol::new(
//Arc::clone(&socket),
//addr,
//#[cfg(feature = "metrics")]
//Arc::clone(&self.metrics),
//udp_data_receiver,
//);
//self.runtime.spawn(
//Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender)
//.instrument(tracing::info_span!("udp", ?addr)),
//);
//(Protocols::Udp(protocol), true)
//},
_ => unimplemented!(), _ => unimplemented!(),
}; };
self.init_protocol(protocol, cid, Some(pid_sender), handshake) let protocol = match protocol {
Ok(p) => p,
Err(e) => {
pid_sender.send(Err(e)).unwrap();
continue;
},
};
self.init_protocol(protocol, cid, Some(pid_sender), false)
.await; .await;
} }
trace!("Stop connect_mgr"); trace!("Stop connect_mgr");