From 383482a36e6f82f0fce446e1c0c10cd532a6f293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Fri, 9 Apr 2021 13:17:38 +0200 Subject: [PATCH] Quic: We had the followuing problem: - locally we open a stream, our local Drain is sending OpenStream - remote Sink will know this and notify remote Drain - remote side sends a message - local sink does not know about the Stream. as there is (and CANT) be a wat to notify local Sink from local Drain (it could introduce race conditions). One of the possible solutions was, that the remote drain will copy the OpenStream Msg ON the Quic::stream before first data is send. This would work but is complicated. Instead we now just mark such streams as "potentially open" and we listen for the first DataHeader to get it's SID. add support for unreliable messages in quic protocol, benchmarks --- Cargo.lock | 53 +- network/Cargo.toml | 5 +- network/protocol/benches/protocols.rs | 89 ++- network/protocol/src/lib.rs | 13 +- network/protocol/src/prio.rs | 14 +- network/protocol/src/quic.rs | 955 ++++++++++++++++++++++++++ network/protocol/src/tcp.rs | 2 +- network/protocol/src/util.rs | 71 ++ network/src/channel.rs | 58 ++ network/src/participant.rs | 3 +- network/src/util.rs | 71 -- 11 files changed, 1246 insertions(+), 88 deletions(-) create mode 100644 network/protocol/src/quic.rs create mode 100644 network/protocol/src/util.rs diff --git a/Cargo.lock b/Cargo.lock index bf345408d0..5403842ba5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2268,7 +2268,7 @@ dependencies = [ "httpdate", "itoa", "pin-project", - "socket2", + "socket2 0.4.0", "tokio", "tower-service", "tracing", @@ -3854,6 +3854,45 @@ dependencies = [ "tracing", ] +[[package]] +name = "quinn" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c82c0a393b300104f989f3db8b8637c0d11f7a32a9c214560b47849ba8f119aa" +dependencies = [ + "bytes", + "futures", + "lazy_static", + "libc", + "mio 0.7.11", + "quinn-proto", + "rustls", + "socket2 0.3.19", + "thiserror", + "tokio", + "tracing", + "webpki", +] + +[[package]] +name = "quinn-proto" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09027365a21874b71e1fbd9d31cb99bff8e11ba81cc9ef2b9425bb607e42d3b2" +dependencies = [ + "bytes", + "ct-logs", + "rand 0.8.3", + "ring", + "rustls", + "rustls-native-certs", + "slab", + "thiserror", + "tinyvec", + "tracing", + "webpki", +] + [[package]] name = "quote" version = "0.6.13" @@ -4673,6 +4712,17 @@ dependencies = [ "wayland-client 0.28.5", ] +[[package]] +name = "socket2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "winapi 0.3.9", +] + [[package]] name = "socket2" version = "0.4.0" @@ -5576,6 +5626,7 @@ dependencies = [ "lz-fear", "prometheus", "prometheus-hyper", + "quinn", "rand 0.8.3", "serde", "shellexpand", diff --git a/network/Cargo.toml b/network/Cargo.toml index 7f0c45c073..a51119a16a 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -9,8 +9,9 @@ edition = "2018" [features] metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] +quic = ["quinn"] -default = ["metrics","compression"] +default = ["metrics","compression","quinn"] [dependencies] @@ -33,6 +34,8 @@ async-channel = "1.5.1" #use for .close() channels #mpsc channel registry lazy_static = { version = "1.4", default-features = false } rand = { version = "0.8" } +#quic support +quinn = { version = "0.7.2", optional = true } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs index d8859943ed..cb1dc02038 100644 --- a/network/protocol/benches/protocols.rs +++ b/network/protocol/benches/protocols.rs @@ -6,8 +6,9 @@ use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; use veloren_network_protocol::{ InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError, - ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid, - TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, QuicRecvProtocol, + QuicSendProtocol, RecvProtocol, SendProtocol, Sid, TcpRecvProtocol, TcpSendProtocol, + UnreliableDrain, UnreliableSink, _internal::OTFrame, }; fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); } @@ -145,7 +146,35 @@ fn criterion_tcp(c: &mut Criterion) { c.finish(); } -criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +fn criterion_quic(c: &mut Criterion) { + let mut c = c.benchmark_group("quic"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buf = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buf = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +criterion_group!( + benches, + criterion_util, + criterion_mpsc, + criterion_tcp, + criterion_quic +); criterion_main!(benches); mod utils { @@ -210,6 +239,36 @@ mod utils { ] } + pub struct QuicDrain { + pub sender: Sender, + } + + pub struct QuicSink { + pub receiver: Receiver, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + metrics: Option, + ) -> [(QuicSendProtocol, QuicRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new(QuicDrain { sender: s1 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new(QuicDrain { sender: s2 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + #[async_trait] impl UnreliableDrain for ACDrain { type DataFormat = MpscMsg; @@ -257,4 +316,28 @@ mod utils { .map_err(|_| ProtocolError::Closed) } } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } } diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs index 79c1ae867a..3c2eb70c75 100644 --- a/network/protocol/src/lib.rs +++ b/network/protocol/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(drain_filter)] //! Network Protocol //! //! a I/O-Free protocol for the veloren network crate. @@ -13,9 +14,9 @@ //! This crate currently defines: //! - TCP //! - MPSC +//! - QUIC //! -//! a UDP implementation will quickly follow, and it's also possible to abstract -//! over QUIC. +//! eventually a pure UDP implementation will follow //! //! warning: don't mix protocol, using the TCP variant for actual UDP socket //! will result in dropped data using UDP with a TCP socket will be a waste of @@ -57,8 +58,10 @@ mod message; mod metrics; mod mpsc; mod prio; +mod quic; mod tcp; mod types; +mod util; pub use error::{InitProtocolError, ProtocolError}; pub use event::ProtocolEvent; @@ -66,12 +69,16 @@ pub use metrics::ProtocolMetricCache; #[cfg(feature = "metrics")] pub use metrics::ProtocolMetrics; pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol}; +pub use quic::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; pub use tcp::{TcpRecvProtocol, TcpSendProtocol}; pub use types::{Bandwidth, Cid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION}; ///use at own risk, might change any time, for internal benchmarks pub mod _internal { - pub use crate::frame::{ITFrame, OTFrame}; + pub use crate::{ + frame::{ITFrame, OTFrame}, + util::SortedVec, + }; } use async_trait::async_trait; diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs index 374a1ac216..7304086a1b 100644 --- a/network/protocol/src/prio.rs +++ b/network/protocol/src/prio.rs @@ -75,7 +75,7 @@ impl PrioManager { /// bandwidth might be extended, as for technical reasons /// guaranteed_bandwidth is used and frames are always 1400 bytes. - pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec, Bandwidth) { + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<(Sid, OTFrame)>, Bandwidth) { let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; let mut cur_bytes = 0u64; let mut frames = vec![]; @@ -84,7 +84,7 @@ impl PrioManager { let metrics = &mut self.metrics; let mut process_stream = - |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + |sid: &Sid, stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { let mut finished = None; 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { while let Some(frame) = msg.next() { @@ -95,7 +95,7 @@ impl PrioManager { } as u64; bandwidth -= b as i64; *cur_bytes += b; - frames.push(frame); + frames.push((*sid, frame)); if bandwidth <= 0 { break 'outer; } @@ -111,10 +111,10 @@ impl PrioManager { }; // Add guaranteed bandwidth - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { prios[stream.prio as usize] += 1; let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; - process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + process_stream(sid, stream, stream_byte_cnt as i64, &mut cur_bytes); } if cur_bytes < total_bytes { @@ -124,11 +124,11 @@ impl PrioManager { continue; } let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64; - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { if stream.prio != prio { continue; } - process_stream(stream, per_stream_bytes, &mut cur_bytes); + process_stream(sid, stream, per_stream_bytes, &mut cur_bytes); } } } diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs new file mode 100644 index 0000000000..d2be37c010 --- /dev/null +++ b/network/protocol/src/quic.rs @@ -0,0 +1,955 @@ +use crate::{ + error::ProtocolError, + event::ProtocolEvent, + frame::{ITFrame, InitFrame, OTFrame}, + handshake::{ReliableDrain, ReliableSink}, + message::{ITMessage, ALLOC_BLOCK}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::{Bandwidth, Mid, Promises, Sid}, + util::SortedVec, + RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, +}; +use async_trait::async_trait; +use bytes::BytesMut; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; +use tracing::info; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; + +#[derive(PartialEq)] +pub enum QuicDataFormatStream { + Main, + Reliable(u64), + Unreliable, +} + +pub struct QuicDataFormat { + stream: QuicDataFormatStream, + data: BytesMut, +} + +impl QuicDataFormat { + fn with_main(buffer: &mut BytesMut) -> Self { + Self { + stream: QuicDataFormatStream::Main, + data: buffer.split(), + } + } + + fn with_reliable(buffer: &mut BytesMut, id: u64) -> Self { + Self { + stream: QuicDataFormatStream::Reliable(id), + data: buffer.split(), + } + } + + fn with_unreliable(frame: OTFrame) -> Self { + let mut buffer = BytesMut::new(); + frame.write_bytes(&mut buffer); + Self { + stream: QuicDataFormatStream::Unreliable, + data: buffer, + } + } +} + +/// QUIC implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[derive(Debug)] +pub struct QuicSendProtocol +where + D: UnreliableDrain, +{ + main_buffer: BytesMut, + reliable_buffers: SortedVec, + store: PrioManager, + next_mid: Mid, + closing_streams: Vec, + notify_closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +/// QUIC implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug)] +pub struct QuicRecvProtocol +where + S: UnreliableSink, +{ + main_buffer: BytesMut, + unreliable_buffer: BytesMut, + reliable_buffers: SortedVec, + pending_reliable_buffers: Vec<(u64, BytesMut)>, + itmsg_allocator: BytesMut, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +impl QuicSendProtocol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + store: PrioManager::new(metrics.clone()), + next_mid: 0u64, + closing_streams: vec![], + notify_closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } + + /// returns all promises that this Protocol can take care of + /// If you open a Stream anyway, unsupported promises are ignored. + pub fn supported_promises() -> Promises { + Promises::ORDERED + | Promises::CONSISTENCY + | Promises::GUARANTEED_DELIVERY + | Promises::COMPRESSED + | Promises::ENCRYPTED + } +} + +impl QuicRecvProtocol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + unreliable_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + pending_reliable_buffers: vec![], + itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK), + incoming: HashMap::new(), + sink, + metrics, + } + } + + async fn recv_into_stream(&mut self) -> Result { + let chunk = self.sink.recv().await?; + let buffer = match chunk.stream { + QuicDataFormatStream::Main => &mut self.main_buffer, + QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer, + QuicDataFormatStream::Reliable(id) => { + match self.reliable_buffers.data.get_mut(id as usize) { + Some((_, buffer)) => buffer, + None => { + self.pending_reliable_buffers.push((id, BytesMut::new())); + //Violated but will never happen + &mut self + .pending_reliable_buffers + .last_mut() + .ok_or(ProtocolError::Violated)? + .1 + }, + } + }, + }; + if buffer.is_empty() { + *buffer = chunk.data + } else { + buffer.extend_from_slice(&chunk.data) + } + Ok(chunk.stream) + } +} + +#[async_trait] +impl SendProtocol for QuicSendProtocol +where + D: UnreliableDrain, +{ + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + }, + ProtocolEvent::CloseStream { sid } => { + if !self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back notify close stream"); + self.notify_closing_streams.push(sid); + } + }, + _ => {}, + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back close stream"); + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!("hold back shutdown"); + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { data, sid } => { + self.metrics.smsg_ib(sid, data.len() as u64); + self.store.add(data, self.next_mid, sid); + self.next_mid += 1; + }, + } + Ok(()) + } + + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: Duration, + ) -> Result { + let (frames, _) = self.store.grab(bandwidth, dt); + //Todo: optimize reserve + let mut data_frames = 0; + let mut data_bandwidth = 0; + for (sid, frame) in frames { + if let OTFrame::Data { mid: _, data } = &frame { + data_bandwidth += data.len(); + data_frames += 1; + } + match self.reliable_buffers.get_mut(&sid) { + Some(buffer) => frame.write_bytes(buffer), + None => { + self.drain + .send(QuicDataFormat::with_unreliable(frame)) + .await? + }, + } + } + for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() { + self.drain + .send(QuicDataFormat::with_reliable(buffer, id as u64)) + .await?; + } + self.metrics + .sdata_frames_b(data_frames, data_bandwidth as u64); + + let mut finished_streams = vec![]; + for (i, &sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + + let mut finished_streams = vec![]; + for (i, sid) in self.notify_closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.notify_closing_streams.remove(*i); + } + + if self.pending_shutdown && self.store.is_empty() { + #[cfg(feature = "trace_pedantic")] + trace!("shutdown, as it's now empty"); + OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + self.pending_shutdown = false; + } + Ok(data_bandwidth as u64) + } +} + +#[async_trait] +impl RecvProtocol for QuicRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + 'outer: loop { + loop { + match ITFrame::read_frame(&mut self.main_buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth, + }); + }, + ITFrame::CloseStream { sid } => { + //FIXME: defer close! + //let _ = self.reliable_buffers.delete(sid); // if it was reliable + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + + // try to order pending + let mut pending_violated = false; + let mut reliable = vec![]; + self.pending_reliable_buffers.drain_filter(|(_, buffer)| { + // try to get Sid without touching buffer + let mut testbuffer = buffer.clone(); + match ITFrame::read_frame(&mut testbuffer) { + Ok(Some(ITFrame::DataHeader { + sid, + mid: _, + length: _, + })) => { + reliable.push((sid, buffer.clone())); + true + }, + Ok(Some(_)) | Err(_) => { + pending_violated = true; + true + }, + Ok(None) => false, + } + }); + if pending_violated { + break 'outer Err(ProtocolError::Violated); + } + for (sid, buffer) in reliable.into_iter() { + self.reliable_buffers.insert(sid, buffer) + } + + let mut iter = self + .reliable_buffers + .data + .iter_mut() + .map(|(_, b)| (b, true)) + .collect::>(); + iter.push((&mut self.unreliable_buffer, false)); + + for (buffer, reliable) in iter { + loop { + match ITFrame::read_frame(buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { mid, data } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + if reliable { + info!( + ?mid, + "protocol violation by remote side: send Data before \ + Header" + ); + break 'outer Err(ProtocolError::Violated); + } else { + //TODO: cleanup old messages from time to time + continue; + } + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + data: m.data.freeze(), + }); + } + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + } + + self.recv_into_stream().await?; + } + } +} + +#[async_trait] +impl ReliableDrain for QuicSendProtocol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.main_buffer.reserve(500); + frame.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await + } +} + +#[async_trait] +impl ReliableSink for QuicRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + while self.main_buffer.len() < 100 { + if self.recv_into_stream().await? == QuicDataFormatStream::Main { + if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) { + return Ok(frame); + } + } + } + Err(ProtocolError::Violated) + } +} + +#[cfg(test)] +mod test_utils { + //Quic protocol based on Channel + use super::*; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; + use async_channel::*; + use std::sync::Arc; + + pub struct QuicDrain { + pub sender: Sender, + pub drop_ratio: f32, + } + + pub struct QuicSink { + pub receiver: Receiver, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + drop_ratio: f32, + metrics: Option, + ) -> [(QuicSendProtocol, QuicRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new( + QuicDrain { + sender: s1, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new( + QuicDrain { + sender: s2, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + use rand::Rng; + if matches!(data.stream, QuicDataFormatStream::Unreliable) + && rand::thread_rng().gen::() < self.drop_ratio + { + return Ok(()); + } + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + error::ProtocolError, + frame::OTFrame, + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + quic::{test_utils::*, QuicDataFormat}, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use bytes::{Bytes, BytesMut}; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 0u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[7u8; 30][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let mut metrics = + ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone())); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_drop() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[100u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + OTFrame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + .await + .unwrap(); + + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA1[..]), + } + .write_bytes(&mut bytes); + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA2[..]), + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + .await + .unwrap(); + + OTFrame::CloseStream { sid }.write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn drop_sink_while_recv() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + + let e = e.await.unwrap(); + assert_eq!(e, Err(ProtocolError::Closed)); + } + + #[tokio::test] + #[should_panic] + async fn send_on_stream_from_remote_without_notify() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let _ = p2.1.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_on_stream_from_remote() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn unrealiable_test() { + const MIN_CHECK: usize = 10; + const COUNT: usize = 10_000; + //We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && != + // COUNT reach their target + + let [mut p1, mut p2] = quic_bound( + COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all + * succeed */ + 0.5, + None, + ); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1337), + prio: 3u8, + promises: Promises::empty(), /* on purpose! */ + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(1337), + data: Bytes::from(&[188u8; 600][..]), + }; + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + p2.0.flush(1_000_000_000, Duration::from_secs(1)) + .await + .unwrap(); + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + for _ in 0..MIN_CHECK { + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index 43d14e2a1e..e6c74df4d4 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -176,7 +176,7 @@ where self.buffer.reserve(total_bytes as usize); let mut data_frames = 0; let mut data_bandwidth = 0; - for frame in frames { + for (_, frame) in frames { if let OTFrame::Data { mid: _, data } = &frame { data_bandwidth += data.len(); data_frames += 1; diff --git a/network/protocol/src/util.rs b/network/protocol/src/util.rs new file mode 100644 index 0000000000..1e28d4c4ab --- /dev/null +++ b/network/protocol/src/util.rs @@ -0,0 +1,71 @@ +/// Used for storing Buffers in a QUIC +#[derive(Debug)] +pub struct SortedVec { + pub data: Vec<(K, V)>, +} + +impl Default for SortedVec { + fn default() -> Self { Self { data: vec![] } } +} + +impl SortedVec +where + K: Ord + Copy, +{ + pub fn insert(&mut self, k: K, v: V) { + self.data.push((k, v)); + self.data.sort_by_key(|&(k, _)| k); + } + + pub fn delete(&mut self, k: &K) -> Option { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(self.data.remove(i).1) + } else { + None + } + } + + pub fn get(&self, k: &K) -> Option<&V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&self.data[i].1) + } else { + None + } + } + + pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&mut self.data[i].1) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sorted_vec() { + let mut vec = SortedVec::default(); + vec.insert(10, "Hello"); + println!("{:?}", vec.data); + vec.insert(30, "World"); + println!("{:?}", vec.data); + vec.insert(20, " "); + println!("{:?}", vec.data); + assert_eq!(vec.data[0].1, "Hello"); + assert_eq!(vec.data[1].1, " "); + assert_eq!(vec.data[2].1, "World"); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), Some(&"Hello")); + assert_eq!(vec.delete(&40), None); + assert_eq!(vec.delete(&10), Some("Hello")); + assert_eq!(vec.delete(&10), None); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), None); + } +} diff --git a/network/src/channel.rs b/network/src/channel.rs index cf7a7851bd..ce6ee08cab 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -5,6 +5,7 @@ use network_protocol::{ ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; +#[cfg(feature = "quic")] use quinn::*; use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -16,18 +17,24 @@ use tokio::{ pub(crate) enum Protocols { Tcp((TcpSendProtocol, TcpRecvProtocol)), Mpsc((MpscSendProtocol, MpscRecvProtocol)), + #[cfg(feature = "quic")] + Quic((QuicSendProtocol, QuicRecvProtocol)), } #[derive(Debug)] pub(crate) enum SendProtocols { Tcp(TcpSendProtocol), Mpsc(MpscSendProtocol), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol), } #[derive(Debug)] pub(crate) enum RecvProtocols { Tcp(TcpRecvProtocol), Mpsc(MpscRecvProtocol), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol), } impl Protocols { @@ -67,6 +74,8 @@ impl Protocols { match self { Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + #[cfg(feature = "quic")] + Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)), } } } @@ -82,6 +91,8 @@ impl network_protocol::InitProtocol for Protocols { match self { Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + #[cfg(feature = "quic")] + Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await, } } } @@ -92,6 +103,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.notify_from_recv(event), SendProtocols::Mpsc(s) => s.notify_from_recv(event), + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.notify_from_recv(event), } } @@ -99,6 +112,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.send(event).await, SendProtocols::Mpsc(s) => s.send(event).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.send(event).await, } } @@ -110,6 +125,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.flush(bandwidth, dt).await, } } } @@ -120,6 +137,8 @@ impl network_protocol::RecvProtocol for RecvProtocols { match self { RecvProtocols::Tcp(r) => r.recv().await, RecvProtocols::Mpsc(r) => r.recv().await, + #[cfg(feature = "quic")] + RecvProtocols::Quic(r) => r.recv().await, } } } @@ -196,6 +215,45 @@ impl UnreliableSink for MpscSink { } } +/////////////////////////////////////// +//// QUIC +#[derive(Debug)] +pub struct QuicDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct QuicSink { + half: OwnedReadHalf, + buffer: BytesMut, +} + +#[async_trait] +impl UnreliableDrain for QuicDrain { + type DataFormat = BytesMut; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for QuicSink { + type DataFormat = BytesMut; + + async fn recv(&mut self) -> Result { + self.buffer.resize(1500, 0u8); + match self.half.read(&mut self.buffer).await { + Ok(0) => Err(ProtocolError::Closed), + Ok(n) => Ok(self.buffer.split_to(n)), + Err(_) => Err(ProtocolError::Closed), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/network/src/participant.rs b/network/src/participant.rs index afa30266e8..a06321201c 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -2,12 +2,13 @@ use crate::{ api::{ParticipantError, Stream}, channel::{Protocols, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, - util::{DeferredTracer, SortedVec}, + util::DeferredTracer, }; use bytes::Bytes; use futures_util::{FutureExt, StreamExt}; use network_protocol::{ Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid, + _internal::SortedVec, }; use std::{ collections::HashMap, diff --git a/network/src/util.rs b/network/src/util.rs index b9a8801263..640d65ee55 100644 --- a/network/src/util.rs +++ b/network/src/util.rs @@ -44,74 +44,3 @@ impl DeferredTracer { } } } - -/// Used for storing Protocols in a Participant or Stream <-> Protocol -pub(crate) struct SortedVec { - pub data: Vec<(K, V)>, -} - -impl Default for SortedVec { - fn default() -> Self { Self { data: vec![] } } -} - -impl SortedVec -where - K: Ord + Copy, -{ - pub fn insert(&mut self, k: K, v: V) { - self.data.push((k, v)); - self.data.sort_by_key(|&(k, _)| k); - } - - pub fn delete(&mut self, k: &K) -> Option { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(self.data.remove(i).1) - } else { - None - } - } - - pub fn get(&self, k: &K) -> Option<&V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&self.data[i].1) - } else { - None - } - } - - pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&mut self.data[i].1) - } else { - None - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sorted_vec() { - let mut vec = SortedVec::default(); - vec.insert(10, "Hello"); - println!("{:?}", vec.data); - vec.insert(30, "World"); - println!("{:?}", vec.data); - vec.insert(20, " "); - println!("{:?}", vec.data); - assert_eq!(vec.data[0].1, "Hello"); - assert_eq!(vec.data[1].1, " "); - assert_eq!(vec.data[2].1, "World"); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), Some(&"Hello")); - assert_eq!(vec.delete(&40), None); - assert_eq!(vec.delete(&10), Some("Hello")); - assert_eq!(vec.delete(&10), None); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), None); - } -}