From e8b7485abe1c723391354fcfc9fd66ae2c742a51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Fri, 9 Apr 2021 13:17:38 +0200 Subject: [PATCH 1/7] Quic: We had the followuing problem: - locally we open a stream, our local Drain is sending OpenStream - remote Sink will know this and notify remote Drain - remote side sends a message - local sink does not know about the Stream. as there is (and CANT) be a wat to notify local Sink from local Drain (it could introduce race conditions). One of the possible solutions was, that the remote drain will copy the OpenStream Msg ON the Quic::stream before first data is send. This would work but is complicated. Instead we now just mark such streams as "potentially open" and we listen for the first DataHeader to get it's SID. add support for unreliable messages in quic protocol, benchmarks --- Cargo.lock | 53 +- network/Cargo.toml | 5 +- network/protocol/benches/protocols.rs | 89 ++- network/protocol/src/lib.rs | 13 +- network/protocol/src/prio.rs | 14 +- network/protocol/src/quic.rs | 955 ++++++++++++++++++++++++++ network/protocol/src/tcp.rs | 2 +- network/protocol/src/util.rs | 71 ++ network/src/channel.rs | 58 ++ network/src/participant.rs | 3 +- network/src/util.rs | 71 -- 11 files changed, 1246 insertions(+), 88 deletions(-) create mode 100644 network/protocol/src/quic.rs create mode 100644 network/protocol/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index bf345408d0..5403842ba5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2268,7 +2268,7 @@ dependencies = [ "httpdate", "itoa", "pin-project", - "socket2", + "socket2 0.4.0", "tokio", "tower-service", "tracing", @@ -3854,6 +3854,45 @@ dependencies = [ "tracing", ] +[[package]] +name = "quinn" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c82c0a393b300104f989f3db8b8637c0d11f7a32a9c214560b47849ba8f119aa" +dependencies = [ + "bytes", + "futures", + "lazy_static", + "libc", + "mio 0.7.11", + "quinn-proto", + "rustls", + "socket2 0.3.19", + "thiserror", + "tokio", + "tracing", + "webpki", +] + +[[package]] +name = "quinn-proto" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09027365a21874b71e1fbd9d31cb99bff8e11ba81cc9ef2b9425bb607e42d3b2" +dependencies = [ + "bytes", + "ct-logs", + "rand 0.8.3", + "ring", + "rustls", + "rustls-native-certs", + "slab", + "thiserror", + "tinyvec", + "tracing", + "webpki", +] + [[package]] name = "quote" version = "0.6.13" @@ -4673,6 +4712,17 @@ dependencies = [ "wayland-client 0.28.5", ] +[[package]] +name = "socket2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "winapi 0.3.9", +] + [[package]] name = "socket2" version = "0.4.0" @@ -5576,6 +5626,7 @@ dependencies = [ "lz-fear", "prometheus", "prometheus-hyper", + "quinn", "rand 0.8.3", "serde", "shellexpand", diff --git a/network/Cargo.toml b/network/Cargo.toml index 7f0c45c073..a51119a16a 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -9,8 +9,9 @@ edition = "2018" [features] metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] +quic = ["quinn"] -default = ["metrics","compression"] +default = ["metrics","compression","quinn"] [dependencies] @@ -33,6 +34,8 @@ async-channel = "1.5.1" #use for .close() channels #mpsc channel registry lazy_static = { version = "1.4", default-features = false } rand = { version = "0.8" } +#quic support +quinn = { version = "0.7.2", optional = true } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs index d8859943ed..cb1dc02038 100644 --- a/network/protocol/benches/protocols.rs +++ b/network/protocol/benches/protocols.rs @@ -6,8 +6,9 @@ use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; use veloren_network_protocol::{ InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError, - ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid, - TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, QuicRecvProtocol, + QuicSendProtocol, RecvProtocol, SendProtocol, Sid, TcpRecvProtocol, TcpSendProtocol, + UnreliableDrain, UnreliableSink, _internal::OTFrame, }; fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); } @@ -145,7 +146,35 @@ fn criterion_tcp(c: &mut Criterion) { c.finish(); } -criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +fn criterion_quic(c: &mut Criterion) { + let mut c = c.benchmark_group("quic"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buf = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buf = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +criterion_group!( + benches, + criterion_util, + criterion_mpsc, + criterion_tcp, + criterion_quic +); criterion_main!(benches); mod utils { @@ -210,6 +239,36 @@ mod utils { ] } + pub struct QuicDrain { + pub sender: Sender<QuicDataFormat>, + } + + pub struct QuicSink { + pub receiver: Receiver<QuicDataFormat>, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + metrics: Option<ProtocolMetricCache>, + ) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new(QuicDrain { sender: s1 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new(QuicDrain { sender: s2 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + #[async_trait] impl UnreliableDrain for ACDrain { type DataFormat = MpscMsg; @@ -257,4 +316,28 @@ mod utils { .map_err(|_| ProtocolError::Closed) } } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } } diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs index 79c1ae867a..3c2eb70c75 100644 --- a/network/protocol/src/lib.rs +++ b/network/protocol/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(drain_filter)] //! Network Protocol //! //! a I/O-Free protocol for the veloren network crate. @@ -13,9 +14,9 @@ //! This crate currently defines: //! - TCP //! - MPSC +//! - QUIC //! -//! a UDP implementation will quickly follow, and it's also possible to abstract -//! over QUIC. +//! eventually a pure UDP implementation will follow //! //! warning: don't mix protocol, using the TCP variant for actual UDP socket //! will result in dropped data using UDP with a TCP socket will be a waste of @@ -57,8 +58,10 @@ mod message; mod metrics; mod mpsc; mod prio; +mod quic; mod tcp; mod types; +mod util; pub use error::{InitProtocolError, ProtocolError}; pub use event::ProtocolEvent; @@ -66,12 +69,16 @@ pub use metrics::ProtocolMetricCache; #[cfg(feature = "metrics")] pub use metrics::ProtocolMetrics; pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol}; +pub use quic::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; pub use tcp::{TcpRecvProtocol, TcpSendProtocol}; pub use types::{Bandwidth, Cid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION}; ///use at own risk, might change any time, for internal benchmarks pub mod _internal { - pub use crate::frame::{ITFrame, OTFrame}; + pub use crate::{ + frame::{ITFrame, OTFrame}, + util::SortedVec, + }; } use async_trait::async_trait; diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs index 374a1ac216..7304086a1b 100644 --- a/network/protocol/src/prio.rs +++ b/network/protocol/src/prio.rs @@ -75,7 +75,7 @@ impl PrioManager { /// bandwidth might be extended, as for technical reasons /// guaranteed_bandwidth is used and frames are always 1400 bytes. - pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<OTFrame>, Bandwidth) { + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<(Sid, OTFrame)>, Bandwidth) { let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; let mut cur_bytes = 0u64; let mut frames = vec![]; @@ -84,7 +84,7 @@ impl PrioManager { let metrics = &mut self.metrics; let mut process_stream = - |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + |sid: &Sid, stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { let mut finished = None; 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { while let Some(frame) = msg.next() { @@ -95,7 +95,7 @@ impl PrioManager { } as u64; bandwidth -= b as i64; *cur_bytes += b; - frames.push(frame); + frames.push((*sid, frame)); if bandwidth <= 0 { break 'outer; } @@ -111,10 +111,10 @@ impl PrioManager { }; // Add guaranteed bandwidth - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { prios[stream.prio as usize] += 1; let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; - process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + process_stream(sid, stream, stream_byte_cnt as i64, &mut cur_bytes); } if cur_bytes < total_bytes { @@ -124,11 +124,11 @@ impl PrioManager { continue; } let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64; - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { if stream.prio != prio { continue; } - process_stream(stream, per_stream_bytes, &mut cur_bytes); + process_stream(sid, stream, per_stream_bytes, &mut cur_bytes); } } } diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs new file mode 100644 index 0000000000..d2be37c010 --- /dev/null +++ b/network/protocol/src/quic.rs @@ -0,0 +1,955 @@ +use crate::{ + error::ProtocolError, + event::ProtocolEvent, + frame::{ITFrame, InitFrame, OTFrame}, + handshake::{ReliableDrain, ReliableSink}, + message::{ITMessage, ALLOC_BLOCK}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::{Bandwidth, Mid, Promises, Sid}, + util::SortedVec, + RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, +}; +use async_trait::async_trait; +use bytes::BytesMut; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; +use tracing::info; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; + +#[derive(PartialEq)] +pub enum QuicDataFormatStream { + Main, + Reliable(u64), + Unreliable, +} + +pub struct QuicDataFormat { + stream: QuicDataFormatStream, + data: BytesMut, +} + +impl QuicDataFormat { + fn with_main(buffer: &mut BytesMut) -> Self { + Self { + stream: QuicDataFormatStream::Main, + data: buffer.split(), + } + } + + fn with_reliable(buffer: &mut BytesMut, id: u64) -> Self { + Self { + stream: QuicDataFormatStream::Reliable(id), + data: buffer.split(), + } + } + + fn with_unreliable(frame: OTFrame) -> Self { + let mut buffer = BytesMut::new(); + frame.write_bytes(&mut buffer); + Self { + stream: QuicDataFormatStream::Unreliable, + data: buffer, + } + } +} + +/// QUIC implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[derive(Debug)] +pub struct QuicSendProtocol<D> +where + D: UnreliableDrain<DataFormat = QuicDataFormat>, +{ + main_buffer: BytesMut, + reliable_buffers: SortedVec<Sid, BytesMut>, + store: PrioManager, + next_mid: Mid, + closing_streams: Vec<Sid>, + notify_closing_streams: Vec<Sid>, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +/// QUIC implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug)] +pub struct QuicRecvProtocol<S> +where + S: UnreliableSink<DataFormat = QuicDataFormat>, +{ + main_buffer: BytesMut, + unreliable_buffer: BytesMut, + reliable_buffers: SortedVec<Sid, BytesMut>, + pending_reliable_buffers: Vec<(u64, BytesMut)>, + itmsg_allocator: BytesMut, + incoming: HashMap<Mid, ITMessage>, + sink: S, + metrics: ProtocolMetricCache, +} + +impl<D> QuicSendProtocol<D> +where + D: UnreliableDrain<DataFormat = QuicDataFormat>, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + store: PrioManager::new(metrics.clone()), + next_mid: 0u64, + closing_streams: vec![], + notify_closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } + + /// returns all promises that this Protocol can take care of + /// If you open a Stream anyway, unsupported promises are ignored. + pub fn supported_promises() -> Promises { + Promises::ORDERED + | Promises::CONSISTENCY + | Promises::GUARANTEED_DELIVERY + | Promises::COMPRESSED + | Promises::ENCRYPTED + } +} + +impl<S> QuicRecvProtocol<S> +where + S: UnreliableSink<DataFormat = QuicDataFormat>, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + unreliable_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + pending_reliable_buffers: vec![], + itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK), + incoming: HashMap::new(), + sink, + metrics, + } + } + + async fn recv_into_stream(&mut self) -> Result<QuicDataFormatStream, ProtocolError> { + let chunk = self.sink.recv().await?; + let buffer = match chunk.stream { + QuicDataFormatStream::Main => &mut self.main_buffer, + QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer, + QuicDataFormatStream::Reliable(id) => { + match self.reliable_buffers.data.get_mut(id as usize) { + Some((_, buffer)) => buffer, + None => { + self.pending_reliable_buffers.push((id, BytesMut::new())); + //Violated but will never happen + &mut self + .pending_reliable_buffers + .last_mut() + .ok_or(ProtocolError::Violated)? + .1 + }, + } + }, + }; + if buffer.is_empty() { + *buffer = chunk.data + } else { + buffer.extend_from_slice(&chunk.data) + } + Ok(chunk.stream) + } +} + +#[async_trait] +impl<D> SendProtocol for QuicSendProtocol<D> +where + D: UnreliableDrain<DataFormat = QuicDataFormat>, +{ + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + }, + ProtocolEvent::CloseStream { sid } => { + if !self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back notify close stream"); + self.notify_closing_streams.push(sid); + } + }, + _ => {}, + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back close stream"); + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!("hold back shutdown"); + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { data, sid } => { + self.metrics.smsg_ib(sid, data.len() as u64); + self.store.add(data, self.next_mid, sid); + self.next_mid += 1; + }, + } + Ok(()) + } + + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: Duration, + ) -> Result</* actual */ Bandwidth, ProtocolError> { + let (frames, _) = self.store.grab(bandwidth, dt); + //Todo: optimize reserve + let mut data_frames = 0; + let mut data_bandwidth = 0; + for (sid, frame) in frames { + if let OTFrame::Data { mid: _, data } = &frame { + data_bandwidth += data.len(); + data_frames += 1; + } + match self.reliable_buffers.get_mut(&sid) { + Some(buffer) => frame.write_bytes(buffer), + None => { + self.drain + .send(QuicDataFormat::with_unreliable(frame)) + .await? + }, + } + } + for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() { + self.drain + .send(QuicDataFormat::with_reliable(buffer, id as u64)) + .await?; + } + self.metrics + .sdata_frames_b(data_frames, data_bandwidth as u64); + + let mut finished_streams = vec![]; + for (i, &sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + + let mut finished_streams = vec![]; + for (i, sid) in self.notify_closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.notify_closing_streams.remove(*i); + } + + if self.pending_shutdown && self.store.is_empty() { + #[cfg(feature = "trace_pedantic")] + trace!("shutdown, as it's now empty"); + OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + self.pending_shutdown = false; + } + Ok(data_bandwidth as u64) + } +} + +#[async_trait] +impl<S> RecvProtocol for QuicRecvProtocol<S> +where + S: UnreliableSink<DataFormat = QuicDataFormat>, +{ + async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { + 'outer: loop { + loop { + match ITFrame::read_frame(&mut self.main_buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth, + }); + }, + ITFrame::CloseStream { sid } => { + //FIXME: defer close! + //let _ = self.reliable_buffers.delete(sid); // if it was reliable + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + + // try to order pending + let mut pending_violated = false; + let mut reliable = vec![]; + self.pending_reliable_buffers.drain_filter(|(_, buffer)| { + // try to get Sid without touching buffer + let mut testbuffer = buffer.clone(); + match ITFrame::read_frame(&mut testbuffer) { + Ok(Some(ITFrame::DataHeader { + sid, + mid: _, + length: _, + })) => { + reliable.push((sid, buffer.clone())); + true + }, + Ok(Some(_)) | Err(_) => { + pending_violated = true; + true + }, + Ok(None) => false, + } + }); + if pending_violated { + break 'outer Err(ProtocolError::Violated); + } + for (sid, buffer) in reliable.into_iter() { + self.reliable_buffers.insert(sid, buffer) + } + + let mut iter = self + .reliable_buffers + .data + .iter_mut() + .map(|(_, b)| (b, true)) + .collect::<Vec<_>>(); + iter.push((&mut self.unreliable_buffer, false)); + + for (buffer, reliable) in iter { + loop { + match ITFrame::read_frame(buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { mid, data } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + if reliable { + info!( + ?mid, + "protocol violation by remote side: send Data before \ + Header" + ); + break 'outer Err(ProtocolError::Violated); + } else { + //TODO: cleanup old messages from time to time + continue; + } + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + data: m.data.freeze(), + }); + } + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + } + + self.recv_into_stream().await?; + } + } +} + +#[async_trait] +impl<D> ReliableDrain for QuicSendProtocol<D> +where + D: UnreliableDrain<DataFormat = QuicDataFormat>, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.main_buffer.reserve(500); + frame.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await + } +} + +#[async_trait] +impl<S> ReliableSink for QuicRecvProtocol<S> +where + S: UnreliableSink<DataFormat = QuicDataFormat>, +{ + async fn recv(&mut self) -> Result<InitFrame, ProtocolError> { + while self.main_buffer.len() < 100 { + if self.recv_into_stream().await? == QuicDataFormatStream::Main { + if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) { + return Ok(frame); + } + } + } + Err(ProtocolError::Violated) + } +} + +#[cfg(test)] +mod test_utils { + //Quic protocol based on Channel + use super::*; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; + use async_channel::*; + use std::sync::Arc; + + pub struct QuicDrain { + pub sender: Sender<QuicDataFormat>, + pub drop_ratio: f32, + } + + pub struct QuicSink { + pub receiver: Receiver<QuicDataFormat>, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + drop_ratio: f32, + metrics: Option<ProtocolMetricCache>, + ) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new( + QuicDrain { + sender: s1, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new( + QuicDrain { + sender: s2, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + use rand::Rng; + if matches!(data.stream, QuicDataFormatStream::Unreliable) + && rand::thread_rng().gen::<f32>() < self.drop_ratio + { + return Ok(()); + } + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + error::ProtocolError, + frame::OTFrame, + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + quic::{test_utils::*, QuicDataFormat}, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use bytes::{Bytes, BytesMut}; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 0u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[7u8; 30][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let mut metrics = + ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone())); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_drop() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[100u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + OTFrame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + .await + .unwrap(); + + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA1[..]), + } + .write_bytes(&mut bytes); + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA2[..]), + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + .await + .unwrap(); + + OTFrame::CloseStream { sid }.write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn drop_sink_while_recv() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + + let e = e.await.unwrap(); + assert_eq!(e, Err(ProtocolError::Closed)); + } + + #[tokio::test] + #[should_panic] + async fn send_on_stream_from_remote_without_notify() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let _ = p2.1.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_on_stream_from_remote() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn unrealiable_test() { + const MIN_CHECK: usize = 10; + const COUNT: usize = 10_000; + //We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && != + // COUNT reach their target + + let [mut p1, mut p2] = quic_bound( + COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all + * succeed */ + 0.5, + None, + ); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1337), + prio: 3u8, + promises: Promises::empty(), /* on purpose! */ + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(1337), + data: Bytes::from(&[188u8; 600][..]), + }; + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + p2.0.flush(1_000_000_000, Duration::from_secs(1)) + .await + .unwrap(); + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + for _ in 0..MIN_CHECK { + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index 43d14e2a1e..e6c74df4d4 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -176,7 +176,7 @@ where self.buffer.reserve(total_bytes as usize); let mut data_frames = 0; let mut data_bandwidth = 0; - for frame in frames { + for (_, frame) in frames { if let OTFrame::Data { mid: _, data } = &frame { data_bandwidth += data.len(); data_frames += 1; diff --git a/network/protocol/src/util.rs b/network/protocol/src/util.rs new file mode 100644 index 0000000000..1e28d4c4ab --- /dev/null +++ b/network/protocol/src/util.rs @@ -0,0 +1,71 @@ +/// Used for storing Buffers in a QUIC +#[derive(Debug)] +pub struct SortedVec<K, V> { + pub data: Vec<(K, V)>, +} + +impl<K, V> Default for SortedVec<K, V> { + fn default() -> Self { Self { data: vec![] } } +} + +impl<K, V> SortedVec<K, V> +where + K: Ord + Copy, +{ + pub fn insert(&mut self, k: K, v: V) { + self.data.push((k, v)); + self.data.sort_by_key(|&(k, _)| k); + } + + pub fn delete(&mut self, k: &K) -> Option<V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(self.data.remove(i).1) + } else { + None + } + } + + pub fn get(&self, k: &K) -> Option<&V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&self.data[i].1) + } else { + None + } + } + + pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&mut self.data[i].1) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sorted_vec() { + let mut vec = SortedVec::default(); + vec.insert(10, "Hello"); + println!("{:?}", vec.data); + vec.insert(30, "World"); + println!("{:?}", vec.data); + vec.insert(20, " "); + println!("{:?}", vec.data); + assert_eq!(vec.data[0].1, "Hello"); + assert_eq!(vec.data[1].1, " "); + assert_eq!(vec.data[2].1, "World"); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), Some(&"Hello")); + assert_eq!(vec.delete(&40), None); + assert_eq!(vec.delete(&10), Some("Hello")); + assert_eq!(vec.delete(&10), None); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), None); + } +} diff --git a/network/src/channel.rs b/network/src/channel.rs index cf7a7851bd..ce6ee08cab 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -5,6 +5,7 @@ use network_protocol::{ ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; +#[cfg(feature = "quic")] use quinn::*; use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -16,18 +17,24 @@ use tokio::{ pub(crate) enum Protocols { Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)), Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)), + #[cfg(feature = "quic")] + Quic((QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>)), } #[derive(Debug)] pub(crate) enum SendProtocols { Tcp(TcpSendProtocol<TcpDrain>), Mpsc(MpscSendProtocol<MpscDrain>), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol<QuicDrain>), } #[derive(Debug)] pub(crate) enum RecvProtocols { Tcp(TcpRecvProtocol<TcpSink>), Mpsc(MpscRecvProtocol<MpscSink>), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol<QuicDrain>), } impl Protocols { @@ -67,6 +74,8 @@ impl Protocols { match self { Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + #[cfg(feature = "quic")] + Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)), } } } @@ -82,6 +91,8 @@ impl network_protocol::InitProtocol for Protocols { match self { Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + #[cfg(feature = "quic")] + Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await, } } } @@ -92,6 +103,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.notify_from_recv(event), SendProtocols::Mpsc(s) => s.notify_from_recv(event), + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.notify_from_recv(event), } } @@ -99,6 +112,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.send(event).await, SendProtocols::Mpsc(s) => s.send(event).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.send(event).await, } } @@ -110,6 +125,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.flush(bandwidth, dt).await, } } } @@ -120,6 +137,8 @@ impl network_protocol::RecvProtocol for RecvProtocols { match self { RecvProtocols::Tcp(r) => r.recv().await, RecvProtocols::Mpsc(r) => r.recv().await, + #[cfg(feature = "quic")] + RecvProtocols::Quic(r) => r.recv().await, } } } @@ -196,6 +215,45 @@ impl UnreliableSink for MpscSink { } } +/////////////////////////////////////// +//// QUIC +#[derive(Debug)] +pub struct QuicDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct QuicSink { + half: OwnedReadHalf, + buffer: BytesMut, +} + +#[async_trait] +impl UnreliableDrain for QuicDrain { + type DataFormat = BytesMut; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for QuicSink { + type DataFormat = BytesMut; + + async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { + self.buffer.resize(1500, 0u8); + match self.half.read(&mut self.buffer).await { + Ok(0) => Err(ProtocolError::Closed), + Ok(n) => Ok(self.buffer.split_to(n)), + Err(_) => Err(ProtocolError::Closed), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/network/src/participant.rs b/network/src/participant.rs index afa30266e8..a06321201c 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -2,12 +2,13 @@ use crate::{ api::{ParticipantError, Stream}, channel::{Protocols, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, - util::{DeferredTracer, SortedVec}, + util::DeferredTracer, }; use bytes::Bytes; use futures_util::{FutureExt, StreamExt}; use network_protocol::{ Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid, + _internal::SortedVec, }; use std::{ collections::HashMap, diff --git a/network/src/util.rs b/network/src/util.rs index b9a8801263..640d65ee55 100644 --- a/network/src/util.rs +++ b/network/src/util.rs @@ -44,74 +44,3 @@ impl<T: Eq + Hash> DeferredTracer<T> { } } } - -/// Used for storing Protocols in a Participant or Stream <-> Protocol -pub(crate) struct SortedVec<K, V> { - pub data: Vec<(K, V)>, -} - -impl<K, V> Default for SortedVec<K, V> { - fn default() -> Self { Self { data: vec![] } } -} - -impl<K, V> SortedVec<K, V> -where - K: Ord + Copy, -{ - pub fn insert(&mut self, k: K, v: V) { - self.data.push((k, v)); - self.data.sort_by_key(|&(k, _)| k); - } - - pub fn delete(&mut self, k: &K) -> Option<V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(self.data.remove(i).1) - } else { - None - } - } - - pub fn get(&self, k: &K) -> Option<&V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&self.data[i].1) - } else { - None - } - } - - pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&mut self.data[i].1) - } else { - None - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sorted_vec() { - let mut vec = SortedVec::default(); - vec.insert(10, "Hello"); - println!("{:?}", vec.data); - vec.insert(30, "World"); - println!("{:?}", vec.data); - vec.insert(20, " "); - println!("{:?}", vec.data); - assert_eq!(vec.data[0].1, "Hello"); - assert_eq!(vec.data[1].1, " "); - assert_eq!(vec.data[2].1, "World"); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), Some(&"Hello")); - assert_eq!(vec.delete(&40), None); - assert_eq!(vec.delete(&10), Some("Hello")); - assert_eq!(vec.delete(&10), None); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), None); - } -} From d40261e38e1a0961384678aee25f3fdfbd1ded0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Sun, 11 Apr 2021 23:37:48 +0200 Subject: [PATCH 2/7] work on getting quic in the network --- network/Cargo.toml | 2 +- network/protocol/src/quic.rs | 4 +-- network/src/api.rs | 2 ++ network/src/channel.rs | 52 ++++++++++++++++++++++++++++++------ network/src/scheduler.rs | 41 ++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 11 deletions(-) diff --git a/network/Cargo.toml b/network/Cargo.toml index a51119a16a..7f854f68a9 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -11,7 +11,7 @@ metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] quic = ["quinn"] -default = ["metrics","compression","quinn"] +default = ["metrics","compression","quic"] [dependencies] diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index d2be37c010..b4af04a193 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -28,8 +28,8 @@ pub enum QuicDataFormatStream { } pub struct QuicDataFormat { - stream: QuicDataFormatStream, - data: BytesMut, + pub stream: QuicDataFormatStream, + pub data: BytesMut, } impl QuicDataFormat { diff --git a/network/src/api.rs b/network/src/api.rs index e04094d6ce..d38318aa57 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -33,6 +33,8 @@ type A2sDisconnect = Arc<Mutex<Option<mpsc::UnboundedSender<(Pid, S2bShutdownBpa pub enum ProtocolAddr { Tcp(SocketAddr), Udp(SocketAddr), + #[cfg(feature = "quic")] + Quic(SocketAddr, quinn::ServerConfig), Mpsc(u64), } diff --git a/network/src/channel.rs b/network/src/channel.rs index ce6ee08cab..85bf824134 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use bytes::BytesMut; use network_protocol::{ + QuicDataFormat, QuicDataFormatStream, QuicSendProtocol, QuicRecvProtocol, Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, @@ -70,6 +71,31 @@ impl Protocols { Protocols::Mpsc((sp, rp)) } + #[cfg(feature = "quic")] + pub(crate) async fn new_quic( + connection: quinn::NewConnection, + cid: Cid, + metrics: Arc<ProtocolMetrics>, + ) -> Result<Self, quinn::ConnectionError> { + let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); + + let (sendstream, recvstream) = connection.connection.open_bi().await?; + + + let sp = QuicSendProtocol::new(QuicDrain { + con: connection.connection.clone(), + main: sendstream, + reliables: vec!(), + }, metrics.clone()); + let rp = QuicRecvProtocol::new(QuicSink { + con: connection.connection, + main: recvstream, + reliables: vec!(), + buffer: BytesMut::new(), + }, metrics); + Ok(Protocols::Quic((sp, rp))) + } + pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) { match self { Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), @@ -219,21 +245,29 @@ impl UnreliableSink for MpscSink { //// QUIC #[derive(Debug)] pub struct QuicDrain { - half: OwnedWriteHalf, + con: quinn::Connection, + main: quinn::SendStream, + reliables: Vec<quinn::SendStream>, } #[derive(Debug)] pub struct QuicSink { - half: OwnedReadHalf, + con: quinn::Connection, + main: quinn::RecvStream, + reliables: Vec<quinn::RecvStream>, buffer: BytesMut, } #[async_trait] impl UnreliableDrain for QuicDrain { - type DataFormat = BytesMut; + type DataFormat = QuicDataFormat; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { - match self.half.write_all(&data).await { + match match data.stream { + QuicDataFormatStream::Main => self.main.write_all(&data.data), + QuicDataFormatStream::Unreliable => unimplemented!(), + QuicDataFormatStream::Reliable(id) => self.reliables.get_mut(id as usize).ok_or(ProtocolError::Closed)?.write_all(&data.data), + }.await { Ok(()) => Ok(()), Err(_) => Err(ProtocolError::Closed), } @@ -242,13 +276,15 @@ impl UnreliableDrain for QuicDrain { #[async_trait] impl UnreliableSink for QuicSink { - type DataFormat = BytesMut; + type DataFormat = QuicDataFormat; async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { self.buffer.resize(1500, 0u8); - match self.half.read(&mut self.buffer).await { - Ok(0) => Err(ProtocolError::Closed), - Ok(n) => Ok(self.buffer.split_to(n)), + //TODO improve + match self.main.read(&mut self.buffer).await { + Ok(Some(0)) => Err(ProtocolError::Closed), + Ok(Some(n)) => Ok(QuicDataFormat{stream: QuicDataFormatStream::Main, data: self.buffer.split_to(n)}), + Ok(None) => Err(ProtocolError::Closed), Err(_) => Err(ProtocolError::Closed), } } diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 527ea6f5fe..11a2a0f774 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -431,6 +431,47 @@ impl Scheduler { .await; } }, + #[cfg(feature = "quic")] + ProtocolAddr::Quic(addr, server_config) => { + let mut endpoint = quinn::Endpoint::builder(); + endpoint.listen(server_config); + let (endpoint, mut listener) = match endpoint.bind(&addr) { + Ok((endpoint, listener)) => { + s2a_listen_result_s.send(Ok(())).unwrap(); + (endpoint, listener) + }, + Err(quinn::EndpointError::Socket(e)) => { + info!( + ?addr, + ?e, + "Quic bind error during listener startup" + ); + s2a_listen_result_s.send(Err(e)).unwrap(); + return; + } + }; + trace!(?addr, "Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + while let Some(Some(connecting)) = select! { + next = listener.next().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let remote_addr = connecting.remote_address(); + let connection = match connecting.await { + Ok(c) => c, + Err(e) => { + debug!(?e, ?remote_addr, "skipping connection attempt"); + continue; + }, + }; + #[cfg(feature = "metrics")] + mcache.inc(); + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Quic from"); + self.init_protocol(Protocols::new_quic(connection, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) + .await; + } + }, ProtocolAddr::Mpsc(addr) => { let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); MPSC_POOL.lock().await.insert(addr, mpsc_s); From 4d360a871c0d00042abeff3e210e4f1e5e77b455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Thu, 15 Apr 2021 10:16:42 +0200 Subject: [PATCH 3/7] protocoladdr change for listen and connect (remove a loop in quic protocol which wasnt a actual loop) --- Cargo.lock | 33 ++++++ client/src/lib.rs | 6 +- network/Cargo.toml | 2 + network/benches/speed.rs | 17 +-- network/examples/chat.rs | 28 +++-- network/examples/fileshare/commands.rs | 4 +- network/examples/fileshare/main.rs | 8 +- network/examples/fileshare/server.rs | 4 +- network/examples/network-speed/main.rs | 30 +++-- network/protocol/src/quic.rs | 83 ++++++------- network/src/api.rs | 132 +++++++++++---------- network/src/channel.rs | 154 ++++++++++++++++++++----- network/src/lib.rs | 19 +-- network/src/message.rs | 6 +- network/src/metrics.rs | 60 +++++++--- network/src/scheduler.rs | 55 ++++++--- network/tests/closing.rs | 12 +- network/tests/helper.rs | 78 ++++++++++--- network/tests/integration.rs | 52 ++++++--- server/src/lib.rs | 6 +- 20 files changed, 534 insertions(+), 255 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5403842ba5..dc612eb409 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3631,6 +3631,17 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pem" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd56cbd21fea48d0c440b41cd69c589faacade08c992d9a54e471b79d0fd13eb" +dependencies = [ + "base64", + "once_cell", + "regex", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -4061,6 +4072,18 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "rcgen" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e80a701a04edd9cab874a3d59323bebe24c9a92dd602088c78da83732066d1b" +dependencies = [ + "chrono", + "pem", + "ring", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.1.57" @@ -5628,6 +5651,7 @@ dependencies = [ "prometheus-hyper", "quinn", "rand 0.8.3", + "rcgen", "serde", "shellexpand", "tokio", @@ -6638,3 +6662,12 @@ name = "xml-rs" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b07db065a5cf61a7e4ba64f29e67db906fb1787316516c4e6e5ff0fea1efcd8a" + +[[package]] +name = "yasna" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de7bff972b4f2a06c85f6d8454b09df153af7e3a4ec2aac81db1b105b684ddb" +dependencies = [ + "chrono", +] diff --git a/client/src/lib.rs b/client/src/lib.rs index 51c577aaeb..35511879a4 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -61,7 +61,7 @@ use comp::BuffKind; use futures_util::FutureExt; use hashbrown::{HashMap, HashSet}; use image::DynamicImage; -use network::{Network, Participant, Pid, ProtocolAddr, Stream}; +use network::{ConnectAddr, Network, Participant, Pid, Stream}; use num::traits::FloatConst; use rayon::prelude::*; use specs::Component; @@ -217,7 +217,7 @@ impl Client { // Try to connect to all IP's and return the first that works let mut participant = None; for addr in addrs { - match network.connect(ProtocolAddr::Tcp(addr)).await { + match network.connect(ConnectAddr::Tcp(addr)).await { Ok(p) => { participant = Some(Ok(p)); break; @@ -228,7 +228,7 @@ impl Client { participant .unwrap_or_else(|| Err(Error::Other("No Ip Addr provided".to_string())))? }, - ConnectionArgs::Mpsc(id) => network.connect(ProtocolAddr::Mpsc(id)).await?, + ConnectionArgs::Mpsc(id) => network.connect(ConnectAddr::Mpsc(id)).await?, }; let stream = participant.opened().await?; diff --git a/network/Cargo.toml b/network/Cargo.toml index 7f854f68a9..bcb509aea1 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -52,6 +52,8 @@ shellexpand = "2.0.0" serde = { version = "1.0", features = ["derive"] } prometheus-hyper = "0.1.2" criterion = { version = "0.3.4", features = ["default", "async_tokio"] } +#quic +rcgen = { version = "0.8.10"} [[bench]] name = "speed" diff --git a/network/benches/speed.rs b/network/benches/speed.rs index b110d308ae..d7f0f2b63c 100644 --- a/network/benches/speed.rs +++ b/network/benches/speed.rs @@ -1,7 +1,9 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use std::{net::SocketAddr, sync::Arc}; use tokio::{runtime::Runtime, sync::Mutex}; -use veloren_network::{Message, Network, Participant, Pid, Promises, ProtocolAddr, Stream}; +use veloren_network::{ + ConnectAddr, ListenAddr, Message, Network, Participant, Pid, Promises, Stream, +}; fn serialize(data: &[u8], stream: &Stream) { let _ = Message::serialize(data, stream.params()); } @@ -30,7 +32,7 @@ fn criterion_util(c: &mut Criterion) { c.significance_level(0.1).sample_size(100); let (r, _n_a, p_a, s1_a, _n_b, _p_b, _s1_b) = - network_participant_stream(ProtocolAddr::Mpsc(5000)); + network_participant_stream((ListenAddr::Mpsc(5000), ConnectAddr::Mpsc(5000))); let s2_a = r.block_on(p_a.open(4, Promises::COMPRESSED, 0)).unwrap(); c.throughput(Throughput::Bytes(1000)) @@ -50,7 +52,7 @@ fn criterion_mpsc(c: &mut Criterion) { c.significance_level(0.1).sample_size(10); let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = - network_participant_stream(ProtocolAddr::Mpsc(5000)); + network_participant_stream((ListenAddr::Mpsc(5000), ConnectAddr::Mpsc(5000))); let s1_a = Arc::new(Mutex::new(s1_a)); let s1_b = Arc::new(Mutex::new(s1_b)); @@ -82,8 +84,9 @@ fn criterion_tcp(c: &mut Criterion) { let mut c = c.benchmark_group("net_tcp"); c.significance_level(0.1).sample_size(10); + let socket_addr = SocketAddr::from(([127, 0, 0, 1], 5000)); let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = - network_participant_stream(ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], 5000)))); + network_participant_stream((ListenAddr::Tcp(socket_addr), ConnectAddr::Tcp(socket_addr))); let s1_a = Arc::new(Mutex::new(s1_a)); let s1_b = Arc::new(Mutex::new(s1_b)); @@ -115,7 +118,7 @@ criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); criterion_main!(benches); pub fn network_participant_stream( - addr: ProtocolAddr, + addr: (ListenAddr, ConnectAddr), ) -> ( Runtime, Network, @@ -130,8 +133,8 @@ pub fn network_participant_stream( let n_a = Network::new(Pid::fake(0), &runtime); let n_b = Network::new(Pid::fake(1), &runtime); - n_a.listen(addr.clone()).await.unwrap(); - let p1_b = n_b.connect(addr).await.unwrap(); + n_a.listen(addr.0).await.unwrap(); + let p1_b = n_b.connect(addr.1).await.unwrap(); let p1_a = n_a.connected().await.unwrap(); let s1_a = p1_a.open(4, Promises::empty(), 0).await.unwrap(); diff --git a/network/examples/chat.rs b/network/examples/chat.rs index 8746479f73..2dc1e56e78 100644 --- a/network/examples/chat.rs +++ b/network/examples/chat.rs @@ -8,7 +8,7 @@ use std::{sync::Arc, thread, time::Duration}; use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::RwLock}; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr}; +use veloren_network::{ConnectAddr, ListenAddr, Network, Participant, Pid, Promises}; ///This example contains a simple chatserver, that allows to send messages /// between participants, it's neither pretty nor perfect, but it should show @@ -75,21 +75,27 @@ fn main() { let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); let ip: &str = matches.value_of("ip").unwrap(); - let address = match matches.value_of("protocol") { - Some("tcp") => ProtocolAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), - Some("udp") => ProtocolAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + let addresses = match matches.value_of("protocol") { + Some("tcp") => ( + ListenAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ), + Some("udp") => ( + ListenAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ), _ => panic!("invalid mode, run --help!"), }; let mut background = None; match matches.value_of("mode") { - Some("server") => server(address), - Some("client") => client(address), + Some("server") => server(addresses.0), + Some("client") => client(addresses.1), Some("both") => { - let address1 = address.clone(); - background = Some(thread::spawn(|| server(address1))); + let s = addresses.0; + background = Some(thread::spawn(|| server(s))); thread::sleep(Duration::from_millis(200)); //start client after server - client(address) + client(addresses.1) }, _ => panic!("invalid mode, run --help!"), }; @@ -98,7 +104,7 @@ fn main() { } } -fn server(address: ProtocolAddr) { +fn server(address: ListenAddr) { let r = Arc::new(Runtime::new().unwrap()); let server = Network::new(Pid::new(), &r); let server = Arc::new(server); @@ -144,7 +150,7 @@ async fn client_connection( println!("[{}] disconnected", username); } -fn client(address: ProtocolAddr) { +fn client(address: ConnectAddr) { let r = Arc::new(Runtime::new().unwrap()); let client = Network::new(Pid::new(), &r); diff --git a/network/examples/fileshare/commands.rs b/network/examples/fileshare/commands.rs index a18c90b38e..9f23ddb6aa 100644 --- a/network/examples/fileshare/commands.rs +++ b/network/examples/fileshare/commands.rs @@ -2,7 +2,7 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use tokio::fs; -use veloren_network::{Participant, ProtocolAddr, Stream}; +use veloren_network::{ConnectAddr, Participant, Stream}; use std::collections::HashMap; @@ -10,7 +10,7 @@ use std::collections::HashMap; pub enum LocalCommand { Shutdown, Disconnect, - Connect(ProtocolAddr), + Connect(ConnectAddr), List, Serve(FileInfo), Get(u32, Option<String>), diff --git a/network/examples/fileshare/main.rs b/network/examples/fileshare/main.rs index f000f371e0..158b825073 100644 --- a/network/examples/fileshare/main.rs +++ b/network/examples/fileshare/main.rs @@ -9,7 +9,7 @@ use std::{path::PathBuf, sync::Arc, thread, time::Duration}; use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::mpsc}; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::ProtocolAddr; +use veloren_network::{ConnectAddr, ListenAddr}; mod commands; mod server; use commands::{FileInfo, LocalCommand}; @@ -50,7 +50,7 @@ fn main() { .init(); let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); - let address = ProtocolAddr::Tcp(format!("{}:{}", "127.0.0.1", port).parse().unwrap()); + let address = ListenAddr::Tcp(format!("{}:{}", "127.0.0.1", port).parse().unwrap()); let runtime = Arc::new(Runtime::new().unwrap()); let (server, cmd_sender) = Server::new(Arc::clone(&runtime)); @@ -158,12 +158,12 @@ async fn client(cmd_sender: mpsc::UnboundedSender<LocalCommand>) { .parse() .unwrap(); cmd_sender - .send(LocalCommand::Connect(ProtocolAddr::Tcp(socketaddr))) + .send(LocalCommand::Connect(ConnectAddr::Tcp(socketaddr))) .unwrap(); }, ("t", _) => { cmd_sender - .send(LocalCommand::Connect(ProtocolAddr::Tcp( + .send(LocalCommand::Connect(ConnectAddr::Tcp( "127.0.0.1:1231".parse().unwrap(), ))) .unwrap(); diff --git a/network/examples/fileshare/server.rs b/network/examples/fileshare/server.rs index 252ebdf32c..7a40d5be11 100644 --- a/network/examples/fileshare/server.rs +++ b/network/examples/fileshare/server.rs @@ -8,7 +8,7 @@ use tokio::{ }; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; +use veloren_network::{ListenAddr, Network, Participant, Pid, Promises, Stream}; #[derive(Debug)] struct ControlChannels { @@ -42,7 +42,7 @@ impl Server { ) } - pub async fn run(mut self, address: ProtocolAddr) { + pub async fn run(mut self, address: ListenAddr) { let run_channels = self.run_channels.take().unwrap(); self.network.listen(address).await.unwrap(); diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index e058aac7a8..e8ccc8f278 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -16,7 +16,7 @@ use std::{ use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{Message, Network, Pid, Promises, ProtocolAddr}; +use veloren_network::{ConnectAddr, ListenAddr, Message, Network, Pid, Promises}; #[derive(Serialize, Deserialize, Debug)] enum Msg { @@ -96,23 +96,29 @@ fn main() { let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); let ip: &str = matches.value_of("ip").unwrap(); - let address = match matches.value_of("protocol") { - Some("tcp") => ProtocolAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), - Some("udp") => ProtocolAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), - _ => panic!("Invalid mode, run --help!"), + let addresses = match matches.value_of("protocol") { + Some("tcp") => ( + ListenAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ), + Some("udp") => ( + ListenAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ), + _ => panic!("invalid mode, run --help!"), }; let mut background = None; let runtime = Arc::new(Runtime::new().unwrap()); match matches.value_of("mode") { - Some("server") => server(address, Arc::clone(&runtime)), - Some("client") => client(address, Arc::clone(&runtime)), + Some("server") => server(addresses.0, Arc::clone(&runtime)), + Some("client") => client(addresses.1, Arc::clone(&runtime)), Some("both") => { - let address1 = address.clone(); + let s = addresses.0; let runtime2 = Arc::clone(&runtime); - background = Some(thread::spawn(|| server(address1, runtime2))); + background = Some(thread::spawn(|| server(s, runtime2))); thread::sleep(Duration::from_millis(200)); //start client after server - client(address, Arc::clone(&runtime)); + client(addresses.1, Arc::clone(&runtime)); }, _ => panic!("Invalid mode, run --help!"), }; @@ -121,7 +127,7 @@ fn main() { } } -fn server(address: ProtocolAddr, runtime: Arc<Runtime>) { +fn server(address: ListenAddr, runtime: Arc<Runtime>) { let registry = Arc::new(Registry::new()); let server = Network::new_with_registry(Pid::new(), &runtime, ®istry); runtime.spawn(Server::run( @@ -153,7 +159,7 @@ fn server(address: ProtocolAddr, runtime: Arc<Runtime>) { } } -fn client(address: ProtocolAddr, runtime: Arc<Runtime>) { +fn client(address: ConnectAddr, runtime: Arc<Runtime>) { let registry = Arc::new(Registry::new()); let client = Network::new_with_registry(Pid::new(), &runtime, ®istry); runtime.spawn(Server::run( diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index b4af04a193..e656fdf5a1 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -285,9 +285,11 @@ where } } for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() { - self.drain - .send(QuicDataFormat::with_reliable(buffer, id as u64)) - .await?; + if !buffer.is_empty() { + self.drain + .send(QuicDataFormat::with_reliable(buffer, id as u64)) + .await?; + } } self.metrics .sdata_frames_b(data_frames, data_bandwidth as u64); @@ -340,43 +342,41 @@ where { async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { 'outer: loop { - loop { - match ITFrame::read_frame(&mut self.main_buffer) { - Ok(Some(frame)) => { - #[cfg(feature = "trace_pedantic")] - trace!(?frame, "recv"); - match frame { - ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), - ITFrame::OpenStream { + match ITFrame::read_frame(&mut self.main_buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + break 'outer Ok(ProtocolEvent::OpenStream { sid, - prio, + prio: prio.min(crate::types::HIGHEST_PRIO), promises, guaranteed_bandwidth, - } => { - if promises.contains(Promises::ORDERED) - || promises.contains(Promises::CONSISTENCY) - || promises.contains(Promises::GUARANTEED_DELIVERY) - { - self.reliable_buffers.insert(sid, BytesMut::new()); - } - break 'outer Ok(ProtocolEvent::OpenStream { - sid, - prio: prio.min(crate::types::HIGHEST_PRIO), - promises, - guaranteed_bandwidth, - }); - }, - ITFrame::CloseStream { sid } => { - //FIXME: defer close! - //let _ = self.reliable_buffers.delete(sid); // if it was reliable - break 'outer Ok(ProtocolEvent::CloseStream { sid }); - }, - _ => break 'outer Err(ProtocolError::Violated), - }; - }, - Ok(None) => break, //inner => read more data - Err(()) => return Err(ProtocolError::Violated), - } + }); + }, + ITFrame::CloseStream { sid } => { + //FIXME: defer close! + //let _ = self.reliable_buffers.delete(sid); // if it was reliable + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => {}, + Err(()) => return Err(ProtocolError::Violated), } // try to order pending @@ -401,6 +401,7 @@ where Ok(None) => false, } }); + if pending_violated { break 'outer Err(ProtocolError::Violated); } @@ -435,10 +436,10 @@ where None => { if reliable { info!( - ?mid, - "protocol violation by remote side: send Data before \ - Header" - ); + ?mid, + "protocol violation by remote side: send Data \ + before Header" + ); break 'outer Err(ProtocolError::Violated); } else { //TODO: cleanup old messages from time to time diff --git a/network/src/api.rs b/network/src/api.rs index d38318aa57..ad95dd3419 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -28,9 +28,19 @@ use tracing::*; type A2sDisconnect = Arc<Mutex<Option<mpsc::UnboundedSender<(Pid, S2bShutdownBparticipant)>>>>; -/// Represents a Tcp or Udp or Mpsc address -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub enum ProtocolAddr { +/// Represents a Tcp, Quic, Udp or Mpsc connection address +#[derive(Clone, Debug)] +pub enum ConnectAddr { + Tcp(SocketAddr), + Udp(SocketAddr), + #[cfg(feature = "quic")] + Quic(SocketAddr, quinn::ClientConfig, String), + Mpsc(u64), +} + +/// Represents a Tcp, Quic, Udp or Mpsc listen address +#[derive(Clone, Debug)] +pub enum ListenAddr { Tcp(SocketAddr), Udp(SocketAddr), #[cfg(feature = "quic")] @@ -135,8 +145,8 @@ pub struct StreamParams { /// [`Arc`](std::sync::Arc) as all commands have internal mutability. /// /// The `Network` has methods to [`connect`] to other [`Participants`] actively -/// via their [`ProtocolAddr`], or [`listen`] passively for [`connected`] -/// [`Participants`]. +/// via their [`ProtocolConnectAddr`], or [`listen`] passively for [`connected`] +/// [`Participants`] via [`ProtocolListenAddr`]. /// /// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before /// the Network. @@ -144,7 +154,7 @@ pub struct StreamParams { /// # Examples /// ```rust /// use tokio::runtime::Runtime; -/// use veloren_network::{Network, ProtocolAddr, Pid}; +/// use veloren_network::{Network, ConnectAddr, ListenAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on port `2999` to accept connections and connect to port `8080` to connect to a (pseudo) database Application @@ -153,9 +163,9 @@ pub struct StreamParams { /// runtime.block_on(async{ /// # //setup pseudo database! /// # let database = Network::new(Pid::new(), &runtime); -/// # database.listen(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; -/// network.listen(ProtocolAddr::Tcp("127.0.0.1:2999".parse().unwrap())).await?; -/// let database = network.connect(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; +/// # database.listen(ListenAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; +/// network.listen(ListenAddr::Tcp("127.0.0.1:2999".parse().unwrap())).await?; +/// let database = network.connect(ConnectAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; /// drop(network); /// # drop(database); /// # Ok(()) @@ -171,7 +181,7 @@ pub struct StreamParams { pub struct Network { local_pid: Pid, participant_disconnect_sender: Arc<Mutex<HashMap<Pid, A2sDisconnect>>>, - listen_sender: Mutex<mpsc::UnboundedSender<(ProtocolAddr, oneshot::Sender<io::Result<()>>)>>, + listen_sender: Mutex<mpsc::UnboundedSender<(ListenAddr, oneshot::Sender<io::Result<()>>)>>, connect_sender: Mutex<mpsc::UnboundedSender<A2sConnect>>, connected_receiver: Mutex<mpsc::UnboundedReceiver<Participant>>, shutdown_network_s: Option<oneshot::Sender<oneshot::Sender<()>>>, @@ -197,7 +207,7 @@ impl Network { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid}; /// /// let runtime = Runtime::new().unwrap(); /// let network = Network::new(Pid::new(), &runtime); @@ -230,7 +240,7 @@ impl Network { /// ```rust /// use prometheus::Registry; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid}; /// /// let runtime = Runtime::new().unwrap(); /// let registry = Registry::new(); @@ -283,7 +293,7 @@ impl Network { } } - /// starts listening on an [`ProtocolAddr`]. + /// starts listening on an [`ProtocolListenAddr`]. /// When the method returns the `Network` is ready to listen for incoming /// connections OR has returned a [`NetworkError`] (e.g. port already used). /// You can call [`connected`] to asynchrony wait for a [`Participant`] to @@ -293,7 +303,7 @@ impl Network { /// # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid, ProtocolListenAddr}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally @@ -301,10 +311,10 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { /// network - /// .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) + /// .listen(ProtocolListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .await?; /// network - /// .listen(ProtocolAddr::Udp("127.0.0.1:2001".parse().unwrap())) + /// .listen(ProtocolListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) /// .await?; /// drop(network); /// # Ok(()) @@ -314,7 +324,7 @@ impl Network { /// /// [`connected`]: Network::connected #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] - pub async fn listen(&self, address: ProtocolAddr) -> Result<(), NetworkError> { + pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> { let (s2a_result_s, s2a_result_r) = oneshot::channel::<tokio::io::Result<()>>(); debug!(?address, "listening on address"); self.listen_sender @@ -329,13 +339,13 @@ impl Network { } } - /// starts connection to an [`ProtocolAddr`]. + /// starts connection to an [`ProtocolConnectAddr`]. /// When the method returns the Network either returns a [`Participant`] /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// can't connect, or invalid Handshake) # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid, ProtocolListenAddr, ProtocolConnectAddr}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above @@ -343,16 +353,16 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; - /// # remote.listen(ProtocolAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; + /// # remote.listen(ProtocolListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; + /// # remote.listen(ProtocolListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// let p1 = network - /// .connect(ProtocolAddr::Tcp("127.0.0.1:2010".parse().unwrap())) + /// .connect(ProtocolConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) /// .await?; /// # //this doesn't work yet, so skip the test /// # //TODO fixme! /// # return Ok(()); /// let p2 = network - /// .connect(ProtocolAddr::Udp("127.0.0.1:2011".parse().unwrap())) + /// .connect(ProtocolConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) /// .await?; /// assert_eq!(&p1, &p2); /// # Ok(()) @@ -364,15 +374,15 @@ impl Network { /// ``` /// Usually the `Network` guarantees that a operation on a [`Participant`] /// succeeds, e.g. by automatic retrying unless it fails completely e.g. by - /// disconnecting from the remote. If 2 [`ProtocolAddres`] you `connect` to - /// belongs to the same [`Participant`], you get the same [`Participant`] as - /// a result. This is useful e.g. by connecting to the same - /// [`Participant`] via multiple Protocols. + /// disconnecting from the remote. If 2 [`ProtocolConnectAddres`] you + /// `connect` to belongs to the same [`Participant`], you get the same + /// [`Participant`] as a result. This is useful e.g. by connecting to + /// the same [`Participant`] via multiple Protocols. /// /// [`Streams`]: crate::api::Stream - /// [`ProtocolAddres`]: crate::api::ProtocolAddr + /// [`ProtocolConnectAddres`]: crate::api::ProtocolConnectAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] - pub async fn connect(&self, address: ProtocolAddr) -> Result<Participant, NetworkError> { + pub async fn connect(&self, address: ConnectAddr) -> Result<Participant, NetworkError> { let (pid_sender, pid_receiver) = oneshot::channel::<Result<Participant, NetworkConnectError>>(); debug!(?address, "Connect to address"); @@ -393,15 +403,15 @@ impl Network { Ok(participant) } - /// returns a [`Participant`] created from a [`ProtocolAddr`] you called - /// [`listen`] on before. This function will either return a working - /// [`Participant`] ready to open [`Streams`] on OR has returned a - /// [`NetworkError`] (e.g. Network got closed) + /// returns a [`Participant`] created from a [`ProtocolListenAddr`] you + /// called [`listen`] on before. This function will either return a + /// working [`Participant`] ready to open [`Streams`] on OR has returned + /// a [`NetworkError`] (e.g. Network got closed) /// /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{ConnectAddr, ListenAddr, Network, Pid}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on port `2020` TCP and opens returns their Pid @@ -410,9 +420,9 @@ impl Network { /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { /// network - /// .listen(ProtocolAddr::Tcp("127.0.0.1:2020".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2020".parse().unwrap())) /// .await?; - /// # remote.connect(ProtocolAddr::Tcp("127.0.0.1:2020".parse().unwrap())).await?; + /// # remote.connect(ConnectAddr::Tcp("127.0.0.1:2020".parse().unwrap())).await?; /// while let Ok(participant) = network.connected().await { /// println!("Participant connected: {}", participant.remote_pid()); /// # //skip test here as it would be a endless loop @@ -530,7 +540,7 @@ impl Participant { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, Promises, ProtocolAddr}; + /// use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, connect on port 2100 and open a stream @@ -538,9 +548,9 @@ impl Participant { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2100".parse().unwrap())).await?; /// let p1 = network - /// .connect(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())) + /// .connect(ConnectAddr::Tcp("127.0.0.1:2100".parse().unwrap())) /// .await?; /// let _s1 = p1 /// .open(4, Promises::ORDERED | Promises::CONSISTENCY, 1000) @@ -597,7 +607,7 @@ impl Participant { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr, Promises}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr, Promises}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, connect on port 2110 and wait for the other side to open a stream @@ -606,8 +616,8 @@ impl Participant { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; - /// let p1 = network.connect(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; + /// let p1 = network.connect(ConnectAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; /// # let p2 = remote.connected().await?; /// # p2.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// let _s1 = p1.opened().await?; @@ -654,7 +664,7 @@ impl Participant { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on port `2030` TCP and opens returns their Pid and close connection. @@ -663,9 +673,9 @@ impl Participant { /// # let remote = Network::new(Pid::new(), &runtime); /// let err = runtime.block_on(async { /// network - /// .listen(ProtocolAddr::Tcp("127.0.0.1:2030".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2030".parse().unwrap())) /// .await?; - /// # let keep_alive = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2030".parse().unwrap())).await?; + /// # let keep_alive = remote.connect(ConnectAddr::Tcp("127.0.0.1:2030".parse().unwrap())).await?; /// while let Ok(participant) = network.connected().await { /// println!("Participant connected: {}", participant.remote_pid()); /// participant.disconnect().await?; @@ -790,7 +800,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on Port `2200` and wait for a Stream to be opened, then answer `Hello World` @@ -798,8 +808,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; /// # // keep it alive /// # let _stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// let participant_a = network.connected().await?; @@ -832,7 +842,7 @@ impl Stream { /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; /// use bincode; - /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid, Message}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// let runtime = Runtime::new().unwrap(); @@ -840,9 +850,9 @@ impl Stream { /// # let remote1 = Network::new(Pid::new(), &runtime); /// # let remote2 = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; - /// # let remote1_p = remote1.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; - /// # let remote2_p = remote2.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; + /// # let remote1_p = remote1.connect(ConnectAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; + /// # let remote2_p = remote2.connect(ConnectAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # assert_eq!(remote1_p.remote_pid(), remote2_p.remote_pid()); /// # remote1_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # remote2_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; @@ -891,7 +901,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on Port `2220` and wait for a Stream to be opened, then listen on it @@ -899,8 +909,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; @@ -925,7 +935,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on Port `2230` and wait for a Stream to be opened, then listen on it @@ -933,8 +943,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; @@ -981,7 +991,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on Port `2240` and wait for a Stream to be opened, then listen on it @@ -989,8 +999,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// # std::thread::sleep(std::time::Duration::from_secs(1)); diff --git a/network/src/channel.rs b/network/src/channel.rs index 85bf824134..9866d88da9 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,19 +1,20 @@ use async_trait::async_trait; use bytes::BytesMut; use network_protocol::{ - QuicDataFormat, QuicDataFormatStream, QuicSendProtocol, QuicRecvProtocol, Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, - ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, + ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, + QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; -#[cfg(feature = "quic")] use quinn::*; use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, sync::mpsc, }; +use tokio_stream::StreamExt; +#[allow(clippy::large_enum_variant)] #[derive(Debug)] pub(crate) enum Protocols { Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)), @@ -35,7 +36,7 @@ pub(crate) enum RecvProtocols { Tcp(TcpRecvProtocol<TcpSink>), Mpsc(MpscRecvProtocol<MpscSink>), #[cfg(feature = "quic")] - Quic(QuicSendProtocol<QuicDrain>), + Quic(QuicRecvProtocol<QuicSink>), } impl Protocols { @@ -73,26 +74,39 @@ impl Protocols { #[cfg(feature = "quic")] pub(crate) async fn new_quic( - connection: quinn::NewConnection, + mut connection: quinn::NewConnection, + listen: bool, cid: Cid, metrics: Arc<ProtocolMetrics>, ) -> Result<Self, quinn::ConnectionError> { let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let (sendstream, recvstream) = connection.connection.open_bi().await?; - - - let sp = QuicSendProtocol::new(QuicDrain { - con: connection.connection.clone(), - main: sendstream, - reliables: vec!(), - }, metrics.clone()); - let rp = QuicRecvProtocol::new(QuicSink { - con: connection.connection, - main: recvstream, - reliables: vec!(), - buffer: BytesMut::new(), - }, metrics); + let (sendstream, recvstream) = if listen { + connection.connection.open_bi().await? + } else { + connection.bi_streams.next().await.expect("none").expect("dasdasd") + }; + let (streams_s,streams_r) = mpsc::unbounded_channel(); + let streams_s_clone = streams_s.clone(); + let sp = QuicSendProtocol::new( + QuicDrain { + con: connection.connection.clone(), + main: sendstream, + reliables: std::collections::HashMap::new(), + streams_s: streams_s_clone, + }, + metrics.clone(), + ); + spawn_new(recvstream, None, &streams_s); + let rp = QuicRecvProtocol::new( + QuicSink { + con: connection.connection, + bi: connection.bi_streams, + streams_r, + streams_s, + }, + metrics, + ); Ok(Protocols::Quic((sp, rp))) } @@ -243,50 +257,128 @@ impl UnreliableSink for MpscSink { /////////////////////////////////////// //// QUIC +#[cfg(feature = "quic")] +type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<u64>); + +#[cfg(feature = "quic")] #[derive(Debug)] pub struct QuicDrain { con: quinn::Connection, main: quinn::SendStream, - reliables: Vec<quinn::SendStream>, + reliables: std::collections::HashMap<u64, quinn::SendStream>, + streams_s: mpsc::UnboundedSender<QuicStream>, } +#[cfg(feature = "quic")] #[derive(Debug)] pub struct QuicSink { con: quinn::Connection, - main: quinn::RecvStream, - reliables: Vec<quinn::RecvStream>, - buffer: BytesMut, + bi: quinn::IncomingBiStreams, + streams_r: mpsc::UnboundedReceiver<QuicStream>, + streams_s: mpsc::UnboundedSender<QuicStream>, } +#[cfg(feature = "quic")] +fn spawn_new(mut recvstream: quinn::RecvStream, id: Option<u64>, streams_s: &mpsc::UnboundedSender<QuicStream>) { + let streams_s_clone = streams_s.clone(); + tokio::spawn(async move { + let mut buffer = BytesMut::new(); + buffer.resize(1500, 0u8); + let r = recvstream.read(&mut buffer).await; + let _ = streams_s_clone.send((buffer, r, recvstream, id)); + }); +} + +#[cfg(feature = "quic")] #[async_trait] impl UnreliableDrain for QuicDrain { type DataFormat = QuicDataFormat; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { match match data.stream { - QuicDataFormatStream::Main => self.main.write_all(&data.data), + QuicDataFormatStream::Main => { + self.main.write_all(&data.data).await + }, QuicDataFormatStream::Unreliable => unimplemented!(), - QuicDataFormatStream::Reliable(id) => self.reliables.get_mut(id as usize).ok_or(ProtocolError::Closed)?.write_all(&data.data), - }.await { + QuicDataFormatStream::Reliable(id) => { + use std::collections::hash_map::Entry; + match self.reliables.entry(id) { + Entry::Occupied(mut occupied) => { + occupied.get_mut().write_all(&data.data).await + }, + Entry::Vacant(vacant) => { + match self.con.open_bi().await { + Ok((sendstream, recvstream)) => { + let id = Some(0); //TODO FIXME + spawn_new(recvstream, id, &self.streams_s); + vacant.insert(sendstream).write_all(&data.data).await + }, + Err(_) => return Err(ProtocolError::Closed), + } + }, + } + }, + } + { Ok(()) => Ok(()), Err(_) => Err(ProtocolError::Closed), } } } +#[cfg(feature = "quic")] #[async_trait] impl UnreliableSink for QuicSink { type DataFormat = QuicDataFormat; async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { - self.buffer.resize(1500, 0u8); - //TODO improve - match self.main.read(&mut self.buffer).await { + let (mut buffer, result, mut recvstream, id) = loop { + use futures_util::FutureExt; + // first handle all bi streams! + let (a, b) = tokio::select! { + biased; + Some(n) = self.bi.next().fuse() => (Some(n), None), + Some(n) = self.streams_r.recv().fuse() => (None, Some(n)), + }; + + if let Some(remote_stream) = a { + match remote_stream { + Ok((sendstream, recvstream)) => { + //FIXME TODO + let id = Some(0); // get real ID + drop(sendstream); // not drop it! + spawn_new(recvstream, id, &self.streams_s); + }, + Err(_) => return Err(ProtocolError::Closed), + } + } + + if let Some(data) = b { + break data; + } + }; + + let r = match result { Ok(Some(0)) => Err(ProtocolError::Closed), - Ok(Some(n)) => Ok(QuicDataFormat{stream: QuicDataFormatStream::Main, data: self.buffer.split_to(n)}), + Ok(Some(n)) => Ok(QuicDataFormat { + stream: match id { + Some(id) => QuicDataFormatStream::Reliable(id), + None => QuicDataFormatStream::Main, + }, + data: buffer.split_to(n), + }), Ok(None) => Err(ProtocolError::Closed), Err(_) => Err(ProtocolError::Closed), - } + }?; + + + let streams_s_clone = self.streams_s.clone(); + tokio::spawn(async move { + buffer.resize(1500, 0u8); + let r = recvstream.read(&mut buffer).await; + let _ = streams_s_clone.send((buffer, r, recvstream, id)); + }); + Ok(r) } } diff --git a/network/src/lib.rs b/network/src/lib.rs index 448b50f41c..70b68fbafb 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -13,14 +13,14 @@ //! Say you have an application that wants to communicate with other application //! over a Network or on the same computer. Now each application instances the //! struct [`Network`] once with a new [`Pid`]. The Pid is necessary to identify -//! other [`Networks`] over the network protocols (e.g. TCP, UDP) +//! other [`Networks`] over the network protocols (e.g. TCP, UDP, QUIC, MPSC) //! -//! To connect to another application, you must know it's [`ProtocolAddr`]. One +//! To connect to another application, you must know it's [`ConnectAddr`]. One //! side will call [`connect`], the other [`connected`]. If successful both //! applications will now get a [`Participant`]. //! //! This [`Participant`] represents the connection between those 2 applications. -//! over the respective [`ProtocolAddr`] and with it the chosen network +//! over the respective [`ConnectAddr`] and with it the chosen network //! protocol. However messages can't be send directly via [`Participants`], //! instead you must open a [`Stream`] on it. Like above, one side has to call //! [`open`], the other [`opened`]. [`Streams`] can have a different priority @@ -41,14 +41,14 @@ //! ```rust //! use std::sync::Arc; //! use tokio::{join, runtime::Runtime, time::sleep}; -//! use veloren_network::{Network, Pid, Promises, ProtocolAddr}; +//! use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; //! //! // Client //! async fn client(runtime: &Runtime) -> std::result::Result<(), Box<dyn std::error::Error>> { //! sleep(std::time::Duration::from_secs(1)).await; // `connect` MUST be after `listen` //! let client_network = Network::new(Pid::new(), runtime); //! let server = client_network -//! .connect(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) +//! .connect(ConnectAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; //! let mut stream = server //! .open(4, Promises::ORDERED | Promises::CONSISTENCY, 0) @@ -61,7 +61,7 @@ //! async fn server(runtime: &Runtime) -> std::result::Result<(), Box<dyn std::error::Error>> { //! let server_network = Network::new(Pid::new(), runtime); //! server_network -//! .listen(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) +//! .listen(ListenAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; //! let client = server_network.connected().await?; //! let mut stream = client.opened().await?; @@ -95,7 +95,8 @@ //! [`send`]: crate::api::Stream::send //! [`recv`]: crate::api::Stream::recv //! [`Pid`]: network_protocol::Pid -//! [`ProtocolAddr`]: crate::api::ProtocolAddr +//! [`ListenAddr`]: crate::api::ListenAddr +//! [`ConnectAddr`]: crate::api::ConnectAddr //! [`Promises`]: network_protocol::Promises mod api; @@ -107,8 +108,8 @@ mod scheduler; mod util; pub use api::{ - Network, NetworkConnectError, NetworkError, Participant, ParticipantError, ProtocolAddr, - Stream, StreamError, StreamParams, + ConnectAddr, ListenAddr, Network, NetworkConnectError, NetworkError, Participant, + ParticipantError, Stream, StreamError, StreamParams, }; pub use message::Message; pub use network_protocol::{InitProtocolError, Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index bc81e25802..5c0029cf16 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -70,7 +70,7 @@ impl Message { /// /// # Example /// ``` - /// # use veloren_network::{Network, ProtocolAddr, Pid}; + /// # use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// # use veloren_network::Promises; /// # use tokio::runtime::Runtime; /// # use std::sync::Arc; @@ -81,8 +81,8 @@ impl Message { /// # let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// # runtime.block_on(async { - /// # network.listen(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; + /// # network.listen(ListenAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// # let participant_a = network.connected().await?; diff --git a/network/src/metrics.rs b/network/src/metrics.rs index c46fe16bda..d532347140 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,8 +1,29 @@ -use crate::api::ProtocolAddr; +use crate::api::{ConnectAddr, ListenAddr}; use network_protocol::{Cid, Pid}; #[cfg(feature = "metrics")] use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; -use std::error::Error; +use std::{error::Error, net::SocketAddr}; + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub(crate) enum ProtocolInfo { + Tcp(SocketAddr), + Udp(SocketAddr), + #[cfg(feature = "quic")] + Quic(SocketAddr), + Mpsc(u64), +} + +impl From<ListenAddr> for ProtocolInfo { + fn from(other: ListenAddr) -> ProtocolInfo { + match other { + ListenAddr::Tcp(s) => ProtocolInfo::Tcp(s), + ListenAddr::Udp(s) => ProtocolInfo::Udp(s), + #[cfg(feature = "quic")] + ListenAddr::Quic(s, _) => ProtocolInfo::Quic(s), + ListenAddr::Mpsc(s) => ProtocolInfo::Mpsc(s), + } + } +} /// 1:1 relation between NetworkMetrics and Network #[cfg(feature = "metrics")] @@ -154,9 +175,9 @@ impl NetworkMetrics { Ok(()) } - pub(crate) fn connect_requests_cache(&self, protocol: &ProtocolAddr) -> prometheus::IntCounter { + pub(crate) fn connect_requests_cache(&self, protocol: &ListenAddr) -> prometheus::IntCounter { self.incoming_connections_total - .with_label_values(&[protocol_name(protocol)]) + .with_label_values(&[protocollisten_name(protocol)]) } pub(crate) fn channels_connected(&self, remote_p: &str, no: usize, cid: Cid) { @@ -192,15 +213,15 @@ impl NetworkMetrics { .inc(); } - pub(crate) fn listen_request(&self, protocol: &ProtocolAddr) { + pub(crate) fn listen_request(&self, protocol: &ListenAddr) { self.listen_requests_total - .with_label_values(&[protocol_name(protocol)]) + .with_label_values(&[protocollisten_name(protocol)]) .inc(); } - pub(crate) fn connect_request(&self, protocol: &ProtocolAddr) { + pub(crate) fn connect_request(&self, protocol: &ConnectAddr) { self.connect_requests_total - .with_label_values(&[protocol_name(protocol)]) + .with_label_values(&[protocolconnect_name(protocol)]) .inc(); } @@ -225,11 +246,22 @@ impl NetworkMetrics { } #[cfg(feature = "metrics")] -fn protocol_name(protocol: &ProtocolAddr) -> &str { +fn protocolconnect_name(protocol: &ConnectAddr) -> &str { match protocol { - ProtocolAddr::Tcp(_) => "tcp", - ProtocolAddr::Udp(_) => "udp", - ProtocolAddr::Mpsc(_) => "mpsc", + ConnectAddr::Tcp(_) => "tcp", + ConnectAddr::Udp(_) => "udp", + ConnectAddr::Mpsc(_) => "mpsc", + ConnectAddr::Quic(_, _, _) => "quic", + } +} + +#[cfg(feature = "metrics")] +fn protocollisten_name(protocol: &ListenAddr) -> &str { + match protocol { + ListenAddr::Tcp(_) => "tcp", + ListenAddr::Udp(_) => "udp", + ListenAddr::Mpsc(_) => "mpsc", + ListenAddr::Quic(_, _) => "quic", } } @@ -247,9 +279,9 @@ impl NetworkMetrics { pub(crate) fn streams_closed(&self, _remote_p: &str) {} - pub(crate) fn listen_request(&self, _protocol: &ProtocolAddr) {} + pub(crate) fn listen_request(&self, _protocol: &ListenAddr) {} - pub(crate) fn connect_request(&self, _protocol: &ProtocolAddr) {} + pub(crate) fn connect_request(&self, _protocol: &ConnectAddr) {} pub(crate) fn cleanup_participant(&self, _remote_p: &str) {} } diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 11a2a0f774..475e34371f 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -1,7 +1,7 @@ use crate::{ - api::{NetworkConnectError, Participant, ProtocolAddr}, + api::{ConnectAddr, ListenAddr, NetworkConnectError, Participant}, channel::Protocols, - metrics::NetworkMetrics, + metrics::{NetworkMetrics, ProtocolInfo}, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, }; use futures_util::{FutureExt, StreamExt}; @@ -46,9 +46,9 @@ struct ParticipantInfo { s2b_shutdown_bparticipant_s: Option<oneshot::Sender<S2bShutdownBparticipant>>, } -type A2sListen = (ProtocolAddr, oneshot::Sender<io::Result<()>>); +type A2sListen = (ListenAddr, oneshot::Sender<io::Result<()>>); pub(crate) type A2sConnect = ( - ProtocolAddr, + ConnectAddr, oneshot::Sender<Result<Participant, NetworkConnectError>>, ); type A2sDisconnect = (Pid, S2bShutdownBparticipant); @@ -82,7 +82,7 @@ pub struct Scheduler { participant_channels: Arc<Mutex<Option<ParticipantChannels>>>, participants: Arc<Mutex<HashMap<Pid, ParticipantInfo>>>, channel_ids: Arc<AtomicU64>, - channel_listener: Mutex<HashMap<ProtocolAddr, oneshot::Sender<()>>>, + channel_listener: Mutex<HashMap<ProtocolInfo, oneshot::Sender<()>>>, metrics: Arc<NetworkMetrics>, protocol_metrics: Arc<ProtocolMetrics>, } @@ -182,7 +182,7 @@ impl Scheduler { self.channel_listener .lock() .await - .insert(address.clone(), end_sender); + .insert(address.clone().into(), end_sender); self.channel_creator(address, end_receiver, s2a_listen_result_s) .await; } @@ -198,7 +198,7 @@ impl Scheduler { let metrics = Arc::clone(&self.protocol_metrics); self.metrics.connect_request(&addr); let (protocol, handshake) = match addr { - ProtocolAddr::Tcp(addr) => { + ConnectAddr::Tcp(addr) => { let stream = match net::TcpStream::connect(addr).await { Ok(stream) => stream, Err(e) => { @@ -209,7 +209,21 @@ impl Scheduler { info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); (Protocols::new_tcp(stream, cid, metrics), false) }, - ProtocolAddr::Mpsc(addr) => { + #[cfg(feature = "quic")] + ConnectAddr::Quic(addr, ref config, name) => { + let config = config.clone(); + let endpoint = quinn::Endpoint::builder(); + let (endpoint, _) = endpoint.bind(&"[::]:0".parse().unwrap()).expect("FIXME"); + + let connecting = endpoint.connect_with(config, &addr, &name).expect("FIXME"); + let connection = connecting.await.expect("FIXME"); + ( + Protocols::new_quic(connection, false, cid, metrics).await.unwrap(), + false, + ) + //pid_sender.send(Ok(())).unwrap(); + }, + ConnectAddr::Mpsc(addr) => { let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { Some(s) => s.clone(), None => { @@ -236,7 +250,7 @@ impl Scheduler { ) }, /* */ - //ProtocolAddr::Udp(addr) => { + //ProtocolConnectAddr::Udp(addr) => { //#[cfg(feature = "metrics")] //self.metrics //.connect_requests_total @@ -386,7 +400,7 @@ impl Scheduler { async fn channel_creator( &self, - addr: ProtocolAddr, + addr: ListenAddr, s2s_stop_listening_r: oneshot::Receiver<()>, s2a_listen_result_s: oneshot::Sender<io::Result<()>>, ) { @@ -394,7 +408,7 @@ impl Scheduler { #[cfg(feature = "metrics")] let mcache = self.metrics.connect_requests_cache(&addr); match addr { - ProtocolAddr::Tcp(addr) => { + ListenAddr::Tcp(addr) => { let listener = match net::TcpListener::bind(addr).await { Ok(listener) => { s2a_listen_result_s.send(Ok(())).unwrap(); @@ -432,10 +446,10 @@ impl Scheduler { } }, #[cfg(feature = "quic")] - ProtocolAddr::Quic(addr, server_config) => { + ListenAddr::Quic(addr, ref server_config) => { let mut endpoint = quinn::Endpoint::builder(); - endpoint.listen(server_config); - let (endpoint, mut listener) = match endpoint.bind(&addr) { + endpoint.listen(server_config.clone()); + let (_endpoint, mut listener) = match endpoint.bind(&addr) { Ok((endpoint, listener)) => { s2a_listen_result_s.send(Ok(())).unwrap(); (endpoint, listener) @@ -468,11 +482,18 @@ impl Scheduler { mcache.inc(); let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); info!(?remote_addr, ?cid, "Accepting Quic from"); - self.init_protocol(Protocols::new_quic(connection, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) + let quic = match Protocols::new_quic(connection, true, cid, Arc::clone(&self.protocol_metrics)).await { + Ok(quic) => quic, + Err(e) => { + trace!(?e, "failed to start quic"); + continue; + } + }; + self.init_protocol(quic, cid, None, true) .await; } }, - ProtocolAddr::Mpsc(addr) => { + ListenAddr::Mpsc(addr) => { let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); MPSC_POOL.lock().await.insert(addr, mpsc_s); s2a_listen_result_s.send(Ok(())).unwrap(); @@ -494,7 +515,7 @@ impl Scheduler { } warn!("MpscStream Failed, stopping"); },/* - ProtocolAddr::Udp(addr) => { + ProtocolListenAddr::Udp(addr) => { let socket = match net::UdpSocket::bind(addr).await { Ok(socket) => { s2a_listen_result_s.send(Ok(())).unwrap(); diff --git a/network/tests/closing.rs b/network/tests/closing.rs index 7d6a2cb0ee..100e84e544 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -347,8 +347,8 @@ fn open_participant_before_remote_part_is_closed() { let n_a = Network::new(Pid::fake(0), &r); let n_b = Network::new(Pid::fake(1), &r); let addr = tcp(); - r.block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = r.block_on(n_b.connect(addr)).unwrap(); + r.block_on(n_a.listen(addr.0)).unwrap(); + let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); let p_a = r.block_on(n_a.connected()).unwrap(); @@ -367,8 +367,8 @@ fn open_participant_after_remote_part_is_closed() { let n_a = Network::new(Pid::fake(0), &r); let n_b = Network::new(Pid::fake(1), &r); let addr = tcp(); - r.block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = r.block_on(n_b.connect(addr)).unwrap(); + r.block_on(n_a.listen(addr.0)).unwrap(); + let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); drop(s1_b); @@ -387,8 +387,8 @@ fn close_network_scheduler_completely() { let n_a = Network::new(Pid::fake(0), &r); let n_b = Network::new(Pid::fake(1), &r); let addr = tcp(); - r.block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = r.block_on(n_b.connect(addr)).unwrap(); + r.block_on(n_a.listen(addr.0)).unwrap(); + let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); diff --git a/network/tests/helper.rs b/network/tests/helper.rs index 68d5cebd87..9e78928f55 100644 --- a/network/tests/helper.rs +++ b/network/tests/helper.rs @@ -11,7 +11,7 @@ use std::{ use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; +use veloren_network::{ConnectAddr, ListenAddr, Network, Participant, Pid, Promises, Stream}; #[allow(dead_code)] pub fn setup(tracing: bool, sleep: u64) -> (u64, u64) { @@ -47,7 +47,7 @@ pub fn setup(tracing: bool, sleep: u64) -> (u64, u64) { #[allow(dead_code)] pub fn network_participant_stream( - addr: ProtocolAddr, + addr: (ListenAddr, ConnectAddr), ) -> ( Arc<Runtime>, Network, @@ -62,11 +62,11 @@ pub fn network_participant_stream( let n_a = Network::new(Pid::fake(0), &runtime); let n_b = Network::new(Pid::fake(1), &runtime); - n_a.listen(addr.clone()).await.unwrap(); - let p1_b = n_b.connect(addr).await.unwrap(); + n_a.listen(addr.0).await.unwrap(); + let p1_b = n_b.connect(addr.1).await.unwrap(); let p1_a = n_a.connected().await.unwrap(); - let s1_a = p1_a.open(4, Promises::empty(), 0).await.unwrap(); + let s1_a = p1_a.open(4, Promises::ORDERED, 0).await.unwrap(); let s1_b = p1_b.opened().await.unwrap(); (n_a, p1_a, s1_a, n_b, p1_b, s1_b) @@ -75,28 +75,76 @@ pub fn network_participant_stream( } #[allow(dead_code)] -pub fn tcp() -> ProtocolAddr { +pub fn tcp() -> (ListenAddr, ConnectAddr) { lazy_static! { static ref PORTS: AtomicU16 = AtomicU16::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))) + ( + ListenAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))), + ConnectAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))), + ) +} + +lazy_static! { + static ref UDP_PORTS: AtomicU16 = AtomicU16::new(5000); } #[allow(dead_code)] -pub fn udp() -> ProtocolAddr { - lazy_static! { - static ref PORTS: AtomicU16 = AtomicU16::new(5000); - } - let port = PORTS.fetch_add(1, Ordering::Relaxed); - ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) +pub fn quic() -> (ListenAddr, ConnectAddr) { + const LOCALHOST: &str = "localhost"; + let port = UDP_PORTS.fetch_add(1, Ordering::Relaxed); + + let transport_config = quinn::TransportConfig::default(); + let mut server_config = quinn::ServerConfig::default(); + server_config.transport = Arc::new(transport_config); + let mut server_config = quinn::ServerConfigBuilder::new(server_config); + server_config.protocols(&[b"veloren"]); + + trace!("generating self-signed certificate"); + let cert = rcgen::generate_simple_self_signed(vec![LOCALHOST.into()]).unwrap(); + let key = cert.serialize_private_key_der(); + let cert = cert.serialize_der().unwrap(); + + let key = quinn::PrivateKey::from_der(&key).expect("private key failed"); + let cert = quinn::Certificate::from_der(&cert).expect("cert failed"); + server_config + .certificate(quinn::CertificateChain::from_certs(vec![cert.clone()]), key) + .expect("set cert failed"); + + let server_config = server_config.build(); + + let mut client_config = quinn::ClientConfigBuilder::default(); + client_config.protocols(&[b"veloren"]); + client_config + .add_certificate_authority(cert) + .expect("adding certificate failed"); + + let client_config = client_config.build(); + ( + ListenAddr::Quic(SocketAddr::from(([127, 0, 0, 1], port)), server_config), + ConnectAddr::Quic( + SocketAddr::from(([127, 0, 0, 1], port)), + client_config, + LOCALHOST.to_owned(), + ), + ) } #[allow(dead_code)] -pub fn mpsc() -> ProtocolAddr { +pub fn udp() -> (ListenAddr, ConnectAddr) { + let port = UDP_PORTS.fetch_add(1, Ordering::Relaxed); + ( + ListenAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))), + ConnectAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))), + ) +} + +#[allow(dead_code)] +pub fn mpsc() -> (ListenAddr, ConnectAddr) { lazy_static! { static ref PORTS: AtomicU64 = AtomicU64::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - ProtocolAddr::Mpsc(port) + (ListenAddr::Mpsc(port), ConnectAddr::Mpsc(port)) } diff --git a/network/tests/integration.rs b/network/tests/integration.rs index 93534ac082..e81530b4f0 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use tokio::runtime::Runtime; use veloren_network::{NetworkError, StreamError}; mod helper; -use helper::{mpsc, network_participant_stream, tcp, udp}; +use helper::{mpsc, network_participant_stream, quic, tcp, udp}; use std::io::ErrorKind; -use veloren_network::{Network, Pid, Promises, ProtocolAddr}; +use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; #[test] #[ignore] @@ -73,6 +73,30 @@ fn stream_simple_mpsc_3msg() { drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } +#[test] +fn stream_simple_quic() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); + + s1_a.send("Hello World").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + +#[test] +fn stream_simple_quic_3msg() { + let (_, _) = helper::setup(true, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); + + s1_a.send("Hello World").unwrap(); + s1_a.send(1337).unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); + s1_a.send("3rdMessage").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + #[test] #[ignore] fn stream_simple_udp() { @@ -110,16 +134,16 @@ fn tcp_and_udp_2_connections() -> std::result::Result<(), Box<dyn std::error::Er let network = network; let remote = remote; remote - .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) + .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) .await?; remote - .listen(ProtocolAddr::Udp("127.0.0.1:2001".parse().unwrap())) + .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) .await?; let p1 = network - .connect(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) + .connect(ConnectAddr::Tcp("127.0.0.1:2000".parse().unwrap())) .await?; let p2 = network - .connect(ProtocolAddr::Udp("127.0.0.1:2001".parse().unwrap())) + .connect(ConnectAddr::Udp("127.0.0.1:2001".parse().unwrap())) .await?; assert_eq!(&p1, &p2); Ok(()) @@ -134,13 +158,13 @@ fn failed_listen_on_used_ports() -> std::result::Result<(), Box<dyn std::error:: let network = Network::new(Pid::new(), &r); let udp1 = udp(); let tcp1 = tcp(); - r.block_on(network.listen(udp1.clone()))?; - r.block_on(network.listen(tcp1.clone()))?; + r.block_on(network.listen(udp1.0.clone()))?; + r.block_on(network.listen(tcp1.0.clone()))?; std::thread::sleep(std::time::Duration::from_millis(200)); let network2 = Network::new(Pid::new(), &r); - let e1 = r.block_on(network2.listen(udp1)); - let e2 = r.block_on(network2.listen(tcp1)); + let e1 = r.block_on(network2.listen(udp1.0)); + let e2 = r.block_on(network2.listen(tcp1.0)); match e1 { Err(NetworkError::ListenFailed(e)) if e.kind() == ErrorKind::AddrInUse => (), _ => panic!(), @@ -170,10 +194,10 @@ fn api_stream_send_main() -> std::result::Result<(), Box<dyn std::error::Error>> let network = network; let remote = remote; network - .listen(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) + .listen(ListenAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; let remote_p = remote - .connect(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) + .connect(ConnectAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; // keep it alive let _stream_p = remote_p @@ -199,10 +223,10 @@ fn api_stream_recv_main() -> std::result::Result<(), Box<dyn std::error::Error>> let network = network; let remote = remote; network - .listen(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) + .listen(ListenAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; let remote_p = remote - .connect(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) + .connect(ConnectAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; let mut stream_p = remote_p .open(4, Promises::ORDERED | Promises::CONSISTENCY, 0) diff --git a/server/src/lib.rs b/server/src/lib.rs index 692ae02c0e..910f124a13 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -83,7 +83,7 @@ use common_state::plugin::PluginMgr; use common_state::{BuildAreas, State}; use common_systems::add_local_systems; use metrics::{EcsSystemMetrics, PhysicsMetrics, TickMetrics}; -use network::{Network, Pid, ProtocolAddr}; +use network::{ListenAddr, Network, Pid}; use persistence::{ character_loader::{CharacterLoader, CharacterLoaderResponseKind}, character_updater::CharacterUpdater, @@ -386,8 +386,8 @@ impl Server { ) .await }); - runtime.block_on(network.listen(ProtocolAddr::Tcp(settings.gameserver_address)))?; - runtime.block_on(network.listen(ProtocolAddr::Mpsc(14004)))?; + runtime.block_on(network.listen(ListenAddr::Tcp(settings.gameserver_address)))?; + runtime.block_on(network.listen(ListenAddr::Mpsc(14004)))?; let connection_handler = ConnectionHandler::new(network, &runtime); // Initiate real-time world simulation From 01992c05c66f50695c6182c4a061e76c7679e066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Mon, 19 Apr 2021 16:49:23 +0200 Subject: [PATCH 4/7] QuicSink and QuicDrain do work now. When local SendProtocol is opening a Stream, it will send a empty message to QuicDrain which will then know that its time to open a quic stream. It will open a QuicStream and send its SID over to remote. The RecvStream will be send to local QuicSink RemoteRecv will notice a new BiStream was opened and read its Sid. It will now start listening on it. while remote main will get the information that a stream was opened and will notice the frontend. in participant remote Recv is synced with remote send (without triggering a empty message!). RemoteRecv Sink will send the sendstream to RemoteSend Drain and it will be used when a first message is send on this stream. --- network/protocol/src/quic.rs | 43 ++++++++++--------- network/protocol/src/types.rs | 2 + network/src/channel.rs | 78 ++++++++++++++++++++++------------- 3 files changed, 72 insertions(+), 51 deletions(-) diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index e656fdf5a1..0e76e1fe32 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -23,7 +23,7 @@ use tracing::trace; #[derive(PartialEq)] pub enum QuicDataFormatStream { Main, - Reliable(u64), + Reliable(Sid), Unreliable, } @@ -40,9 +40,9 @@ impl QuicDataFormat { } } - fn with_reliable(buffer: &mut BytesMut, id: u64) -> Self { + fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self { Self { - stream: QuicDataFormatStream::Reliable(id), + stream: QuicDataFormatStream::Reliable(sid), data: buffer.split(), } } @@ -88,13 +88,19 @@ where main_buffer: BytesMut, unreliable_buffer: BytesMut, reliable_buffers: SortedVec<Sid, BytesMut>, - pending_reliable_buffers: Vec<(u64, BytesMut)>, + pending_reliable_buffers: Vec<(Sid, BytesMut)>, itmsg_allocator: BytesMut, incoming: HashMap<Mid, ITMessage>, sink: S, metrics: ProtocolMetricCache, } +fn is_reliable(p: &Promises) -> bool { + p.contains(Promises::ORDERED) + || p.contains(Promises::CONSISTENCY) + || p.contains(Promises::GUARANTEED_DELIVERY) +} + impl<D> QuicSendProtocol<D> where D: UnreliableDrain<DataFormat = QuicDataFormat>, @@ -148,8 +154,8 @@ where QuicDataFormatStream::Main => &mut self.main_buffer, QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer, QuicDataFormatStream::Reliable(id) => { - match self.reliable_buffers.data.get_mut(id as usize) { - Some((_, buffer)) => buffer, + match self.reliable_buffers.get_mut(&id) { + Some(buffer) => buffer, None => { self.pending_reliable_buffers.push((id, BytesMut::new())); //Violated but will never happen @@ -186,10 +192,7 @@ where } => { self.store .open_stream(sid, prio, promises, guaranteed_bandwidth); - if promises.contains(Promises::ORDERED) - || promises.contains(Promises::CONSISTENCY) - || promises.contains(Promises::GUARANTEED_DELIVERY) - { + if is_reliable(&promises) { self.reliable_buffers.insert(sid, BytesMut::new()); } }, @@ -216,11 +219,10 @@ where } => { self.store .open_stream(sid, prio, promises, guaranteed_bandwidth); - if promises.contains(Promises::ORDERED) - || promises.contains(Promises::CONSISTENCY) - || promises.contains(Promises::GUARANTEED_DELIVERY) - { + if is_reliable(&promises) { self.reliable_buffers.insert(sid, BytesMut::new()); + //Send a empty message to notify local drain of stream + self.drain.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)).await?; } event.to_frame().write_bytes(&mut self.main_buffer); self.drain @@ -284,10 +286,10 @@ where }, } } - for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() { + for (sid, buffer) in self.reliable_buffers.data.iter_mut() { if !buffer.is_empty() { self.drain - .send(QuicDataFormat::with_reliable(buffer, id as u64)) + .send(QuicDataFormat::with_reliable(buffer, *sid)) .await?; } } @@ -354,10 +356,7 @@ where promises, guaranteed_bandwidth, } => { - if promises.contains(Promises::ORDERED) - || promises.contains(Promises::CONSISTENCY) - || promises.contains(Promises::GUARANTEED_DELIVERY) - { + if is_reliable(&promises) { self.reliable_buffers.insert(sid, BytesMut::new()); } break 'outer Ok(ProtocolEvent::OpenStream { @@ -808,7 +807,7 @@ mod tests { length: (DATA1.len() + DATA2.len()) as u64, } .write_bytes(&mut bytes); - s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + s.send(QuicDataFormat::with_reliable(&mut bytes, sid)) .await .unwrap(); @@ -822,7 +821,7 @@ mod tests { data: Bytes::from(&DATA2[..]), } .write_bytes(&mut bytes); - s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + s.send(QuicDataFormat::with_reliable(&mut bytes, sid)) .await .unwrap(); diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs index dfc9142f38..2e189b412d 100644 --- a/network/protocol/src/types.rs +++ b/network/protocol/src/types.rs @@ -118,6 +118,8 @@ impl Pid { impl Sid { pub const fn new(internal: u64) -> Self { Self { internal } } + pub fn get_u64(&self) -> u64 { self.internal } + #[inline] pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { Self { diff --git a/network/src/channel.rs b/network/src/channel.rs index 9866d88da9..872c0647cf 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -86,24 +86,27 @@ impl Protocols { } else { connection.bi_streams.next().await.expect("none").expect("dasdasd") }; - let (streams_s,streams_r) = mpsc::unbounded_channel(); - let streams_s_clone = streams_s.clone(); + let (recvstreams_s,recvstreams_r) = mpsc::unbounded_channel(); + let streams_s_clone = recvstreams_s.clone(); + let (sendstreams_s,sendstreams_r) = mpsc::unbounded_channel(); let sp = QuicSendProtocol::new( QuicDrain { con: connection.connection.clone(), main: sendstream, reliables: std::collections::HashMap::new(), - streams_s: streams_s_clone, + recvstreams_s: streams_s_clone, + sendstreams_r, }, metrics.clone(), ); - spawn_new(recvstream, None, &streams_s); + spawn_new(recvstream, None, &recvstreams_s); let rp = QuicRecvProtocol::new( QuicSink { con: connection.connection, bi: connection.bi_streams, - streams_r, - streams_s, + recvstreams_r, + recvstreams_s, + sendstreams_s, }, metrics, ); @@ -258,15 +261,16 @@ impl UnreliableSink for MpscSink { /////////////////////////////////////// //// QUIC #[cfg(feature = "quic")] -type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<u64>); +type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<Sid>); #[cfg(feature = "quic")] #[derive(Debug)] pub struct QuicDrain { con: quinn::Connection, main: quinn::SendStream, - reliables: std::collections::HashMap<u64, quinn::SendStream>, - streams_s: mpsc::UnboundedSender<QuicStream>, + reliables: std::collections::HashMap<Sid, quinn::SendStream>, + recvstreams_s: mpsc::UnboundedSender<QuicStream>, + sendstreams_r: mpsc::UnboundedReceiver<quinn::SendStream>, } #[cfg(feature = "quic")] @@ -274,18 +278,19 @@ pub struct QuicDrain { pub struct QuicSink { con: quinn::Connection, bi: quinn::IncomingBiStreams, - streams_r: mpsc::UnboundedReceiver<QuicStream>, - streams_s: mpsc::UnboundedSender<QuicStream>, + recvstreams_r: mpsc::UnboundedReceiver<QuicStream>, + recvstreams_s: mpsc::UnboundedSender<QuicStream>, + sendstreams_s: mpsc::UnboundedSender<quinn::SendStream>, } #[cfg(feature = "quic")] -fn spawn_new(mut recvstream: quinn::RecvStream, id: Option<u64>, streams_s: &mpsc::UnboundedSender<QuicStream>) { +fn spawn_new(mut recvstream: quinn::RecvStream, sid: Option<Sid>, streams_s: &mpsc::UnboundedSender<QuicStream>) { let streams_s_clone = streams_s.clone(); tokio::spawn(async move { let mut buffer = BytesMut::new(); buffer.resize(1500, 0u8); let r = recvstream.read(&mut buffer).await; - let _ = streams_s_clone.send((buffer, r, recvstream, id)); + let _ = streams_s_clone.send((buffer, r, recvstream, sid)); }); } @@ -300,20 +305,30 @@ impl UnreliableDrain for QuicDrain { self.main.write_all(&data.data).await }, QuicDataFormatStream::Unreliable => unimplemented!(), - QuicDataFormatStream::Reliable(id) => { + QuicDataFormatStream::Reliable(sid) => { use std::collections::hash_map::Entry; - match self.reliables.entry(id) { + tracing::trace!(?sid, "Reliable"); + match self.reliables.entry(sid) { Entry::Occupied(mut occupied) => { occupied.get_mut().write_all(&data.data).await }, Entry::Vacant(vacant) => { - match self.con.open_bi().await { - Ok((sendstream, recvstream)) => { - let id = Some(0); //TODO FIXME - spawn_new(recvstream, id, &self.streams_s); - vacant.insert(sendstream).write_all(&data.data).await - }, - Err(_) => return Err(ProtocolError::Closed), + // IF the buffer is empty this was created localy and WE are allowed to open_bi(), if not, we NEED to block on sendstreams_r + if data.data.is_empty() { + match self.con.open_bi().await { + Ok((mut sendstream, recvstream)) => { + // send SID as first msg + if sendstream.write_u64(sid.get_u64()).await.is_err() { + return Err(ProtocolError::Closed); + } + spawn_new(recvstream, Some(sid), &self.recvstreams_s); + vacant.insert(sendstream).write_all(&data.data).await + }, + Err(_) => return Err(ProtocolError::Closed), + } + } else { + let sendstream = self.sendstreams_r.recv().await.ok_or(ProtocolError::Closed)?; + vacant.insert(sendstream).write_all(&data.data).await } }, } @@ -338,16 +353,21 @@ impl UnreliableSink for QuicSink { let (a, b) = tokio::select! { biased; Some(n) = self.bi.next().fuse() => (Some(n), None), - Some(n) = self.streams_r.recv().fuse() => (None, Some(n)), + Some(n) = self.recvstreams_r.recv().fuse() => (None, Some(n)), }; if let Some(remote_stream) = a { match remote_stream { - Ok((sendstream, recvstream)) => { - //FIXME TODO - let id = Some(0); // get real ID - drop(sendstream); // not drop it! - spawn_new(recvstream, id, &self.streams_s); + Ok((sendstream, mut recvstream)) => { + let sid = match recvstream.read_u64().await { + Ok(u64::MAX) => None, //unreliable + Ok(sid) => Some(Sid::new(sid)), + Err(_) => return Err(ProtocolError::Violated), + }; + if self.sendstreams_s.send(sendstream).is_err() { + return Err(ProtocolError::Closed); + } + spawn_new(recvstream, sid, &self.recvstreams_s); }, Err(_) => return Err(ProtocolError::Closed), } @@ -372,7 +392,7 @@ impl UnreliableSink for QuicSink { }?; - let streams_s_clone = self.streams_s.clone(); + let streams_s_clone = self.recvstreams_s.clone(); tokio::spawn(async move { buffer.resize(1500, 0u8); let r = recvstream.read(&mut buffer).await; From 66e206847692e0b1a447799ecbb9588c61672ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Thu, 22 Apr 2021 21:37:27 +0200 Subject: [PATCH 5/7] move connect code to channel and get rid of unwraps --- network/protocol/src/quic.rs | 4 +- network/src/channel.rs | 151 ++++++++++++++++++++++++++++++----- network/src/scheduler.rs | 97 +++------------------- 3 files changed, 149 insertions(+), 103 deletions(-) diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index 0e76e1fe32..a10764491b 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -222,7 +222,9 @@ where if is_reliable(&promises) { self.reliable_buffers.insert(sid, BytesMut::new()); //Send a empty message to notify local drain of stream - self.drain.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)).await?; + self.drain + .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)) + .await?; } event.to_frame().write_bytes(&mut self.main_buffer); self.drain diff --git a/network/src/channel.rs b/network/src/channel.rs index 872c0647cf..fe3bff971e 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,3 +1,4 @@ +use crate::api::NetworkConnectError; use async_trait::async_trait; use bytes::BytesMut; use network_protocol::{ @@ -9,10 +10,12 @@ use network_protocol::{ use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, + net, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - sync::mpsc, + sync::{mpsc, oneshot}, }; use tokio_stream::StreamExt; +use tracing::{info, trace}; #[allow(clippy::large_enum_variant)] #[derive(Debug)] @@ -40,6 +43,23 @@ pub(crate) enum RecvProtocols { } impl Protocols { + const MPSC_CHANNEL_BOUND: usize = 1000; + + pub(crate) async fn with_tcp_connect( + addr: std::net::SocketAddr, + cid: Cid, + metrics: Arc<ProtocolMetrics>, + ) -> Result<Self, NetworkConnectError> { + let stream = match net::TcpStream::connect(addr).await { + Ok(stream) => stream, + Err(e) => { + return Err(crate::api::NetworkConnectError::Io(e)); + }, + }; + info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); + Ok(Protocols::new_tcp(stream, cid, metrics)) + } + pub(crate) fn new_tcp( stream: tokio::net::TcpStream, cid: Cid, @@ -59,6 +79,49 @@ impl Protocols { Protocols::Tcp((sp, rp)) } + pub(crate) async fn with_mpsc_connect( + addr: u64, + cid: Cid, + metrics: Arc<ProtocolMetrics>, + ) -> Result<Self, NetworkConnectError> { + let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) { + Some(s) => s.clone(), + None => { + return Err(NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "no mpsc listen on this addr", + ))); + }, + }; + let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); + let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); + if mpsc_s + .send((remote_to_local_s, local_to_remote_oneshot_s)) + .is_err() + { + return Err(NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "mpsc pipe broke during connect", + ))); + } + let local_to_remote_s = match local_to_remote_oneshot_r.await { + Ok(s) => s, + Err(e) => { + return Err(NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + e, + ))); + }, + }; + info!(?addr, "Connecting Mpsc"); + Ok(Self::new_mpsc( + local_to_remote_s, + remote_to_local_r, + cid, + metrics, + )) + } + pub(crate) fn new_mpsc( sender: mpsc::Sender<MpscMsg>, receiver: mpsc::Receiver<MpscMsg>, @@ -72,6 +135,46 @@ impl Protocols { Protocols::Mpsc((sp, rp)) } + pub(crate) async fn with_quic_connect( + addr: std::net::SocketAddr, + config: quinn::ClientConfig, + name: String, + cid: Cid, + metrics: Arc<ProtocolMetrics>, + ) -> Result<Self, NetworkConnectError> { + let config = config.clone(); + let endpoint = quinn::Endpoint::builder(); + let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) { + Ok(e) => e, + Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)), + }; + + info!("Connecting Quic to: {}", &addr); + let connecting = endpoint.connect_with(config, &addr, &name).map_err(|e| { + trace!(?e, "error setting up quic"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + })?; + let connection = connecting.await.map_err(|e| { + trace!(?e, "error with quic connection"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + })?; + Protocols::new_quic(connection, false, cid, metrics) + .await + .map_err(|e| { + trace!(?e, "error with quic"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + }) + } + #[cfg(feature = "quic")] pub(crate) async fn new_quic( mut connection: quinn::NewConnection, @@ -81,14 +184,18 @@ impl Protocols { ) -> Result<Self, quinn::ConnectionError> { let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let (sendstream, recvstream) = if listen { + let (sendstream, recvstream) = if listen { connection.connection.open_bi().await? } else { - connection.bi_streams.next().await.expect("none").expect("dasdasd") + connection + .bi_streams + .next() + .await + .ok_or_else(|| quinn::ConnectionError::LocallyClosed)?? }; - let (recvstreams_s,recvstreams_r) = mpsc::unbounded_channel(); + let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel(); let streams_s_clone = recvstreams_s.clone(); - let (sendstreams_s,sendstreams_r) = mpsc::unbounded_channel(); + let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel(); let sp = QuicSendProtocol::new( QuicDrain { con: connection.connection.clone(), @@ -261,7 +368,12 @@ impl UnreliableSink for MpscSink { /////////////////////////////////////// //// QUIC #[cfg(feature = "quic")] -type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<Sid>); +type QuicStream = ( + BytesMut, + Result<Option<usize>, quinn::ReadError>, + quinn::RecvStream, + Option<Sid>, +); #[cfg(feature = "quic")] #[derive(Debug)] @@ -284,7 +396,11 @@ pub struct QuicSink { } #[cfg(feature = "quic")] -fn spawn_new(mut recvstream: quinn::RecvStream, sid: Option<Sid>, streams_s: &mpsc::UnboundedSender<QuicStream>) { +fn spawn_new( + mut recvstream: quinn::RecvStream, + sid: Option<Sid>, + streams_s: &mpsc::UnboundedSender<QuicStream>, +) { let streams_s_clone = streams_s.clone(); tokio::spawn(async move { let mut buffer = BytesMut::new(); @@ -301,19 +417,16 @@ impl UnreliableDrain for QuicDrain { async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { match match data.stream { - QuicDataFormatStream::Main => { - self.main.write_all(&data.data).await - }, + QuicDataFormatStream::Main => self.main.write_all(&data.data).await, QuicDataFormatStream::Unreliable => unimplemented!(), QuicDataFormatStream::Reliable(sid) => { use std::collections::hash_map::Entry; tracing::trace!(?sid, "Reliable"); match self.reliables.entry(sid) { - Entry::Occupied(mut occupied) => { - occupied.get_mut().write_all(&data.data).await - }, + Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await, Entry::Vacant(vacant) => { - // IF the buffer is empty this was created localy and WE are allowed to open_bi(), if not, we NEED to block on sendstreams_r + // IF the buffer is empty this was created localy and WE are allowed to + // open_bi(), if not, we NEED to block on sendstreams_r if data.data.is_empty() { match self.con.open_bi().await { Ok((mut sendstream, recvstream)) => { @@ -327,14 +440,17 @@ impl UnreliableDrain for QuicDrain { Err(_) => return Err(ProtocolError::Closed), } } else { - let sendstream = self.sendstreams_r.recv().await.ok_or(ProtocolError::Closed)?; + let sendstream = self + .sendstreams_r + .recv() + .await + .ok_or(ProtocolError::Closed)?; vacant.insert(sendstream).write_all(&data.data).await } }, } }, - } - { + } { Ok(()) => Ok(()), Err(_) => Err(ProtocolError::Closed), } @@ -391,7 +507,6 @@ impl UnreliableSink for QuicSink { Err(_) => Err(ProtocolError::Closed), }?; - let streams_s_clone = self.recvstreams_s.clone(); tokio::spawn(async move { buffer.resize(1500, 0u8); diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 475e34371f..1e8d5c69a8 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -34,7 +34,7 @@ use tracing::*; // - c: channel/handshake lazy_static::lazy_static! { - static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = { + pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = { Mutex::new(HashMap::new()) }; } @@ -197,94 +197,23 @@ impl Scheduler { let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let metrics = Arc::clone(&self.protocol_metrics); self.metrics.connect_request(&addr); - let (protocol, handshake) = match addr { - ConnectAddr::Tcp(addr) => { - let stream = match net::TcpStream::connect(addr).await { - Ok(stream) => stream, - Err(e) => { - pid_sender.send(Err(NetworkConnectError::Io(e))).unwrap(); - continue; - }, - }; - info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - (Protocols::new_tcp(stream, cid, metrics), false) - }, + let protocol = match addr { + ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await, #[cfg(feature = "quic")] ConnectAddr::Quic(addr, ref config, name) => { - let config = config.clone(); - let endpoint = quinn::Endpoint::builder(); - let (endpoint, _) = endpoint.bind(&"[::]:0".parse().unwrap()).expect("FIXME"); - - let connecting = endpoint.connect_with(config, &addr, &name).expect("FIXME"); - let connection = connecting.await.expect("FIXME"); - ( - Protocols::new_quic(connection, false, cid, metrics).await.unwrap(), - false, - ) - //pid_sender.send(Ok(())).unwrap(); + Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await }, - ConnectAddr::Mpsc(addr) => { - let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { - Some(s) => s.clone(), - None => { - pid_sender - .send(Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "no mpsc listen on this addr", - )))) - .unwrap(); - continue; - }, - }; - let (remote_to_local_s, remote_to_local_r) = - mpsc::channel(Self::MPSC_CHANNEL_BOUND); - let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); - mpsc_s - .send((remote_to_local_s, local_to_remote_oneshot_s)) - .unwrap(); - let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap(); - info!(?addr, "Connecting Mpsc"); - ( - Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, metrics), - false, - ) - }, - /* */ - //ProtocolConnectAddr::Udp(addr) => { - //#[cfg(feature = "metrics")] - //self.metrics - //.connect_requests_total - //.with_label_values(&["udp"]) - //.inc(); - //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - //Ok(socket) => Arc::new(socket), - //Err(e) => { - //pid_sender.send(Err(e)).unwrap(); - //continue; - //}, - //}; - //if let Err(e) = socket.connect(addr).await { - //pid_sender.send(Err(e)).unwrap(); - //continue; - //}; - //info!("Connecting Udp to: {}", addr); - //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::<Vec<u8>>(); - //let protocol = UdpProtocol::new( - //Arc::clone(&socket), - //addr, - //#[cfg(feature = "metrics")] - //Arc::clone(&self.metrics), - //udp_data_receiver, - //); - //self.runtime.spawn( - //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - //.instrument(tracing::info_span!("udp", ?addr)), - //); - //(Protocols::Udp(protocol), true) - //}, + ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await, _ => unimplemented!(), }; - self.init_protocol(protocol, cid, Some(pid_sender), handshake) + let protocol = match protocol { + Ok(p) => p, + Err(e) => { + pid_sender.send(Err(e)).unwrap(); + continue; + }, + }; + self.init_protocol(protocol, cid, Some(pid_sender), false) .await; } trace!("Stop connect_mgr"); From 99a23c6aeac4a39069dbcd0ebdf630fceb9deabf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Tue, 27 Apr 2021 17:59:36 +0200 Subject: [PATCH 6/7] extract protocol specific listen code from scheduler and move it to channel.rs --- network/protocol/src/quic.rs | 5 +- network/src/api.rs | 34 ++-- network/src/channel.rs | 272 +++++++++++++++++++++++--------- network/src/message.rs | 2 +- network/src/metrics.rs | 2 + network/src/participant.rs | 10 +- network/src/scheduler.rs | 291 ++++++++--------------------------- network/tests/integration.rs | 2 +- 8 files changed, 300 insertions(+), 318 deletions(-) diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index a10764491b..a4dfa328d1 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -451,7 +451,10 @@ where m.data.extend_from_slice(&data); if m.data.len() == m.length as usize { // finished, yay - let m = self.incoming.remove(&mid).unwrap(); + let m = self + .incoming + .remove(&mid) + .ok_or(ProtocolError::Violated)?; self.metrics.rmsg_ob( m.sid, RemoveReason::Finished, diff --git a/network/src/api.rs b/network/src/api.rs index ad95dd3419..0da58aa6d5 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -145,8 +145,8 @@ pub struct StreamParams { /// [`Arc`](std::sync::Arc) as all commands have internal mutability. /// /// The `Network` has methods to [`connect`] to other [`Participants`] actively -/// via their [`ProtocolConnectAddr`], or [`listen`] passively for [`connected`] -/// [`Participants`] via [`ProtocolListenAddr`]. +/// via their [`ConnectAddr`], or [`listen`] passively for [`connected`] +/// [`Participants`] via [`ListenAddr`]. /// /// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before /// the Network. @@ -178,6 +178,8 @@ pub struct StreamParams { /// [`connect`]: Network::connect /// [`listen`]: Network::listen /// [`connected`]: Network::connected +/// [`ConnectAddr`]: crate::api::ConnectAddr +/// [`ListenAddr`]: crate::api::ListenAddr pub struct Network { local_pid: Pid, participant_disconnect_sender: Arc<Mutex<HashMap<Pid, A2sDisconnect>>>, @@ -293,7 +295,7 @@ impl Network { } } - /// starts listening on an [`ProtocolListenAddr`]. + /// starts listening on an [`ListenAddr`]. /// When the method returns the `Network` is ready to listen for incoming /// connections OR has returned a [`NetworkError`] (e.g. port already used). /// You can call [`connected`] to asynchrony wait for a [`Participant`] to @@ -303,7 +305,7 @@ impl Network { /// # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolListenAddr}; + /// use veloren_network::{Network, Pid, ListenAddr}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally @@ -311,10 +313,10 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { /// network - /// .listen(ProtocolListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .await?; /// network - /// .listen(ProtocolListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) + /// .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) /// .await?; /// drop(network); /// # Ok(()) @@ -323,6 +325,7 @@ impl Network { /// ``` /// /// [`connected`]: Network::connected + /// [`ListenAddr`]: crate::api::ListenAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> { let (s2a_result_s, s2a_result_r) = oneshot::channel::<tokio::io::Result<()>>(); @@ -339,13 +342,13 @@ impl Network { } } - /// starts connection to an [`ProtocolConnectAddr`]. + /// starts connection to an [`ConnectAddr`]. /// When the method returns the Network either returns a [`Participant`] /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// can't connect, or invalid Handshake) # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolListenAddr, ProtocolConnectAddr}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr}; /// /// # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above @@ -353,16 +356,16 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; - /// # remote.listen(ProtocolListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// let p1 = network - /// .connect(ProtocolConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) + /// .connect(ConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) /// .await?; /// # //this doesn't work yet, so skip the test /// # //TODO fixme! /// # return Ok(()); /// let p2 = network - /// .connect(ProtocolConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) + /// .connect(ConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) /// .await?; /// assert_eq!(&p1, &p2); /// # Ok(()) @@ -374,13 +377,13 @@ impl Network { /// ``` /// Usually the `Network` guarantees that a operation on a [`Participant`] /// succeeds, e.g. by automatic retrying unless it fails completely e.g. by - /// disconnecting from the remote. If 2 [`ProtocolConnectAddres`] you + /// disconnecting from the remote. If 2 [`ConnectAddr] you /// `connect` to belongs to the same [`Participant`], you get the same /// [`Participant`] as a result. This is useful e.g. by connecting to /// the same [`Participant`] via multiple Protocols. /// /// [`Streams`]: crate::api::Stream - /// [`ProtocolConnectAddres`]: crate::api::ProtocolConnectAddr + /// [`ConnectAddr`]: crate::api::ConnectAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn connect(&self, address: ConnectAddr) -> Result<Participant, NetworkError> { let (pid_sender, pid_receiver) = @@ -403,7 +406,7 @@ impl Network { Ok(participant) } - /// returns a [`Participant`] created from a [`ProtocolListenAddr`] you + /// returns a [`Participant`] created from a [`ListenAddr`] you /// called [`listen`] on before. This function will either return a /// working [`Participant`] ready to open [`Streams`] on OR has returned /// a [`NetworkError`] (e.g. Network got closed) @@ -437,6 +440,7 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen + /// [`ListenAddr`]: crate::api::ListenAddr #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn connected(&self) -> Result<Participant, NetworkError> { let participant = self.connected_receiver.lock().await.recv().await?; diff --git a/network/src/channel.rs b/network/src/channel.rs index fe3bff971e..03930c03ec 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,21 +1,34 @@ use crate::api::NetworkConnectError; use async_trait::async_trait; use bytes::BytesMut; +use futures_util::FutureExt; +#[cfg(feature = "quic")] +use futures_util::StreamExt; use network_protocol::{ Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, - ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, - QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol, Sid, TcpRecvProtocol, + ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; -use std::{sync::Arc, time::Duration}; +#[cfg(feature = "quic")] +use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; +use std::{ + collections::HashMap, + io, + net::SocketAddr, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - sync::{mpsc, oneshot}, + select, + sync::{mpsc, oneshot, Mutex}, }; -use tokio_stream::StreamExt; -use tracing::{info, trace}; +use tracing::{error, info, trace, warn}; #[allow(clippy::large_enum_variant)] #[derive(Debug)] @@ -42,32 +55,67 @@ pub(crate) enum RecvProtocols { Quic(QuicRecvProtocol<QuicSink>), } +lazy_static::lazy_static! { + pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<C2cMpscConnect>>> = { + Mutex::new(HashMap::new()) + }; +} + +pub(crate) type C2cMpscConnect = ( + mpsc::Sender<MpscMsg>, + oneshot::Sender<mpsc::Sender<MpscMsg>>, +); + impl Protocols { const MPSC_CHANNEL_BOUND: usize = 1000; pub(crate) async fn with_tcp_connect( - addr: std::net::SocketAddr, - cid: Cid, - metrics: Arc<ProtocolMetrics>, + addr: SocketAddr, + metrics: ProtocolMetricCache, ) -> Result<Self, NetworkConnectError> { - let stream = match net::TcpStream::connect(addr).await { - Ok(stream) => stream, - Err(e) => { - return Err(crate::api::NetworkConnectError::Io(e)); - }, - }; - info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - Ok(Protocols::new_tcp(stream, cid, metrics)) + let stream = net::TcpStream::connect(addr) + .await + .map_err(NetworkConnectError::Io)?; + info!( + "Connecting Tcp to: {}", + stream.peer_addr().map_err(NetworkConnectError::Io)? + ); + Ok(Self::new_tcp(stream, metrics)) } - pub(crate) fn new_tcp( - stream: tokio::net::TcpStream, - cid: Cid, + pub(crate) async fn with_tcp_listen( + addr: SocketAddr, + cids: Arc<AtomicU64>, metrics: Arc<ProtocolMetrics>, - ) -> Self { - let (r, w) = stream.into_split(); - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let listener = net::TcpListener::bind(addr).await?; + trace!(?addr, "Tcp Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some(data) = select! { + next = listener.accept().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let (stream, remote_addr) = match data { + Ok((s, p)) => (s, p), + Err(e) => { + trace!(?e, "TcpStream Error, ignoring connection attempt"); + continue; + }, + }; + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Tcp from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid)); + } + }); + Ok(()) + } + pub(crate) fn new_tcp(stream: tokio::net::TcpStream, metrics: ProtocolMetricCache) -> Self { + let (r, w) = stream.into_split(); let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone()); let rp = TcpRecvProtocol::new( TcpSink { @@ -81,70 +129,104 @@ impl Protocols { pub(crate) async fn with_mpsc_connect( addr: u64, - cid: Cid, - metrics: Arc<ProtocolMetrics>, + metrics: ProtocolMetricCache, ) -> Result<Self, NetworkConnectError> { - let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) { - Some(s) => s.clone(), - None => { - return Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::NotConnected, + let mpsc_s = MPSC_POOL + .lock() + .await + .get(&addr) + .ok_or_else(|| { + NetworkConnectError::Io(io::Error::new( + io::ErrorKind::NotConnected, "no mpsc listen on this addr", - ))); - }, - }; + )) + })? + .clone(); let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); - if mpsc_s + mpsc_s .send((remote_to_local_s, local_to_remote_oneshot_s)) - .is_err() - { - return Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - "mpsc pipe broke during connect", - ))); - } - let local_to_remote_s = match local_to_remote_oneshot_r.await { - Ok(s) => s, - Err(e) => { - return Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::BrokenPipe, - e, - ))); - }, - }; + .map_err(|_| { + NetworkConnectError::Io(io::Error::new( + io::ErrorKind::BrokenPipe, + "mpsc pipe broke during connect", + )) + })?; + let local_to_remote_s = local_to_remote_oneshot_r + .await + .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?; info!(?addr, "Connecting Mpsc"); Ok(Self::new_mpsc( local_to_remote_s, remote_to_local_r, - cid, metrics, )) } + pub(crate) async fn with_mpsc_listen( + addr: u64, + cids: Arc<AtomicU64>, + metrics: Arc<ProtocolMetrics>, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); + MPSC_POOL.lock().await.insert(addr, mpsc_s); + trace!(?addr, "Mpsc Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { + next = mpsc_r.recv().fuse() => next, + _ = &mut end_receiver => None, + } { + let (remote_to_local_s, remote_to_local_r) = + mpsc::channel(Self::MPSC_CHANNEL_BOUND); + if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) { + error!(?e, "mpsc listen aborted"); + } + + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?addr, ?cid, "Accepting Mpsc from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + let _ = c2s_protocol_s.send(( + Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()), + cid, + )); + } + warn!("MpscStream Failed, stopping"); + }); + Ok(()) + } + pub(crate) fn new_mpsc( sender: mpsc::Sender<MpscMsg>, receiver: mpsc::Receiver<MpscMsg>, - cid: Cid, - metrics: Arc<ProtocolMetrics>, + metrics: ProtocolMetricCache, ) -> Self { - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone()); let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics); Protocols::Mpsc((sp, rp)) } + #[cfg(feature = "quic")] pub(crate) async fn with_quic_connect( - addr: std::net::SocketAddr, + addr: SocketAddr, config: quinn::ClientConfig, name: String, - cid: Cid, - metrics: Arc<ProtocolMetrics>, + metrics: ProtocolMetricCache, ) -> Result<Self, NetworkConnectError> { let config = config.clone(); let endpoint = quinn::Endpoint::builder(); - let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) { + + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + let bindsock = match addr { + SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0), + SocketAddr::V6(_) => { + SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) + }, + }; + let (endpoint, _) = match endpoint.bind(&bindsock) { Ok(e) => e, Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)), }; @@ -164,7 +246,7 @@ impl Protocols { e, )) })?; - Protocols::new_quic(connection, false, cid, metrics) + Self::new_quic(connection, false, metrics) .await .map_err(|e| { trace!(?e, "error with quic"); @@ -175,15 +257,60 @@ impl Protocols { }) } + #[cfg(feature = "quic")] + pub(crate) async fn with_quic_listen( + addr: SocketAddr, + server_config: quinn::ServerConfig, + cids: Arc<AtomicU64>, + metrics: Arc<ProtocolMetrics>, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let mut endpoint = quinn::Endpoint::builder(); + endpoint.listen(server_config); + let (_endpoint, mut listener) = match endpoint.bind(&addr) { + Ok(v) => v, + Err(quinn::EndpointError::Socket(e)) => return Err(e), + }; + trace!(?addr, "Quic Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some(Some(connecting)) = select! { + next = listener.next().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let remote_addr = connecting.remote_address(); + let connection = match connecting.await { + Ok(c) => c, + Err(e) => { + tracing::debug!(?e, ?remote_addr, "skipping connection attempt"); + continue; + }, + }; + + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Quic from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + match Protocols::new_quic(connection, true, metrics).await { + Ok(quic) => { + let _ = c2s_protocol_s.send((quic, cid)); + }, + Err(e) => { + trace!(?e, "failed to start quic"); + continue; + }, + } + } + }); + Ok(()) + } + #[cfg(feature = "quic")] pub(crate) async fn new_quic( mut connection: quinn::NewConnection, listen: bool, - cid: Cid, - metrics: Arc<ProtocolMetrics>, + metrics: ProtocolMetricCache, ) -> Result<Self, quinn::ConnectionError> { - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let (sendstream, recvstream) = if listen { connection.connection.open_bi().await? } else { @@ -191,7 +318,7 @@ impl Protocols { .bi_streams .next() .await - .ok_or_else(|| quinn::ConnectionError::LocallyClosed)?? + .ok_or(quinn::ConnectionError::LocallyClosed)?? }; let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel(); let streams_s_clone = recvstreams_s.clone(); @@ -521,7 +648,8 @@ impl UnreliableSink for QuicSink { mod tests { use super::*; use bytes::Bytes; - use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol}; + use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; #[tokio::test] @@ -533,9 +661,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); - let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap())); + let client = Protocols::new_tcp(client, metrics.clone()); + let server = Protocols::new_tcp(server, metrics); let (mut s, _) = client.split(); let (_, mut r) = server.split(); let event = ProtocolEvent::OpenStream { @@ -582,9 +710,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5001").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); - let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap())); + let client = Protocols::new_tcp(client, metrics.clone()); + let server = Protocols::new_tcp(server, metrics); let (s, _) = client.split(); let (_, mut r) = server.split(); let e = tokio::spawn(async move { r.recv().await }); diff --git a/network/src/message.rs b/network/src/message.rs index 5c0029cf16..f821511450 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -30,7 +30,7 @@ impl Message { /// # Example /// for example coding, see [`send_raw`] /// - /// [`send_raw`]: Stream::send_raw + /// [`send_raw`]: crate::api::Stream::send_raw /// [`Participants`]: crate::api::Participant /// [`compress`]: lz_fear::raw::compress2 /// [`Message::serialize`]: crate::message::Message::serialize diff --git a/network/src/metrics.rs b/network/src/metrics.rs index d532347140..f3341e392b 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -251,6 +251,7 @@ fn protocolconnect_name(protocol: &ConnectAddr) -> &str { ConnectAddr::Tcp(_) => "tcp", ConnectAddr::Udp(_) => "udp", ConnectAddr::Mpsc(_) => "mpsc", + #[cfg(feature = "quic")] ConnectAddr::Quic(_, _, _) => "quic", } } @@ -261,6 +262,7 @@ fn protocollisten_name(protocol: &ListenAddr) -> &str { ListenAddr::Tcp(_) => "tcp", ListenAddr::Udp(_) => "udp", ListenAddr::Mpsc(_) => "mpsc", + #[cfg(feature = "quic")] ListenAddr::Quic(_, _) => "quic", } } diff --git a/network/src/participant.rs b/network/src/participant.rs index a06321201c..2735fd5bdd 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -756,7 +756,7 @@ impl BParticipant { #[cfg(test)] mod tests { use super::*; - use network_protocol::ProtocolMetrics; + use network_protocol::{ProtocolMetricCache, ProtocolMetrics}; use tokio::{ runtime::Runtime, sync::{mpsc, oneshot}, @@ -816,14 +816,16 @@ mod tests { ) -> Protocols { let (s1, r1) = mpsc::channel(100); let (s2, r2) = mpsc::channel(100); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics)); + let met = Arc::new(ProtocolMetrics::new().unwrap()); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met)); + let p1 = Protocols::new_mpsc(s1, r2, metrics); let (complete_s, complete_r) = oneshot::channel(); create_channel .send((cid, Sid::new(0), p1, complete_s)) .unwrap(); complete_r.await.unwrap(); - Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics)) + let metrics = ProtocolMetricCache::new(&cid.to_string(), met); + Protocols::new_mpsc(s2, r1, metrics) } #[test] diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 1e8d5c69a8..a232be440b 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -4,8 +4,8 @@ use crate::{ metrics::{NetworkMetrics, ProtocolInfo}, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, }; -use futures_util::{FutureExt, StreamExt}; -use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics}; +use futures_util::StreamExt; +use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -18,7 +18,7 @@ use std::{ time::Duration, }; use tokio::{ - io, net, select, + io, sync::{mpsc, oneshot, Mutex}, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -33,12 +33,6 @@ use tracing::*; // - w: wire // - c: channel/handshake -lazy_static::lazy_static! { - pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = { - Mutex::new(HashMap::new()) - }; -} - #[derive(Debug)] struct ParticipantInfo { secret: u128, @@ -52,10 +46,6 @@ pub(crate) type A2sConnect = ( oneshot::Sender<Result<Participant, NetworkConnectError>>, ); type A2sDisconnect = (Pid, S2bShutdownBparticipant); -type S2sMpscConnect = ( - mpsc::Sender<MpscMsg>, - oneshot::Sender<mpsc::Sender<MpscMsg>>, -); #[derive(Debug)] struct ControlChannels { @@ -88,8 +78,6 @@ pub struct Scheduler { } impl Scheduler { - const MPSC_CHANNEL_BOUND: usize = 1000; - pub fn new( local_pid: Pid, #[cfg(feature = "metrics")] registry: Option<&Registry>, @@ -157,7 +145,10 @@ impl Scheduler { } pub async fn run(mut self) { - let run_channels = self.run_channels.take().unwrap(); + let run_channels = self + .run_channels + .take() + .expect("run() can only be called once"); tokio::join!( self.listen_mgr(run_channels.a2s_listen_r), @@ -174,17 +165,66 @@ impl Scheduler { a2s_listen_r .for_each_concurrent(None, |(address, s2a_listen_result_s)| { let address = address; + let cids = Arc::clone(&self.channel_ids); + + #[cfg(feature = "metrics")] + let mcache = self.metrics.connect_requests_cache(&address); + + debug!(?address, "Got request to open a channel_creator"); + self.metrics.listen_request(&address); + let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>(); + let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel(); + let metrics = Arc::clone(&self.protocol_metrics); async move { - debug!(?address, "Got request to open a channel_creator"); - self.metrics.listen_request(&address); - let (end_sender, end_receiver) = oneshot::channel::<()>(); self.channel_listener .lock() .await - .insert(address.clone().into(), end_sender); - self.channel_creator(address, end_receiver, s2a_listen_result_s) - .await; + .insert(address.clone().into(), s2s_stop_listening_s); + + #[cfg(feature = "metrics")] + mcache.inc(); + + let res = match address { + ListenAddr::Tcp(addr) => { + Protocols::with_tcp_listen( + addr, + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + #[cfg(feature = "quic")] + ListenAddr::Quic(addr, ref server_config) => { + Protocols::with_quic_listen( + addr, + server_config.clone(), + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + ListenAddr::Mpsc(addr) => { + Protocols::with_mpsc_listen( + addr, + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + _ => unimplemented!(), + }; + let _ = s2a_listen_result_s.send(res); + + while let Some((prot, cid)) = c2s_protocol_r.recv().await { + self.init_protocol(prot, cid, None, true).await; + } } }) .await; @@ -195,15 +235,16 @@ impl Scheduler { trace!("Start connect_mgr"); while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - let metrics = Arc::clone(&self.protocol_metrics); + let metrics = + ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics)); self.metrics.connect_request(&addr); let protocol = match addr { - ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await, + ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await, #[cfg(feature = "quic")] ConnectAddr::Quic(addr, ref config, name) => { - Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await + Protocols::with_quic_connect(addr, config.clone(), name, metrics).await }, - ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await, + ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await, _ => unimplemented!(), }; let protocol = match protocol { @@ -327,204 +368,6 @@ impl Scheduler { trace!("Stop scheduler_shutdown_mgr"); } - async fn channel_creator( - &self, - addr: ListenAddr, - s2s_stop_listening_r: oneshot::Receiver<()>, - s2a_listen_result_s: oneshot::Sender<io::Result<()>>, - ) { - trace!(?addr, "Start up channel creator"); - #[cfg(feature = "metrics")] - let mcache = self.metrics.connect_requests_cache(&addr); - match addr { - ListenAddr::Tcp(addr) => { - let listener = match net::TcpListener::bind(addr).await { - Ok(listener) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - listener - }, - Err(e) => { - info!( - ?addr, - ?e, - "Tcp bind error during listener startup" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - }, - }; - trace!(?addr, "Listener bound"); - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(data) = select! { - next = listener.accept().fuse() => Some(next), - _ = &mut end_receiver => None, - } { - let (stream, remote_addr) = match data { - Ok((s, p)) => (s, p), - Err(e) => { - warn!(?e, "TcpStream Error, ignoring connection attempt"); - continue; - }, - }; - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?remote_addr, ?cid, "Accepting Tcp from"); - self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) - .await; - } - }, - #[cfg(feature = "quic")] - ListenAddr::Quic(addr, ref server_config) => { - let mut endpoint = quinn::Endpoint::builder(); - endpoint.listen(server_config.clone()); - let (_endpoint, mut listener) = match endpoint.bind(&addr) { - Ok((endpoint, listener)) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - (endpoint, listener) - }, - Err(quinn::EndpointError::Socket(e)) => { - info!( - ?addr, - ?e, - "Quic bind error during listener startup" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - } - }; - trace!(?addr, "Listener bound"); - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(Some(connecting)) = select! { - next = listener.next().fuse() => Some(next), - _ = &mut end_receiver => None, - } { - let remote_addr = connecting.remote_address(); - let connection = match connecting.await { - Ok(c) => c, - Err(e) => { - debug!(?e, ?remote_addr, "skipping connection attempt"); - continue; - }, - }; - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?remote_addr, ?cid, "Accepting Quic from"); - let quic = match Protocols::new_quic(connection, true, cid, Arc::clone(&self.protocol_metrics)).await { - Ok(quic) => quic, - Err(e) => { - trace!(?e, "failed to start quic"); - continue; - } - }; - self.init_protocol(quic, cid, None, true) - .await; - } - }, - ListenAddr::Mpsc(addr) => { - let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); - MPSC_POOL.lock().await.insert(addr, mpsc_s); - s2a_listen_result_s.send(Ok(())).unwrap(); - trace!(?addr, "Listener bound"); - - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { - next = mpsc_r.recv().fuse() => next, - _ = &mut end_receiver => None, - } { - let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); - local_remote_to_local_s.send(remote_to_local_s).unwrap(); - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?addr, ?cid, "Accepting Mpsc from"); - self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) - .await; - } - warn!("MpscStream Failed, stopping"); - },/* - ProtocolListenAddr::Udp(addr) => { - let socket = match net::UdpSocket::bind(addr).await { - Ok(socket) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - Arc::new(socket) - }, - Err(e) => { - info!( - ?addr, - ?e, - "Listener couldn't be started due to error on udp bind" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - }, - }; - trace!(?addr, "Listener bound"); - // receiving is done from here and will be piped to protocol as UDP does not - // have any state - let mut listeners = HashMap::new(); - let mut end_receiver = s2s_stop_listening_r.fuse(); - const UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER: usize = 9216; - let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER]; - while let Ok((size, remote_addr)) = select! { - next = socket.recv_from(&mut data).fuse() => next, - _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), - } { - let mut datavec = Vec::with_capacity(size); - datavec.extend_from_slice(&data[0..size]); - //Due to the async nature i cannot make of .entry() as it would lead to a still - // borrowed in another branch situation - #[allow(clippy::map_entry)] - if !listeners.contains_key(&remote_addr) { - info!("Accepting Udp from: {}", &remote_addr); - let (udp_data_sender, udp_data_receiver) = - mpsc::unbounded_channel::<Vec<u8>>(); - listeners.insert(remote_addr, udp_data_sender); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - remote_addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.init_protocol(Protocols::Udp(protocol), None, false) - .await; - } - let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); - udp_data_sender.send(datavec).unwrap(); - } - },*/ - _ => unimplemented!(), - } - trace!(?addr, "Ending channel creator"); - } - - #[allow(dead_code)] - async fn udp_single_channel_connect( - socket: Arc<net::UdpSocket>, - w2p_udp_package_s: mpsc::UnboundedSender<Vec<u8>>, - ) { - let addr = socket.local_addr(); - trace!(?addr, "Start udp_single_channel_connect"); - //TODO: implement real closing - let (_end_sender, end_receiver) = oneshot::channel::<()>(); - - // receiving is done from here and will be piped to protocol as UDP does not - // have any state - let mut end_receiver = end_receiver.fuse(); - let mut data = [0u8; 9216]; - while let Ok(size) = select! { - next = socket.recv(&mut data).fuse() => next, - _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), - } { - let mut datavec = Vec::with_capacity(size); - datavec.extend_from_slice(&data[0..size]); - w2p_udp_package_s.send(datavec).unwrap(); - } - trace!(?addr, "Stop udp_single_channel_connect"); - } - async fn init_protocol( &self, mut protocol: Protocols, diff --git a/network/tests/integration.rs b/network/tests/integration.rs index e81530b4f0..9d2e57bf77 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -85,7 +85,7 @@ fn stream_simple_quic() { #[test] fn stream_simple_quic_3msg() { - let (_, _) = helper::setup(true, 0); + let (_, _) = helper::setup(false, 0); let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); s1_a.send("Hello World").unwrap(); From cecf3e5fd0fe0e6768eac89235d3314a34dc946b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com> Date: Thu, 29 Apr 2021 19:12:57 +0200 Subject: [PATCH 7/7] switch network/protocol to hashbrown (5% perf increase) --- Cargo.lock | 1 + network/protocol/Cargo.toml | 1 + network/protocol/src/quic.rs | 6 ++---- network/protocol/src/tcp.rs | 6 ++---- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dc612eb409..51a55f62bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5670,6 +5670,7 @@ dependencies = [ "bitflags", "bytes", "criterion", + "hashbrown", "prometheus", "rand 0.8.3", "tokio", diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml index 5b06792fc9..4043ac488f 100644 --- a/network/protocol/Cargo.toml +++ b/network/protocol/Cargo.toml @@ -24,6 +24,7 @@ rand = { version = "0.8" } # async traits async-trait = "0.1.42" bytes = "^1" +hashbrown = { version = ">=0.9, <0.12" } [dev-dependencies] async-channel = "1.5.1" diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs index a4dfa328d1..cd76c0b3ef 100644 --- a/network/protocol/src/quic.rs +++ b/network/protocol/src/quic.rs @@ -12,10 +12,8 @@ use crate::{ }; use async_trait::async_trait; use bytes::BytesMut; -use std::{ - collections::HashMap, - time::{Duration, Instant}, -}; +use hashbrown::HashMap; +use std::time::{Duration, Instant}; use tracing::info; #[cfg(feature = "trace_pedantic")] use tracing::trace; diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index e6c74df4d4..0909336a58 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -11,10 +11,8 @@ use crate::{ }; use async_trait::async_trait; use bytes::BytesMut; -use std::{ - collections::HashMap, - time::{Duration, Instant}, -}; +use hashbrown::HashMap; +use std::time::{Duration, Instant}; use tracing::info; #[cfg(feature = "trace_pedantic")] use tracing::trace;