Quic: We had the followuing problem:

- locally we open a stream, our local Drain is sending OpenStream
 - remote Sink will know this and notify remote Drain
 - remote side sends a message
 - local sink does not know about the Stream. as there is (and CANT) be a wat to notify local Sink from local Drain (it could introduce race conditions).

One of the possible solutions was, that the remote drain will copy the OpenStream Msg ON the Quic::stream before first data is send. This would work but is complicated.

Instead we now just mark such streams as "potentially open" and we listen for the first DataHeader to get it's SID.

add support for unreliable messages in quic protocol, benchmarks
This commit is contained in:
Marcel Märtens 2021-04-09 13:17:38 +02:00
parent c16bf51ab2
commit 383482a36e
11 changed files with 1246 additions and 88 deletions

53
Cargo.lock generated
View File

@ -2268,7 +2268,7 @@ dependencies = [
"httpdate",
"itoa",
"pin-project",
"socket2",
"socket2 0.4.0",
"tokio",
"tower-service",
"tracing",
@ -3854,6 +3854,45 @@ dependencies = [
"tracing",
]
[[package]]
name = "quinn"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c82c0a393b300104f989f3db8b8637c0d11f7a32a9c214560b47849ba8f119aa"
dependencies = [
"bytes",
"futures",
"lazy_static",
"libc",
"mio 0.7.11",
"quinn-proto",
"rustls",
"socket2 0.3.19",
"thiserror",
"tokio",
"tracing",
"webpki",
]
[[package]]
name = "quinn-proto"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09027365a21874b71e1fbd9d31cb99bff8e11ba81cc9ef2b9425bb607e42d3b2"
dependencies = [
"bytes",
"ct-logs",
"rand 0.8.3",
"ring",
"rustls",
"rustls-native-certs",
"slab",
"thiserror",
"tinyvec",
"tracing",
"webpki",
]
[[package]]
name = "quote"
version = "0.6.13"
@ -4673,6 +4712,17 @@ dependencies = [
"wayland-client 0.28.5",
]
[[package]]
name = "socket2"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e"
dependencies = [
"cfg-if 1.0.0",
"libc",
"winapi 0.3.9",
]
[[package]]
name = "socket2"
version = "0.4.0"
@ -5576,6 +5626,7 @@ dependencies = [
"lz-fear",
"prometheus",
"prometheus-hyper",
"quinn",
"rand 0.8.3",
"serde",
"shellexpand",

View File

@ -9,8 +9,9 @@ edition = "2018"
[features]
metrics = ["prometheus", "network-protocol/metrics"]
compression = ["lz-fear"]
quic = ["quinn"]
default = ["metrics","compression"]
default = ["metrics","compression","quinn"]
[dependencies]
@ -33,6 +34,8 @@ async-channel = "1.5.1" #use for .close() channels
#mpsc channel registry
lazy_static = { version = "1.4", default-features = false }
rand = { version = "0.8" }
#quic support
quinn = { version = "0.7.2", optional = true }
#stream flags
bitflags = "1.2.1"
lz-fear = { version = "0.1.1", optional = true }

View File

@ -6,8 +6,9 @@ use std::{sync::Arc, time::Duration};
use tokio::runtime::Runtime;
use veloren_network_protocol::{
InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError,
ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid,
TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame,
ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, QuicRecvProtocol,
QuicSendProtocol, RecvProtocol, SendProtocol, Sid, TcpRecvProtocol, TcpSendProtocol,
UnreliableDrain, UnreliableSink, _internal::OTFrame,
};
fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); }
@ -145,7 +146,35 @@ fn criterion_tcp(c: &mut Criterion) {
c.finish();
}
criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp);
fn criterion_quic(c: &mut Criterion) {
let mut c = c.benchmark_group("quic");
c.significance_level(0.1).sample_size(10);
c.throughput(Throughput::Bytes(1000000000))
.bench_function("1GB_in_10000_msg", |b| {
let buf = Bytes::from(&[155u8; 100_000][..]);
b.to_async(rt()).iter_with_setup(
|| (buf.clone(), utils::quic_bound(10000, None)),
|(b, p)| send_and_recv_msg(p, b, 10_000),
)
});
c.throughput(Throughput::Elements(1000000))
.bench_function("1000000_tiny_msg", |b| {
let buf = Bytes::from(&[3u8; 5][..]);
b.to_async(rt()).iter_with_setup(
|| (buf.clone(), utils::quic_bound(10000, None)),
|(b, p)| send_and_recv_msg(p, b, 1_000_000),
)
});
c.finish();
}
criterion_group!(
benches,
criterion_util,
criterion_mpsc,
criterion_tcp,
criterion_quic
);
criterion_main!(benches);
mod utils {
@ -210,6 +239,36 @@ mod utils {
]
}
pub struct QuicDrain {
pub sender: Sender<QuicDataFormat>,
}
pub struct QuicSink {
pub receiver: Receiver<QuicDataFormat>,
}
/// emulate Quic protocol on Channels
pub fn quic_bound(
cap: usize,
metrics: Option<ProtocolMetricCache>,
) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] {
let (s1, r1) = async_channel::bounded(cap);
let (s2, r2) = async_channel::bounded(cap);
let m = metrics.unwrap_or_else(|| {
ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()))
});
[
(
QuicSendProtocol::new(QuicDrain { sender: s1 }, m.clone()),
QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()),
),
(
QuicSendProtocol::new(QuicDrain { sender: s2 }, m.clone()),
QuicRecvProtocol::new(QuicSink { receiver: r1 }, m),
),
]
}
#[async_trait]
impl UnreliableDrain for ACDrain {
type DataFormat = MpscMsg;
@ -257,4 +316,28 @@ mod utils {
.map_err(|_| ProtocolError::Closed)
}
}
#[async_trait]
impl UnreliableDrain for QuicDrain {
type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
}
}
#[async_trait]
impl UnreliableSink for QuicSink {
type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
}
}
}

