diff --git a/Cargo.lock b/Cargo.lock index bf345408d0..5403842ba5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2268,7 +2268,7 @@ dependencies = [ "httpdate", "itoa", "pin-project", - "socket2", + "socket2 0.4.0", "tokio", "tower-service", "tracing", @@ -3854,6 +3854,45 @@ dependencies = [ "tracing", ] +[[package]] +name = "quinn" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c82c0a393b300104f989f3db8b8637c0d11f7a32a9c214560b47849ba8f119aa" +dependencies = [ + "bytes", + "futures", + "lazy_static", + "libc", + "mio 0.7.11", + "quinn-proto", + "rustls", + "socket2 0.3.19", + "thiserror", + "tokio", + "tracing", + "webpki", +] + +[[package]] +name = "quinn-proto" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09027365a21874b71e1fbd9d31cb99bff8e11ba81cc9ef2b9425bb607e42d3b2" +dependencies = [ + "bytes", + "ct-logs", + "rand 0.8.3", + "ring", + "rustls", + "rustls-native-certs", + "slab", + "thiserror", + "tinyvec", + "tracing", + "webpki", +] + [[package]] name = "quote" version = "0.6.13" @@ -4673,6 +4712,17 @@ dependencies = [ "wayland-client 0.28.5", ] +[[package]] +name = "socket2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "winapi 0.3.9", +] + [[package]] name = "socket2" version = "0.4.0" @@ -5576,6 +5626,7 @@ dependencies = [ "lz-fear", "prometheus", "prometheus-hyper", + "quinn", "rand 0.8.3", "serde", "shellexpand", diff --git a/network/Cargo.toml b/network/Cargo.toml index 7f0c45c073..a51119a16a 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -9,8 +9,9 @@ edition = "2018" [features] metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] +quic = ["quinn"] -default = ["metrics","compression"] +default = ["metrics","compression","quinn"] [dependencies] @@ -33,6 +34,8 @@ async-channel = "1.5.1" #use for .close() channels #mpsc channel registry lazy_static = { version = "1.4", default-features = false } rand = { version = "0.8" } +#quic support +quinn = { version = "0.7.2", optional = true } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs index d8859943ed..cb1dc02038 100644 --- a/network/protocol/benches/protocols.rs +++ b/network/protocol/benches/protocols.rs @@ -6,8 +6,9 @@ use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; use veloren_network_protocol::{ InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError, - ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid, - TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, QuicRecvProtocol, + QuicSendProtocol, RecvProtocol, SendProtocol, Sid, TcpRecvProtocol, TcpSendProtocol, + UnreliableDrain, UnreliableSink, _internal::OTFrame, }; fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); } @@ -145,7 +146,35 @@ fn criterion_tcp(c: &mut Criterion) { c.finish(); } -criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +fn criterion_quic(c: &mut Criterion) { + let mut c = c.benchmark_group("quic"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buf = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buf = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +criterion_group!( + benches, + criterion_util, + criterion_mpsc, + criterion_tcp, + criterion_quic +); criterion_main!(benches); mod utils { @@ -210,6 +239,36 @@ mod utils { ] } + pub struct QuicDrain { + pub sender: Sender, + } + + pub struct QuicSink { + pub receiver: Receiver, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + metrics: Option, + ) -> [(QuicSendProtocol, QuicRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new(QuicDrain { sender: s1 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new(QuicDrain { sender: s2 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + #[async_trait] impl UnreliableDrain for ACDrain { type DataFormat = MpscMsg; @@ -257,4 +316,28 @@ mod utils { .map_err(|_| ProtocolError::Closed) } } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } } diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs index 79c1ae867a..3c2eb70c75 100644 --- a/network/protocol/src/lib.rs +++ b/network/protocol/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(drain_filter)] //! Network Protocol //! //! a I/O-Free protocol for the veloren network crate. @@ -13,9 +14,9 @@ //! This crate currently defines: //! - TCP //! - MPSC +//! - QUIC //! -//! a UDP implementation will quickly follow, and it's also possible to abstract -//! over QUIC. +//! eventually a pure UDP implementation will follow //! //! warning: don't mix protocol, using the TCP variant for actual UDP socket //! will result in dropped data using UDP with a TCP socket will be a waste of @@ -57,8 +58,10 @@ mod message; mod metrics; mod mpsc; mod prio; +mod quic; mod tcp; mod types; +mod util; pub use error::{InitProtocolError, ProtocolError}; pub use event::ProtocolEvent; @@ -66,12 +69,16 @@ pub use metrics::ProtocolMetricCache; #[cfg(feature = "metrics")] pub use metrics::ProtocolMetrics; pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol}; +pub use quic::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; pub use tcp::{TcpRecvProtocol, TcpSendProtocol}; pub use types::{Bandwidth, Cid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION}; ///use at own risk, might change any time, for internal benchmarks pub mod _internal { - pub use crate::frame::{ITFrame, OTFrame}; + pub use crate::{ + frame::{ITFrame, OTFrame}, + util::SortedVec, + }; } use async_trait::async_trait; diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs index 374a1ac216..7304086a1b 100644 --- a/network/protocol/src/prio.rs +++ b/network/protocol/src/prio.rs @@ -75,7 +75,7 @@ impl PrioManager { /// bandwidth might be extended, as for technical reasons /// guaranteed_bandwidth is used and frames are always 1400 bytes. - pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec, Bandwidth) { + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<(Sid, OTFrame)>, Bandwidth) { let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; let mut cur_bytes = 0u64; let mut frames = vec![]; @@ -84,7 +84,7 @@ impl PrioManager { let metrics = &mut self.metrics; let mut process_stream = - |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + |sid: &Sid, stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { let mut finished = None; 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { while let Some(frame) = msg.next() { @@ -95,7 +95,7 @@ impl PrioManager { } as u64; bandwidth -= b as i64; *cur_bytes += b; - frames.push(frame); + frames.push((*sid, frame)); if bandwidth <= 0 { break 'outer; } @@ -111,10 +111,10 @@ impl PrioManager { }; // Add guaranteed bandwidth - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { prios[stream.prio as usize] += 1; let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; - process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + process_stream(sid, stream, stream_byte_cnt as i64, &mut cur_bytes); } if cur_bytes < total_bytes { @@ -124,11 +124,11 @@ impl PrioManager { continue; } let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64; - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { if stream.prio != prio { continue; } - process_stream(stream, per_stream_bytes, &mut cur_bytes); + process_stream(sid, stream, per_stream_bytes, &mut cur_bytes); } } } diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs new file mode 100644 index 0000000000..d2be37c010 --- /dev/null +++ b/network/protocol/src/quic.rs @@ -0,0 +1,955 @@ +use crate::{ + error::ProtocolError, + event::ProtocolEvent, + frame::{ITFrame, InitFrame, OTFrame}, + handshake::{ReliableDrain, ReliableSink}, + message::{ITMessage, ALLOC_BLOCK}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::{Bandwidth, Mid, Promises, Sid}, + util::SortedVec, + RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, +}; +use async_trait::async_trait; +use bytes::BytesMut; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; +use tracing::info; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; + +#[derive(PartialEq)] +pub enum QuicDataFormatStream { + Main, + Reliable(u64), + Unreliable, +} + +pub struct QuicDataFormat { + stream: QuicDataFormatStream, + data: BytesMut, +} + +impl QuicDataFormat { + fn with_main(buffer: &mut BytesMut) -> Self { + Self { + stream: QuicDataFormatStream::Main, + data: buffer.split(), + } + } + + fn with_reliable(buffer: &mut BytesMut, id: u64) -> Self { + Self { + stream: QuicDataFormatStream::Reliable(id), + data: buffer.split(), + } + } + + fn with_unreliable(frame: OTFrame) -> Self { + let mut buffer = BytesMut::new(); + frame.write_bytes(&mut buffer); + Self { + stream: QuicDataFormatStream::Unreliable, + data: buffer, + } + } +} + +/// QUIC implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[derive(Debug)] +pub struct QuicSendProtocol +where + D: UnreliableDrain, +{ + main_buffer: BytesMut, + reliable_buffers: SortedVec, + store: PrioManager, + next_mid: Mid, + closing_streams: Vec, + notify_closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +/// QUIC implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug)] +pub struct QuicRecvProtocol +where + S: UnreliableSink, +{ + main_buffer: BytesMut, + unreliable_buffer: BytesMut, + reliable_buffers: SortedVec, + pending_reliable_buffers: Vec<(u64, BytesMut)>, + itmsg_allocator: BytesMut, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +impl QuicSendProtocol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + store: PrioManager::new(metrics.clone()), + next_mid: 0u64, + closing_streams: vec![], + notify_closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } + + /// returns all promises that this Protocol can take care of + /// If you open a Stream anyway, unsupported promises are ignored. + pub fn supported_promises() -> Promises { + Promises::ORDERED + | Promises::CONSISTENCY + | Promises::GUARANTEED_DELIVERY + | Promises::COMPRESSED + | Promises::ENCRYPTED + } +} + +impl QuicRecvProtocol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + unreliable_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + pending_reliable_buffers: vec![], + itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK), + incoming: HashMap::new(), + sink, + metrics, + } + } + + async fn recv_into_stream(&mut self) -> Result { + let chunk = self.sink.recv().await?; + let buffer = match chunk.stream { + QuicDataFormatStream::Main => &mut self.main_buffer, + QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer, + QuicDataFormatStream::Reliable(id) => { + match self.reliable_buffers.data.get_mut(id as usize) { + Some((_, buffer)) => buffer, + None => { + self.pending_reliable_buffers.push((id, BytesMut::new())); + //Violated but will never happen + &mut self + .pending_reliable_buffers + .last_mut() + .ok_or(ProtocolError::Violated)? + .1 + }, + } + }, + }; + if buffer.is_empty() { + *buffer = chunk.data + } else { + buffer.extend_from_slice(&chunk.data) + } + Ok(chunk.stream) + } +} + +#[async_trait] +impl SendProtocol for QuicSendProtocol +where + D: UnreliableDrain, +{ + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + }, + ProtocolEvent::CloseStream { sid } => { + if !self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back notify close stream"); + self.notify_closing_streams.push(sid); + } + }, + _ => {}, + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back close stream"); + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!("hold back shutdown"); + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { data, sid } => { + self.metrics.smsg_ib(sid, data.len() as u64); + self.store.add(data, self.next_mid, sid); + self.next_mid += 1; + }, + } + Ok(()) + } + + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: Duration, + ) -> Result { + let (frames, _) = self.store.grab(bandwidth, dt); + //Todo: optimize reserve + let mut data_frames = 0; + let mut data_bandwidth = 0; + for (sid, frame) in frames { + if let OTFrame::Data { mid: _, data } = &frame { + data_bandwidth += data.len(); + data_frames += 1; + } + match self.reliable_buffers.get_mut(&sid) { + Some(buffer) => frame.write_bytes(buffer), + None => { + self.drain + .send(QuicDataFormat::with_unreliable(frame)) + .await? + }, + } + } + for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() { + self.drain + .send(QuicDataFormat::with_reliable(buffer, id as u64)) + .await?; + } + self.metrics + .sdata_frames_b(data_frames, data_bandwidth as u64); + + let mut finished_streams = vec![]; + for (i, &sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + + let mut finished_streams = vec![]; + for (i, sid) in self.notify_closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.notify_closing_streams.remove(*i); + } + + if self.pending_shutdown && self.store.is_empty() { + #[cfg(feature = "trace_pedantic")] + trace!("shutdown, as it's now empty"); + OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + self.pending_shutdown = false; + } + Ok(data_bandwidth as u64) + } +} + +#[async_trait] +impl RecvProtocol for QuicRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + 'outer: loop { + loop { + match ITFrame::read_frame(&mut self.main_buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + if promises.contains(Promises::ORDERED) + || promises.contains(Promises::CONSISTENCY) + || promises.contains(Promises::GUARANTEED_DELIVERY) + { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth, + }); + }, + ITFrame::CloseStream { sid } => { + //FIXME: defer close! + //let _ = self.reliable_buffers.delete(sid); // if it was reliable + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + + // try to order pending + let mut pending_violated = false; + let mut reliable = vec![]; + self.pending_reliable_buffers.drain_filter(|(_, buffer)| { + // try to get Sid without touching buffer + let mut testbuffer = buffer.clone(); + match ITFrame::read_frame(&mut testbuffer) { + Ok(Some(ITFrame::DataHeader { + sid, + mid: _, + length: _, + })) => { + reliable.push((sid, buffer.clone())); + true + }, + Ok(Some(_)) | Err(_) => { + pending_violated = true; + true + }, + Ok(None) => false, + } + }); + if pending_violated { + break 'outer Err(ProtocolError::Violated); + } + for (sid, buffer) in reliable.into_iter() { + self.reliable_buffers.insert(sid, buffer) + } + + let mut iter = self + .reliable_buffers + .data + .iter_mut() + .map(|(_, b)| (b, true)) + .collect::>(); + iter.push((&mut self.unreliable_buffer, false)); + + for (buffer, reliable) in iter { + loop { + match ITFrame::read_frame(buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { mid, data } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + if reliable { + info!( + ?mid, + "protocol violation by remote side: send Data before \ + Header" + ); + break 'outer Err(ProtocolError::Violated); + } else { + //TODO: cleanup old messages from time to time + continue; + } + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + data: m.data.freeze(), + }); + } + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + } + + self.recv_into_stream().await?; + } + } +} + +#[async_trait] +impl ReliableDrain for QuicSendProtocol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.main_buffer.reserve(500); + frame.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await + } +} + +#[async_trait] +impl ReliableSink for QuicRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + while self.main_buffer.len() < 100 { + if self.recv_into_stream().await? == QuicDataFormatStream::Main { + if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) { + return Ok(frame); + } + } + } + Err(ProtocolError::Violated) + } +} + +#[cfg(test)] +mod test_utils { + //Quic protocol based on Channel + use super::*; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; + use async_channel::*; + use std::sync::Arc; + + pub struct QuicDrain { + pub sender: Sender, + pub drop_ratio: f32, + } + + pub struct QuicSink { + pub receiver: Receiver, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + drop_ratio: f32, + metrics: Option, + ) -> [(QuicSendProtocol, QuicRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new( + QuicDrain { + sender: s1, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new( + QuicDrain { + sender: s2, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + use rand::Rng; + if matches!(data.stream, QuicDataFormatStream::Unreliable) + && rand::thread_rng().gen::() < self.drop_ratio + { + return Ok(()); + } + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + error::ProtocolError, + frame::OTFrame, + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + quic::{test_utils::*, QuicDataFormat}, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use bytes::{Bytes, BytesMut}; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 0u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[7u8; 30][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let mut metrics = + ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone())); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_drop() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[100u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + OTFrame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + .await + .unwrap(); + + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA1[..]), + } + .write_bytes(&mut bytes); + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA2[..]), + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, 0)) + .await + .unwrap(); + + OTFrame::CloseStream { sid }.write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn drop_sink_while_recv() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + + let e = e.await.unwrap(); + assert_eq!(e, Err(ProtocolError::Closed)); + } + + #[tokio::test] + #[should_panic] + async fn send_on_stream_from_remote_without_notify() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let _ = p2.1.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_on_stream_from_remote() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn unrealiable_test() { + const MIN_CHECK: usize = 10; + const COUNT: usize = 10_000; + //We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && != + // COUNT reach their target + + let [mut p1, mut p2] = quic_bound( + COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all + * succeed */ + 0.5, + None, + ); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1337), + prio: 3u8, + promises: Promises::empty(), /* on purpose! */ + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(1337), + data: Bytes::from(&[188u8; 600][..]), + }; + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + p2.0.flush(1_000_000_000, Duration::from_secs(1)) + .await + .unwrap(); + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + for _ in 0..MIN_CHECK { + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index 43d14e2a1e..e6c74df4d4 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -176,7 +176,7 @@ where self.buffer.reserve(total_bytes as usize); let mut data_frames = 0; let mut data_bandwidth = 0; - for frame in frames { + for (_, frame) in frames { if let OTFrame::Data { mid: _, data } = &frame { data_bandwidth += data.len(); data_frames += 1; diff --git a/network/protocol/src/util.rs b/network/protocol/src/util.rs new file mode 100644 index 0000000000..1e28d4c4ab --- /dev/null +++ b/network/protocol/src/util.rs @@ -0,0 +1,71 @@ +/// Used for storing Buffers in a QUIC +#[derive(Debug)] +pub struct SortedVec { + pub data: Vec<(K, V)>, +} + +impl Default for SortedVec { + fn default() -> Self { Self { data: vec![] } } +} + +impl SortedVec +where + K: Ord + Copy, +{ + pub fn insert(&mut self, k: K, v: V) { + self.data.push((k, v)); + self.data.sort_by_key(|&(k, _)| k); + } + + pub fn delete(&mut self, k: &K) -> Option { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(self.data.remove(i).1) + } else { + None + } + } + + pub fn get(&self, k: &K) -> Option<&V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&self.data[i].1) + } else { + None + } + } + + pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&mut self.data[i].1) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sorted_vec() { + let mut vec = SortedVec::default(); + vec.insert(10, "Hello"); + println!("{:?}", vec.data); + vec.insert(30, "World"); + println!("{:?}", vec.data); + vec.insert(20, " "); + println!("{:?}", vec.data); + assert_eq!(vec.data[0].1, "Hello"); + assert_eq!(vec.data[1].1, " "); + assert_eq!(vec.data[2].1, "World"); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), Some(&"Hello")); + assert_eq!(vec.delete(&40), None); + assert_eq!(vec.delete(&10), Some("Hello")); + assert_eq!(vec.delete(&10), None); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), None); + } +} diff --git a/network/src/channel.rs b/network/src/channel.rs index cf7a7851bd..ce6ee08cab 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -5,6 +5,7 @@ use network_protocol::{ ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; +#[cfg(feature = "quic")] use quinn::*; use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -16,18 +17,24 @@ use tokio::{ pub(crate) enum Protocols { Tcp((TcpSendProtocol, TcpRecvProtocol)), Mpsc((MpscSendProtocol, MpscRecvProtocol)), + #[cfg(feature = "quic")] + Quic((QuicSendProtocol, QuicRecvProtocol)), } #[derive(Debug)] pub(crate) enum SendProtocols { Tcp(TcpSendProtocol), Mpsc(MpscSendProtocol), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol), } #[derive(Debug)] pub(crate) enum RecvProtocols { Tcp(TcpRecvProtocol), Mpsc(MpscRecvProtocol), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol), } impl Protocols { @@ -67,6 +74,8 @@ impl Protocols { match self { Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + #[cfg(feature = "quic")] + Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)), } } } @@ -82,6 +91,8 @@ impl network_protocol::InitProtocol for Protocols { match self { Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + #[cfg(feature = "quic")] + Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await, } } } @@ -92,6 +103,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.notify_from_recv(event), SendProtocols::Mpsc(s) => s.notify_from_recv(event), + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.notify_from_recv(event), } } @@ -99,6 +112,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.send(event).await, SendProtocols::Mpsc(s) => s.send(event).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.send(event).await, } } @@ -110,6 +125,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.flush(bandwidth, dt).await, } } } @@ -120,6 +137,8 @@ impl network_protocol::RecvProtocol for RecvProtocols { match self { RecvProtocols::Tcp(r) => r.recv().await, RecvProtocols::Mpsc(r) => r.recv().await, + #[cfg(feature = "quic")] + RecvProtocols::Quic(r) => r.recv().await, } } } @@ -196,6 +215,45 @@ impl UnreliableSink for MpscSink { } } +/////////////////////////////////////// +//// QUIC +#[derive(Debug)] +pub struct QuicDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct QuicSink { + half: OwnedReadHalf, + buffer: BytesMut, +} + +#[async_trait] +impl UnreliableDrain for QuicDrain { + type DataFormat = BytesMut; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for QuicSink { + type DataFormat = BytesMut; + + async fn recv(&mut self) -> Result { + self.buffer.resize(1500, 0u8); + match self.half.read(&mut self.buffer).await { + Ok(0) => Err(ProtocolError::Closed), + Ok(n) => Ok(self.buffer.split_to(n)), + Err(_) => Err(ProtocolError::Closed), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/network/src/participant.rs b/network/src/participant.rs index afa30266e8..a06321201c 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -2,12 +2,13 @@ use crate::{ api::{ParticipantError, Stream}, channel::{Protocols, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, - util::{DeferredTracer, SortedVec}, + util::DeferredTracer, }; use bytes::Bytes; use futures_util::{FutureExt, StreamExt}; use network_protocol::{ Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid, + _internal::SortedVec, }; use std::{ collections::HashMap, diff --git a/network/src/util.rs b/network/src/util.rs index b9a8801263..640d65ee55 100644 --- a/network/src/util.rs +++ b/network/src/util.rs @@ -44,74 +44,3 @@ impl DeferredTracer { } } } - -/// Used for storing Protocols in a Participant or Stream <-> Protocol -pub(crate) struct SortedVec { - pub data: Vec<(K, V)>, -} - -impl Default for SortedVec { - fn default() -> Self { Self { data: vec![] } } -} - -impl SortedVec -where - K: Ord + Copy, -{ - pub fn insert(&mut self, k: K, v: V) { - self.data.push((k, v)); - self.data.sort_by_key(|&(k, _)| k); - } - - pub fn delete(&mut self, k: &K) -> Option { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(self.data.remove(i).1) - } else { - None - } - } - - pub fn get(&self, k: &K) -> Option<&V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&self.data[i].1) - } else { - None - } - } - - pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&mut self.data[i].1) - } else { - None - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sorted_vec() { - let mut vec = SortedVec::default(); - vec.insert(10, "Hello"); - println!("{:?}", vec.data); - vec.insert(30, "World"); - println!("{:?}", vec.data); - vec.insert(20, " "); - println!("{:?}", vec.data); - assert_eq!(vec.data[0].1, "Hello"); - assert_eq!(vec.data[1].1, " "); - assert_eq!(vec.data[2].1, "World"); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), Some(&"Hello")); - assert_eq!(vec.delete(&40), None); - assert_eq!(vec.delete(&10), Some("Hello")); - assert_eq!(vec.delete(&10), None); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), None); - } -}