View File

@ -1,3 +1,4 @@
#![feature(drain_filter)]
//! Network Protocol
//!
//! a I/O-Free protocol for the veloren network crate.
@ -13,9 +14,9 @@
//! This crate currently defines:
//! - TCP
//! - MPSC
//! - QUIC
//!
//! a UDP implementation will quickly follow, and it's also possible to abstract
//! over QUIC.
//! eventually a pure UDP implementation will follow
//!
//! warning: don't mix protocol, using the TCP variant for actual UDP socket
//! will result in dropped data using UDP with a TCP socket will be a waste of
@ -57,8 +58,10 @@ mod message;
mod metrics;
mod mpsc;
mod prio;
mod quic;
mod tcp;
mod types;
mod util;
pub use error::{InitProtocolError, ProtocolError};
pub use event::ProtocolEvent;
@ -66,12 +69,16 @@ pub use metrics::ProtocolMetricCache;
#[cfg(feature = "metrics")]
pub use metrics::ProtocolMetrics;
pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol};
pub use quic::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol};
pub use tcp::{TcpRecvProtocol, TcpSendProtocol};
pub use types::{Bandwidth, Cid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION};
///use at own risk, might change any time, for internal benchmarks
pub mod _internal {
pub use crate::frame::{ITFrame, OTFrame};
pub use crate::{
frame::{ITFrame, OTFrame},
util::SortedVec,
};
}
use async_trait::async_trait;

View File

@ -75,7 +75,7 @@ impl PrioManager {
/// bandwidth might be extended, as for technical reasons
/// guaranteed_bandwidth is used and frames are always 1400 bytes.
pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<OTFrame>, Bandwidth) {
pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<(Sid, OTFrame)>, Bandwidth) {
let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64;
let mut cur_bytes = 0u64;
let mut frames = vec![];
@ -84,7 +84,7 @@ impl PrioManager {
let metrics = &mut self.metrics;
let mut process_stream =
|stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| {
|sid: &Sid, stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| {
let mut finished = None;
'outer: for (i, msg) in stream.messages.iter_mut().enumerate() {
while let Some(frame) = msg.next() {
@ -95,7 +95,7 @@ impl PrioManager {
} as u64;
bandwidth -= b as i64;
*cur_bytes += b;
frames.push(frame);
frames.push((*sid, frame));
if bandwidth <= 0 {
break 'outer;
}
@ -111,10 +111,10 @@ impl PrioManager {
};
// Add guaranteed bandwidth
for stream in self.streams.values_mut() {
for (sid, stream) in self.streams.iter_mut() {
prios[stream.prio as usize] += 1;
let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64;
process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes);
process_stream(sid, stream, stream_byte_cnt as i64, &mut cur_bytes);
}
if cur_bytes < total_bytes {
@ -124,11 +124,11 @@ impl PrioManager {
continue;
}
let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64;
for stream in self.streams.values_mut() {
for (sid, stream) in self.streams.iter_mut() {
if stream.prio != prio {
continue;
}
process_stream(stream, per_stream_bytes, &mut cur_bytes);
process_stream(sid, stream, per_stream_bytes, &mut cur_bytes);
}
}
}

View File

@ -0,0 +1,955 @@
use crate::{
error::ProtocolError,
event::ProtocolEvent,
frame::{ITFrame, InitFrame, OTFrame},
handshake::{ReliableDrain, ReliableSink},
message::{ITMessage, ALLOC_BLOCK},
metrics::{ProtocolMetricCache, RemoveReason},
prio::PrioManager,
types::{Bandwidth, Mid, Promises, Sid},
util::SortedVec,
RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink,
};
use async_trait::async_trait;
use bytes::BytesMut;
use std::{
collections::HashMap,
time::{Duration, Instant},
};
use tracing::info;
#[cfg(feature = "trace_pedantic")]
use tracing::trace;
#[derive(PartialEq)]
pub enum QuicDataFormatStream {
Main,
Reliable(u64),
Unreliable,
}
pub struct QuicDataFormat {
stream: QuicDataFormatStream,
data: BytesMut,
}
impl QuicDataFormat {
fn with_main(buffer: &mut BytesMut) -> Self {
Self {
stream: QuicDataFormatStream::Main,
data: buffer.split(),
}
}
fn with_reliable(buffer: &mut BytesMut, id: u64) -> Self {
Self {
stream: QuicDataFormatStream::Reliable(id),
data: buffer.split(),
}
}
fn with_unreliable(frame: OTFrame) -> Self {
let mut buffer = BytesMut::new();
frame.write_bytes(&mut buffer);
Self {
stream: QuicDataFormatStream::Unreliable,
data: buffer,
}
}
}
/// QUIC implementation of [`SendProtocol`]
///
/// [`SendProtocol`]: crate::SendProtocol
#[derive(Debug)]
pub struct QuicSendProtocol<D>
where
D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
main_buffer: BytesMut,
reliable_buffers: SortedVec<Sid, BytesMut>,
store: PrioManager,
next_mid: Mid,
closing_streams: Vec<Sid>,
notify_closing_streams: Vec<Sid>,
pending_shutdown: bool,
drain: D,
last: Instant,
metrics: ProtocolMetricCache,
}
/// QUIC implementation of [`RecvProtocol`]
///
/// [`RecvProtocol`]: crate::RecvProtocol
#[derive(Debug)]
pub struct QuicRecvProtocol<S>
where
S: UnreliableSink<DataFormat = QuicDataFormat>,
{
main_buffer: BytesMut,
unreliable_buffer: BytesMut,
reliable_buffers: SortedVec<Sid, BytesMut>,
pending_reliable_buffers: Vec<(u64, BytesMut)>,
itmsg_allocator: BytesMut,
incoming: HashMap<Mid, ITMessage>,
sink: S,
metrics: ProtocolMetricCache,
}
impl<D> QuicSendProtocol<D>
where
D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self {
Self {
main_buffer: BytesMut::new(),
reliable_buffers: SortedVec::default(),
store: PrioManager::new(metrics.clone()),
next_mid: 0u64,
closing_streams: vec![],
notify_closing_streams: vec![],
pending_shutdown: false,
drain,
last: Instant::now(),
metrics,
}
}
/// returns all promises that this Protocol can take care of
/// If you open a Stream anyway, unsupported promises are ignored.
pub fn supported_promises() -> Promises {
Promises::ORDERED
| Promises::CONSISTENCY
| Promises::GUARANTEED_DELIVERY
| Promises::COMPRESSED
| Promises::ENCRYPTED
}
}
impl<S> QuicRecvProtocol<S>
where
S: UnreliableSink<DataFormat = QuicDataFormat>,
{
pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self {
Self {
main_buffer: BytesMut::new(),
unreliable_buffer: BytesMut::new(),
reliable_buffers: SortedVec::default(),
pending_reliable_buffers: vec![],
itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK),
incoming: HashMap::new(),
sink,
metrics,
}
}
async fn recv_into_stream(&mut self) -> Result<QuicDataFormatStream, ProtocolError> {
let chunk = self.sink.recv().await?;
let buffer = match chunk.stream {
QuicDataFormatStream::Main => &mut self.main_buffer,
QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer,
QuicDataFormatStream::Reliable(id) => {
match self.reliable_buffers.data.get_mut(id as usize) {
Some((_, buffer)) => buffer,
None => {
self.pending_reliable_buffers.push((id, BytesMut::new()));
//Violated but will never happen
&mut self
.pending_reliable_buffers
.last_mut()
.ok_or(ProtocolError::Violated)?
.1
},
}
},
};
if buffer.is_empty() {
*buffer = chunk.data
} else {
buffer.extend_from_slice(&chunk.data)
}
Ok(chunk.stream)
}
}
#[async_trait]
impl<D> SendProtocol for QuicSendProtocol<D>
where
D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
fn notify_from_recv(&mut self, event: ProtocolEvent) {
match event {
ProtocolEvent::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
} => {
self.store
.open_stream(sid, prio, promises, guaranteed_bandwidth);
if promises.contains(Promises::ORDERED)
|| promises.contains(Promises::CONSISTENCY)
|| promises.contains(Promises::GUARANTEED_DELIVERY)
{
self.reliable_buffers.insert(sid, BytesMut::new());
}
},
ProtocolEvent::CloseStream { sid } => {
if !self.store.try_close_stream(sid) {
#[cfg(feature = "trace_pedantic")]
trace!(?sid, "hold back notify close stream");
self.notify_closing_streams.push(sid);
}
},
_ => {},
}
}
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> {
#[cfg(feature = "trace_pedantic")]
trace!(?event, "send");
match event {
ProtocolEvent::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
} => {
self.store
.open_stream(sid, prio, promises, guaranteed_bandwidth);
if promises.contains(Promises::ORDERED)
|| promises.contains(Promises::CONSISTENCY)
|| promises.contains(Promises::GUARANTEED_DELIVERY)
{
self.reliable_buffers.insert(sid, BytesMut::new());
}
event.to_frame().write_bytes(&mut self.main_buffer);
self.drain
.send(QuicDataFormat::with_main(&mut self.main_buffer))
.await?;
},
ProtocolEvent::CloseStream { sid } => {
if self.store.try_close_stream(sid) {
let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable
event.to_frame().write_bytes(&mut self.main_buffer);
self.drain
.send(QuicDataFormat::with_main(&mut self.main_buffer))
.await?;
} else {
#[cfg(feature = "trace_pedantic")]
trace!(?sid, "hold back close stream");
self.closing_streams.push(sid);
}
},
ProtocolEvent::Shutdown => {
if self.store.is_empty() {
event.to_frame().write_bytes(&mut self.main_buffer);
self.drain
.send(QuicDataFormat::with_main(&mut self.main_buffer))
.await?;
} else {
#[cfg(feature = "trace_pedantic")]
trace!("hold back shutdown");
self.pending_shutdown = true;
}
},
ProtocolEvent::Message { data, sid } => {
self.metrics.smsg_ib(sid, data.len() as u64);
self.store.add(data, self.next_mid, sid);
self.next_mid += 1;
},
}
Ok(())
}
async fn flush(
&mut self,
bandwidth: Bandwidth,
dt: Duration,
) -> Result</* actual */ Bandwidth, ProtocolError> {
let (frames, _) = self.store.grab(bandwidth, dt);
//Todo: optimize reserve
let mut data_frames = 0;
let mut data_bandwidth = 0;
for (sid, frame) in frames {
if let OTFrame::Data { mid: _, data } = &frame {
data_bandwidth += data.len();
data_frames += 1;
}
match self.reliable_buffers.get_mut(&sid) {
Some(buffer) => frame.write_bytes(buffer),
None => {
self.drain
.send(QuicDataFormat::with_unreliable(frame))
.await?
},
}
}
for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() {
self.drain
.send(QuicDataFormat::with_reliable(buffer, id as u64))
.await?;
}
self.metrics
.sdata_frames_b(data_frames, data_bandwidth as u64);
let mut finished_streams = vec![];
for (i, &sid) in self.closing_streams.iter().enumerate() {
if self.store.try_close_stream(sid) {
#[cfg(feature = "trace_pedantic")]
trace!(?sid, "close stream, as it's now empty");
OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer);
self.drain
.send(QuicDataFormat::with_main(&mut self.main_buffer))
.await?;
finished_streams.push(i);
}
}
for i in finished_streams.iter().rev() {
self.closing_streams.remove(*i);
}
let mut finished_streams = vec![];
for (i, sid) in self.notify_closing_streams.iter().enumerate() {
if self.store.try_close_stream(*sid) {
#[cfg(feature = "trace_pedantic")]
trace!(?sid, "close stream, as it's now empty");
finished_streams.push(i);
}
}
for i in finished_streams.iter().rev() {
self.notify_closing_streams.remove(*i);
}
if self.pending_shutdown && self.store.is_empty() {
#[cfg(feature = "trace_pedantic")]
trace!("shutdown, as it's now empty");
OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer);
self.drain
.send(QuicDataFormat::with_main(&mut self.main_buffer))
.await?;
self.pending_shutdown = false;
}
Ok(data_bandwidth as u64)
}
}
#[async_trait]
impl<S> RecvProtocol for QuicRecvProtocol<S>
where
S: UnreliableSink<DataFormat = QuicDataFormat>,
{
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
'outer: loop {
loop {
match ITFrame::read_frame(&mut self.main_buffer) {
Ok(Some(frame)) => {
#[cfg(feature = "trace_pedantic")]
trace!(?frame, "recv");
match frame {
ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown),
ITFrame::OpenStream {
sid,
prio,
promises,
guaranteed_bandwidth,
} => {
if promises.contains(Promises::ORDERED)
|| promises.contains(Promises::CONSISTENCY)
|| promises.contains(Promises::GUARANTEED_DELIVERY)
{
self.reliable_buffers.insert(sid, BytesMut::new());
}
break 'outer Ok(ProtocolEvent::OpenStream {
sid,
prio: prio.min(crate::types::HIGHEST_PRIO),
promises,
guaranteed_bandwidth,
});
},
ITFrame::CloseStream { sid } => {
//FIXME: defer close!
//let _ = self.reliable_buffers.delete(sid); // if it was reliable
break 'outer Ok(ProtocolEvent::CloseStream { sid });
},
_ => break 'outer Err(ProtocolError::Violated),
};
},
Ok(None) => break, //inner => read more data
Err(()) => return Err(ProtocolError::Violated),
}
}
// try to order pending
let mut pending_violated = false;
let mut reliable = vec![];
self.pending_reliable_buffers.drain_filter(|(_, buffer)| {
// try to get Sid without touching buffer
let mut testbuffer = buffer.clone();
match ITFrame::read_frame(&mut testbuffer) {
Ok(Some(ITFrame::DataHeader {
sid,
mid: _,
length: _,
})) => {
reliable.push((sid, buffer.clone()));
true
},
Ok(Some(_)) | Err(_) => {
pending_violated = true;
true
},
Ok(None) => false,
}
});
if pending_violated {
break 'outer Err(ProtocolError::Violated);
}
for (sid, buffer) in reliable.into_iter() {
self.reliable_buffers.insert(sid, buffer)
}
let mut iter = self
.reliable_buffers
.data
.iter_mut()
.map(|(_, b)| (b, true))
.collect::<Vec<_>>();
iter.push((&mut self.unreliable_buffer, false));
for (buffer, reliable) in iter {
loop {
match ITFrame::read_frame(buffer) {
Ok(Some(frame)) => {
#[cfg(feature = "trace_pedantic")]
trace!(?frame, "recv");
match frame {
ITFrame::DataHeader { sid, mid, length } => {
let m = ITMessage::new(sid, length, &mut self.itmsg_allocator);
self.metrics.rmsg_ib(sid, length);
self.incoming.insert(mid, m);
},
ITFrame::Data { mid, data } => {
self.metrics.rdata_frames_b(data.len() as u64);
let m = match self.incoming.get_mut(&mid) {
Some(m) => m,
None => {
if reliable {
info!(
?mid,
"protocol violation by remote side: send Data before \
Header"
);
break 'outer Err(ProtocolError::Violated);
} else {
//TODO: cleanup old messages from time to time
continue;
}
},
};
m.data.extend_from_slice(&data);
if m.data.len() == m.length as usize {
// finished, yay
let m = self.incoming.remove(&mid).unwrap();
self.metrics.rmsg_ob(
m.sid,
RemoveReason::Finished,
m.data.len() as u64,
);
break 'outer Ok(ProtocolEvent::Message {
sid: m.sid,
data: m.data.freeze(),
});
}
},
_ => break 'outer Err(ProtocolError::Violated),
};
},
Ok(None) => break, //inner => read more data
Err(()) => return Err(ProtocolError::Violated),
}
}
}
self.recv_into_stream().await?;
}
}
}
#[async_trait]
impl<D> ReliableDrain for QuicSendProtocol<D>
where
D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> {
self.main_buffer.reserve(500);
frame.write_bytes(&mut self.main_buffer);
self.drain
.send(QuicDataFormat::with_main(&mut self.main_buffer))
.await
}
}
#[async_trait]
impl<S> ReliableSink for QuicRecvProtocol<S>
where
S: UnreliableSink<DataFormat = QuicDataFormat>,
{
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> {
while self.main_buffer.len() < 100 {
if self.recv_into_stream().await? == QuicDataFormatStream::Main {
if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) {
return Ok(frame);
}
}
}
Err(ProtocolError::Violated)
}
}
#[cfg(test)]
mod test_utils {
//Quic protocol based on Channel
use super::*;
use crate::metrics::{ProtocolMetricCache, ProtocolMetrics};
use async_channel::*;
use std::sync::Arc;
pub struct QuicDrain {
pub sender: Sender<QuicDataFormat>,
pub drop_ratio: f32,
}
pub struct QuicSink {
pub receiver: Receiver<QuicDataFormat>,
}
/// emulate Quic protocol on Channels
pub fn quic_bound(
cap: usize,
drop_ratio: f32,
metrics: Option<ProtocolMetricCache>,
) -> [(QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>); 2] {
let (s1, r1) = async_channel::bounded(cap);
let (s2, r2) = async_channel::bounded(cap);
let m = metrics.unwrap_or_else(|| {
ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()))
});
[
(
QuicSendProtocol::new(
QuicDrain {
sender: s1,
drop_ratio,
},
m.clone(),
),
QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()),
),
(
QuicSendProtocol::new(
QuicDrain {
sender: s2,
drop_ratio,
},
m.clone(),
),
QuicRecvProtocol::new(QuicSink { receiver: r1 }, m),
),
]
}
#[async_trait]
impl UnreliableDrain for QuicDrain {
type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
use rand::Rng;
if matches!(data.stream, QuicDataFormatStream::Unreliable)
&& rand::thread_rng().gen::<f32>() < self.drop_ratio
{
return Ok(());
}
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
}
}
#[async_trait]
impl UnreliableSink for QuicSink {
type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
}
}
}
#[cfg(test)]
mod tests {
use crate::{
error::ProtocolError,
frame::OTFrame,
metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason},
quic::{test_utils::*, QuicDataFormat},
types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2},
InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol,
};
use bytes::{Bytes, BytesMut};
use std::{sync::Arc, time::Duration};
#[tokio::test]
async fn handshake_all_good() {
let [mut p1, mut p2] = quic_bound(10, 0.5, None);
let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await });
let (r1, r2) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42)));
assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337)));
}
#[tokio::test]
async fn open_stream() {
let [p1, p2] = quic_bound(10, 0.5, None);
let (mut s, mut r) = (p1.0, p2.1);
let event = ProtocolEvent::OpenStream {
sid: Sid::new(10),
prio: 0u8,
promises: Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
};
s.send(event.clone()).await.unwrap();
let e = r.recv().await.unwrap();
assert_eq!(event, e);
}
#[tokio::test]
async fn send_short_msg() {
let [p1, p2] = quic_bound(10, 0.5, None);
let (mut s, mut r) = (p1.0, p2.1);
let event = ProtocolEvent::OpenStream {
sid: Sid::new(10),
prio: 3u8,
promises: Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
};
s.send(event).await.unwrap();
let _ = r.recv().await.unwrap();
let event = ProtocolEvent::Message {
sid: Sid::new(10),
data: Bytes::from(&[188u8; 600][..]),
};
s.send(event.clone()).await.unwrap();
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = r.recv().await.unwrap();
assert_eq!(event, e);
// 2nd short message
let event = ProtocolEvent::Message {
sid: Sid::new(10),
data: Bytes::from(&[7u8; 30][..]),
};
s.send(event.clone()).await.unwrap();
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = r.recv().await.unwrap();
assert_eq!(event, e)
}
#[tokio::test]
async fn send_long_msg() {
let mut metrics =
ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap()));
let sid = Sid::new(1);
let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone()));
let (mut s, mut r) = (p1.0, p2.1);
let event = ProtocolEvent::OpenStream {
sid,
prio: 5u8,
promises: Promises::COMPRESSED | Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
};
s.send(event).await.unwrap();
let _ = r.recv().await.unwrap();
let event = ProtocolEvent::Message {
sid,
data: Bytes::from(&[99u8; 500_000][..]),
};
s.send(event.clone()).await.unwrap();
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = r.recv().await.unwrap();
assert_eq!(event, e);
metrics.assert_msg(sid, 1, RemoveReason::Finished);
metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished);
metrics.assert_data_frames(358);
metrics.assert_data_frames_bytes(500_000);
}
#[tokio::test]
async fn msg_finishes_after_close() {
let sid = Sid::new(1);
let [p1, p2] = quic_bound(10000, 0.5, None);
let (mut s, mut r) = (p1.0, p2.1);
let event = ProtocolEvent::OpenStream {
sid,
prio: 5u8,
promises: Promises::COMPRESSED | Promises::ORDERED,
guaranteed_bandwidth: 0,
};
s.send(event).await.unwrap();
let _ = r.recv().await.unwrap();
let event = ProtocolEvent::Message {
sid,
data: Bytes::from(&[99u8; 500_000][..]),
};
s.send(event).await.unwrap();
let event = ProtocolEvent::CloseStream { sid };
s.send(event).await.unwrap();
//send
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::Message { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
}
#[tokio::test]
async fn msg_finishes_after_shutdown() {
let sid = Sid::new(1);
let [p1, p2] = quic_bound(10000, 0.5, None);
let (mut s, mut r) = (p1.0, p2.1);
let event = ProtocolEvent::OpenStream {
sid,
prio: 5u8,
promises: Promises::COMPRESSED | Promises::ORDERED,
guaranteed_bandwidth: 0,
};
s.send(event).await.unwrap();
let _ = r.recv().await.unwrap();
let event = ProtocolEvent::Message {
sid,
data: Bytes::from(&[99u8; 500_000][..]),
};
s.send(event).await.unwrap();
let event = ProtocolEvent::Shutdown {};
s.send(event).await.unwrap();
let event = ProtocolEvent::CloseStream { sid };
s.send(event).await.unwrap();
//send
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::Message { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::Shutdown { .. }));
}
#[tokio::test]
async fn msg_finishes_after_drop() {
let sid = Sid::new(1);
let [p1, p2] = quic_bound(10000, 0.5, None);
let (mut s, mut r) = (p1.0, p2.1);
let event = ProtocolEvent::OpenStream {
sid,
prio: 5u8,
promises: Promises::COMPRESSED | Promises::ORDERED,
guaranteed_bandwidth: 0,
};
s.send(event).await.unwrap();
let event = ProtocolEvent::Message {
sid,
data: Bytes::from(&[99u8; 500_000][..]),
};
s.send(event).await.unwrap();
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let event = ProtocolEvent::Message {
sid,
data: Bytes::from(&[100u8; 500_000][..]),
};
s.send(event).await.unwrap();
s.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
drop(s);
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::Message { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::Message { .. }));
}
#[tokio::test]
async fn header_and_data_in_seperate_msg() {
let sid = Sid::new(1);
let (s, r) = async_channel::bounded(10);
let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
let mut r =
super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone());
const DATA1: &[u8; 69] =
b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD ";
const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open";
let mut bytes = BytesMut::with_capacity(1500);
OTFrame::OpenStream {
sid,
prio: 5u8,
promises: Promises::COMPRESSED | Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
}
.write_bytes(&mut bytes);
s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
OTFrame::DataHeader {
mid: 99,
sid,
length: (DATA1.len() + DATA2.len()) as u64,
}
.write_bytes(&mut bytes);
s.send(QuicDataFormat::with_reliable(&mut bytes, 0))
.await
.unwrap();
OTFrame::Data {
mid: 99,
data: Bytes::from(&DATA1[..]),
}
.write_bytes(&mut bytes);
OTFrame::Data {
mid: 99,
data: Bytes::from(&DATA2[..]),
}
.write_bytes(&mut bytes);
s.send(QuicDataFormat::with_reliable(&mut bytes, 0))
.await
.unwrap();
OTFrame::CloseStream { sid }.write_bytes(&mut bytes);
s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::Message { .. }));
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::CloseStream { .. }));
}
#[tokio::test]
async fn drop_sink_while_recv() {
let sid = Sid::new(1);
let (s, r) = async_channel::bounded(10);
let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap()));
let mut r =
super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone());
let mut bytes = BytesMut::with_capacity(1500);
OTFrame::OpenStream {
sid,
prio: 5u8,
promises: Promises::COMPRESSED,
guaranteed_bandwidth: 1_000_000,
}
.write_bytes(&mut bytes);
s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap();
let e = r.recv().await.unwrap();
assert!(matches!(e, ProtocolEvent::OpenStream { .. }));
let e = tokio::spawn(async move { r.recv().await });
drop(s);
let e = e.await.unwrap();
assert_eq!(e, Err(ProtocolError::Closed));
}
#[tokio::test]
#[should_panic]
async fn send_on_stream_from_remote_without_notify() {
//remote opens stream
//we send on it
let [mut p1, mut p2] = quic_bound(10, 0.5, None);
let event = ProtocolEvent::OpenStream {
sid: Sid::new(10),
prio: 3u8,
promises: Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
};
p1.0.send(event).await.unwrap();
let _ = p2.1.recv().await.unwrap();
let event = ProtocolEvent::Message {
sid: Sid::new(10),
data: Bytes::from(&[188u8; 600][..]),
};
p2.0.send(event.clone()).await.unwrap();
p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = p1.1.recv().await.unwrap();
assert_eq!(event, e);
}
#[tokio::test]
async fn send_on_stream_from_remote() {
//remote opens stream
//we send on it
let [mut p1, mut p2] = quic_bound(10, 0.5, None);
let event = ProtocolEvent::OpenStream {
sid: Sid::new(10),
prio: 3u8,
promises: Promises::ORDERED,
guaranteed_bandwidth: 1_000_000,
};
p1.0.send(event).await.unwrap();
let e = p2.1.recv().await.unwrap();
p2.0.notify_from_recv(e);
let event = ProtocolEvent::Message {
sid: Sid::new(10),
data: Bytes::from(&[188u8; 600][..]),
};
p2.0.send(event.clone()).await.unwrap();
p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap();
let e = p1.1.recv().await.unwrap();
assert_eq!(event, e);
}
#[tokio::test]
async fn unrealiable_test() {
const MIN_CHECK: usize = 10;
const COUNT: usize = 10_000;
//We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && !=
// COUNT reach their target
let [mut p1, mut p2] = quic_bound(
COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all
* succeed */
0.5,
None,
);
let event = ProtocolEvent::OpenStream {
sid: Sid::new(1337),
prio: 3u8,
promises: Promises::empty(), /* on purpose! */
guaranteed_bandwidth: 1_000_000,
};
p1.0.send(event).await.unwrap();
let e = p2.1.recv().await.unwrap();
p2.0.notify_from_recv(e);
let event = ProtocolEvent::Message {
sid: Sid::new(1337),
data: Bytes::from(&[188u8; 600][..]),
};
for _ in 0..COUNT {
p2.0.send(event.clone()).await.unwrap();
}
p2.0.flush(1_000_000_000, Duration::from_secs(1))
.await
.unwrap();
for _ in 0..COUNT {
p2.0.send(event.clone()).await.unwrap();
}
for _ in 0..MIN_CHECK {
let e = p1.1.recv().await.unwrap();
assert_eq!(event, e);
}
}
}

View File

@ -176,7 +176,7 @@ where
self.buffer.reserve(total_bytes as usize);
let mut data_frames = 0;
let mut data_bandwidth = 0;
for frame in frames {
for (_, frame) in frames {
if let OTFrame::Data { mid: _, data } = &frame {
data_bandwidth += data.len();
data_frames += 1;

View File

@ -0,0 +1,71 @@
/// Used for storing Buffers in a QUIC
#[derive(Debug)]
pub struct SortedVec<K, V> {
pub data: Vec<(K, V)>,
}
impl<K, V> Default for SortedVec<K, V> {
fn default() -> Self { Self { data: vec![] } }
}
impl<K, V> SortedVec<K, V>
where
K: Ord + Copy,
{
pub fn insert(&mut self, k: K, v: V) {
self.data.push((k, v));
self.data.sort_by_key(|&(k, _)| k);
}
pub fn delete(&mut self, k: &K) -> Option<V> {
if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) {
Some(self.data.remove(i).1)
} else {
None
}
}
pub fn get(&self, k: &K) -> Option<&V> {
if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) {
Some(&self.data[i].1)
} else {
None
}
}
pub fn get_mut(&mut self, k: &K) -> Option<&mut V> {
if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) {
Some(&mut self.data[i].1)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sorted_vec() {
let mut vec = SortedVec::default();
vec.insert(10, "Hello");
println!("{:?}", vec.data);
vec.insert(30, "World");
println!("{:?}", vec.data);
vec.insert(20, " ");
println!("{:?}", vec.data);
assert_eq!(vec.data[0].1, "Hello");
assert_eq!(vec.data[1].1, " ");
assert_eq!(vec.data[2].1, "World");
assert_eq!(vec.get(&30), Some(&"World"));
assert_eq!(vec.get_mut(&20), Some(&mut " "));
assert_eq!(vec.get(&10), Some(&"Hello"));
assert_eq!(vec.delete(&40), None);
assert_eq!(vec.delete(&10), Some("Hello"));
assert_eq!(vec.delete(&10), None);
assert_eq!(vec.get(&30), Some(&"World"));
assert_eq!(vec.get_mut(&20), Some(&mut " "));
assert_eq!(vec.get(&10), None);
}
}

View File

@ -5,6 +5,7 @@ use network_protocol::{
ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol,
TcpSendProtocol, UnreliableDrain, UnreliableSink,
};
#[cfg(feature = "quic")] use quinn::*;
use std::{sync::Arc, time::Duration};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
@ -16,18 +17,24 @@ use tokio::{
pub(crate) enum Protocols {
Tcp((TcpSendProtocol<TcpDrain>, TcpRecvProtocol<TcpSink>)),
Mpsc((MpscSendProtocol<MpscDrain>, MpscRecvProtocol<MpscSink>)),
#[cfg(feature = "quic")]
Quic((QuicSendProtocol<QuicDrain>, QuicRecvProtocol<QuicSink>)),
}
#[derive(Debug)]
pub(crate) enum SendProtocols {
Tcp(TcpSendProtocol<TcpDrain>),
Mpsc(MpscSendProtocol<MpscDrain>),
#[cfg(feature = "quic")]
Quic(QuicSendProtocol<QuicDrain>),
}
#[derive(Debug)]
pub(crate) enum RecvProtocols {
Tcp(TcpRecvProtocol<TcpSink>),
Mpsc(MpscRecvProtocol<MpscSink>),
#[cfg(feature = "quic")]
Quic(QuicSendProtocol<QuicDrain>),
}
impl Protocols {
@ -67,6 +74,8 @@ impl Protocols {
match self {
Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)),
Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)),
#[cfg(feature = "quic")]
Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)),
}
}
}
@ -82,6 +91,8 @@ impl network_protocol::InitProtocol for Protocols {
match self {
Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
#[cfg(feature = "quic")]
Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await,
}
}
}
@ -92,6 +103,8 @@ impl network_protocol::SendProtocol for SendProtocols {
match self {
SendProtocols::Tcp(s) => s.notify_from_recv(event),
SendProtocols::Mpsc(s) => s.notify_from_recv(event),
#[cfg(feature = "quic")]
SendProtocols::Quic(s) => s.notify_from_recv(event),
}
}
@ -99,6 +112,8 @@ impl network_protocol::SendProtocol for SendProtocols {
match self {
SendProtocols::Tcp(s) => s.send(event).await,
SendProtocols::Mpsc(s) => s.send(event).await,
#[cfg(feature = "quic")]
SendProtocols::Quic(s) => s.send(event).await,
}
}
@ -110,6 +125,8 @@ impl network_protocol::SendProtocol for SendProtocols {
match self {
SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
#[cfg(feature = "quic")]
SendProtocols::Quic(s) => s.flush(bandwidth, dt).await,
}
}
}
@ -120,6 +137,8 @@ impl network_protocol::RecvProtocol for RecvProtocols {
match self {
RecvProtocols::Tcp(r) => r.recv().await,
RecvProtocols::Mpsc(r) => r.recv().await,
#[cfg(feature = "quic")]
RecvProtocols::Quic(r) => r.recv().await,
}
}
}
@ -196,6 +215,45 @@ impl UnreliableSink for MpscSink {
}
}
///////////////////////////////////////
//// QUIC
#[derive(Debug)]
pub struct QuicDrain {
half: OwnedWriteHalf,
}
#[derive(Debug)]
pub struct QuicSink {
half: OwnedReadHalf,
buffer: BytesMut,
}
#[async_trait]
impl UnreliableDrain for QuicDrain {
type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
match self.half.write_all(&data).await {
Ok(()) => Ok(()),
Err(_) => Err(ProtocolError::Closed),
}
}
}
#[async_trait]
impl UnreliableSink for QuicSink {
type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
self.buffer.resize(1500, 0u8);
match self.half.read(&mut self.buffer).await {
Ok(0) => Err(ProtocolError::Closed),
Ok(n) => Ok(self.buffer.split_to(n)),
Err(_) => Err(ProtocolError::Closed),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -2,12 +2,13 @@ use crate::{
api::{ParticipantError, Stream},
channel::{Protocols, RecvProtocols, SendProtocols},
metrics::NetworkMetrics,
util::{DeferredTracer, SortedVec},
util::DeferredTracer,
};
use bytes::Bytes;
use futures_util::{FutureExt, StreamExt};
use network_protocol::{
Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid,
_internal::SortedVec,
};
use std::{
collections::HashMap,

View File

@ -44,74 +44,3 @@ impl<T: Eq + Hash> DeferredTracer<T> {
}
}
}
/// Used for storing Protocols in a Participant or Stream <-> Protocol
pub(crate) struct SortedVec<K, V> {
pub data: Vec<(K, V)>,
}
impl<K, V> Default for SortedVec<K, V> {
fn default() -> Self { Self { data: vec![] } }
}
impl<K, V> SortedVec<K, V>
where
K: Ord + Copy,
{
pub fn insert(&mut self, k: K, v: V) {
self.data.push((k, v));
self.data.sort_by_key(|&(k, _)| k);
}
pub fn delete(&mut self, k: &K) -> Option<V> {
if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) {
Some(self.data.remove(i).1)
} else {
None
}
}
pub fn get(&self, k: &K) -> Option<&V> {
if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) {
Some(&self.data[i].1)
} else {
None
}
}
pub fn get_mut(&mut self, k: &K) -> Option<&mut V> {
if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) {
Some(&mut self.data[i].1)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sorted_vec() {
let mut vec = SortedVec::default();
vec.insert(10, "Hello");
println!("{:?}", vec.data);
vec.insert(30, "World");
println!("{:?}", vec.data);
vec.insert(20, " ");
println!("{:?}", vec.data);
assert_eq!(vec.data[0].1, "Hello");
assert_eq!(vec.data[1].1, " ");
assert_eq!(vec.data[2].1, "World");
assert_eq!(vec.get(&30), Some(&"World"));
assert_eq!(vec.get_mut(&20), Some(&mut " "));
assert_eq!(vec.get(&10), Some(&"Hello"));
assert_eq!(vec.delete(&40), None);
assert_eq!(vec.delete(&10), Some("Hello"));
assert_eq!(vec.delete(&10), None);
assert_eq!(vec.get(&30), Some(&"World"));
assert_eq!(vec.get_mut(&20), Some(&mut " "));
assert_eq!(vec.get(&10), None);
}
}