From 9884019963241823326d564fe7103dd9e9198344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Fri, 22 Jan 2021 17:09:20 +0100 Subject: [PATCH] COMPLETE REDESIGN of network crate - Implementing a async non-io protocol crate a) no tokio / no channels b) I/O is based on abstraction Sink/Drain c) different Protocols can have a different Drain Type This allow MPSC to send its content without splitting up messages at all! It allows UDP to have internal extra frames to care for security It allows better abstraction for tests Allows benchmarks on the mpsc variant Custom Handshakes to allow sth like Quic protocol easily - reduce the participant managers to 4: channel creations, send, recv and shutdown. keeping the `mut data` in one manager removes the need for all RwLocks. reducing complexity and parallel access problems - more strategic participant shutdown. first send. then wait for remote side to notice recv stop, then remote side will stop send, then local side can stop recv. - metrics are internally abstracted to fit protocol and network layer - in this commit network/protocol tests work and network tests work someway, veloren compiles but does not work - handshake compatible to async_std --- Cargo.lock | 83 +- Cargo.toml | 5 +- client/Cargo.toml | 2 +- network/Cargo.toml | 10 +- network/protocol/Cargo.toml | 33 + network/protocol/benches/protocols.rs | 243 +++++ network/protocol/src/event.rs | 74 ++ network/protocol/src/frame.rs | 634 ++++++++++++ network/protocol/src/handshake.rs | 227 +++++ network/protocol/src/io.rs | 62 ++ network/protocol/src/lib.rs | 75 ++ network/protocol/src/message.rs | 127 +++ network/protocol/src/metrics.rs | 414 ++++++++ network/protocol/src/mpsc.rs | 217 ++++ network/protocol/src/prio.rs | 139 +++ network/protocol/src/tcp.rs | 584 +++++++++++ network/{ => protocol}/src/types.rs | 152 +-- network/protocol/src/udp.rs | 37 + network/src/api.rs | 170 ++-- network/src/channel.rs | 560 ++++------ network/src/lib.rs | 6 +- network/src/message.rs | 89 +- network/src/metrics.rs | 284 +----- network/src/participant.rs | 1350 ++++++++++++------------- network/src/prios.rs | 697 ------------- network/src/protocols.rs | 591 ----------- network/src/scheduler.rs | 140 ++- server/Cargo.toml | 2 +- voxygen/src/hud/chat.rs | 2 +- 29 files changed, 3987 insertions(+), 3022 deletions(-) create mode 100644 network/protocol/Cargo.toml create mode 100644 network/protocol/benches/protocols.rs create mode 100644 network/protocol/src/event.rs create mode 100644 network/protocol/src/frame.rs create mode 100644 network/protocol/src/handshake.rs create mode 100644 network/protocol/src/io.rs create mode 100644 network/protocol/src/lib.rs create mode 100644 network/protocol/src/message.rs create mode 100644 network/protocol/src/metrics.rs create mode 100644 network/protocol/src/mpsc.rs create mode 100644 network/protocol/src/prio.rs create mode 100644 network/protocol/src/tcp.rs rename network/{ => protocol}/src/types.rs (52%) create mode 100644 network/protocol/src/udp.rs delete mode 100644 network/src/prios.rs delete mode 100644 network/src/protocols.rs diff --git a/Cargo.lock b/Cargo.lock index 69b549cd2b..bd0dcb62ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,6 +259,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-trait" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.60", +] + [[package]] name = "atom" version = "0.3.6" @@ -1057,6 +1068,7 @@ dependencies = [ "clap", "criterion-plot", "csv", + "futures", "itertools 0.10.0", "lazy_static", "num-traits", @@ -1069,6 +1081,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio 1.2.0", "walkdir 2.3.1", ] @@ -5580,7 +5593,7 @@ dependencies = [ "veloren-common", "veloren-common-net", "veloren-common-sys", - "veloren_network", + "veloren-network", ] [[package]] @@ -5661,6 +5674,47 @@ dependencies = [ "wasmer", ] +[[package]] +name = "veloren-network" +version = "0.3.0" +dependencies = [ + "async-channel", + "async-trait", + "bincode", + "bitflags", + "clap", + "crossbeam-channel 0.5.0", + "futures-core", + "futures-util", + "lazy_static", + "lz-fear", + "prometheus", + "rand 0.8.3", + "serde", + "shellexpand", + "tiny_http", + "tokio 1.2.0", + "tokio-stream", + "tracing", + "tracing-futures", + "tracing-subscriber", + "veloren-network-protocol", +] + +[[package]] +name = "veloren-network-protocol" +version = "0.5.0" +dependencies = [ + "async-channel", + "async-trait", + "bitflags", + "criterion", + "prometheus", + "rand 0.8.3", + "tokio 1.2.0", + "tracing", +] + [[package]] name = "veloren-plugin-api" version = "0.1.0" @@ -5725,9 +5779,9 @@ dependencies = [ "veloren-common", "veloren-common-net", "veloren-common-sys", + "veloren-network", "veloren-plugin-api", "veloren-world", - "veloren_network", ] [[package]] @@ -5864,31 +5918,6 @@ dependencies = [ "veloren-common-net", ] -[[package]] -name = "veloren_network" -version = "0.3.0" -dependencies = [ - "async-channel", - "bincode", - "bitflags", - "clap", - "crossbeam-channel 0.5.0", - "futures-core", - "futures-util", - "lazy_static", - "lz-fear", - "prometheus", - "rand 0.8.3", - "serde", - "shellexpand", - "tiny_http", - "tokio 1.2.0", - "tokio-stream", - "tracing", - "tracing-futures", - "tracing-subscriber", -] - [[package]] name = "version-compare" version = "0.0.10" diff --git a/Cargo.toml b/Cargo.toml index 6bd656493d..e8104a7195 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "voxygen/anim", "world", "network", + "network/protocol" ] # default profile for devs, fast to compile, okay enough to run, no debug information @@ -30,8 +31,10 @@ incremental = true # All dependencies (but not this crate itself) [profile.dev.package."*"] opt-level = 3 -[profile.dev.package."veloren_network"] +[profile.dev.package."veloren-network"] opt-level = 2 +[profile.dev.package."veloren-network-protocol"] +opt-level = 3 [profile.dev.package."veloren-common"] opt-level = 2 [profile.dev.package."veloren-client"] diff --git a/client/Cargo.toml b/client/Cargo.toml index b2ebcfead1..fcda01cb65 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,7 +14,7 @@ default = ["simd"] common = { package = "veloren-common", path = "../common", features = ["no-assets"] } common-sys = { package = "veloren-common-sys", path = "../common/sys", default-features = false } common-net = { package = "veloren-common-net", path = "../common/net" } -network = { package = "veloren_network", path = "../network", features = ["compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["compression"], default-features = false } byteorder = "1.3.2" uvth = "3.1.1" diff --git a/network/Cargo.toml b/network/Cargo.toml index 0a540ca6dc..f548278896 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "veloren_network" +name = "veloren-network" version = "0.3.0" authors = ["Marcel Märtens "] edition = "2018" @@ -7,13 +7,15 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -metrics = ["prometheus"] +metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] default = ["metrics","compression"] [dependencies] +network-protocol = { package = "veloren-network-protocol", path = "protocol", default-features = false } + #serialisation bincode = "1.3.1" serde = { version = "1.0" } @@ -35,10 +37,12 @@ rand = { version = "0.8" } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } +# async traits +async-trait = "0.1.42" [dev-dependencies] tracing-subscriber = { version = "0.2.3", default-features = false, features = ["env-filter", "fmt", "chrono", "ansi", "smallvec"] } -tokio = { version = "1.0.1", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } +tokio = { version = "1.1.0", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml new file mode 100644 index 0000000000..e097314b6b --- /dev/null +++ b/network/protocol/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "veloren-network-protocol" +description = "pure Protocol without any I/O itself" +version = "0.5.0" +authors = ["Marcel Märtens "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +metrics = ["prometheus"] + +default = ["metrics"] + +[dependencies] + +#tracing and metrics +tracing = { version = "0.1", default-features = false } +prometheus = { version = "0.11", default-features = false, optional = true } +#stream flags +bitflags = "1.2.1" +rand = { version = "0.8" } +# async traits +async-trait = "0.1.42" + +[dev-dependencies] +async-channel = "1.5.1" +tokio = { version = "1.2", default-features = false, features = ["rt", "macros"] } +criterion = { version = "0.3.4", features = ["default", "async_tokio"] } + +[[bench]] +name = "protocols" +harness = false \ No newline at end of file diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs new file mode 100644 index 0000000000..5151083b98 --- /dev/null +++ b/network/protocol/benches/protocols.rs @@ -0,0 +1,243 @@ +use async_channel::*; +use async_trait::async_trait; +use criterion::{criterion_group, criterion_main, Criterion}; +use std::{sync::Arc, time::Duration}; +use veloren_network_protocol::{ + InitProtocol, MessageBuffer, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, Promises, + ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, + Sid, TcpRecvProtcol, TcpSendProtcol, UnreliableDrain, UnreliableSink, _internal::Frame, +}; + +fn frame_serialize(frame: Frame, buffer: &mut [u8]) -> usize { frame.to_bytes(buffer).0 } + +async fn mpsc_msg(buffer: Arc) { + // Arrrg, need to include constructor here + let [p1, p2] = utils::ac_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + s.send(ProtocolEvent::Message { + sid: Sid::new(12), + mid: 0, + buffer, + }) + .await + .unwrap(); + r.recv().await.unwrap(); +} + +async fn mpsc_handshake() { + let [mut p1, mut p2] = utils::ac_bound(10, None); + let r1 = tokio::spawn(async move { + p1.initialize(true, Pid::fake(2), 1337).await.unwrap(); + p1 + }); + let r2 = tokio::spawn(async move { + p2.initialize(false, Pid::fake(3), 42).await.unwrap(); + p2 + }); + let (r1, r2) = tokio::join!(r1, r2); + r1.unwrap(); + r2.unwrap(); +} + +async fn tcp_msg(buffer: Arc, cnt: usize) { + let [p1, p2] = utils::tcp_bound(10000, None); /*10kbit*/ + let (mut s, mut r) = (p1.0, p2.1); + + let buffer = Arc::clone(&buffer); + let bandwidth = buffer.data.len() as u64 + 1000; + + let r1 = tokio::spawn(async move { + s.send(ProtocolEvent::OpenStream { + sid: Sid::new(12), + prio: 0, + promises: Promises::ORDERED, + guaranteed_bandwidth: 100_000, + }) + .await + .unwrap(); + + for i in 0..cnt { + s.send(ProtocolEvent::Message { + sid: Sid::new(12), + mid: i as u64, + buffer: Arc::clone(&buffer), + }) + .await + .unwrap(); + s.flush(bandwidth, Duration::from_secs(1)).await.unwrap(); + } + }); + let r2 = tokio::spawn(async move { + r.recv().await.unwrap(); + + for _ in 0..cnt { + r.recv().await.unwrap(); + } + }); + let (r1, r2) = tokio::join!(r1, r2); + r1.unwrap(); + r2.unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rt = || { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() + }; + + c.bench_function("mpsc_short_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: b"hello_world".to_vec(), + }); + b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer))) + }); + c.bench_function("mpsc_long_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: vec![150u8; 500_000], + }); + b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer))) + }); + c.bench_function("mpsc_handshake", |b| { + b.to_async(rt()).iter(|| mpsc_handshake()) + }); + + let mut buffer = [0u8; 1500]; + + c.bench_function("frame_serialize_short", |b| { + let frame = Frame::Data { + mid: 65, + start: 89u64, + data: b"hello_world".to_vec(), + }; + b.iter(move || frame_serialize(frame.clone(), &mut buffer)) + }); + + c.bench_function("tcp_short_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: b"hello_world".to_vec(), + }); + b.to_async(rt()).iter(|| tcp_msg(Arc::clone(&buffer), 1)) + }); + c.bench_function("tcp_1GB_in_10000_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: vec![155u8; 100_000], + }); + b.to_async(rt()) + .iter(|| tcp_msg(Arc::clone(&buffer), 10_000)) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + +mod utils { + use super::*; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + ), + ] + } + + pub struct TcpDrain { + sender: Sender>, + } + + pub struct TcpSink { + receiver: Receiver>, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} diff --git a/network/protocol/src/event.rs b/network/protocol/src/event.rs new file mode 100644 index 0000000000..14b74de558 --- /dev/null +++ b/network/protocol/src/event.rs @@ -0,0 +1,74 @@ +use crate::{ + frame::Frame, + message::MessageBuffer, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use std::sync::Arc; + +/* used for communication with Protocols */ +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ProtocolEvent { + Shutdown, + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + }, + CloseStream { + sid: Sid, + }, + Message { + buffer: Arc, + mid: Mid, + sid: Sid, + }, +} + +impl ProtocolEvent { + pub(crate) fn to_frame(&self) -> Frame { + match self { + ProtocolEvent::Shutdown => Frame::Shutdown, + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + } => Frame::OpenStream { + sid: *sid, + prio: *prio, + promises: *promises, + }, + ProtocolEvent::CloseStream { sid } => Frame::CloseStream { sid: *sid }, + ProtocolEvent::Message { .. } => { + unimplemented!("Event::Message to Frame IS NOT supported") + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_to_frame() { + assert_eq!(ProtocolEvent::Shutdown.to_frame(), Frame::Shutdown); + assert_eq!( + ProtocolEvent::CloseStream { sid: Sid::new(42) }.to_frame(), + Frame::CloseStream { sid: Sid::new(42) } + ); + } + + #[test] + #[should_panic] + fn test_sixlet_to_str() { + let _ = ProtocolEvent::Message { + buffer: Arc::new(MessageBuffer { data: vec![] }), + mid: 0, + sid: Sid::new(23), + } + .to_frame(); + } +} diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs new file mode 100644 index 0000000000..4824498940 --- /dev/null +++ b/network/protocol/src/frame.rs @@ -0,0 +1,634 @@ +use crate::types::{Mid, Pid, Prio, Promises, Sid}; +use std::{collections::VecDeque, convert::TryFrom}; + +// const FRAME_RESERVED_1: u8 = 0; +const FRAME_HANDSHAKE: u8 = 1; +const FRAME_INIT: u8 = 2; +const FRAME_SHUTDOWN: u8 = 3; +const FRAME_OPEN_STREAM: u8 = 4; +const FRAME_CLOSE_STREAM: u8 = 5; +const FRAME_DATA_HEADER: u8 = 6; +const FRAME_DATA: u8 = 7; +const FRAME_RAW: u8 = 8; +//const FRAME_RESERVED_2: u8 = 10; +//const FRAME_RESERVED_3: u8 = 13; + +/// Used for Communication between Channel <----(TCP/UDP)----> Channel +#[derive(Debug, PartialEq, Clone)] +pub /* should be crate only */ enum InitFrame { + Handshake { + magic_number: [u8; 7], + version: [u32; 3], + }, + Init { + pid: Pid, + secret: u128, + }, + /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API + * against veloren Server! */ + Raw(Vec), +} + +/// Used for Communication between Channel <----(TCP/UDP)----> Channel +#[derive(Debug, PartialEq, Clone)] +pub enum Frame { + Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully), + * Participant is deleted */ + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + }, + CloseStream { + sid: Sid, + }, + DataHeader { + mid: Mid, + sid: Sid, + length: u64, + }, + Data { + mid: Mid, + start: u64, + data: Vec, + }, +} + +impl InitFrame { + // Size WITHOUT the 1rst indicating byte + pub(crate) const HANDSHAKE_CNS: usize = 19; + pub(crate) const INIT_CNS: usize = 32; + /// const part of the RAW frame, actual size is variable + pub(crate) const RAW_CNS: usize = 2; + + //provide an appropriate buffer size. > 1500 + pub(crate) fn to_bytes(self, bytes: &mut [u8]) -> usize { + match self { + InitFrame::Handshake { + magic_number, + version, + } => { + let x = FRAME_HANDSHAKE.to_be_bytes(); + bytes[0] = x[0]; + bytes[1..8].copy_from_slice(&magic_number); + bytes[8..12].copy_from_slice(&version[0].to_le_bytes()); + bytes[12..16].copy_from_slice(&version[1].to_le_bytes()); + bytes[16..Self::HANDSHAKE_CNS + 1].copy_from_slice(&version[2].to_le_bytes()); + Self::HANDSHAKE_CNS + 1 + }, + InitFrame::Init { pid, secret } => { + bytes[0] = FRAME_INIT.to_be_bytes()[0]; + bytes[1..17].copy_from_slice(&pid.to_le_bytes()); + bytes[17..Self::INIT_CNS + 1].copy_from_slice(&secret.to_le_bytes()); + Self::INIT_CNS + 1 + }, + InitFrame::Raw(data) => { + bytes[0] = FRAME_RAW.to_be_bytes()[0]; + bytes[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); + bytes[Self::RAW_CNS + 1..(data.len() + Self::RAW_CNS + 1)] + .clone_from_slice(&data[..]); + Self::RAW_CNS + 1 + data.len() + }, + } + } + + pub(crate) fn to_frame(bytes: Vec) -> Option { + let frame_no = match bytes.get(0) { + Some(&f) => f, + None => return None, + }; + let frame = match frame_no { + FRAME_HANDSHAKE => { + if bytes.len() < Self::HANDSHAKE_CNS + 1 { + return None; + } + InitFrame::gen_handshake( + *<&[u8; Self::HANDSHAKE_CNS]>::try_from(&bytes[1..Self::HANDSHAKE_CNS + 1]) + .unwrap(), + ) + }, + FRAME_INIT => { + if bytes.len() < Self::INIT_CNS + 1 { + return None; + } + InitFrame::gen_init( + *<&[u8; Self::INIT_CNS]>::try_from(&bytes[1..Self::INIT_CNS + 1]).unwrap(), + ) + }, + FRAME_RAW => { + if bytes.len() < Self::RAW_CNS + 1 { + return None; + } + let length = InitFrame::gen_raw( + *<&[u8; Self::RAW_CNS]>::try_from(&bytes[1..Self::RAW_CNS + 1]).unwrap(), + ); + let mut data = vec![0; length as usize]; + let slice = &bytes[Self::RAW_CNS + 1..]; + if slice.len() != length as usize { + return None; + } + data.copy_from_slice(&bytes[Self::RAW_CNS + 1..]); + InitFrame::Raw(data) + }, + _ => InitFrame::Raw(bytes), + }; + Some(frame) + } + + fn gen_handshake(buf: [u8; Self::HANDSHAKE_CNS]) -> Self { + let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); + InitFrame::Handshake { + magic_number, + version: [ + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..Self::HANDSHAKE_CNS]).unwrap()), + ], + } + } + + fn gen_init(buf: [u8; Self::INIT_CNS]) -> Self { + InitFrame::Init { + pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), + secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..Self::INIT_CNS]).unwrap()), + } + } + + fn gen_raw(buf: [u8; Self::RAW_CNS]) -> u16 { + u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..Self::RAW_CNS]).unwrap()) + } +} + +impl Frame { + pub(crate) const CLOSE_STREAM_CNS: usize = 8; + /// const part of the DATA frame, actual size is variable + pub(crate) const DATA_CNS: usize = 18; + pub(crate) const DATA_HEADER_CNS: usize = 24; + #[cfg(feature = "metrics")] + pub const FRAMES_LEN: u8 = 5; + pub(crate) const OPEN_STREAM_CNS: usize = 10; + // Size WITHOUT the 1rst indicating byte + pub(crate) const SHUTDOWN_CNS: usize = 0; + + #[cfg(feature = "metrics")] + pub const fn int_to_string(i: u8) -> &'static str { + match i { + 0 => "Shutdown", + 1 => "OpenStream", + 2 => "CloseStream", + 3 => "DataHeader", + 4 => "Data", + _ => "", + } + } + + #[cfg(feature = "metrics")] + pub fn get_int(&self) -> u8 { + match self { + Frame::Shutdown => 0, + Frame::OpenStream { .. } => 1, + Frame::CloseStream { .. } => 2, + Frame::DataHeader { .. } => 3, + Frame::Data { .. } => 4, + } + } + + #[cfg(feature = "metrics")] + pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } + + //provide an appropriate buffer size. > 1500 + pub fn to_bytes(self, bytes: &mut [u8]) -> (/* buf */ usize, /* actual data */ u64) { + match self { + Frame::Shutdown => { + bytes[Self::SHUTDOWN_CNS] = FRAME_SHUTDOWN.to_be_bytes()[0]; + (Self::SHUTDOWN_CNS + 1, 0) + }, + Frame::OpenStream { + sid, + prio, + promises, + } => { + bytes[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&sid.to_le_bytes()); + bytes[9] = prio.to_le_bytes()[0]; + bytes[Self::OPEN_STREAM_CNS] = promises.to_le_bytes()[0]; + (Self::OPEN_STREAM_CNS + 1, 0) + }, + Frame::CloseStream { sid } => { + bytes[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; + bytes[1..Self::CLOSE_STREAM_CNS + 1].copy_from_slice(&sid.to_le_bytes()); + (Self::CLOSE_STREAM_CNS + 1, 0) + }, + Frame::DataHeader { mid, sid, length } => { + bytes[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&mid.to_le_bytes()); + bytes[9..17].copy_from_slice(&sid.to_le_bytes()); + bytes[17..Self::DATA_HEADER_CNS + 1].copy_from_slice(&length.to_le_bytes()); + (Self::DATA_HEADER_CNS + 1, 0) + }, + Frame::Data { mid, start, data } => { + bytes[0] = FRAME_DATA.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&mid.to_le_bytes()); + bytes[9..17].copy_from_slice(&start.to_le_bytes()); + bytes[17..Self::DATA_CNS + 1].copy_from_slice(&(data.len() as u16).to_le_bytes()); + bytes[Self::DATA_CNS + 1..(data.len() + Self::DATA_CNS + 1)] + .clone_from_slice(&data[..]); + (Self::DATA_CNS + 1 + data.len(), data.len() as u64) + }, + } + } + + pub(crate) fn to_frame(bytes: &mut VecDeque) -> Option { + let frame_no = match bytes.get(0) { + Some(&f) => f, + None => return None, + }; + let size = match frame_no { + FRAME_SHUTDOWN => Self::SHUTDOWN_CNS, + FRAME_OPEN_STREAM => Self::OPEN_STREAM_CNS, + FRAME_CLOSE_STREAM => Self::CLOSE_STREAM_CNS, + FRAME_DATA_HEADER => Self::DATA_HEADER_CNS, + FRAME_DATA => { + u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + Self::DATA_CNS + }, + _ => return None, + }; + + if bytes.len() < size + 1 { + return None; + } + + let frame = match frame_no { + FRAME_SHUTDOWN => { + let _ = bytes.drain(..size + 1); + Frame::Shutdown + }, + FRAME_OPEN_STREAM => { + let bytes = bytes.drain(..size + 1).skip(1).collect::>(); + Frame::gen_open_stream(<[u8; 10]>::try_from(bytes).unwrap()) + }, + FRAME_CLOSE_STREAM => { + let bytes = bytes.drain(..size + 1).skip(1).collect::>(); + Frame::gen_close_stream(<[u8; 8]>::try_from(bytes).unwrap()) + }, + FRAME_DATA_HEADER => { + let bytes = bytes.drain(..size + 1).skip(1).collect::>(); + Frame::gen_data_header(<[u8; 24]>::try_from(bytes).unwrap()) + }, + FRAME_DATA => { + let info = bytes + .drain(..Self::DATA_CNS + 1) + .skip(1) + .collect::>(); + let (mid, start, length) = Frame::gen_data(<[u8; 18]>::try_from(info).unwrap()); + debug_assert_eq!(length as usize, size - Self::DATA_CNS); + let data = bytes.drain(..length as usize).collect::>(); + Frame::Data { mid, start, data } + }, + _ => unreachable!("Frame::to_frame should be handled before!"), + }; + Some(frame) + } + + fn gen_open_stream(buf: [u8; Self::OPEN_STREAM_CNS]) -> Self { + Frame::OpenStream { + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + prio: buf[8], + promises: Promises::from_bits_truncate(buf[Self::OPEN_STREAM_CNS - 1]), + } + } + + fn gen_close_stream(buf: [u8; Self::CLOSE_STREAM_CNS]) -> Self { + Frame::CloseStream { + sid: Sid::from_le_bytes( + *<&[u8; 8]>::try_from(&buf[0..Self::CLOSE_STREAM_CNS]).unwrap(), + ), + } + } + + fn gen_data_header(buf: [u8; Self::DATA_HEADER_CNS]) -> Self { + Frame::DataHeader { + mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), + length: u64::from_le_bytes( + *<&[u8; 8]>::try_from(&buf[16..Self::DATA_HEADER_CNS]).unwrap(), + ), + } + } + + fn gen_data(buf: [u8; Self::DATA_CNS]) -> (Mid, u64, u16) { + let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); + let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); + let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..Self::DATA_CNS]).unwrap()); + (mid, start, length) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{VELOREN_MAGIC_NUMBER, VELOREN_NETWORK_VERSION}; + + fn get_initframes() -> Vec { + vec![ + InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }, + InitFrame::Init { + pid: Pid::fake(0), + secret: 0u128, + }, + InitFrame::Raw(vec![1, 2, 3]), + ] + } + + fn get_frames() -> Vec { + vec![ + Frame::OpenStream { + sid: Sid::new(1337), + prio: 14, + promises: Promises::GUARANTEED_DELIVERY, + }, + Frame::DataHeader { + sid: Sid::new(1337), + mid: 0, + length: 36, + }, + Frame::Data { + mid: 0, + start: 0, + data: vec![77u8; 20], + }, + Frame::Data { + mid: 0, + start: 20, + data: vec![42u8; 16], + }, + Frame::CloseStream { + sid: Sid::new(1337), + }, + Frame::Shutdown, + ] + } + + #[test] + fn initframe_individual() { + let dupl = |frame: InitFrame| { + let mut buffer = vec![0u8; 1500]; + let size = InitFrame::to_bytes(frame.clone(), &mut buffer); + buffer.truncate(size); + InitFrame::to_frame(buffer) + }; + + for frame in get_initframes() { + println!("initframe: {:?}", &frame); + assert_eq!(Some(frame.clone()), dupl(frame)); + } + } + + #[test] + fn initframe_multiple() { + let mut buffer = vec![0u8; 3000]; + + let mut frames = get_initframes(); + let mut last = 0; + // to string + let sizes = frames + .iter() + .map(|f| { + let s = InitFrame::to_bytes(f.clone(), &mut buffer[last..]); + last += s; + s + }) + .collect::>(); + + // from string + let mut last = 0; + let mut framesd = sizes + .iter() + .map(|&s| { + let f = InitFrame::to_frame(buffer[last..last + s].to_vec()); + last += s; + f + }) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("initframe: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_individual() { + let dupl = |frame: Frame| { + let mut buffer = vec![0u8; 1500]; + let (size, _) = Frame::to_bytes(frame.clone(), &mut buffer); + let mut deque = buffer[..size].iter().map(|b| *b).collect(); + Frame::to_frame(&mut deque) + }; + + for frame in get_frames() { + println!("frame: {:?}", &frame); + assert_eq!(Some(frame.clone()), dupl(frame)); + } + } + + #[test] + fn frame_multiple() { + let mut buffer = vec![0u8; 3000]; + + let mut frames = get_frames(); + let mut last = 0; + // to string + let sizes = frames + .iter() + .map(|f| { + let s = Frame::to_bytes(f.clone(), &mut buffer[last..]).0; + last += s; + s + }) + .collect::>(); + + assert_eq!(sizes[0], 1 + Frame::OPEN_STREAM_CNS); + assert_eq!(sizes[1], 1 + Frame::DATA_HEADER_CNS); + assert_eq!(sizes[2], 1 + Frame::DATA_CNS + 20); + assert_eq!(sizes[3], 1 + Frame::DATA_CNS + 16); + assert_eq!(sizes[4], 1 + Frame::CLOSE_STREAM_CNS); + assert_eq!(sizes[5], 1 + Frame::SHUTDOWN_CNS); + + let mut buffer = buffer.drain(..).collect::>(); + + // from string + let mut framesd = sizes + .iter() + .map(|&_| Frame::to_frame(&mut buffer)) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("frame: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_exact_size() { + let mut buffer = vec![0u8; Frame::CLOSE_STREAM_CNS+1/*first byte*/]; + + let frame1 = Frame::CloseStream { + sid: Sid::new(1337), + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + let mut deque = buffer.iter().map(|b| *b).collect(); + let frame2 = Frame::to_frame(&mut deque); + assert_eq!(Some(frame1), frame2); + } + + #[test] + #[should_panic] + fn initframe_too_short_buffer() { + let mut buffer = vec![0u8; 10]; + + let frame1 = InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }; + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + } + + #[test] + fn initframe_too_less_data() { + let mut buffer = vec![0u8; 20]; + + let frame1 = InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }; + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let frame1d = InitFrame::to_frame(buffer[..6].to_vec()); + assert_eq!(frame1d, None); + } + + #[test] + fn initframe_rubish() { + let buffer = b"dtrgwcser".to_vec(); + assert_eq!( + InitFrame::to_frame(buffer), + Some(InitFrame::Raw(b"dtrgwcser".to_vec())) + ); + } + + #[test] + fn initframe_attack_too_much_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = InitFrame::Raw(b"foobar".to_vec()); + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + buffer[2] = 255; + let framed = InitFrame::to_frame(buffer); + assert_eq!(framed, None); + } + + #[test] + fn initframe_attack_too_low_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = InitFrame::Raw(b"foobar".to_vec()); + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + buffer[2] = 3; + let framed = InitFrame::to_frame(buffer); + assert_eq!(framed, None); + } + + #[test] + #[should_panic] + fn frame_too_short_buffer() { + let mut buffer = vec![0u8; 10]; + + let frame1 = Frame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + } + + #[test] + fn frame_too_less_data() { + let mut buffer = vec![0u8; 20]; + + let frame1 = Frame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let mut buffer = buffer.drain(..6).collect::>(); + let frame1d = Frame::to_frame(&mut buffer); + assert_eq!(frame1d, None); + } + + #[test] + fn frame_rubish() { + let mut buffer = b"dtrgwcser".iter().map(|u| *u).collect::>(); + assert_eq!(Frame::to_frame(&mut buffer), None); + } + + #[test] + fn frame_attack_too_much_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foobar".to_vec(), + }; + + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer[17] = 255; + let mut buffer = buffer.drain(..).collect::>(); + let framed = Frame::to_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_attack_too_low_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foobar".to_vec(), + }; + + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer[17] = 3; + let mut buffer = buffer.drain(..).collect::>(); + let framed = Frame::to_frame(&mut buffer); + assert_eq!( + framed, + Some(Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foo".to_vec(), + }) + ); + //next = Invalid => Empty + let framed = Frame::to_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_int2str() { + assert_eq!(Frame::int_to_string(0), "Shutdown"); + assert_eq!(Frame::int_to_string(1), "OpenStream"); + assert_eq!(Frame::int_to_string(2), "CloseStream"); + assert_eq!(Frame::int_to_string(3), "DataHeader"); + assert_eq!(Frame::int_to_string(4), "Data"); + } +} diff --git a/network/protocol/src/handshake.rs b/network/protocol/src/handshake.rs new file mode 100644 index 0000000000..cc46791fc6 --- /dev/null +++ b/network/protocol/src/handshake.rs @@ -0,0 +1,227 @@ +use crate::{ + frame::InitFrame, + types::{ + Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, + VELOREN_NETWORK_VERSION, + }, + InitProtocol, InitProtocolError, ProtocolError, +}; +use async_trait::async_trait; +use tracing::{debug, error, info, trace}; + +// Protocols might define a Reliable Variant for auto Handshake discovery +// this doesn't need to be effective +#[async_trait] +pub trait ReliableDrain { + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait ReliableSink { + async fn recv(&mut self) -> Result; +} + +#[async_trait] +impl InitProtocol for (D, S) +where + D: ReliableDrain + Send, + S: ReliableSink + Send, +{ + async fn initialize( + &mut self, + initializer: bool, + local_pid: Pid, + local_secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + #[cfg(debug_assertions)] + const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required \ + by veloren server.\nWe are not sure if you are a \ + valid veloren client.\nClosing the connection" + .as_bytes(); + #[cfg(debug_assertions)] + const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ + invalid version.\nWe don't know how to communicate \ + with you.\nClosing the connection"; + const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ + something went wrong on network layer and connection will be closed"; + + let drain = &mut self.0; + let sink = &mut self.1; + + if initializer { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + + match sink.recv().await? { + InitFrame::Handshake { + magic_number, + version, + } => { + trace!(?magic_number, ?version, "Recv handshake"); + if magic_number != VELOREN_MAGIC_NUMBER { + error!(?magic_number, "Connection with invalid magic_number"); + #[cfg(debug_assertions)] + drain.send(InitFrame::Raw(WRONG_NUMBER.to_vec())).await?; + Err(InitProtocolError::WrongMagicNumber(magic_number)) + } else if version != VELOREN_NETWORK_VERSION { + error!(?version, "Connection with wrong network version"); + #[cfg(debug_assertions)] + drain + .send(InitFrame::Raw( + format!( + "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", + WRONG_VERSION, VELOREN_NETWORK_VERSION, version, + ) + .as_bytes() + .to_vec(), + )) + .await?; + Err(InitProtocolError::WrongVersion(version)) + } else { + trace!("Handshake Frame completed"); + if initializer { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + } else { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + Ok(()) + } + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + }?; + + match sink.recv().await? { + InitFrame::Init { pid, secret } => { + debug!(?pid, "Participant send their ID"); + let stream_id_offset = if initializer { + STREAM_ID_OFFSET1 + } else { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + STREAM_ID_OFFSET2 + }; + info!(?pid, "This Handshake is now configured!"); + Ok((pid, stream_id_offset, secret)) + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{mpsc::test_utils::*, InitProtocolError}; + + #[tokio::test] + async fn handshake_drop_start() { + let [mut p1, p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2; + }); + let (r1, _) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_wrong_magic_number() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: *b"woopsie", + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!( + r1.unwrap(), + Err(InitProtocolError::WrongMagicNumber(*b"woopsie")) + ); + assert_eq!(r2.unwrap(), Ok(())); + } + + #[tokio::test] + async fn handshake_wrong_version() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: [0, 1, 2], + }) + .await?; + let _ = p2.1.recv().await?; + let _ = p2.1.recv().await?; //this should be closed now + Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2]))); + assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_unexpected_raw() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + assert_eq!(r2.unwrap(), Ok(())); + } +} diff --git a/network/protocol/src/io.rs b/network/protocol/src/io.rs new file mode 100644 index 0000000000..c4e3eba43e --- /dev/null +++ b/network/protocol/src/io.rs @@ -0,0 +1,62 @@ +use crate::ProtocolError; +use async_trait::async_trait; +use std::collections::VecDeque; +///! I/O-Free (Sans-I/O) protocol https://sans-io.readthedocs.io/how-to-sans-io.html + +// Protocols should base on the Unrealiable variants to get something effective! +#[async_trait] +pub trait UnreliableDrain: Send { + type DataFormat; + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait UnreliableSink: Send { + type DataFormat; + async fn recv(&mut self) -> Result; +} + +pub struct BaseDrain { + data: VecDeque>, +} + +pub struct BaseSink { + data: VecDeque>, +} + +impl BaseDrain { + pub fn new() -> Self { + Self { + data: VecDeque::new(), + } + } +} + +impl BaseSink { + pub fn new() -> Self { + Self { + data: VecDeque::new(), + } + } +} + +//TODO: Test Sinks that drop 20% by random and log that + +#[async_trait] +impl UnreliableDrain for BaseDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.data.push_back(data); + Ok(()) + } +} + +#[async_trait] +impl UnreliableSink for BaseSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.data.pop_front().ok_or(ProtocolError::Closed) + } +} diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs new file mode 100644 index 0000000000..8d49ed58c9 --- /dev/null +++ b/network/protocol/src/lib.rs @@ -0,0 +1,75 @@ +mod event; +mod frame; +mod handshake; +mod io; +mod message; +mod metrics; +mod mpsc; +mod prio; +mod tcp; +mod types; + +pub use event::ProtocolEvent; +pub use io::{BaseDrain, BaseSink, UnreliableDrain, UnreliableSink}; +pub use message::MessageBuffer; +pub use metrics::ProtocolMetricCache; +#[cfg(feature = "metrics")] +pub use metrics::ProtocolMetrics; +pub use mpsc::{MpscMsg, MpscRecvProtcol, MpscSendProtcol}; +pub use tcp::{TcpRecvProtcol, TcpSendProtcol}; +pub use types::{Bandwidth, Cid, Mid, Pid, Prio, Promises, Sid, VELOREN_NETWORK_VERSION}; + +///use at own risk, might change any time, for internal benchmarks +pub mod _internal { + pub use crate::frame::Frame; +} + +use async_trait::async_trait; + +#[async_trait] +pub trait InitProtocol { + async fn initialize( + &mut self, + initializer: bool, + local_pid: Pid, + secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError>; +} + +#[async_trait] +pub trait SendProtocol { + //a stream MUST be bound to a specific Protocol, there will be a failover + // feature comming for the case where a Protocol fails completly + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: std::time::Duration, + ) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait RecvProtocol { + async fn recv(&mut self) -> Result; +} + +#[derive(Debug, PartialEq)] +pub enum InitProtocolError { + Closed, + WrongMagicNumber([u8; 7]), + WrongVersion([u32; 3]), +} + +#[derive(Debug, PartialEq)] +/// When you return closed you must stay closed! +pub enum ProtocolError { + Closed, +} + +impl From for InitProtocolError { + fn from(err: ProtocolError) -> Self { + match err { + ProtocolError::Closed => InitProtocolError::Closed, + } + } +} diff --git a/network/protocol/src/message.rs b/network/protocol/src/message.rs new file mode 100644 index 0000000000..1bda1325ad --- /dev/null +++ b/network/protocol/src/message.rs @@ -0,0 +1,127 @@ +use crate::{ + frame::Frame, + types::{Mid, Sid}, +}; +use std::{collections::VecDeque, sync::Arc}; + +//Todo: Evaluate switching to VecDeque for quickly adding and removing data +// from front, back. +// - It would prob require custom bincode code but thats possible. +#[cfg_attr(test, derive(PartialEq))] +pub struct MessageBuffer { + pub data: Vec, +} + +impl std::fmt::Debug for MessageBuffer { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + //TODO: small messages! + let len = self.data.len(); + if len > 20 { + write!( + f, + "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", + len, + u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), + u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), + u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), + &self.data[13..16], + &self.data[len - 8..len] + ) + } else { + write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) + } + } +} + +/// Contains a outgoing message and store what was *send* and *confirmed* +/// All Chunks have the same size, except for the last chunk which can end +/// earlier. E.g. +/// ```ignore +/// msg = OutgoingMessage::new(); +/// msg.next(); +/// msg.next(); +/// msg.confirm(1); +/// msg.confirm(2); +/// ``` +#[derive(Debug)] +pub(crate) struct OutgoingMessage { + buffer: Arc, + send_index: u64, // 3 => 4200 (3*FRAME_DATA_SIZE) + send_header: bool, + mid: Mid, + sid: Sid, + max_index: u64, //speedup + missing_header: bool, + missing_indices: VecDeque, +} + +impl OutgoingMessage { + pub(crate) const FRAME_DATA_SIZE: u64 = 1400; + + pub(crate) fn new(buffer: Arc, mid: Mid, sid: Sid) -> Self { + let max_index = + (buffer.data.len() as u64 + Self::FRAME_DATA_SIZE - 1) / Self::FRAME_DATA_SIZE; + Self { + buffer, + send_index: 0, + send_header: false, + mid, + sid, + max_index, + missing_header: false, + missing_indices: VecDeque::new(), + } + } + + /// all has been send once, but might been resend due to failures. + #[allow(dead_code)] + pub(crate) fn initial_sent(&self) -> bool { self.send_index == self.max_index } + + pub fn get_header(&self) -> Frame { + Frame::DataHeader { + mid: self.mid, + sid: self.sid, + length: self.buffer.data.len() as u64, + } + } + + pub fn get_data(&self, index: u64) -> Frame { + let start = index * Self::FRAME_DATA_SIZE; + let to_send = std::cmp::min( + self.buffer.data[start as usize..].len() as u64, + Self::FRAME_DATA_SIZE, + ); + Frame::Data { + mid: self.mid, + start, + data: self.buffer.data[start as usize..][..to_send as usize].to_vec(), + } + } + + #[allow(dead_code)] + pub(crate) fn set_missing(&mut self, missing_header: bool, missing_indicies: VecDeque) { + self.missing_header = missing_header; + self.missing_indices = missing_indicies; + } + + /// returns if something was added + pub(crate) fn next(&mut self) -> Option { + if !self.send_header { + self.send_header = true; + Some(self.get_header()) + } else if self.send_index < self.max_index { + self.send_index += 1; + Some(self.get_data(self.send_index - 1)) + } else if self.missing_header { + self.missing_header = false; + Some(self.get_header()) + } else if let Some(index) = self.missing_indices.pop_front() { + Some(self.get_data(index)) + } else { + None + } + } + + pub(crate) fn get_sid_len(&self) -> (Sid, u64) { (self.sid, self.buffer.data.len() as u64) } +} diff --git a/network/protocol/src/metrics.rs b/network/protocol/src/metrics.rs new file mode 100644 index 0000000000..715a06fc9d --- /dev/null +++ b/network/protocol/src/metrics.rs @@ -0,0 +1,414 @@ +use crate::types::Sid; +#[cfg(feature = "metrics")] +use prometheus::{IntCounterVec, IntGaugeVec, Opts, Registry}; +#[cfg(feature = "metrics")] +use std::{error::Error, sync::Arc}; + +#[allow(dead_code)] +pub enum RemoveReason { + Finished, + Dropped, +} + +#[cfg(feature = "metrics")] +pub struct ProtocolMetrics { + // smsg=send_msg rdata=receive_data + // i=in o=out + // t=total b=byte throughput + //e.g smsg_it = sending messages, in (responsibility of protocol) total + + // based on CHANNEL/STREAM + /// messages added to be send total, by STREAM, + smsg_it: IntCounterVec, + /// messages bytes added to be send throughput, by STREAM, + smsg_ib: IntCounterVec, + /// messages removed from to be send, because they where finished total, by + /// STREAM AND REASON(finished/canceled), + smsg_ot: IntCounterVec, + /// messages bytes removed from to be send throughput, because they where + /// finished total, by STREAM AND REASON(finished/dropped), + smsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + sdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + sdata_frames_b: IntCounterVec, + + // based on CHANNEL/STREAM + /// messages added to be received total, by STREAM, + rmsg_it: IntCounterVec, + /// messages bytes added to be received throughput, by STREAM, + rmsg_ib: IntCounterVec, + /// messages removed from to be received, because they where finished total, + /// by STREAM AND REASON(finished/canceled), + rmsg_ot: IntCounterVec, + /// messages bytes removed from to be received throughput, because they + /// where finished total, by STREAM AND REASON(finished/dropped), + rmsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + rdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + rdata_frames_b: IntCounterVec, + /// ping per CHANNEL //TODO: implement + ping: IntGaugeVec, +} + +#[cfg(feature = "metrics")] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache { + cid: String, + m: Arc, +} + +#[cfg(not(feature = "metrics"))] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache {} + +#[cfg(feature = "metrics")] +impl ProtocolMetrics { + pub fn new() -> Result> { + let smsg_it = IntCounterVec::new( + Opts::new( + "send_messages_in_total", + "All Messages that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ib = IntCounterVec::new( + Opts::new( + "send_messages_in_throughput", + "All Message bytes that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ot = IntCounterVec::new( + Opts::new( + "send_messages_out_total", + "All Messages that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let smsg_ob = IntCounterVec::new( + Opts::new( + "send_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let sdata_frames_t = IntCounterVec::new( + Opts::new( + "send_data_frames_total", + "Number of data frames send per channel", + ), + &["channel"], + )?; + let sdata_frames_b = IntCounterVec::new( + Opts::new( + "send_data_frames_throughput", + "Number of data frames bytes send per channel", + ), + &["channel"], + )?; + + let rmsg_it = IntCounterVec::new( + Opts::new( + "recv_messages_in_total", + "All Messages that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ib = IntCounterVec::new( + Opts::new( + "recv_messages_in_throughput", + "All Message bytes that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ot = IntCounterVec::new( + Opts::new( + "recv_messages_out_total", + "All Messages that are removed from this Protocol to be received at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rmsg_ob = IntCounterVec::new( + Opts::new( + "recv_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be received at stream \ + and reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rdata_frames_t = IntCounterVec::new( + Opts::new( + "recv_data_frames_total", + "Number of data frames received per channel", + ), + &["channel"], + )?; + let rdata_frames_b = IntCounterVec::new( + Opts::new( + "recv_data_frames_throughput", + "Number of data frames bytes received per channel", + ), + &["channel"], + )?; + let ping = IntGaugeVec::new(Opts::new("ping", "Ping per channel"), &["channel"])?; + + Ok(Self { + smsg_it, + smsg_ib, + smsg_ot, + smsg_ob, + sdata_frames_t, + sdata_frames_b, + rmsg_it, + rmsg_ib, + rmsg_ot, + rmsg_ob, + rdata_frames_t, + rdata_frames_b, + ping, + }) + } + + pub fn register(&self, registry: &Registry) -> Result<(), Box> { + registry.register(Box::new(self.smsg_it.clone()))?; + registry.register(Box::new(self.smsg_ib.clone()))?; + registry.register(Box::new(self.smsg_ot.clone()))?; + registry.register(Box::new(self.smsg_ob.clone()))?; + registry.register(Box::new(self.sdata_frames_t.clone()))?; + registry.register(Box::new(self.sdata_frames_b.clone()))?; + registry.register(Box::new(self.rmsg_it.clone()))?; + registry.register(Box::new(self.rmsg_ib.clone()))?; + registry.register(Box::new(self.rmsg_ot.clone()))?; + registry.register(Box::new(self.rmsg_ob.clone()))?; + registry.register(Box::new(self.rdata_frames_t.clone()))?; + registry.register(Box::new(self.rdata_frames_b.clone()))?; + registry.register(Box::new(self.ping.clone()))?; + Ok(()) + } +} + +#[cfg(feature = "metrics")] +impl ProtocolMetricCache { + pub fn new(channel_key: &str, metrics: Arc) -> Self { + Self { + cid: channel_key.to_string(), + m: metrics, + } + } + + pub(crate) fn smsg_it(&self, sid: Sid) { + self.m + .smsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc(); + } + + pub(crate) fn smsg_ib(&self, sid: Sid, bytes: u64) { + self.m + .smsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc_by(bytes); + } + + pub(crate) fn smsg_ot(&self, sid: Sid, reason: RemoveReason) { + self.m + .smsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc(); + } + + pub(crate) fn smsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { + self.m + .smsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc_by(bytes); + } + + pub(crate) fn sdata_frames_t(&self) { + self.m.sdata_frames_t.with_label_values(&[&self.cid]).inc(); + } + + pub(crate) fn sdata_frames_b(&self, bytes: u64) { + self.m + .sdata_frames_b + .with_label_values(&[&self.cid]) + .inc_by(bytes); + } + + pub(crate) fn rmsg_it(&self, sid: Sid) { + self.m + .rmsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc(); + } + + pub(crate) fn rmsg_ib(&self, sid: Sid, bytes: u64) { + self.m + .rmsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc_by(bytes); + } + + pub(crate) fn rmsg_ot(&self, sid: Sid, reason: RemoveReason) { + self.m + .rmsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc(); + } + + pub(crate) fn rmsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { + self.m + .rmsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc_by(bytes); + } + + pub(crate) fn rdata_frames_t(&self) { + self.m.rdata_frames_t.with_label_values(&[&self.cid]).inc(); + } + + pub(crate) fn rdata_frames_b(&self, bytes: u64) { + self.m + .rdata_frames_b + .with_label_values(&[&self.cid]) + .inc_by(bytes); + } + + #[cfg(test)] + pub(crate) fn assert_msg(&self, sid: Sid, cnt: u64, reason: RemoveReason) { + assert_eq!( + self.m + .smsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + cnt + ); + assert_eq!( + self.m + .smsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + cnt + ); + assert_eq!( + self.m + .rmsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + cnt + ); + assert_eq!( + self.m + .rmsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + cnt + ); + } + + #[cfg(test)] + pub(crate) fn assert_msg_bytes(&self, sid: Sid, bytes: u64, reason: RemoveReason) { + assert_eq!( + self.m + .smsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + bytes + ); + assert_eq!( + self.m + .smsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + bytes + ); + assert_eq!( + self.m + .rmsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + bytes + ); + assert_eq!( + self.m + .rmsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + bytes + ); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames(&self, cnt: u64) { + assert_eq!( + self.m.sdata_frames_t.with_label_values(&[&self.cid]).get(), + cnt + ); + assert_eq!( + self.m.rdata_frames_t.with_label_values(&[&self.cid]).get(), + cnt + ); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames_bytes(&self, bytes: u64) { + assert_eq!( + self.m.sdata_frames_b.with_label_values(&[&self.cid]).get(), + bytes + ); + assert_eq!( + self.m.rdata_frames_b.with_label_values(&[&self.cid]).get(), + bytes + ); + } +} + +#[cfg(feature = "metrics")] +impl std::fmt::Debug for ProtocolMetrics { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ProtocolMetrics()") + } +} + +#[cfg(not(feature = "metrics"))] +impl ProtocolMetricCache { + pub(crate) fn smsg_it(&self, _sid: Sid) {} + + pub(crate) fn smsg_ib(&self, _sid: Sid, _b: u64) {} + + pub(crate) fn smsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + + pub(crate) fn smsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn sdata_frames_t(&self) {} + + pub(crate) fn sdata_frames_b(&self, _b: u64) {} + + pub(crate) fn rmsg_it(&self, _sid: Sid) {} + + pub(crate) fn rmsg_ib(&self, _sid: Sid, _b: u64) {} + + pub(crate) fn rmsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + + pub(crate) fn rmsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn rdata_frames_t(&self) {} + + pub(crate) fn rdata_frames_b(&self, _b: u64) {} +} + +impl RemoveReason { + #[cfg(feature = "metrics")] + fn to_str(&self) -> &str { + match self { + RemoveReason::Dropped => "Dropped", + RemoveReason::Finished => "Finished", + } + } +} diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs new file mode 100644 index 0000000000..3e9e5d55fe --- /dev/null +++ b/network/protocol/src/mpsc.rs @@ -0,0 +1,217 @@ +use crate::{ + event::ProtocolEvent, + frame::InitFrame, + handshake::{ReliableDrain, ReliableSink}, + io::{UnreliableDrain, UnreliableSink}, + metrics::{ProtocolMetricCache, RemoveReason}, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, +}; +use async_trait::async_trait; +use std::time::{Duration, Instant}; + +pub /* should be private */ enum MpscMsg { + Event(ProtocolEvent), + InitFrame(InitFrame), +} + +#[derive(Debug)] +pub struct MpscSendProtcol +where + D: UnreliableDrain, +{ + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +#[derive(Debug)] +pub struct MpscRecvProtcol +where + S: UnreliableSink, +{ + sink: S, + metrics: ProtocolMetricCache, +} + +impl MpscSendProtcol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + drain, + last: Instant::now(), + metrics, + } + } +} + +impl MpscRecvProtcol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { sink, metrics } } +} + +#[async_trait] +impl SendProtocol for MpscSendProtcol +where + D: UnreliableDrain, +{ + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match &event { + ProtocolEvent::Message { + buffer, + mid: _, + sid, + } => { + let sid = *sid; + let bytes = buffer.data.len() as u64; + self.metrics.smsg_it(sid); + self.metrics.smsg_ib(sid, bytes); + let r = self.drain.send(MpscMsg::Event(event)).await; + self.metrics.smsg_ot(sid, RemoveReason::Finished); + self.metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + r + }, + _ => self.drain.send(MpscMsg::Event(event)).await, + } + } + + async fn flush(&mut self, _: Bandwidth, _: Duration) -> Result<(), ProtocolError> { Ok(()) } +} + +#[async_trait] +impl RecvProtocol for MpscRecvProtcol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(e) => { + if let ProtocolEvent::Message { + buffer, + mid: _, + sid, + } = &e + { + let sid = *sid; + let bytes = buffer.data.len() as u64; + self.metrics.rmsg_it(sid); + self.metrics.rmsg_ib(sid, bytes); + self.metrics.rmsg_ot(sid, RemoveReason::Finished); + self.metrics.rmsg_ob(sid, RemoveReason::Finished, bytes); + } + Ok(e) + }, + MpscMsg::InitFrame(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl ReliableDrain for MpscSendProtcol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.drain.send(MpscMsg::InitFrame(frame)).await + } +} + +#[async_trait] +impl ReliableSink for MpscRecvProtcol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(_) => Err(ProtocolError::Closed), + MpscMsg::InitFrame(f) => Ok(f), + } + } +} + +#[cfg(test)] +pub mod test_utils { + use super::*; + use crate::{ + io::*, + metrics::{ProtocolMetricCache, ProtocolMetrics}, + }; + use async_channel::*; + use std::sync::Arc; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + mpsc::test_utils::*, + types::{Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, + }; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } +} diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs new file mode 100644 index 0000000000..35b7067352 --- /dev/null +++ b/network/protocol/src/prio.rs @@ -0,0 +1,139 @@ +use crate::{ + frame::Frame, + message::{MessageBuffer, OutgoingMessage}, + metrics::{ProtocolMetricCache, RemoveReason}, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use std::{collections::HashMap, sync::Arc, time::Duration}; + +#[derive(Debug)] +struct StreamInfo { + pub(crate) guaranteed_bandwidth: Bandwidth, + pub(crate) prio: Prio, + pub(crate) promises: Promises, + pub(crate) messages: Vec, +} + +/// Responsible for queueing messages. +/// every stream has a guaranteed bandwidth and a prio 0-7. +/// when `n` Bytes are available in the buffer, first the guaranteed bandwidth +/// is used. Then remaining bandwidth is used to fill up the prios. +#[derive(Debug)] +pub(crate) struct PrioManager { + streams: HashMap, + metrics: ProtocolMetricCache, +} + +// Send everything ONCE, then keep it till it's confirmed + +impl PrioManager { + const HIGHEST_PRIO: u8 = 7; + + pub fn new(metrics: ProtocolMetricCache) -> Self { + Self { + streams: HashMap::new(), + metrics, + } + } + + pub fn open_stream( + &mut self, + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + ) { + self.streams.insert(sid, StreamInfo { + guaranteed_bandwidth, + prio, + promises, + messages: vec![], + }); + } + + pub fn try_close_stream(&mut self, sid: Sid) -> bool { + if let Some(si) = self.streams.get(&sid) { + if si.messages.is_empty() { + self.streams.remove(&sid); + return true; + } + } + false + } + + pub fn is_empty(&self) -> bool { self.streams.is_empty() } + + pub fn add(&mut self, buffer: Arc, mid: Mid, sid: Sid) { + self.streams + .get_mut(&sid) + .unwrap() + .messages + .push(OutgoingMessage::new(buffer, mid, sid)); + } + + /// bandwidth might be extended, as for technical reasons + /// guaranteed_bandwidth is used and frames are always 1400 bytes. + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> Vec { + let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; + let mut cur_bytes = 0u64; + let mut frames = vec![]; + + let mut prios = [0u64; (Self::HIGHEST_PRIO + 1) as usize]; + let metrics = &self.metrics; + + let mut process_stream = + |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + let mut finished = vec![]; + 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { + while let Some(frame) = msg.next() { + let b = if matches!(frame, Frame::DataHeader { .. }) { + 25 + } else { + 19 + OutgoingMessage::FRAME_DATA_SIZE + }; + bandwidth -= b as i64; + *cur_bytes += b; + frames.push(frame); + if bandwidth <= 0 { + break 'outer; + } + } + finished.push(i); + } + + //cleanup + for i in finished.iter().rev() { + let msg = stream.messages.remove(*i); + let (sid, bytes) = msg.get_sid_len(); + metrics.smsg_ot(sid, RemoveReason::Finished); + metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + } + }; + + // Add guaranteed bandwidth + for (_, stream) in &mut self.streams { + prios[stream.prio.min(Self::HIGHEST_PRIO) as usize] += 1; + let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; + process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + } + + if cur_bytes < total_bytes { + // Add optional bandwidth + for prio in 0..=Self::HIGHEST_PRIO { + if prios[prio as usize] == 0 { + continue; + } + let per_stream_bytes = (total_bytes - cur_bytes) / prios[prio as usize]; + + for (_, stream) in &mut self.streams { + if stream.prio != prio { + continue; + } + process_stream(stream, per_stream_bytes as i64, &mut cur_bytes); + } + } + } + + frames + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs new file mode 100644 index 0000000000..e1c8e10e84 --- /dev/null +++ b/network/protocol/src/tcp.rs @@ -0,0 +1,584 @@ +use crate::{ + event::ProtocolEvent, + frame::{Frame, InitFrame}, + handshake::{ReliableDrain, ReliableSink}, + io::{UnreliableDrain, UnreliableSink}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, +}; +use async_trait::async_trait; +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, + time::{Duration, Instant}, +}; +use tracing::info; + +#[derive(Debug)] +pub struct TcpSendProtcol +where + D: UnreliableDrain>, +{ + buffer: Vec, + store: PrioManager, + closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +#[derive(Debug)] +pub struct TcpRecvProtcol +where + S: UnreliableSink>, +{ + buffer: VecDeque, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +impl TcpSendProtcol +where + D: UnreliableDrain>, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: vec![0u8; 1500], + store: PrioManager::new(metrics.clone()), + closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } +} + +impl TcpRecvProtcol +where + S: UnreliableSink>, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: VecDeque::new(), + incoming: HashMap::new(), + sink, + metrics, + } + } +} + +#[async_trait] +impl SendProtocol for TcpSendProtcol +where + D: UnreliableDrain>, +{ + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + } else { + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + tracing::error!(?event, "send frame"); + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + } else { + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { buffer, mid, sid } => { + self.metrics.smsg_it(sid); + self.metrics.smsg_ib(sid, buffer.data.len() as u64); + self.store.add(buffer, mid, sid); + }, + } + Ok(()) + } + + async fn flush(&mut self, bandwidth: Bandwidth, dt: Duration) -> Result<(), ProtocolError> { + let frames = self.store.grab(bandwidth, dt); + for frame in frames { + if let Frame::Data { + mid: _, + start: _, + data, + } = &frame + { + self.metrics.sdata_frames_t(); + self.metrics.sdata_frames_b(data.len() as u64); + } + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + tracing::warn!("send data frame, woop"); + } + let mut finished_streams = vec![]; + for (i, sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + let frame = ProtocolEvent::CloseStream { sid: *sid }.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + if self.pending_shutdown && self.store.is_empty() { + tracing::error!("send shutdown frame"); + let frame = ProtocolEvent::Shutdown {}.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + self.pending_shutdown = false; + } + Ok(()) + } +} + +use crate::{ + message::MessageBuffer, + types::{Mid, Sid}, +}; + +#[derive(Debug)] +struct IncomingMsg { + sid: Sid, + length: u64, + data: MessageBuffer, +} + +#[async_trait] +impl RecvProtocol for TcpRecvProtcol +where + S: UnreliableSink>, +{ + async fn recv(&mut self) -> Result { + tracing::error!(?self.buffer, "enter loop"); + 'outer: loop { + tracing::error!(?self.buffer, "continue loop"); + while let Some(frame) = Frame::to_frame(&mut self.buffer) { + tracing::error!(?frame, "recv frame"); + match frame { + Frame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + Frame::OpenStream { + sid, + prio, + promises, + } => { + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: 1_000_000, + }); + }, + Frame::CloseStream { sid } => { + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + Frame::DataHeader { sid, mid, length } => { + let m = IncomingMsg { + sid, + length, + data: MessageBuffer { data: vec![] }, + }; + self.metrics.rmsg_it(sid); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + Frame::Data { + mid, + start: _, + mut data, + } => { + self.metrics.rdata_frames_t(); + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + info!("protocol violation by remote side: send Data before Header"); + break 'outer Err(ProtocolError::Closed); + }, + }; + m.data.data.append(&mut data); + if m.data.data.len() == m.length as usize { + // finished, yay + drop(m); + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ot(m.sid, RemoveReason::Finished); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + mid, + buffer: Arc::new(m.data), + }); + } + }, + }; + } + tracing::error!(?self.buffer, "receiving on tcp sink"); + let chunk = self.sink.recv().await?; + self.buffer.reserve(chunk.len()); + for b in chunk { + self.buffer.push_back(b); + } + tracing::error!(?self.buffer,"receiving on tcp sink done"); + } + } +} + +#[async_trait] +impl ReliableDrain for TcpSendProtcol +where + D: UnreliableDrain>, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + let mut buffer = vec![0u8; 1500]; + let s = frame.to_bytes(&mut buffer); + buffer.truncate(s); + self.drain.send(buffer).await + } +} + +#[async_trait] +impl ReliableSink for TcpRecvProtcol +where + S: UnreliableSink>, +{ + async fn recv(&mut self) -> Result { + while self.buffer.len() < 100 { + let chunk = self.sink.recv().await?; + self.buffer.reserve(chunk.len()); + for b in chunk { + self.buffer.push_back(b); + } + let todo_use_bytes_instead = self.buffer.iter().map(|b| *b).collect(); + if let Some(frame) = InitFrame::to_frame(todo_use_bytes_instead) { + match frame { + InitFrame::Handshake { .. } => self.buffer.drain(.. InitFrame::HANDSHAKE_CNS + 1), + InitFrame::Init { .. } => self.buffer.drain(.. InitFrame::INIT_CNS + 1), + InitFrame::Raw { .. } => self.buffer.drain(.. InitFrame::RAW_CNS + 1), + }; + return Ok(frame); + } + } + Err(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod test_utils { + //TCP protocol based on Channel + use super::*; + use crate::{ + io::*, + metrics::{ProtocolMetricCache, ProtocolMetrics}, + }; + use async_channel::*; + + pub struct TcpDrain { + pub sender: Sender>, + } + + pub struct TcpSink { + pub receiver: Receiver>, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + tcp::test_utils::*, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, MessageBuffer, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = tcp_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + buffer: Arc::new(MessageBuffer { + data: vec![188u8; 600], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 1, + buffer: Arc::new(MessageBuffer { + data: vec![7u8; 30], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let metrics = + ProtocolMetricCache::new("long_tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, Some(metrics.clone())); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::TcpRecvProtcol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; + let mut buf = vec![0u8; 1500]; + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + let (i, _) = event.to_frame().to_bytes(&mut buf); + let (i2, _) = crate::frame::Frame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .to_bytes(&mut buf[i..]); + buf.truncate(i + i2); + s.send(buf).await.unwrap(); + + let mut buf = vec![0u8; 1500]; + let (i, _) = crate::frame::Frame::Data { + mid: 99, + start: 0, + data: DATA1.to_vec(), + } + .to_bytes(&mut buf); + let (i2, _) = crate::frame::Frame::Data { + mid: 99, + start: DATA1.len() as u64, + data: DATA2.to_vec(), + } + .to_bytes(&mut buf[i..]); + let (i3, _) = crate::frame::Frame::CloseStream { sid }.to_bytes(&mut buf[i + i2..]); + buf.truncate(i + i2 + i3); + s.send(buf).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } +} diff --git a/network/src/types.rs b/network/protocol/src/types.rs similarity index 52% rename from network/src/types.rs rename to network/protocol/src/types.rs index d257ed808f..b6f63ca208 100644 --- a/network/src/types.rs +++ b/network/protocol/src/types.rs @@ -1,10 +1,10 @@ use bitflags::bitflags; use rand::Rng; -use std::convert::TryFrom; pub type Mid = u64; pub type Cid = u64; pub type Prio = u8; +pub type Bandwidth = u64; bitflags! { /// use promises to modify the behavior of [`Streams`]. @@ -21,9 +21,8 @@ bitflags! { /// this will guarantee that the other side will receive every message exactly /// once no messages are dropped const GUARANTEED_DELIVERY = 0b00000100; - /// this will enable the internal compression on this + /// this will enable the internal compression on this, only useable with #[cfg(feature = "compression")] /// [`Stream`](crate::api::Stream) - #[cfg(feature = "compression")] const COMPRESSED = 0b00001000; /// this will enable the internal encryption on this /// [`Stream`](crate::api::Stream) @@ -35,7 +34,7 @@ impl Promises { pub const fn to_le_bytes(self) -> [u8; 1] { self.bits.to_le_bytes() } } -pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = [86, 69, 76, 79, 82, 69, 78]; //VELOREN +pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = *b"VELOREN"; pub const VELOREN_NETWORK_VERSION: [u32; 3] = [0, 5, 0]; pub(crate) const STREAM_ID_OFFSET1: Sid = Sid::new(0); pub(crate) const STREAM_ID_OFFSET2: Sid = Sid::new(u64::MAX / 2); @@ -51,144 +50,18 @@ pub struct Pid { } #[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub(crate) struct Sid { +pub struct Sid { internal: u64, } -// Used for Communication between Channel <----(TCP/UDP)----> Channel -#[derive(Debug)] -pub(crate) enum Frame { - Handshake { - magic_number: [u8; 7], - version: [u32; 3], - }, - Init { - pid: Pid, - secret: u128, - }, - Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown, Participant - * is deleted */ - OpenStream { - sid: Sid, - prio: Prio, - promises: Promises, - }, - CloseStream { - sid: Sid, - }, - DataHeader { - mid: Mid, - sid: Sid, - length: u64, - }, - Data { - mid: Mid, - start: u64, - data: Vec, - }, - /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API - * against veloren Server! */ - Raw(Vec), -} - -impl Frame { - #[cfg(feature = "metrics")] - pub const FRAMES_LEN: u8 = 8; - - #[cfg(feature = "metrics")] - pub const fn int_to_string(i: u8) -> &'static str { - match i { - 0 => "Handshake", - 1 => "Init", - 2 => "Shutdown", - 3 => "OpenStream", - 4 => "CloseStream", - 5 => "DataHeader", - 6 => "Data", - 7 => "Raw", - _ => "", - } - } - - #[cfg(feature = "metrics")] - pub fn get_int(&self) -> u8 { - match self { - Frame::Handshake { .. } => 0, - Frame::Init { .. } => 1, - Frame::Shutdown => 2, - Frame::OpenStream { .. } => 3, - Frame::CloseStream { .. } => 4, - Frame::DataHeader { .. } => 5, - Frame::Data { .. } => 6, - Frame::Raw(_) => 7, - } - } - - #[cfg(feature = "metrics")] - pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } - - pub fn gen_handshake(buf: [u8; 19]) -> Self { - let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); - Frame::Handshake { - magic_number, - version: [ - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..19]).unwrap()), - ], - } - } - - pub fn gen_init(buf: [u8; 32]) -> Self { - Frame::Init { - pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), - secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..32]).unwrap()), - } - } - - pub fn gen_open_stream(buf: [u8; 10]) -> Self { - Frame::OpenStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - prio: buf[8], - promises: Promises::from_bits_truncate(buf[9]), - } - } - - pub fn gen_close_stream(buf: [u8; 8]) -> Self { - Frame::CloseStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - } - } - - pub fn gen_data_header(buf: [u8; 24]) -> Self { - Frame::DataHeader { - mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), - length: u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[16..24]).unwrap()), - } - } - - pub fn gen_data(buf: [u8; 18]) -> (Mid, u64, u16) { - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); - let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); - let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..18]).unwrap()); - (mid, start, length) - } - - pub fn gen_raw(buf: [u8; 2]) -> u16 { - u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..2]).unwrap()) - } -} - impl Pid { /// create a new Pid with a random interior value /// /// # Example /// ```rust - /// use veloren_network::{Network, Pid}; + /// use veloren_network_protocol::Pid; /// /// let pid = Pid::new(); - /// let _ = Network::new(pid); /// ``` pub fn new() -> Self { Self { @@ -295,20 +168,7 @@ fn sixlet_to_str(sixlet: u128) -> char { #[cfg(test)] mod tests { - use crate::types::*; - - #[test] - fn frame_int2str() { - assert_eq!(Frame::int_to_string(3), "OpenStream"); - assert_eq!(Frame::int_to_string(7), "Raw"); - assert_eq!(Frame::int_to_string(8), ""); - } - - #[test] - fn frame_get_int() { - assert_eq!(Frame::get_int(&Frame::Raw(b"Foo".to_vec())), 7); - assert_eq!(Frame::get_int(&Frame::Shutdown), 2); - } + use super::*; #[test] fn frame_creation() { diff --git a/network/protocol/src/udp.rs b/network/protocol/src/udp.rs new file mode 100644 index 0000000000..ad5c31a126 --- /dev/null +++ b/network/protocol/src/udp.rs @@ -0,0 +1,37 @@ +// TODO: quick and dirty which activly waits for an ack! +/* +UDP protocol + +All Good Case: +S --HEADER--> R +S --DATA--> R +S --DATA--> R +S <--FINISHED-- R + + +Delayed HEADER: +S --HEADER--> +S --DATA--> R // STORE IT + --HEADER--> R // apply left data and continue +S --DATA--> R +S <--FINISHED-- R + + +NO HEADER: +S --HEADER--> ! +S --DATA--> R // STORE IT +S --DATA--> R // STORE IT +S <--MISSING_HEADER-- R // SEND AFTER 10 ms after DATA1 +S --HEADER--> R +S <--FINISHED-- R + + +NO DATA: +S --HEADER--> R +S --DATA--> R +S --DATA--> ! +S --STATUS--> R +S <--MISSING_DATA -- R +S --DATA--> R +S <--FINISHED-- R +*/ diff --git a/network/src/api.rs b/network/src/api.rs index ef6fb113db..08274c90be 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -3,13 +3,13 @@ //! //! (cd network/examples/async_recv && RUST_BACKTRACE=1 cargo run) use crate::{ - message::{partial_eq_bincode, IncomingMessage, Message, OutgoingMessage}, + message::{partial_eq_bincode, Message}, participant::{A2bStreamOpen, S2bShutdownBparticipant}, scheduler::Scheduler, - types::{Mid, Pid, Prio, Promises, Sid}, }; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; +use network_protocol::{Bandwidth, MessageBuffer, Mid, Pid, Prio, Promises, Sid}; #[cfg(feature = "metrics")] use prometheus::Registry; use serde::{de::DeserializeOwned, Serialize}; @@ -20,6 +20,7 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::Duration, }; use tokio::{ io, @@ -49,8 +50,7 @@ pub enum ProtocolAddr { pub struct Participant { local_pid: Pid, remote_pid: Pid, - runtime: Arc, - a2b_stream_open_s: Mutex>, + a2b_open_stream_s: Mutex>, b2a_stream_opened_r: Mutex>, a2s_disconnect_s: A2sDisconnect, } @@ -75,9 +75,10 @@ pub struct Stream { mid: Mid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: Option>, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -419,16 +420,14 @@ impl Participant { pub(crate) fn new( local_pid: Pid, remote_pid: Pid, - runtime: Arc, - a2b_stream_open_s: mpsc::UnboundedSender, + a2b_open_stream_s: mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, S2bShutdownBparticipant)>, ) -> Self { Self { local_pid, remote_pid, - runtime, - a2b_stream_open_s: Mutex::new(a2b_stream_open_s), + a2b_open_stream_s: Mutex::new(a2b_open_stream_s), b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } @@ -477,13 +476,13 @@ impl Participant { /// /// [`Streams`]: crate::api::Stream pub async fn open(&self, prio: u8, promises: Promises) -> Result { - let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - if let Err(e) = - self.a2b_stream_open_s - .lock() - .await - .send((prio, promises, p2a_return_stream_s)) - { + let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel::(); + if let Err(e) = self.a2b_open_stream_s.lock().await.send(( + prio, + promises, + 100000u64, + p2a_return_stream_s, + )) { debug!(?e, "bParticipant is already closed, notifying"); return Err(ParticipantError::ParticipantDisconnected); } @@ -602,7 +601,7 @@ impl Participant { // Participant is connecting to Scheduler here, not as usual // Participant<->BParticipant a2s_disconnect_s - .send((pid, finished_sender)) + .send((pid, (Duration::from_secs(120), finished_sender))) .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { Ok(res) => { @@ -647,9 +646,10 @@ impl Stream { sid: Sid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: async_channel::Receiver, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2a_msg_recv_r: async_channel::Receiver, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { @@ -658,6 +658,7 @@ impl Stream { mid: 0, prio, promises, + guaranteed_bandwidth, send_closed, a2b_msg_s, b2a_msg_recv_r: Some(b2a_msg_recv_r), @@ -776,12 +777,8 @@ impl Stream { } #[cfg(debug_assertions)] message.verify(&self); - self.a2b_msg_s.send((self.prio, self.sid, OutgoingMessage { - buffer: Arc::clone(&message.buffer), - cursor: 0, - mid: self.mid, - sid: self.sid, - }))?; + self.a2b_msg_s + .send((self.sid, Arc::clone(&message.buffer)))?; self.mid += 1; Ok(()) } @@ -864,7 +861,7 @@ impl Stream { Some(b2a_msg_recv_r) => { match b2a_msg_recv_r.recv().await { Ok(msg) => Ok(Message { - buffer: Arc::new(msg.buffer), + buffer: Arc::new(msg), #[cfg(feature = "compression")] compressed: self.promises.contains(Promises::COMPRESSED), }), @@ -917,7 +914,7 @@ impl Stream { Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_recv() { Ok(msg) => Ok(Some( Message { - buffer: Arc::new(msg.buffer), + buffer: Arc::new(msg), #[cfg(feature = "compression")] compressed: self.promises().contains(Promises::COMPRESSED), } @@ -953,47 +950,62 @@ impl Drop for Network { "Shutting down Participants of Network, while we still have metrics" ); let mut finished_receiver_list = vec![]; - self.runtime.block_on(async { - // we MUST avoid nested block_on, good that Network::Drop no longer triggers - // Participant::Drop directly but just the BParticipant - for (remote_pid, a2s_disconnect_s) in - self.participant_disconnect_sender.lock().await.drain() - { - match a2s_disconnect_s.lock().await.take() { - Some(a2s_disconnect_s) => { - trace!(?remote_pid, "Participants will be closed"); - let (finished_sender, finished_receiver) = oneshot::channel(); - finished_receiver_list.push((remote_pid, finished_receiver)); - a2s_disconnect_s.send((remote_pid, finished_sender)).expect( - "Scheduler is closed, but nobody other should be able to close it", - ); - }, - None => trace!(?remote_pid, "Participant already disconnected gracefully"), + + if tokio::runtime::Handle::try_current().is_ok() { + error!("we have a runtime but we mustn't, DROP NETWORK from async runtime is illegal") + } + + tokio::task::block_in_place(|| { + /* This context prevents panic if Dropped in a async fn */ + self.runtime.block_on(async { + for (remote_pid, a2s_disconnect_s) in + self.participant_disconnect_sender.lock().await.drain() + { + match a2s_disconnect_s.lock().await.take() { + Some(a2s_disconnect_s) => { + trace!(?remote_pid, "Participants will be closed"); + let (finished_sender, finished_receiver) = oneshot::channel(); + finished_receiver_list.push((remote_pid, finished_receiver)); + a2s_disconnect_s + .send((remote_pid, (Duration::from_secs(120), finished_sender))) + .expect( + "Scheduler is closed, but nobody other should be able to \ + close it", + ); + }, + None => trace!(?remote_pid, "Participant already disconnected gracefully"), + } } - } - //wait after close is requested for all - for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) { - match finished_receiver.await { - Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), - Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), - Err(e) => warn!( - ?remote_pid, - ?e, - "Failed to get a message back from the scheduler, seems like the network \ - is already closed" - ), + //wait after close is requested for all + for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) { + match finished_receiver.await { + Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), + Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), + Err(e) => warn!( + ?remote_pid, + ?e, + "Failed to get a message back from the scheduler, seems like the \ + network is already closed" + ), + } } - } + }); }); trace!(?pid, "Participants have shut down!"); trace!(?pid, "Shutting down Scheduler"); - self.shutdown_sender.take().unwrap().send(()).expect("Scheduler is closed, but nobody other should be able to close it"); + self.shutdown_sender + .take() + .unwrap() + .send(()) + .expect("Scheduler is closed, but nobody other should be able to close it"); debug!(?pid, "Network has shut down"); } } impl Drop for Participant { fn drop(&mut self) { + use tokio::sync::oneshot::error::TryRecvError; + // ignore closed, as we need to send it even though we disconnected the // participant from network let pid = self.remote_pid; @@ -1011,23 +1023,28 @@ impl Drop for Participant { ), Some(a2s_disconnect_s) => { debug!(?pid, "Disconnect from Scheduler"); - self.runtime.block_on(async { - let (finished_sender, finished_receiver) = oneshot::channel(); - a2s_disconnect_s - .send((self.remote_pid, finished_sender)) - .expect("Something is wrong in internal scheduler coding"); - if let Err(e) = finished_receiver - .await - .expect("Something is wrong in internal scheduler/participant coding") - { - error!( + let (finished_sender, mut finished_receiver) = oneshot::channel(); + a2s_disconnect_s + .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) + .expect("Something is wrong in internal scheduler coding"); + loop { + match finished_receiver.try_recv() { + Ok(Ok(())) => break, + Ok(Err(e)) => error!( ?pid, ?e, "Error while dropping the participant, couldn't send all outgoing \ messages, dropping remaining" - ); - }; - }); + ), + Err(TryRecvError::Closed) => { + panic!("Something is wrong in internal scheduler/participant coding") + }, + Err(TryRecvError::Empty) => { + trace!("activly sleeping"); + std::thread::sleep(Duration::from_millis(20)); + }, + } + } }, } debug!(?pid, "Participant dropped"); @@ -1041,11 +1058,12 @@ impl Drop for Stream { let sid = self.sid; let pid = self.pid; debug!(?pid, ?sid, "Shutting down Stream"); - self.a2b_close_stream_s - .take() - .unwrap() - .send(self.sid) - .expect("bparticipant part of a gracefully shutdown must have crashed"); + if let Err(e) = self.a2b_close_stream_s.take().unwrap().send(self.sid) { + debug!( + ?e, + "bparticipant part of a gracefully shutdown was already closed" + ); + } } else { let sid = self.sid; let pid = self.pid; diff --git a/network/src/channel.rs b/network/src/channel.rs index 7928337bd1..654175fb1d 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,361 +1,231 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - participant::C2pFrame, - protocols::Protocols, - types::{ - Cid, Frame, Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, - VELOREN_NETWORK_VERSION, - }, -}; -use futures_core::task::Poll; -use futures_util::{ - task::{noop_waker, Context}, - FutureExt, +use async_trait::async_trait; +use network_protocol::{ + InitProtocolError, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtcol, TcpSendProtcol, + UnreliableDrain, UnreliableSink, }; #[cfg(feature = "metrics")] use std::sync::Arc; +use std::time::Duration; use tokio::{ - join, - sync::{mpsc, oneshot}, + io::{AsyncReadExt, AsyncWriteExt}, + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + sync::mpsc, }; -use tracing::*; -pub(crate) struct Channel { - cid: Cid, - c2w_frame_r: Option>, - read_stop_receiver: Option>, -} - -impl Channel { - pub fn new(cid: u64) -> (Self, mpsc::UnboundedSender, oneshot::Sender<()>) { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - ( - Self { - cid, - c2w_frame_r: Some(c2w_frame_r), - read_stop_receiver: Some(read_stop_receiver), - }, - c2w_frame_s, - read_stop_sender, - ) - } - - pub async fn run( - mut self, - protocol: Protocols, - mut w2c_cid_frame_s: mpsc::UnboundedSender, - mut leftover_cid_frame: Vec, - ) { - let c2w_frame_r = self.c2w_frame_r.take().unwrap(); - let read_stop_receiver = self.read_stop_receiver.take().unwrap(); - - //reapply leftovers from handshake - let cnt = leftover_cid_frame.len(); - trace!(?cnt, "Reapplying leftovers"); - for cid_frame in leftover_cid_frame.drain(..) { - w2c_cid_frame_s.send(cid_frame).unwrap(); - } - trace!(?cnt, "All leftovers reapplied"); - - trace!("Start up channel"); - match protocol { - Protocols::Tcp(tcp) => { - join!( - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - Protocols::Udp(udp) => { - join!( - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - } - - trace!("Shut down channel"); - } +#[derive(Debug)] +pub(crate) enum Protocols { + Tcp((TcpSendProtcol, TcpRecvProtcol)), + Mpsc((MpscSendProtcol, MpscRecvProtcol)), } #[derive(Debug)] -pub(crate) struct Handshake { - cid: Cid, - local_pid: Pid, - secret: u128, - init_handshake: bool, - #[cfg(feature = "metrics")] - metrics: Arc, +pub(crate) enum SendProtocols { + Tcp(TcpSendProtcol), + Mpsc(MpscSendProtcol), } -impl Handshake { - #[cfg(debug_assertions)] - const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required by \ - veloren server.\nWe are not sure if you are a valid \ - veloren client.\nClosing the connection" - .as_bytes(); - #[cfg(debug_assertions)] - const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ - invalid version.\nWe don't know how to communicate with \ - you.\nClosing the connection"; +#[derive(Debug)] +pub(crate) enum RecvProtocols { + Tcp(TcpRecvProtcol), + Mpsc(MpscRecvProtcol), +} - pub fn new( - cid: u64, +impl Protocols { + pub(crate) fn new_tcp(stream: tokio::net::TcpStream) -> Self { + let (r, w) = stream.into_split(); + #[cfg(feature = "metrics")] + let metrics = ProtocolMetricCache::new( + "foooobaaaarrrrrrrr", + Arc::new(ProtocolMetrics::new().unwrap()), + ); + #[cfg(not(feature = "metrics"))] + let metrics = ProtocolMetricCache {}; + + let sp = TcpSendProtcol::new(TcpDrain { half: w }, metrics.clone()); + let rp = TcpRecvProtcol::new(TcpSink { half: r }, metrics.clone()); + Protocols::Tcp((sp, rp)) + } + + pub(crate) fn new_mpsc( + sender: mpsc::Sender, + receiver: mpsc::Receiver, + ) -> Self { + #[cfg(feature = "metrics")] + let metrics = + ProtocolMetricCache::new("mppppsssscccc", Arc::new(ProtocolMetrics::new().unwrap())); + #[cfg(not(feature = "metrics"))] + let metrics = ProtocolMetricCache {}; + + let sp = MpscSendProtcol::new(MpscDrain { sender }, metrics.clone()); + let rp = MpscRecvProtcol::new(MpscSink { receiver }, metrics.clone()); + Protocols::Mpsc((sp, rp)) + } + + pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) { + match self { + Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), + Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + } + } +} + +#[async_trait] +impl network_protocol::InitProtocol for Protocols { + async fn initialize( + &mut self, + initializer: bool, local_pid: Pid, secret: u128, - #[cfg(feature = "metrics")] metrics: Arc, - init_handshake: bool, - ) -> Self { - Self { - cid, - local_pid, - secret, - #[cfg(feature = "metrics")] - metrics, - init_handshake, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + match self { + Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, + Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + } + } +} + +#[async_trait] +impl network_protocol::SendProtocol for SendProtocols { + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.send(event).await, + SendProtocols::Mpsc(s) => s.send(event).await, + } + } + + async fn flush(&mut self, bandwidth: u64, dt: Duration) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, + SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + } + } +} + +#[async_trait] +impl network_protocol::RecvProtocol for RecvProtocols { + async fn recv(&mut self) -> Result { + match self { + RecvProtocols::Tcp(r) => r.recv().await, + RecvProtocols::Mpsc(r) => r.recv().await, + } + } +} + +/////////////////////////////////////// +//// TCP +#[derive(Debug)] +pub struct TcpDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct TcpSink { + half: OwnedReadHalf, +} + +#[async_trait] +impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + //self.half.recv + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + let mut data = vec![0u8; 1500]; + match self.half.read(&mut data).await { + Ok(n) => { + data.truncate(n); + Ok(data) + }, + Err(_) => Err(ProtocolError::Closed), + } + } +} + +/////////////////////////////////////// +//// MPSC +#[derive(Debug)] +pub struct MpscDrain { + sender: tokio::sync::mpsc::Sender, +} + +#[derive(Debug)] +pub struct MpscSink { + receiver: tokio::sync::mpsc::Receiver, +} + +#[async_trait] +impl UnreliableDrain for MpscDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } +} + +#[async_trait] +impl UnreliableSink for MpscSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver.recv().await.ok_or(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn tokio_sinks() { + let listener = TcpListener::bind("127.0.0.1:5000").await.unwrap(); + let r1 = tokio::spawn(async move { + let (server, _) = listener.accept().await.unwrap(); + (listener, server) + }); + let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); + let (_listener, server) = r1.await.unwrap(); + let client = Protocols::new_tcp(client); + let server = Protocols::new_tcp(server); + let (mut s, _) = client.split(); + let (_, mut r) = server.split(); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1), + prio: 4u8, + promises: Promises::GUARANTEED_DELIVERY, + guaranteed_bandwidth: 1_000, + }; + s.send(event.clone()).await.unwrap(); + let r = r.recv().await; + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + }) => { + assert_eq!(sid, Sid::new(1)); + assert_eq!(prio, 4u8); + assert_eq!(promises, Promises::GUARANTEED_DELIVERY); + }, + _ => { + panic!("wrong type {:?}", r); + }, } } - - pub async fn setup(self, protocol: &Protocols) -> Result<(Pid, Sid, u128, Vec), ()> { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let handler_future = - self.frame_handler(&mut w2c_cid_frame_r, c2w_frame_s, read_stop_sender); - let res = match protocol { - Protocols::Tcp(tcp) => { - (join! { - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r).fuse(), - handler_future, - }) - .2 - }, - Protocols::Udp(udp) => { - (join! { - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - handler_future, - }) - .2 - }, - }; - - match res { - Ok(res) => { - let fake_waker = noop_waker(); - let mut ctx = Context::from_waker(&fake_waker); - let mut leftover_frames = vec![]; - while let Poll::Ready(Some(cid_frame)) = w2c_cid_frame_r.poll_recv(&mut ctx) { - leftover_frames.push(cid_frame); - } - let cnt = leftover_frames.len(); - if cnt > 0 { - debug!( - ?cnt, - "Some additional frames got already transferred, piping them to the \ - bparticipant as leftover_frames" - ); - } - Ok((res.0, res.1, res.2, leftover_frames)) - }, - Err(()) => Err(()), - } - } - - async fn frame_handler( - &self, - w2c_cid_frame_r: &mut mpsc::UnboundedReceiver, - mut c2w_frame_s: mpsc::UnboundedSender, - read_stop_sender: oneshot::Sender<()>, - ) -> Result<(Pid, Sid, u128), ()> { - const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ - something went wrong on network layer and connection will be closed"; - #[cfg(feature = "metrics")] - let cid_string = self.cid.to_string(); - - if self.init_handshake { - self.send_handshake(&mut c2w_frame_s).await; - } - - let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); - #[cfg(feature = "metrics")] - { - if let Some(Ok(ref frame)) = frame { - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, &frame.get_string()]) - .inc(); - } - } - let r = match frame { - Some(Ok(Frame::Handshake { - magic_number, - version, - })) => { - trace!(?magic_number, ?version, "Recv handshake"); - if magic_number != VELOREN_MAGIC_NUMBER { - error!(?magic_number, "Connection with invalid magic_number"); - #[cfg(debug_assertions)] - self.send_raw_and_shutdown(&mut c2w_frame_s, Self::WRONG_NUMBER.to_vec()) - .await; - Err(()) - } else if version != VELOREN_NETWORK_VERSION { - error!(?version, "Connection with wrong network version"); - #[cfg(debug_assertions)] - self.send_raw_and_shutdown( - &mut c2w_frame_s, - format!( - "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", - Self::WRONG_VERSION, - VELOREN_NETWORK_VERSION, - version, - ) - .as_bytes() - .to_vec(), - ) - .await; - Err(()) - } else { - debug!("Handshake completed"); - if self.init_handshake { - self.send_init(&mut c2w_frame_s).await; - } else { - self.send_handshake(&mut c2w_frame_s).await; - } - Ok(()) - } - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if let Err(()) = r { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - return Err(()); - } - - let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); - let r = match frame { - Some(Ok(Frame::Init { pid, secret })) => { - debug!(?pid, "Participant send their ID"); - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, "ParticipantId"]) - .inc(); - let stream_id_offset = if self.init_handshake { - STREAM_ID_OFFSET1 - } else { - self.send_init(&mut c2w_frame_s).await; - STREAM_ID_OFFSET2 - }; - info!(?pid, "This Handshake is now configured!"); - Ok((pid, stream_id_offset, secret)) - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if r.is_err() { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - } - r - } - - async fn send_handshake(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "Handshake"]) - .inc(); - c2w_frame_s - .send(Frame::Handshake { - magic_number: VELOREN_MAGIC_NUMBER, - version: VELOREN_NETWORK_VERSION, - }) - .unwrap(); - } - - async fn send_init(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "ParticipantId"]) - .inc(); - c2w_frame_s - .send(Frame::Init { - pid: self.local_pid, - secret: self.secret, - }) - .unwrap(); - } - - #[cfg(debug_assertions)] - async fn send_raw_and_shutdown( - &self, - c2w_frame_s: &mut mpsc::UnboundedSender, - data: Vec, - ) { - debug!("Sending client instructions before killing"); - #[cfg(feature = "metrics")] - { - let cid_string = self.cid.to_string(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Raw"]) - .inc(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Shutdown"]) - .inc(); - } - c2w_frame_s.send(Frame::Raw(data)).unwrap(); - c2w_frame_s.send(Frame::Shutdown).unwrap(); - } } diff --git a/network/src/lib.rs b/network/src/lib.rs index ffba192643..7593f1edd8 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -104,14 +104,10 @@ mod channel; mod message; #[cfg(feature = "metrics")] mod metrics; mod participant; -mod prios; -mod protocols; mod scheduler; -#[macro_use] -mod types; pub use api::{ Network, NetworkError, Participant, ParticipantError, ProtocolAddr, Stream, StreamError, }; pub use message::Message; -pub use types::{Pid, Promises}; +pub use network_protocol::{Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index 9ab9941599..0ad24c63ad 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -1,11 +1,9 @@ use serde::{de::DeserializeOwned, Serialize}; //use std::collections::VecDeque; +use crate::api::{Stream, StreamError}; +use network_protocol::MessageBuffer; #[cfg(feature = "compression")] -use crate::types::Promises; -use crate::{ - api::{Stream, StreamError}, - types::{Frame, Mid, Sid}, -}; +use network_protocol::Promises; use std::{io, sync::Arc}; #[cfg(all(feature = "compression", debug_assertions))] use tracing::warn; @@ -23,29 +21,6 @@ pub struct Message { pub(crate) compressed: bool, } -//Todo: Evaluate switching to VecDeque for quickly adding and removing data -// from front, back. -// - It would prob require custom bincode code but thats possible. -pub(crate) struct MessageBuffer { - pub data: Vec, -} - -#[derive(Debug)] -pub(crate) struct OutgoingMessage { - pub buffer: Arc, - pub cursor: u64, - pub mid: Mid, - pub sid: Sid, -} - -#[derive(Debug)] -pub(crate) struct IncomingMessage { - pub buffer: MessageBuffer, - pub length: u64, - pub mid: Mid, - pub sid: Sid, -} - impl Message { /// This serializes any message, according to the [`Streams`] [`Promises`]. /// You can reuse this `Message` and send it via other [`Streams`], if the @@ -170,38 +145,6 @@ impl Message { } } -impl OutgoingMessage { - pub(crate) const FRAME_DATA_SIZE: u64 = 1400; - - /// returns if msg is empty - pub(crate) fn fill_next>( - &mut self, - msg_sid: Sid, - frames: &mut E, - ) -> bool { - let to_send = std::cmp::min( - self.buffer.data[self.cursor as usize..].len() as u64, - Self::FRAME_DATA_SIZE, - ); - if to_send > 0 { - if self.cursor == 0 { - frames.extend(std::iter::once((msg_sid, Frame::DataHeader { - mid: self.mid, - sid: self.sid, - length: self.buffer.data.len() as u64, - }))); - } - frames.extend(std::iter::once((msg_sid, Frame::Data { - mid: self.mid, - start: self.cursor, - data: self.buffer.data[self.cursor as usize..][..to_send as usize].to_vec(), - }))); - }; - self.cursor += to_send; - self.cursor >= self.buffer.data.len() as u64 - } -} - ///wouldn't trust this aaaassss much, fine for tests pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool { if let Some(f) = first.raw_os_error() { @@ -231,28 +174,6 @@ pub(crate) fn partial_eq_bincode(first: &bincode::ErrorKind, second: &bincode::E } } -impl std::fmt::Debug for MessageBuffer { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //TODO: small messages! - let len = self.data.len(); - if len > 20 { - write!( - f, - "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", - len, - u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), - u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), - u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), - &self.data[13..16], - &self.data[len - 8..len] - ) - } else { - write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) - } - } -} - #[cfg(test)] mod tests { use crate::{api::Stream, message::*}; @@ -260,7 +181,8 @@ mod tests { use tokio::sync::mpsc; fn stub_stream(compressed: bool) -> Stream { - use crate::{api::*, types::*}; + use crate::api::*; + use network_protocol::*; #[cfg(feature = "compression")] let promises = if compressed { @@ -281,6 +203,7 @@ mod tests { Sid::new(0), 0u8, promises, + 1_000_000, Arc::new(AtomicBool::new(true)), a2b_msg_s, b2a_msg_recv_r, diff --git a/network/src/metrics.rs b/network/src/metrics.rs index d43aeaae3a..60650f68fe 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,10 +1,6 @@ -use crate::types::{Cid, Frame, Pid}; -use prometheus::{ - core::{AtomicU64, GenericCounter}, - IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry, -}; +use network_protocol::Pid; +use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; use std::error::Error; -use tracing::*; /// 1:1 relation between NetworkMetrics and Network /// use 2NF here and avoid redundant data like CHANNEL AND PARTICIPANT encoding. @@ -25,29 +21,6 @@ pub struct NetworkMetrics { pub streams_opened_total: IntCounterVec, pub streams_closed_total: IntCounterVec, pub network_info: IntGauge, - // Frames counted a channel level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_out_total: IntCounterVec, - pub frames_in_total: IntCounterVec, - // Frames counted at protocol level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_wire_out_total: IntCounterVec, - pub frames_wire_in_total: IntCounterVec, - // throughput at protocol level, seperated by CHANNEL (and PARTICIPANT), - pub wire_out_throughput: IntCounterVec, - pub wire_in_throughput: IntCounterVec, - // send(prio) Messages count, seperated by STREAM AND PARTICIPANT, - pub message_out_total: IntCounterVec, - // send(prio) Messages throughput, seperated by STREAM AND PARTICIPANT, - pub message_out_throughput: IntCounterVec, - // flushed(prio) stream count, seperated by PARTICIPANT, - pub streams_flushed: IntCounterVec, - // TODO: queued Messages, seperated by STREAM (add PART, CHANNEL), - // queued Messages, seperated by PARTICIPANT - pub queued_count: IntGaugeVec, - // TODO: queued Messages bytes, seperated by STREAM (add PART, CHANNEL), - // queued Messages bytes, seperated by PARTICIPANT - pub queued_bytes: IntGaugeVec, - // ping calculated based on last msg seperated by PARTICIPANT - pub participants_ping: IntGaugeVec, } impl NetworkMetrics { @@ -115,99 +88,13 @@ impl NetworkMetrics { "version", &format!( "{}.{}.{}", - &crate::types::VELOREN_NETWORK_VERSION[0], - &crate::types::VELOREN_NETWORK_VERSION[1], - &crate::types::VELOREN_NETWORK_VERSION[2] + &network_protocol::VELOREN_NETWORK_VERSION[0], + &network_protocol::VELOREN_NETWORK_VERSION[1], + &network_protocol::VELOREN_NETWORK_VERSION[2] ), ) .const_label("local_pid", &format!("{}", &local_pid)); let network_info = IntGauge::with_opts(opts)?; - let frames_out_total = IntCounterVec::new( - Opts::new( - "frames_out_total", - "Number of all frames send per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_in_total = IntCounterVec::new( - Opts::new( - "frames_in_total", - "Number of all frames received per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_wire_out_total = IntCounterVec::new( - Opts::new( - "frames_wire_out_total", - "Number of all frames send per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let frames_wire_in_total = IntCounterVec::new( - Opts::new( - "frames_wire_in_total", - "Number of all frames received per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let wire_out_throughput = IntCounterVec::new( - Opts::new( - "wire_out_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - let wire_in_throughput = IntCounterVec::new( - Opts::new( - "wire_in_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - //TODO IN - let message_out_total = IntCounterVec::new( - Opts::new( - "message_out_total", - "Number of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - //TODO IN - let message_out_throughput = IntCounterVec::new( - Opts::new( - "message_out_throughput", - "Throughput of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - let streams_flushed = IntCounterVec::new( - Opts::new( - "stream_flushed", - "Number of flushed streams requested to PrioManager at participant level", - ), - &["participant"], - )?; - let queued_count = IntGaugeVec::new( - Opts::new( - "queued_count", - "Queued number of messages by participant on the network", - ), - &["channel"], - )?; - let queued_bytes = IntGaugeVec::new( - Opts::new( - "queued_bytes", - "Queued bytes of messages by participant on the network", - ), - &["channel"], - )?; - let participants_ping = IntGaugeVec::new( - Opts::new( - "participants_ping", - "Ping time to participants on the network", - ), - &["channel"], - )?; Ok(Self { listen_requests_total, @@ -220,18 +107,6 @@ impl NetworkMetrics { streams_opened_total, streams_closed_total, network_info, - frames_out_total, - frames_in_total, - frames_wire_out_total, - frames_wire_in_total, - wire_out_throughput, - wire_in_throughput, - message_out_total, - message_out_throughput, - streams_flushed, - queued_count, - queued_bytes, - participants_ping, }) } @@ -246,22 +121,8 @@ impl NetworkMetrics { registry.register(Box::new(self.streams_opened_total.clone()))?; registry.register(Box::new(self.streams_closed_total.clone()))?; registry.register(Box::new(self.network_info.clone()))?; - registry.register(Box::new(self.frames_out_total.clone()))?; - registry.register(Box::new(self.frames_in_total.clone()))?; - registry.register(Box::new(self.frames_wire_out_total.clone()))?; - registry.register(Box::new(self.frames_wire_in_total.clone()))?; - registry.register(Box::new(self.wire_out_throughput.clone()))?; - registry.register(Box::new(self.wire_in_throughput.clone()))?; - registry.register(Box::new(self.message_out_total.clone()))?; - registry.register(Box::new(self.message_out_throughput.clone()))?; - registry.register(Box::new(self.queued_count.clone()))?; - registry.register(Box::new(self.queued_bytes.clone()))?; - registry.register(Box::new(self.participants_ping.clone()))?; Ok(()) } - - //pub fn _is_100th_tick(&self) -> bool { - // self.tick.load(Ordering::Relaxed).rem_euclid(100) == 0 } } impl std::fmt::Debug for NetworkMetrics { @@ -270,138 +131,3 @@ impl std::fmt::Debug for NetworkMetrics { write!(f, "NetworkMetrics()") } } - -/* -pub(crate) struct PidCidFrameCache { - metric: MetricVec, - pid: String, - cache: Vec<[T::M; 8]>, -} -*/ - -pub(crate) struct MultiCidFrameCache { - metric: IntCounterVec, - cache: Vec<[Option>; Frame::FRAMES_LEN as usize]>, -} - -impl MultiCidFrameCache { - const CACHE_SIZE: usize = 2048; - - pub fn new(metric: IntCounterVec) -> Self { - Self { - metric, - cache: vec![], - } - } - - fn populate(&mut self, cid: Cid) { - let start_cid = self.cache.len(); - if cid >= start_cid as u64 && cid > (Self::CACHE_SIZE as Cid) { - warn!( - ?cid, - "cid, getting quite high, is this a attack on the cache?" - ); - } - self.cache.resize((cid + 1) as usize, [ - None, None, None, None, None, None, None, None, - ]); - } - - pub fn with_label_values(&mut self, cid: Cid, frame: &Frame) -> &GenericCounter { - self.populate(cid); - let frame_int = frame.get_int() as usize; - let r = &mut self.cache[cid as usize][frame_int]; - if r.is_none() { - *r = Some( - self.metric - .with_label_values(&[&cid.to_string(), &frame_int.to_string()]), - ); - } - r.as_ref().unwrap() - } -} - -pub(crate) struct CidFrameCache { - cache: [GenericCounter; Frame::FRAMES_LEN as usize], -} - -impl CidFrameCache { - pub fn new(metric: IntCounterVec, cid: Cid) -> Self { - let cid = cid.to_string(); - let cache = [ - metric.with_label_values(&[&cid, Frame::int_to_string(0)]), - metric.with_label_values(&[&cid, Frame::int_to_string(1)]), - metric.with_label_values(&[&cid, Frame::int_to_string(2)]), - metric.with_label_values(&[&cid, Frame::int_to_string(3)]), - metric.with_label_values(&[&cid, Frame::int_to_string(4)]), - metric.with_label_values(&[&cid, Frame::int_to_string(5)]), - metric.with_label_values(&[&cid, Frame::int_to_string(6)]), - metric.with_label_values(&[&cid, Frame::int_to_string(7)]), - ]; - Self { cache } - } - - pub fn with_label_values(&mut self, frame: &Frame) -> &GenericCounter { - &self.cache[frame.get_int() as usize] - } -} - -#[cfg(test)] -mod tests { - use crate::{ - metrics::*, - types::{Frame, Pid}, - }; - - #[test] - fn register_metrics() { - let registry = Registry::new(); - let metrics = NetworkMetrics::new(&Pid::fake(1)).unwrap(); - metrics.register(®istry).unwrap(); - } - - #[test] - fn multi_cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = MultiCidFrameCache::new(metrics.frames_in_total); - let v1 = cache.with_label_values(1, &frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(1, &frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(1, &frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(3, &frame1); - v4.inc(); - assert_eq!(v4.get(), 1); - let v5 = cache.with_label_values(3, &Frame::Shutdown); - v5.inc(); - assert_eq!(v5.get(), 1); - } - - #[test] - fn cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = CidFrameCache::new(metrics.frames_wire_out_total, 1); - let v1 = cache.with_label_values(&frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(&frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(&frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(&Frame::Shutdown); - v4.inc(); - assert_eq!(v4.get(), 1); - } -} diff --git a/network/src/participant.rs b/network/src/participant.rs index 6986a70e8f..a942632f7b 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,43 +1,39 @@ #[cfg(feature = "metrics")] -use crate::metrics::{MultiCidFrameCache, NetworkMetrics}; +use crate::metrics::NetworkMetrics; use crate::{ api::{ParticipantError, Stream}, - channel::Channel, - message::{IncomingMessage, MessageBuffer, OutgoingMessage}, - prios::PrioManager, - protocols::Protocols, - types::{Cid, Frame, Pid, Prio, Promises, Sid}, + channel::{Protocols, RecvProtocols, SendProtocols}, }; use futures_util::{FutureExt, StreamExt}; +use network_protocol::{ + Bandwidth, Cid, MessageBuffer, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, + Sid, +}; use std::{ - collections::{HashMap, VecDeque}, + collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicI32, Ordering}, Arc, }, time::{Duration, Instant}, }; use tokio::{ - runtime::Runtime, select, sync::{mpsc, oneshot, Mutex, RwLock}, + task::JoinHandle, }; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use tracing_futures::Instrument; -pub(crate) type A2bStreamOpen = (Prio, Promises, oneshot::Sender); -pub(crate) type C2pFrame = (Cid, Result); -pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, Vec, oneshot::Sender<()>); -pub(crate) type S2bShutdownBparticipant = oneshot::Sender>; +pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender); +pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, oneshot::Sender<()>); +pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender>); pub(crate) type B2sPrioStatistic = (Pid, u64, u64); #[derive(Debug)] struct ChannelInfo { cid: Cid, cid_string: String, //optimisationmetrics - b2w_frame_s: mpsc::UnboundedSender, - b2r_read_shutdown: oneshot::Sender<()>, } #[derive(Debug)] @@ -45,23 +41,19 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: Mutex>, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] struct ControlChannels { - a2b_stream_open_r: mpsc::UnboundedReceiver, + a2b_open_stream_r: mpsc::UnboundedReceiver, b2a_stream_opened_s: mpsc::UnboundedSender, - b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, s2b_create_channel_r: mpsc::UnboundedReceiver, - a2b_close_stream_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, s2b_shutdown_bparticipant_r: oneshot::Receiver, /* own */ } #[derive(Debug)] struct ShutdownInfo { - //a2b_stream_open_r: mpsc::UnboundedReceiver, b2b_close_stream_opened_sender_s: Option>, error: Option, } @@ -71,29 +63,27 @@ pub struct BParticipant { remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, - runtime: Arc, channels: Arc>>>, streams: RwLock>, - running_mgr: AtomicUsize, run_channels: Option, + shutdown_barrier: AtomicI32, #[cfg(feature = "metrics")] metrics: Arc, no_channel_error_info: RwLock<(Instant, u64)>, - shutdown_info: RwLock, } impl BParticipant { - const BANDWIDTH: u64 = 25_000_000; - const FRAMES_PER_TICK: u64 = Self::BANDWIDTH * Self::TICK_TIME_MS / 1000 / 1400 /*TCP FRAME*/; + // We use integer instead of Barrier to not block mgr from freeing at the end + const BARR_CHANNEL: i32 = 1; + const BARR_RECV: i32 = 4; + const BARR_SEND: i32 = 2; const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS); - //in bit/s const TICK_TIME_MS: u64 = 10; #[allow(clippy::type_complexity)] pub(crate) fn new( remote_pid: Pid, offset_sid: Sid, - runtime: Arc, #[cfg(feature = "metrics")] metrics: Arc, ) -> ( Self, @@ -102,27 +92,15 @@ impl BParticipant { mpsc::UnboundedSender, oneshot::Sender, ) { - let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded_channel::(); + let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::(); let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::(); - let (b2b_close_stream_opened_sender_s, b2b_close_stream_opened_sender_r) = - oneshot::channel(); - let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel(); let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel(); let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel(); - let shutdown_info = RwLock::new(ShutdownInfo { - //a2b_stream_open_r: a2b_stream_open_r.clone(), - b2b_close_stream_opened_sender_s: Some(b2b_close_stream_opened_sender_s), - error: None, - }); - let run_channels = Some(ControlChannels { - a2b_stream_open_r, + a2b_open_stream_r, b2a_stream_opened_s, - b2b_close_stream_opened_sender_r, s2b_create_channel_r, - a2b_close_stream_r, - a2b_close_stream_s, s2b_shutdown_bparticipant_r, }); @@ -131,17 +109,17 @@ impl BParticipant { remote_pid, remote_pid_string: remote_pid.to_string(), offset_sid, - runtime, channels: Arc::new(RwLock::new(HashMap::new())), streams: RwLock::new(HashMap::new()), - running_mgr: AtomicUsize::new(0), + shutdown_barrier: AtomicI32::new( + Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV, + ), run_channels, #[cfg(feature = "metrics")] metrics, no_channel_error_info: RwLock::new((Instant::now(), 0)), - shutdown_info, }, - a2b_steam_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, @@ -149,693 +127,486 @@ impl BParticipant { } pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender) { - //those managers that listen on api::Participant need an additional oneshot for - // shutdown scenario, those handled by scheduler will be closed by it. - let (shutdown_send_mgr_sender, shutdown_send_mgr_receiver) = oneshot::channel(); - let (shutdown_stream_close_mgr_sender, shutdown_stream_close_mgr_receiver) = - oneshot::channel(); - let (shutdown_open_mgr_sender, shutdown_open_mgr_receiver) = oneshot::channel(); - let (w2b_frames_s, w2b_frames_r) = mpsc::unbounded_channel::(); - let (prios, a2p_msg_s, b2p_notify_empty_stream_s) = PrioManager::new( - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - self.remote_pid_string.clone(), - ); + let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) = + mpsc::unbounded_channel::<(Cid, SendProtocols)>(); + let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) = + mpsc::unbounded_channel::<(Cid, RecvProtocols)>(); + let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) = + async_channel::unbounded::(); + let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = + async_channel::unbounded::(); + + let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); + const STREAM_BOUND: usize = 10_000; + let (a2b_msg_s, a2b_msg_r) = + crossbeam_channel::bounded::<(Sid, Arc)>(STREAM_BOUND); let run_channels = self.run_channels.take().unwrap(); tokio::join!( - self.open_mgr( - run_channels.a2b_stream_open_r, - run_channels.a2b_close_stream_s.clone(), - a2p_msg_s.clone(), - shutdown_open_mgr_receiver, + self.send_mgr( + run_channels.a2b_open_stream_r, + a2b_close_stream_r, + a2b_msg_r, + b2b_add_send_protocol_r, + b2b_close_send_protocol_r, + b2s_prio_statistic_s, + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self ), - self.handle_frames_mgr( - w2b_frames_r, + self.recv_mgr( run_channels.b2a_stream_opened_s, - run_channels.b2b_close_stream_opened_sender_r, - run_channels.a2b_close_stream_s, - a2p_msg_s.clone(), + b2b_add_recv_protocol_r, + b2b_force_close_recv_protocol_r, + b2b_close_send_protocol_s.clone(), + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self ), - self.create_channel_mgr(run_channels.s2b_create_channel_r, w2b_frames_s), - self.send_mgr(prios, shutdown_send_mgr_receiver, b2s_prio_statistic_s), - self.stream_close_mgr( - run_channels.a2b_close_stream_r, - shutdown_stream_close_mgr_receiver, - b2p_notify_empty_stream_s, + self.create_channel_mgr( + run_channels.s2b_create_channel_r, + b2b_add_send_protocol_s, + b2b_add_recv_protocol_s, ), self.participant_shutdown_mgr( run_channels.s2b_shutdown_bparticipant_r, - shutdown_open_mgr_sender, - shutdown_stream_close_mgr_sender, - shutdown_send_mgr_sender, + b2b_close_send_protocol_s.clone(), + b2b_force_close_recv_protocol_s, ), ); } + //TODO: local stream_cid: HashMap to know the respective protocol async fn send_mgr( &self, - mut prios: PrioManager, - mut shutdown_send_mgr_receiver: oneshot::Receiver>, - b2s_prio_statistic_s: mpsc::UnboundedSender, - ) { - //This time equals the MINIMUM Latency in average, so keep it down and //Todo: - // make it configurable or switch to await E.g. Prio 0 = await, prio 50 - // wait for more messages - self.running_mgr.fetch_add(1, Ordering::Relaxed); - let mut b2b_prios_flushed_s = None; //closing up - let mut interval = tokio::time::interval(Self::TICK_TIME); - trace!("Start send_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut i: u64 = 0; - loop { - let mut frames = VecDeque::new(); - prios - .fill_frames(Self::FRAMES_PER_TICK as usize, &mut frames) - .await; - let len = frames.len(); - for (_, frame) in frames { - self.send_frame( - frame, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - b2s_prio_statistic_s - .send((self.remote_pid, len as u64, /* */ 0)) - .unwrap(); - interval.tick().await; - i += 1; - if i.rem_euclid(1000) == 0 { - trace!("Did 1000 ticks"); - } - //shutdown after all msg are send! - // Make sure this is called after the API is closed, and all streams are known - // to be droped to the priomgr - if b2b_prios_flushed_s.is_some() && (len == 0) { - break; - } - if b2b_prios_flushed_s.is_none() { - if let Ok(prios_flushed_s) = shutdown_send_mgr_receiver.try_recv() { - b2b_prios_flushed_s = Some(prios_flushed_s); - } - } - } - trace!("Stop send_mgr"); - b2b_prios_flushed_s - .expect("b2b_prios_flushed_s not set") - .send(()) - .unwrap(); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - //returns false if sending isn't possible. In that case we have to render the - // Participant `closed` - #[must_use = "You need to check if the send was successful and report to client!"] - async fn send_frame( - &self, - frame: Frame, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, - ) -> bool { - let mut drop_cid = None; - // TODO: find out ideal channel here - - let res = if let Some(ci) = self.channels.read().await.values().next() { - let ci = ci.lock().await; - //we are increasing metrics without checking the result to please - // borrow_checker. otherwise we would need to close `frame` what we - // dont want! - #[cfg(feature = "metrics")] - frames_out_total_cache - .with_label_values(ci.cid, &frame) - .inc(); - if let Err(e) = ci.b2w_frame_s.send(frame) { - let cid = ci.cid; - info!(?e, ?cid, "channel no longer available"); - drop_cid = Some(cid); - false - } else { - true - } - } else { - let mut guard = self.no_channel_error_info.write().await; - let now = Instant::now(); - if now.duration_since(guard.0) > Duration::from_secs(1) { - guard.0 = now; - let occurrences = guard.1 + 1; - guard.1 = 0; - let lastframe = frame; - error!( - ?occurrences, - ?lastframe, - "Participant has no channel to communicate on" - ); - } else { - guard.1 += 1; - } - false - }; - if let Some(cid) = drop_cid { - if let Some(ci) = self.channels.write().await.remove(&cid) { - let ci = ci.into_inner(); - trace!(?cid, "stopping read protocol"); - if let Err(e) = ci.b2r_read_shutdown.send(()) { - trace!(?cid, ?e, "seems like was already shut down"); - } - } - //TODO FIXME tags: takeover channel multiple - info!( - "FIXME: the frame is actually drop. which is fine for now as the participant will \ - be closed, but not if we do channel-takeover" - ); - //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant - self.close_write_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) - .await; - }; - res - } - - async fn handle_frames_mgr( - &self, - mut w2b_frames_r: mpsc::UnboundedReceiver, - b2a_stream_opened_s: mpsc::UnboundedSender, - b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, + mut a2b_open_stream_r: mpsc::UnboundedReceiver, + mut a2b_close_stream_r: mpsc::UnboundedReceiver, + a2b_msg_r: crossbeam_channel::Receiver<(Sid, Arc)>, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, + b2b_close_send_protocol_r: async_channel::Receiver, + _b2s_prio_statistic_s: mpsc::UnboundedSender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start handle_frames_mgr"); - let mut messages = HashMap::new(); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut dropped_instant = Instant::now(); - let mut dropped_cnt = 0u64; - let mut dropped_sid = Sid::new(0); - let mut b2a_stream_opened_s = Some(b2a_stream_opened_s); - let mut b2b_close_stream_opened_sender_r = b2b_close_stream_opened_sender_r.fuse(); + let mut send_protocols: HashMap = HashMap::new(); + let mut interval = tokio::time::interval(Self::TICK_TIME); + let mut stream_ids = self.offset_sid; + trace!("workaround, activly wait for first protocol"); + b2b_add_protocol_r + .recv() + .await + .map(|(c, p)| send_protocols.insert(c, p)); + trace!("Start send_mgr"); + loop { + let (open, close, _, addp, remp) = select!( + next = a2b_open_stream_r.recv().fuse() => (Some(next), None, None, None, None), + next = a2b_close_stream_r.recv().fuse() => (None, Some(next), None, None, None), + _ = interval.tick() => (None, None, Some(()), None, None), + next = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(next), None), + next = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(next)), + ); - while let Some((cid, result_frame)) = select!( - next = w2b_frames_r.recv().fuse() => next, - _ = &mut b2b_close_stream_opened_sender_r => { - b2a_stream_opened_s = None; - None - }, - ) { - //trace!(?result_frame, "handling frame"); - let frame = match result_frame { - Ok(frame) => frame, - Err(()) => { - // The read protocol stopped, i need to make sure that write gets stopped, can - // drop channel as it's dead anyway - debug!("read protocol was closed. Stopping channel"); - self.channels.write().await.remove(&cid); + trace!(?open, ?close, ?addp, ?remp, "foobar"); + + addp.flatten().map(|(c, p)| send_protocols.insert(c, p)); + match remp { + Some(Ok(cid)) => { + trace!(?cid, "remove send protocol"); + match send_protocols.remove(&cid) { + Some(mut prot) => { + trace!("blocking flush"); + let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; + trace!("shutdown prot"); + let _ = prot.send(ProtocolEvent::Shutdown).await; + }, + None => trace!("tried to remove protocol twice"), + }; + if send_protocols.is_empty() { + break; + } + }, + _ => (), + }; + + let cid = 0; + let active = match send_protocols.get_mut(&cid) { + Some(a) => a, + None => { + warn!("no channel arrg"); continue; }, }; - #[cfg(feature = "metrics")] - { - let cid_string = cid.to_string(); - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - } - match frame { - Frame::OpenStream { - sid, - prio, - promises, - } => { - trace!(?sid, ?prio, ?promises, "Opened frame from remote"); - let a2p_msg_s = a2p_msg_s.clone(); + + let active_err = async { + if let Some(Some((prio, promises, guaranteed_bandwidth, return_s))) = open { + trace!(?stream_ids, "openuing some new stream"); + let sid = stream_ids; + stream_ids += Sid::from(1); let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) + .create_stream( + sid, + prio, + promises, + guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) .await; - match &b2a_stream_opened_s { - None => debug!("dropping openStream as Channel is already closing"), - Some(s) => { - if let Err(e) = s.send(stream) { - warn!( - ?e, - ?sid, - "couldn't notify api::Participant that a stream got opened. \ - Is the participant already dropped?" - ); - } - }, - } - }, - Frame::CloseStream { sid } => { - // no need to keep flushing as the remote no longer knows about this stream - // anyway - self.delete_stream( - sid, - None, - true, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - }, - Frame::DataHeader { mid, sid, length } => { - let imsg = IncomingMessage { - buffer: MessageBuffer { data: Vec::new() }, - length, - mid, + + let event = ProtocolEvent::OpenStream { sid, + prio, + promises, + guaranteed_bandwidth, }; - messages.insert(mid, imsg); - }, - Frame::Data { - mid, - start: _, - mut data, - } => { - let finished = if let Some(imsg) = messages.get_mut(&mid) { - imsg.buffer.data.append(&mut data); - imsg.buffer.data.len() as u64 == imsg.length - } else { - false - }; - if finished { - //trace!(?mid, "finished receiving message"); - let imsg = messages.remove(&mid).unwrap(); - if let Some(si) = self.streams.read().await.get(&imsg.sid) { - if let Err(e) = si.b2a_msg_recv_s.lock().await.send(imsg).await { - warn!( - ?e, - ?mid, - "Dropping message, as streams seem to be in act of being \ - dropped right now" - ); - } - } else { - //aggregate errors - let n = Instant::now(); - if dropped_cnt > 0 - && (dropped_sid != imsg.sid - || n.duration_since(dropped_instant) > Duration::from_secs(1)) - { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to \ - exist because it was dropped probably." - ); - dropped_cnt = 0; - dropped_instant = n; - dropped_sid = imsg.sid; - } else { - dropped_cnt += 1; - } - } - } - }, - Frame::Shutdown => { - debug!("Shutdown received from remote side"); - self.close_api(Some(ParticipantError::ParticipantDisconnected)) - .await; - }, - f => { - unreachable!( - "Frame should never reach participant!: {:?}, cid: {}", - f, cid - ); - }, + + return_s.send(stream).unwrap(); + active.send(event).await?; + } + + // get all messages and assign it to a channel + for (sid, buffer) in a2b_msg_r.try_iter() { + warn!(?sid, "sending!"); + active + .send(ProtocolEvent::Message { + buffer, + mid: 0u64, + sid, + }) + .await? + } + + if let Some(Some(sid)) = close { + warn!(?sid, "delete_stream!"); + self.delete_stream(sid).await; + // Fire&Forget the protocol will take care to verify that this Frame is delayed + // till the last msg was received! + active.send(ProtocolEvent::CloseStream { sid }).await?; + } + + warn!("flush!"); + active + .flush(1_000_000, Duration::from_secs(1) /* TODO */) + .await?; //this actually blocks, so we cant set streams whilte it. + let r: Result<(), network_protocol::ProtocolError> = Ok(()); + r + } + .await; + if let Err(e) = active_err { + info!(?cid, ?e, "send protocol failed, shutting down channel"); + // remote recv will now fail, which will trigger remote send which will trigger + // recv + send_protocols.remove(&cid).unwrap(); } } - if dropped_cnt > 0 { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to exist because it was \ - dropped probably." + trace!("Stop send_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_SEND, Ordering::Relaxed); + } + + async fn recv_mgr( + &self, + b2a_stream_opened_s: mpsc::UnboundedSender, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, + b2b_force_close_recv_protocol_r: async_channel::Receiver, + b2b_close_send_protocol_s: async_channel::Sender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + a2b_close_stream_s: mpsc::UnboundedSender, + ) { + let mut recv_protocols: HashMap> = HashMap::new(); + // we should be able to directly await futures imo + let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel(); + + let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| { + let hacky_recv_s = hacky_recv_s.clone(); + let handle = tokio::spawn(async move { + let cid = cid; + let r = p.recv().await; + let _ = hacky_recv_s.send((cid, r, p)); // ignoring failed + }); + map.insert(cid, handle); + }; + + let remove_c = |recv_protocols: &mut HashMap>, cid: &Cid| { + match recv_protocols.remove(&cid) { + Some(h) => h.abort(), + None => trace!("tried to remove protocol twice"), + }; + recv_protocols.is_empty() + }; + + trace!("Start recv_mgr"); + loop { + let (event, addp, remp) = select!( + next = hacky_recv_r.recv().fuse() => (Some(next), None, None), + Some(next) = b2b_add_protocol_r.recv().fuse() => (None, Some(next), None), + next = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(next)), ); + + addp.map(|(cid, p)| { + retrigger(cid, p, &mut recv_protocols); + }); + if let Some(Ok(cid)) = remp { + // no need to stop the send_mgr here as it has been canceled before + if remove_c(&mut recv_protocols, &cid) { + break; + } + }; + + warn!(?event, "recv event!"); + if let Some(Some((cid, r, p))) = event { + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + }) => { + trace!(?sid, "open stream"); + let stream = self + .create_stream( + sid, + prio, + promises, + guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) + .await; + b2a_stream_opened_s.send(stream).unwrap(); + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::CloseStream { sid }) => { + trace!(?sid, "close stream"); + self.delete_stream(sid).await; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Message { + buffer, + mid: _, + sid, + }) => { + let buffer = Arc::try_unwrap(buffer).unwrap(); + let lock = self.streams.read().await; + match lock.get(&sid) { + Some(stream) => { + stream + .b2a_msg_recv_s + .lock() + .await + .send(buffer) + .await + .unwrap(); + }, + None => warn!("recv a msg with orphan stream"), + }; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Shutdown) => { + info!(?cid, "shutdown protocol"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + Err(e) => { + info!(?cid, ?e, "recv protocol failed, shutting down channel"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + } + } } - trace!("Stop handle_frames_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + + trace!("Stop recv_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_RECV, Ordering::Relaxed); } async fn create_channel_mgr( &self, s2b_create_channel_r: mpsc::UnboundedReceiver, - w2b_frames_s: mpsc::UnboundedSender, + b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>, + b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); trace!("Start create_channel_mgr"); let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r - .for_each_concurrent( - None, - |(cid, _, protocol, leftover_cid_frame, b2s_create_channel_done_s)| { - // This channel is now configured, and we are running it in scope of the - // participant. - let w2b_frames_s = w2b_frames_s.clone(); - let channels = Arc::clone(&self.channels); - async move { - let (channel, b2w_frame_s, b2r_read_shutdown) = Channel::new(cid); - let mut lock = channels.write().await; - #[cfg(feature = "metrics")] - let mut channel_no = lock.len(); - #[cfg(not(feature = "metrics"))] - let channel_no = lock.len(); - lock.insert( + .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { + // This channel is now configured, and we are running it in scope of the + // participant. + //let w2b_frames_s = w2b_frames_s.clone(); + let channels = Arc::clone(&self.channels); + let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); + let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); + async move { + let mut lock = channels.write().await; + #[cfg(feature = "metrics")] + let mut channel_no = lock.len(); + lock.insert( + cid, + Mutex::new(ChannelInfo { cid, - Mutex::new(ChannelInfo { - cid, - cid_string: cid.to_string(), - b2w_frame_s, - b2r_read_shutdown, - }), - ); - drop(lock); - b2s_create_channel_done_s.send(()).unwrap(); - #[cfg(feature = "metrics")] - { - self.metrics - .channels_connected_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - if channel_no > 5 { - debug!(?channel_no, "metrics will overwrite channel #5"); - channel_no = 5; - } - self.metrics - .participants_channel_ids - .with_label_values(&[ - &self.remote_pid_string, - &channel_no.to_string(), - ]) - .set(cid as i64); - } - trace!(?cid, ?channel_no, "Running channel in participant"); - channel - .run(protocol, w2b_frames_s, leftover_cid_frame) - .instrument(tracing::info_span!("", ?cid)) - .await; - #[cfg(feature = "metrics")] + cid_string: cid.to_string(), + }), + ); + drop(lock); + let (send, recv) = protocol.split(); + b2b_add_send_protocol_s.send((cid, send)).unwrap(); + b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); + b2s_create_channel_done_s.send(()).unwrap(); + #[cfg(feature = "metrics")] + { self.metrics - .channels_disconnected_total + .channels_connected_total .with_label_values(&[&self.remote_pid_string]) .inc(); - info!(?cid, "Channel got closed"); - //maybe channel got already dropped, we don't know. - channels.write().await.remove(&cid); - trace!(?cid, "Channel cleanup completed"); - //TEMP FIX: as we dont have channel takeover yet drop the whole - // bParticipant - self.close_write_api(None).await; + if channel_no > 5 { + debug!(?channel_no, "metrics will overwrite channel #5"); + channel_no = 5; + } + self.metrics + .participants_channel_ids + .with_label_values(&[&self.remote_pid_string, &channel_no.to_string()]) + .set(cid as i64); } - }, - ) + } + }) .await; trace!("Stop create_channel_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + self.shutdown_barrier + .fetch_sub(Self::BARR_CHANNEL, Ordering::Relaxed); } - async fn open_mgr( - &self, - mut a2b_stream_open_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - shutdown_open_mgr_receiver: oneshot::Receiver<()>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start open_mgr"); - let mut stream_ids = self.offset_sid; - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_open_mgr_receiver = shutdown_open_mgr_receiver.fuse(); - //from api or shutdown signal - while let Some((prio, promises, p2a_return_stream)) = select! { - next = a2b_stream_open_r.recv().fuse() => next, - _ = &mut shutdown_open_mgr_receiver => None, - } { - debug!(?prio, ?promises, "Got request to open a new steam"); - //TODO: a2b_stream_open_r isn't closed on api_close yet. This needs to change. - //till then just check here if we are closed and in that case do nothing (not - // even answer) - if self.shutdown_info.read().await.error.is_some() { - continue; - } - - let a2p_msg_s = a2p_msg_s.clone(); - let sid = stream_ids; - let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) - .await; - if self - .send_frame( - Frame::OpenStream { - sid, - prio, - promises, - }, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - //On error, we drop this, so it gets closed and client will handle this as an - // Err any way (: - p2a_return_stream.send(stream).unwrap(); - stream_ids += Sid::from(1); - } - } - trace!("Stop open_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - /// when activated this function will drop the participant completely and - /// wait for everything to go right! Then return 1. Shutting down - /// Streams for API and End user! 2. Wait for all "prio queued" Messages - /// to be send. 3. Send Stream + /// sink shutdown: + /// Situation AS, AR, BS, BR. A wants to close. + /// AS shutdown. + /// BR notices shutdown and tries to stops BS. (success) + /// BS shutdown + /// AR notices shutdown and tries to stop AS. (fails) + /// For the case where BS didn't get shutdowned, e.g. by a handing situation + /// on the remote, we have a timeout to also force close AR. + /// + /// This fn will: + /// - 1. stop api to interact with bparticipant by closing sendmsg and + /// openstream + /// - 2. stop the send_mgr (it will take care of clearing the + /// queue and finish with a Shutdown) + /// - (3). force stop recv after 60 + /// seconds + /// - (4). this fn finishes last and afterwards BParticipant + /// drops + /// + /// before calling this fn, make sure `s2b_create_channel` is closed! /// If BParticipant kills itself managers stay active till this function is /// called by api to get the result status async fn participant_shutdown_mgr( &self, s2b_shutdown_bparticipant_r: oneshot::Receiver, - shutdown_open_mgr_sender: oneshot::Sender<()>, - shutdown_stream_close_mgr_sender: oneshot::Sender>, - shutdown_send_mgr_sender: oneshot::Sender>, + b2b_close_send_protocol_s: async_channel::Sender, + b2b_force_close_recv_protocol_s: async_channel::Sender, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); + let wait_for_manager = || async { + let mut sleep = 0.01f64; + loop { + let bytes = self.shutdown_barrier.load(Ordering::Relaxed); + if bytes == 0 { + break; + } + sleep *= 1.4; + tokio::time::sleep(Duration::from_secs_f64(sleep)).await; + if sleep > 0.2 { + trace!(?bytes, "wait for mgr to close"); + } + } + }; + trace!("Start participant_shutdown_mgr"); - let sender = s2b_shutdown_bparticipant_r.await.unwrap(); + let (timeout_time, sender) = s2b_shutdown_bparticipant_r.await.unwrap(); + debug!("participant_shutdown_mgr triggered"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - - self.close_api(None).await; - - debug!("Closing all managers"); - shutdown_open_mgr_sender - .send(()) - .expect("open_mgr must have crashed before"); - let (b2b_stream_close_shutdown_confirmed_s, b2b_stream_close_shutdown_confirmed_r) = - oneshot::channel(); - shutdown_stream_close_mgr_sender - .send(b2b_stream_close_shutdown_confirmed_s) - .expect("stream_close_mgr must have crashed before"); - // We need to wait for the stream_close_mgr BEFORE send_mgr, as the - // stream_close_mgr needs to wait on the API to drop `Stream` and be triggered - // It will then sleep for streams to be flushed in PRIO, and send_mgr is - // responsible for ticking PRIO WHILE this happens, so we cant close it before! - b2b_stream_close_shutdown_confirmed_r.await.unwrap(); - - //closing send_mgr now: - let (b2b_prios_flushed_s, b2b_prios_flushed_r) = oneshot::channel(); - shutdown_send_mgr_sender - .send(b2b_prios_flushed_s) - .expect("stream_close_mgr must have crashed before"); - b2b_prios_flushed_r.await.unwrap(); - - if Some(ParticipantError::ParticipantDisconnected) != self.shutdown_info.read().await.error + debug!("Closing all streams for send"); { - debug!("Sending shutdown frame after flushed all prios"); - if !self - .send_frame( - Frame::Shutdown, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - warn!("couldn't send shutdown frame, are channels already closed?"); + let lock = self.streams.read().await; + for si in lock.values() { + si.send_closed.store(true, Ordering::Relaxed); } } - debug!("Closing all channels, after flushed prios"); - for (cid, ci) in self.channels.write().await.drain() { - let ci = ci.into_inner(); - if let Err(e) = ci.b2r_read_shutdown.send(()) { + let lock = self.channels.read().await; + assert!( + !lock.is_empty(), + "no channel existed remote_pid={}", + self.remote_pid + ); + for cid in lock.keys() { + if let Err(e) = b2b_close_send_protocol_s.send(*cid).await { debug!( ?e, ?cid, - "Seems like this read protocol got already dropped by closing the Stream \ - itself, ignoring" - ); - }; - } - - //Wait for other bparticipants mgr to close via AtomicUsize - const SLEEP_TIME: Duration = Duration::from_millis(5); - const ALLOWED_MANAGER: usize = 1; - tokio::time::sleep(SLEEP_TIME).await; - let mut i: u32 = 1; - while self.running_mgr.load(Ordering::Relaxed) > ALLOWED_MANAGER { - i += 1; - if i.rem_euclid(10) == 1 { - trace!( - ?ALLOWED_MANAGER, - "Waiting for bparticipant mgr to shut down, remaining {}", - self.running_mgr.load(Ordering::Relaxed) - ALLOWED_MANAGER + "closing send_mgr may fail if we got a recv error simultaneously" ); } - tokio::time::sleep(SLEEP_TIME * i).await; } - trace!("All BParticipant mgr (except me) are shut down now"); + drop(lock); + + trace!("wait for other managers"); + let timeout = tokio::time::sleep(timeout_time); + let timeout = tokio::select! { + _ = wait_for_manager() => false, + _ = timeout => true, + }; + if timeout { + warn!("timeout triggered: for killing recv"); + let lock = self.channels.read().await; + for cid in lock.keys() { + if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await { + debug!( + ?e, + ?cid, + "closing recv_mgr may fail if we got a recv error simultaneously" + ); + } + } + } + + trace!("wait again"); + wait_for_manager().await; + + sender.send(Ok(())).unwrap(); #[cfg(feature = "metrics")] self.metrics.participants_disconnected_total.inc(); - debug!("BParticipant close done"); - - let mut lock = self.shutdown_info.write().await; - sender - .send(match lock.error.take() { - None => Ok(()), - Some(ParticipantError::ProtocolFailedUnrecoverable) => { - Err(ParticipantError::ProtocolFailedUnrecoverable) - }, - Some(ParticipantError::ParticipantDisconnected) => Ok(()), - }) - .unwrap(); - trace!("Stop participant_shutdown_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - async fn stream_close_mgr( - &self, - mut a2b_close_stream_r: mpsc::UnboundedReceiver, - shutdown_stream_close_mgr_receiver: oneshot::Receiver>, - b2p_notify_empty_stream_s: crossbeam_channel::Sender<(Sid, oneshot::Sender<()>)>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start stream_close_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_stream_close_mgr_receiver = shutdown_stream_close_mgr_receiver.fuse(); - let mut b2b_stream_close_shutdown_confirmed_s = None; - - //from api or shutdown signal - while let Some(sid) = select! { - next = a2b_close_stream_r.recv().fuse() => next, - sender = &mut shutdown_stream_close_mgr_receiver => { - b2b_stream_close_shutdown_confirmed_s = Some(sender.unwrap()); - None - } - } { - //TODO: make this concurrent! - //TODO: Performance, closing is slow! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - trace!("deleting all leftover streams"); - let sids = self - .streams - .read() - .await - .keys() - .cloned() - .collect::>(); - for sid in sids { - //flushing is still important, e.g. when Participant::drop is called (but - // Stream:drop isn't)! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - if b2b_stream_close_shutdown_confirmed_s.is_none() { - b2b_stream_close_shutdown_confirmed_s = - Some(shutdown_stream_close_mgr_receiver.await.unwrap()); - } - b2b_stream_close_shutdown_confirmed_s - .unwrap() - .send(()) - .unwrap(); - trace!("Stop stream_close_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); } + /// Stopping API and participant usage + /// Protocol will take care of the order of the frame async fn delete_stream( &self, sid: Sid, - b2p_notify_empty_stream_s: Option)>>, - from_remote: bool, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, + /* #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, */ ) { - //This needs to first stop clients from sending any more. - //Then it will wait for all pending messages (in prio) to be send to the - // protocol After this happened the stream is closed - //Only after all messages are send to the protocol, we can send the CloseStream - // frame! If we would send it before, all followup messages couldn't - // be handled at the remote side. - async { - trace!("Stopping api to use this stream"); - match self.streams.read().await.get(&sid) { - Some(si) => { - si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.lock().await.close(); - }, - None => trace!( - "Couldn't find the stream, might be simultaneous close from local/remote" - ), - } - - if !from_remote { - trace!("Wait for stream to be flushed"); - let (s2b_stream_finished_closed_s, s2b_stream_finished_closed_r) = - oneshot::channel(); - b2p_notify_empty_stream_s - .expect("needs to be set when from_remote is false") - .send((sid, s2b_stream_finished_closed_s)) - .unwrap(); - s2b_stream_finished_closed_r.await.unwrap(); - - trace!("Stream was successfully flushed"); - } - - #[cfg(feature = "metrics")] - self.metrics - .streams_closed_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - //only now remove the Stream, that means we can still recv on it. - self.streams.write().await.remove(&sid); - - if !from_remote { - self.send_frame( - Frame::CloseStream { sid }, - #[cfg(feature = "metrics")] - frames_out_total_cache, - ) - .await; - } + let stream = { self.streams.write().await.remove(&sid) }; + match stream { + Some(si) => { + si.send_closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.lock().await.close(); + }, + None => { + trace!("Couldn't find the stream, might be simultaneous close from local/remote") + }, } - .instrument(tracing::info_span!("close", ?sid, ?from_remote)) - .await; + /* + #[cfg(feature = "metrics")] + self.metrics + .streams_closed_total + .with_label_values(&[&self.remote_pid_string]) + .inc();*/ } async fn create_stream( @@ -843,10 +614,11 @@ impl BParticipant { sid: Sid, prio: Prio, promises: Promises, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, + guaranteed_bandwidth: Bandwidth, + a2b_msg_s: &crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { - let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); + let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, @@ -864,38 +636,248 @@ impl BParticipant { sid, prio, promises, + guaranteed_bandwidth, send_closed, - a2p_msg_s, + a2b_msg_s.clone(), b2a_msg_recv_r, a2b_close_stream_s.clone(), ) } +} - async fn close_write_api(&self, reason: Option) { - trace!(?reason, "close_api"); - let mut lock = self.shutdown_info.write().await; - if let Some(r) = reason { - lock.error = Some(r); - } - lock.b2b_close_stream_opened_sender_s - .take() - .map(|s| s.send(())); +#[cfg(test)] +mod tests { + use super::*; + use tokio::{ + runtime::Runtime, + sync::{mpsc, oneshot}, + task::JoinHandle, + }; - debug!("Closing all streams for write"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream for write"); - si.send_closed.store(true, Ordering::Relaxed); - } + fn mock_bparticipant() -> ( + Arc, + mpsc::UnboundedSender, + mpsc::UnboundedReceiver, + mpsc::UnboundedSender, + oneshot::Sender, + mpsc::UnboundedReceiver, + JoinHandle<()>, + ) { + let runtime = Arc::new(tokio::runtime::Runtime::new().unwrap()); + let runtime_clone = Arc::clone(&runtime); + + let (b2s_prio_statistic_s, b2s_prio_statistic_r) = + mpsc::unbounded_channel::(); + + let ( + bparticipant, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + ) = runtime_clone.block_on(async move { + let pid = Pid::fake(1); + let sid = Sid::new(1000); + let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); + + BParticipant::new(pid, sid, Arc::clone(&metrics)) + }); + + let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s)); + ( + runtime_clone, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) } - ///closing api::Participant is done by closing all channels, expect for the - /// shutdown channel at this point! - async fn close_api(&self, reason: Option) { - self.close_write_api(reason).await; - debug!("Closing all streams"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream"); - si.b2a_msg_recv_s.lock().await.close(); - } + async fn mock_mpsc( + cid: Cid, + _runtime: &Arc, + create_channel: &mut mpsc::UnboundedSender, + ) -> Protocols { + let (s1, r1) = mpsc::channel(100); + let (s2, r2) = mpsc::channel(100); + let p1 = Protocols::new_mpsc(s1, r2); + let (complete_s, complete_r) = oneshot::channel(); + create_channel + .send((cid, Sid::new(0), p1, complete_s)) + .unwrap(); + complete_r.await.unwrap(); + Protocols::new_mpsc(s2, r1) + } + + #[test] + fn close_bparticipant_by_timeout_during_close() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() > Duration::from_millis(900), + "timeout wasn't triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn close_bparticipant_cleanly() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(2), s)) + .unwrap(); + drop(remote); // remote needs to be dropped as soon as local.sender is closed + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() < Duration::from_millis(1900), + "timeout was triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn create_stream() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // created stream + let (rs, mut rr) = remote.split(); + let (stream_sender, _stream_receiver) = oneshot::channel(); + a2b_open_stream_s + .send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender)) + .unwrap(); + + let stream_event = runtime.block_on(rr.recv()).unwrap(); + match stream_event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + assert_eq!(sid, Sid::new(1000)); + assert_eq!(prio, 7u8); + assert_eq!(promises, Promises::ENCRYPTED); + assert_eq!(guaranteed_bandwidth, 1_000_000); + }, + _ => panic!("wrong event"), + }; + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn created_stream() { + let ( + runtime, + a2b_open_stream_s, + mut b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // create stream + let (mut rs, rr) = remote.split(); + runtime + .block_on(rs.send(ProtocolEvent::OpenStream { + sid: Sid::new(1000), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + })) + .unwrap(); + + let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap(); + assert_eq!(stream.promises(), Promises::ORDERED); + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); } } diff --git a/network/src/prios.rs b/network/src/prios.rs deleted file mode 100644 index a544a31241..0000000000 --- a/network/src/prios.rs +++ /dev/null @@ -1,697 +0,0 @@ -//!Priorities are handled the following way. -//!Prios from 0-63 are allowed. -//!all 5 numbers the throughput is halved. -//!E.g. in the same time 100 prio0 messages are send, only 50 prio5, 25 prio10, -//! 12 prio15 or 6 prio20 messages are send. Note: TODO: prio0 will be send -//! immediately when found! -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - message::OutgoingMessage, - types::{Frame, Prio, Sid}, -}; -use crossbeam_channel::{unbounded, Receiver, Sender}; -use std::collections::{HashMap, HashSet, VecDeque}; -#[cfg(feature = "metrics")] use std::sync::Arc; -use tokio::sync::oneshot; -use tracing::trace; - -const PRIO_MAX: usize = 64; - -#[derive(Default)] -struct PidSidInfo { - len: u64, - empty_notify: Option>, -} - -pub(crate) struct PrioManager { - points: [u32; PRIO_MAX], - messages: [VecDeque<(Sid, OutgoingMessage)>; PRIO_MAX], - messages_rx: Receiver<(Prio, Sid, OutgoingMessage)>, - sid_owned: HashMap, - //you can register to be notified if a pid_sid combination is flushed completely here - sid_flushed_rx: Receiver<(Sid, oneshot::Sender<()>)>, - queued: HashSet, - #[cfg(feature = "metrics")] - metrics: Arc, - #[cfg(feature = "metrics")] - pid: String, -} - -impl PrioManager { - const PRIOS: [u32; PRIO_MAX] = [ - 100, 115, 132, 152, 174, 200, 230, 264, 303, 348, 400, 459, 528, 606, 696, 800, 919, 1056, - 1213, 1393, 1600, 1838, 2111, 2425, 2786, 3200, 3676, 4222, 4850, 5572, 6400, 7352, 8445, - 9701, 11143, 12800, 14703, 16890, 19401, 22286, 25600, 29407, 33779, 38802, 44572, 51200, - 58813, 67559, 77605, 89144, 102400, 117627, 135118, 155209, 178289, 204800, 235253, 270235, - 310419, 356578, 409600, 470507, 540470, 620838, - ]; - - #[allow(clippy::type_complexity)] - pub fn new( - #[cfg(feature = "metrics")] metrics: Arc, - pid: String, - ) -> ( - Self, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - #[cfg(not(feature = "metrics"))] - let _pid = pid; - // (a2p_msg_s, a2p_msg_r) - let (messages_tx, messages_rx) = unbounded(); - let (sid_flushed_tx, sid_flushed_rx) = unbounded(); - ( - Self { - points: [0; PRIO_MAX], - messages: [ - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - ], - messages_rx, - queued: HashSet::new(), //TODO: optimize with u64 and 64 bits - sid_flushed_rx, - sid_owned: HashMap::new(), - #[cfg(feature = "metrics")] - metrics, - #[cfg(feature = "metrics")] - pid, - }, - messages_tx, - sid_flushed_tx, - ) - } - - async fn tick(&mut self) { - // Check Range - for (prio, sid, msg) in self.messages_rx.try_iter() { - debug_assert!(prio as usize <= PRIO_MAX); - #[cfg(feature = "metrics")] - { - let sid_string = sid.to_string(); - self.metrics - .message_out_total - .with_label_values(&[&self.pid, &sid_string]) - .inc(); - self.metrics - .message_out_throughput - .with_label_values(&[&self.pid, &sid_string]) - .inc_by(msg.buffer.data.len() as u64); - } - - //trace!(?prio, ?sid_string, "tick"); - self.queued.insert(prio); - self.messages[prio as usize].push_back((sid, msg)); - self.sid_owned.entry(sid).or_default().len += 1; - } - //this must be AFTER messages - for (sid, return_sender) in self.sid_flushed_rx.try_iter() { - #[cfg(feature = "metrics")] - self.metrics - .streams_flushed - .with_label_values(&[&self.pid]) - .inc(); - if let Some(cnt) = self.sid_owned.get_mut(&sid) { - // register sender - cnt.empty_notify = Some(return_sender); - trace!(?sid, "register empty notify"); - } else { - // return immediately - return_sender.send(()).unwrap(); - trace!(?sid, "return immediately that stream is empty"); - } - } - } - - //if None returned, we are empty! - fn calc_next_prio(&self) -> Option { - // compare all queued prios, max 64 operations - let mut lowest = std::u32::MAX; - let mut lowest_id = None; - for &n in &self.queued { - let n_points = self.points[n as usize]; - if n_points < lowest { - lowest = n_points; - lowest_id = Some(n) - } else if n_points == lowest && lowest_id.is_some() && n < lowest_id.unwrap() { - //on equal points lowest first! - lowest_id = Some(n) - } - } - lowest_id - /* - self.queued - .iter() - .min_by_key(|&n| self.points[*n as usize]).cloned()*/ - } - - /// no_of_frames = frames.len() - /// Your goal is to try to find a realistic no_of_frames! - /// no_of_frames should be choosen so, that all Frames can be send out till - /// the next tick! - /// - if no_of_frames is too high you will fill either the Socket buffer, - /// or your internal buffer. In that case you will increase latency for - /// high prio messages! - /// - if no_of_frames is too low you wont saturate your Socket fully, thus - /// have a lower bandwidth as possible - pub async fn fill_frames>( - &mut self, - no_of_frames: usize, - frames: &mut E, - ) { - for v in self.messages.iter_mut() { - v.reserve_exact(no_of_frames) - } - self.tick().await; - for _ in 0..no_of_frames { - match self.calc_next_prio() { - Some(prio) => { - //let prio2 = self.calc_next_prio().unwrap(); - //trace!(?prio, "handle next prio"); - self.points[prio as usize] += Self::PRIOS[prio as usize]; - //pop message from front of VecDeque, handle it and push it back, so that all - // => messages with same prio get a fair chance :) - //TODO: evaluate not popping every time - let (sid, mut msg) = self.messages[prio as usize].pop_front().unwrap(); - if msg.fill_next(sid, frames) { - //trace!(?m.mid, "finish message"); - //check if prio is empty - if self.messages[prio as usize].is_empty() { - self.queued.remove(&prio); - } - //decrease pid_sid counter by 1 again - let cnt = self.sid_owned.get_mut(&sid).expect( - "The pid_sid_owned counter works wrong, more pid,sid removed than \ - inserted", - ); - cnt.len -= 1; - if cnt.len == 0 { - let cnt = self.sid_owned.remove(&sid).unwrap(); - if let Some(empty_notify) = cnt.empty_notify { - empty_notify.send(()).unwrap(); - trace!(?sid, "returned that stream is empty"); - } - } - } else { - self.messages[prio as usize].push_front((sid, msg)); - } - }, - None => { - //QUEUE is empty, we are clearing the POINTS to not build up huge pipes of - // POINTS on a prio from the past - self.points = [0; PRIO_MAX]; - break; - }, - } - } - } -} - -impl std::fmt::Debug for PrioManager { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut cnt = 0; - for m in self.messages.iter() { - cnt += m.len(); - } - write!(f, "PrioManager(len: {}, queued: {:?})", cnt, &self.queued,) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - message::{MessageBuffer, OutgoingMessage}, - metrics::NetworkMetrics, - prios::*, - types::{Frame, Pid, Prio, Sid}, - }; - use crossbeam_channel::Sender; - use std::{collections::VecDeque, sync::Arc}; - use tokio::{runtime::Runtime, sync::oneshot}; - - const SIZE: u64 = OutgoingMessage::FRAME_DATA_SIZE; - const USIZE: usize = OutgoingMessage::FRAME_DATA_SIZE as usize; - - #[allow(clippy::type_complexity)] - fn mock_new() -> ( - PrioManager, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - let pid = Pid::fake(1); - PrioManager::new( - Arc::new(NetworkMetrics::new(&pid).unwrap()), - pid.to_string(), - ) - } - - fn mock_out(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { - data: vec![48, 49, 50], - }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn mock_out_large(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - let mut data = vec![48; USIZE]; - data.append(&mut vec![49; USIZE]); - data.append(&mut vec![50; 20]); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn assert_header(frames: &mut VecDeque<(Sid, Frame)>, f_sid: u64, f_length: u64) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::DataHeader { mid, sid, length } = frame { - assert_eq!(mid, 1); - assert_eq!(sid, Sid::new(f_sid)); - assert_eq!(length, f_length); - } else { - panic!("Wrong frame type!, expected DataHeader"); - } - } - - fn assert_data(frames: &mut VecDeque<(Sid, Frame)>, f_start: u64, f_data: Vec) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::Data { mid, start, data } = frame { - assert_eq!(mid, 1); - assert_eq!(start, f_start); - assert_eq!(data, f_data); - } else { - panic!("Wrong frame type!, expected Data"); - } - } - - #[test] - fn single_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - msg_tx.send(mock_out(20, 42)).unwrap(); - let mut frames = VecDeque::new(); - - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 42)).unwrap(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - for i in 1..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn multiple_fill_frames_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(3, &mut frames)); - for i in 1..4 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(11, &mut frames)); - for i in 4..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn single_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16_sudden_p0() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - - msg_tx.send(mock_out(0, 3)).unwrap(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 3, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_thousand_p16_at_once() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn single_p20_thousand_p16_later() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - //^unimportant frames, gonna be dropped - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - //important in that test is, that after the first frames got cleared i reset - // the Points even though 998 prio 16 messages have been send at this - // point and 0 prio20 messages the next message is a prio16 message - // again, and only then prio20! we dont want to build dept over a idling - // connection - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn gigantic_message() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } - - #[test] - fn gigantic_message_order() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - - #[test] - fn gigantic_message_order_other_prio() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(20, 8)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } -} diff --git a/network/src/protocols.rs b/network/src/protocols.rs deleted file mode 100644 index a18c1e1cbd..0000000000 --- a/network/src/protocols.rs +++ /dev/null @@ -1,591 +0,0 @@ -#[cfg(feature = "metrics")] -use crate::metrics::{CidFrameCache, NetworkMetrics}; -use crate::{ - participant::C2pFrame, - types::{Cid, Frame}, -}; -use futures_util::{future::Fuse, FutureExt}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpStream, UdpSocket}, - select, - sync::{mpsc, oneshot, Mutex}, -}; - -use std::{convert::TryFrom, net::SocketAddr, sync::Arc}; -use tracing::*; - -// Reserving bytes 0, 10, 13 as i have enough space and want to make it easy to -// detect a invalid client, e.g. sending an empty line would make 10 first char -// const FRAME_RESERVED_1: u8 = 0; -const FRAME_HANDSHAKE: u8 = 1; -const FRAME_INIT: u8 = 2; -const FRAME_SHUTDOWN: u8 = 3; -const FRAME_OPEN_STREAM: u8 = 4; -const FRAME_CLOSE_STREAM: u8 = 5; -const FRAME_DATA_HEADER: u8 = 6; -const FRAME_DATA: u8 = 7; -const FRAME_RAW: u8 = 8; -//const FRAME_RESERVED_2: u8 = 10; -//const FRAME_RESERVED_3: u8 = 13; - -#[derive(Debug)] -pub(crate) enum Protocols { - Tcp(TcpProtocol), - Udp(UdpProtocol), - //Mpsc(MpscChannel), -} - -#[derive(Debug)] -pub(crate) struct TcpProtocol { - read_stream: tokio::sync::Mutex, - write_stream: tokio::sync::Mutex, - #[cfg(feature = "metrics")] - metrics: Arc, -} - -#[derive(Debug)] -pub(crate) struct UdpProtocol { - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] - metrics: Arc, - data_in: Mutex>>, -} - -//TODO: PERFORMACE: Use BufWriter and BufReader from std::io! -impl TcpProtocol { - pub(crate) fn new( - stream: TcpStream, - #[cfg(feature = "metrics")] metrics: Arc, - ) -> Self { - let (read_stream, write_stream) = stream.into_split(); - Self { - read_stream: tokio::sync::Mutex::new(read_stream), - write_stream: tokio::sync::Mutex::new(write_stream), - #[cfg(feature = "metrics")] - metrics, - } - } - - async fn read_frame( - r: &mut R, - end_receiver: &mut Fuse>, - ) -> Result> { - let handle = |read_result| match read_result { - Ok(_) => Ok(()), - Err(e) => Err(Some(e)), - }; - - let mut frame_no = [0u8; 1]; - match select! { - r = r.read_exact(&mut frame_no).fuse() => Some(r), - _ = end_receiver => None, - } { - Some(read_result) => handle(read_result)?, - None => { - trace!("shutdown requested"); - return Err(None); - }, - }; - - match frame_no[0] { - FRAME_HANDSHAKE => { - let mut bytes = [0u8; 19]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_handshake(bytes)) - }, - FRAME_INIT => { - let mut bytes = [0u8; 32]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_init(bytes)) - }, - FRAME_SHUTDOWN => Ok(Frame::Shutdown), - FRAME_OPEN_STREAM => { - let mut bytes = [0u8; 10]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_open_stream(bytes)) - }, - FRAME_CLOSE_STREAM => { - let mut bytes = [0u8; 8]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_close_stream(bytes)) - }, - FRAME_DATA_HEADER => { - let mut bytes = [0u8; 24]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_data_header(bytes)) - }, - FRAME_DATA => { - let mut bytes = [0u8; 18]; - handle(r.read_exact(&mut bytes).await)?; - let (mid, start, length) = Frame::gen_data(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Data { mid, start, data }) - }, - FRAME_RAW => { - let mut bytes = [0u8; 2]; - handle(r.read_exact(&mut bytes).await)?; - let length = Frame::gen_raw(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Raw(data)) - }, - other => { - // report a RAW frame, but cannot rely on the next 2 bytes to be a size. - // guessing 32 bytes, which might help to sort down issues - let mut data = vec![0; 32]; - //keep the first byte! - match r.read(&mut data[1..]).await { - Ok(n) => { - data.truncate(n + 1); - Ok(()) - }, - Err(e) => Err(Some(e)), - }?; - data[0] = other; - warn!(?data, "got a unexpected RAW msg"); - Ok(Frame::Raw(data)) - }, - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up tcp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut read_stream = self.read_stream.lock().await; - let mut end_r = end_r.fuse(); - - loop { - match Self::read_frame(&mut *read_stream, &mut end_r).await { - Ok(frame) => { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = w2c_cid_frame_s.send((cid, Ok(frame))) { - warn!(?e, "Channel or Participant seems no longer to exist"); - } - }, - Err(e_option) => { - if let Some(e) = e_option { - info!(?e, "Closing tcp protocol due to read error"); - //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as - // every channel is holding a sender! thats why Ne - // need a explicit STOP here - if let Err(e) = w2c_cid_frame_s.send((cid, Err(()))) { - warn!(?e, "Channel or Participant seems no longer to exist"); - } - } - //None is clean shutdown - break; - }, - } - } - trace!("Shutting down tcp read()"); - } - - pub async fn write_frame( - w: &mut W, - frame: Frame, - ) -> Result<(), std::io::Error> { - match frame { - Frame::Handshake { - magic_number, - version, - } => { - w.write_all(&FRAME_HANDSHAKE.to_be_bytes()).await?; - w.write_all(&magic_number).await?; - w.write_all(&version[0].to_le_bytes()).await?; - w.write_all(&version[1].to_le_bytes()).await?; - w.write_all(&version[2].to_le_bytes()).await?; - }, - Frame::Init { pid, secret } => { - w.write_all(&FRAME_INIT.to_be_bytes()).await?; - w.write_all(&pid.to_le_bytes()).await?; - w.write_all(&secret.to_le_bytes()).await?; - }, - Frame::Shutdown => { - w.write_all(&FRAME_SHUTDOWN.to_be_bytes()).await?; - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - w.write_all(&FRAME_OPEN_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&prio.to_le_bytes()).await?; - w.write_all(&promises.to_le_bytes()).await?; - }, - Frame::CloseStream { sid } => { - w.write_all(&FRAME_CLOSE_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - }, - Frame::DataHeader { mid, sid, length } => { - w.write_all(&FRAME_DATA_HEADER.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&length.to_le_bytes()).await?; - }, - Frame::Data { mid, start, data } => { - w.write_all(&FRAME_DATA.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&start.to_le_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - Frame::Raw(data) => { - w.write_all(&FRAME_RAW.to_be_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - }; - Ok(()) - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up tcp write()"); - let mut write_stream = self.write_stream.lock().await; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - .with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - - while let Some(frame) = c2w_frame_r.recv().await { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = Self::write_frame(&mut *write_stream, frame).await { - info!( - ?e, - "Got an error writing to tcp, going to close this channel" - ); - c2w_frame_r.close(); - break; - }; - } - trace!("shutting down tcp write()"); - } -} - -impl UdpProtocol { - pub(crate) fn new( - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] metrics: Arc, - data_in: mpsc::UnboundedReceiver>, - ) -> Self { - Self { - socket, - remote_addr, - #[cfg(feature = "metrics")] - metrics, - data_in: Mutex::new(data_in), - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up udp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut data_in = self.data_in.lock().await; - let mut end_r = end_r.fuse(); - while let Some(bytes) = select! { - r = data_in.recv().fuse() => match r { - Some(r) => Some(r), - None => { - info!("Udp read ended"); - w2c_cid_frame_s.send((cid, Err(()))).expect("Channel or Participant seems no longer to exist"); - None - } - }, - _ = &mut end_r => None, - } { - trace!("Got raw UDP message with len: {}", bytes.len()); - let frame_no = bytes[0]; - let frame = match frame_no { - FRAME_HANDSHAKE => { - Frame::gen_handshake(*<&[u8; 19]>::try_from(&bytes[1..20]).unwrap()) - }, - FRAME_INIT => Frame::gen_init(*<&[u8; 32]>::try_from(&bytes[1..33]).unwrap()), - FRAME_SHUTDOWN => Frame::Shutdown, - FRAME_OPEN_STREAM => { - Frame::gen_open_stream(*<&[u8; 10]>::try_from(&bytes[1..11]).unwrap()) - }, - FRAME_CLOSE_STREAM => { - Frame::gen_close_stream(*<&[u8; 8]>::try_from(&bytes[1..9]).unwrap()) - }, - FRAME_DATA_HEADER => { - Frame::gen_data_header(*<&[u8; 24]>::try_from(&bytes[1..25]).unwrap()) - }, - FRAME_DATA => { - let (mid, start, length) = - Frame::gen_data(*<&[u8; 18]>::try_from(&bytes[1..19]).unwrap()); - let mut data = vec![0; length as usize]; - #[cfg(feature = "metrics")] - throughput_cache.inc_by(length as u64); - data.copy_from_slice(&bytes[19..]); - Frame::Data { mid, start, data } - }, - FRAME_RAW => { - let length = Frame::gen_raw(*<&[u8; 2]>::try_from(&bytes[1..3]).unwrap()); - let mut data = vec![0; length as usize]; - data.copy_from_slice(&bytes[3..]); - Frame::Raw(data) - }, - _ => Frame::Raw(bytes), - }; - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - w2c_cid_frame_s.send((cid, Ok(frame))).unwrap(); - } - trace!("Shutting down udp read()"); - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up udp write()"); - let mut buffer = [0u8; 2000]; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - .with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - while let Some(frame) = c2w_frame_r.recv().await { - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - let len = match frame { - Frame::Handshake { - magic_number, - version, - } => { - let x = FRAME_HANDSHAKE.to_be_bytes(); - buffer[0] = x[0]; - buffer[1..8].copy_from_slice(&magic_number); - buffer[8..12].copy_from_slice(&version[0].to_le_bytes()); - buffer[12..16].copy_from_slice(&version[1].to_le_bytes()); - buffer[16..20].copy_from_slice(&version[2].to_le_bytes()); - 20 - }, - Frame::Init { pid, secret } => { - buffer[0] = FRAME_INIT.to_be_bytes()[0]; - buffer[1..17].copy_from_slice(&pid.to_le_bytes()); - buffer[17..33].copy_from_slice(&secret.to_le_bytes()); - 33 - }, - Frame::Shutdown => { - buffer[0] = FRAME_SHUTDOWN.to_be_bytes()[0]; - 1 - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - buffer[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - buffer[9] = prio.to_le_bytes()[0]; - buffer[10] = promises.to_le_bytes()[0]; - 11 - }, - Frame::CloseStream { sid } => { - buffer[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - 9 - }, - Frame::DataHeader { mid, sid, length } => { - buffer[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&sid.to_le_bytes()); - buffer[17..25].copy_from_slice(&length.to_le_bytes()); - 25 - }, - Frame::Data { mid, start, data } => { - buffer[0] = FRAME_DATA.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&start.to_le_bytes()); - buffer[17..19].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[19..(data.len() + 19)].clone_from_slice(&data[..]); - #[cfg(feature = "metrics")] - throughput_cache.inc_by(data.len() as u64); - 19 + data.len() - }, - Frame::Raw(data) => { - buffer[0] = FRAME_RAW.to_be_bytes()[0]; - buffer[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[3..(data.len() + 3)].clone_from_slice(&data[..]); - 3 + data.len() - }, - }; - let mut start = 0; - while start < len { - trace!(?start, ?len, "Splitting up udp frame in multiple packages"); - match self - .socket - .send_to(&buffer[start..len], self.remote_addr) - .await - { - Ok(n) => { - start += n; - if n != len { - error!( - "THIS DOESN'T WORK, as RECEIVER CURRENTLY ONLY HANDLES 1 FRAME \ - per UDP message. splitting up will fail!" - ); - } - }, - Err(e) => error!(?e, "Need to handle that error!"), - } - } - } - trace!("Shutting down udp write()"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{metrics::NetworkMetrics, types::Pid}; - use std::sync::Arc; - use tokio::{net, runtime::Runtime, sync::mpsc}; - - #[test] - fn tcp_read_handshake() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50500); - Runtime::new().unwrap().block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let (s_stream, _) = server.accept().await.unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client.write_all(&[FRAME_HANDSHAKE]).await.unwrap(); - client.write_all(b"HELLOWO").await.unwrap(); - client.write_all(&1337u32.to_le_bytes()).await.unwrap(); - client.write_all(&0u32.to_le_bytes()).await.unwrap(); - client.write_all(&42u32.to_le_bytes()).await.unwrap(); - client.flush().await.unwrap(); - - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - Runtime::new().unwrap().block_on(async { - prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Handshake! - //tokio::task::sleep(std::time::Duration::from_millis(1000)); - let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Handshake { - magic_number, - version, - }) = frame - { - assert_eq!(&magic_number, b"HELLOWO"); - assert_eq!(version, [1337, 0, 42]); - } else { - panic!("wrong handshake"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } - - #[test] - fn tcp_read_garbage() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50501); - Runtime::new().unwrap().block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let (s_stream, _) = server.accept().await.unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client - .write_all("x4hrtzsektfhxugzdtz5r78gzrtzfhxfdthfthuzhfzzufasgasdfg".as_bytes()) - .await - .unwrap(); - client.flush().await.unwrap(); - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - Runtime::new().unwrap().block_on(async { - prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Raw! - let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Raw(data)) = frame { - assert_eq!(&data.as_slice(), b"x4hrtzsektfhxugzdtz5r78gzrtzfhxf"); - } else { - panic!("wrong frame type"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } -} diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index f648d48a15..eb6d21bd7e 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -2,12 +2,11 @@ use crate::metrics::NetworkMetrics; use crate::{ api::{Participant, ProtocolAddr}, - channel::Handshake, + channel::Protocols, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, - protocols::{Protocols, TcpProtocol, UdpProtocol}, - types::Pid, }; use futures_util::{FutureExt, StreamExt}; +use network_protocol::Pid; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -17,6 +16,7 @@ use std::{ atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, + time::Duration, }; use tokio::{ io, net, @@ -214,47 +214,40 @@ impl Scheduler { }, }; info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - ( - Protocols::Tcp(TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - )), - false, - ) - }, - ProtocolAddr::Udp(addr) => { - #[cfg(feature = "metrics")] - self.metrics - .connect_requests_total - .with_label_values(&["udp"]) - .inc(); - let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - Ok(socket) => Arc::new(socket), - Err(e) => { - pid_sender.send(Err(e)).unwrap(); - continue; - }, - }; - if let Err(e) = socket.connect(addr).await { - pid_sender.send(Err(e)).unwrap(); - continue; - }; - info!("Connecting Udp to: {}", addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.runtime.spawn( - Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - .instrument(tracing::info_span!("udp", ?addr)), - ); - (Protocols::Udp(protocol), true) - }, + (Protocols::new_tcp(stream), false) + }, /* */ + //ProtocolAddr::Udp(addr) => { + //#[cfg(feature = "metrics")] + //self.metrics + //.connect_requests_total + //.with_label_values(&["udp"]) + //.inc(); + //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { + //Ok(socket) => Arc::new(socket), + //Err(e) => { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}, + //}; + //if let Err(e) = socket.connect(addr).await { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}; + //info!("Connecting Udp to: {}", addr); + //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); + //let protocol = UdpProtocol::new( + //Arc::clone(&socket), + //addr, + //#[cfg(feature = "metrics")] + //Arc::clone(&self.metrics), + //udp_data_receiver, + //); + //self.runtime.spawn( + //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) + //.instrument(tracing::info_span!("udp", ?addr)), + //); + //(Protocols::Udp(protocol), true) + //}, _ => unimplemented!(), }; self.init_protocol(protocol, Some(pid_sender), handshake) @@ -265,7 +258,9 @@ impl Scheduler { async fn disconnect_mgr(&self, mut a2s_disconnect_r: mpsc::UnboundedReceiver) { trace!("Start disconnect_mgr"); - while let Some((pid, return_once_successful_shutdown)) = a2s_disconnect_r.recv().await { + while let Some((pid, (timeout_time, return_once_successful_shutdown))) = + a2s_disconnect_r.recv().await + { //Closing Participants is done the following way: // 1. We drop our senders and receivers // 2. we need to close BParticipant, this will drop its senderns and receivers @@ -279,7 +274,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((timeout_time, finished_sender)) .unwrap(); drop(pi); trace!(?pid, "dropped bparticipant, waiting for finish"); @@ -322,7 +317,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((Duration::from_secs(120), finished_sender)) .unwrap(); (pid, finished_receiver) }) @@ -392,15 +387,10 @@ impl Scheduler { }, }; info!("Accepting Tcp from: {}", remote_addr); - let protocol = TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - ); - self.init_protocol(Protocols::Tcp(protocol), None, true) + self.init_protocol(Protocols::new_tcp(stream), None, true) .await; } - }, + },/* ProtocolAddr::Udp(addr) => { let socket = match net::UdpSocket::bind(addr).await { Ok(socket) => { @@ -451,12 +441,13 @@ impl Scheduler { let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); udp_data_sender.send(datavec).unwrap(); } - }, + },*/ _ => unimplemented!(), } trace!(?addr, "Ending channel creator"); } + #[allow(dead_code)] async fn udp_single_channel_connect( socket: Arc, w2p_udp_package_s: mpsc::UnboundedSender>, @@ -483,7 +474,7 @@ impl Scheduler { async fn init_protocol( &self, - protocol: Protocols, + mut protocol: Protocols, s2a_return_pid_s: Option>>, send_handshake: bool, ) { @@ -509,20 +500,13 @@ impl Scheduler { self.runtime.spawn( async move { trace!(?cid, "Open channel and be ready for Handshake"); - let handshake = Handshake::new( - cid, - local_pid, - local_secret, - #[cfg(feature = "metrics")] - Arc::clone(&metrics), - send_handshake, - ); - match handshake - .setup(&protocol) + use network_protocol::InitProtocol; + let init_result = protocol + .initialize(send_handshake, local_pid, local_secret) .instrument(tracing::info_span!("handshake", ?cid)) - .await - { - Ok((pid, sid, secret, leftover_cid_frame)) => { + .await; + match init_result { + Ok((pid, sid, secret)) => { trace!( ?cid, ?pid, @@ -533,14 +517,13 @@ impl Scheduler { debug!(?cid, "New participant connected via a channel"); let ( bparticipant, - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, ) = BParticipant::new( pid, sid, - Arc::clone(&runtime), #[cfg(feature = "metrics")] Arc::clone(&metrics), ); @@ -548,8 +531,7 @@ impl Scheduler { let participant = Participant::new( local_pid, pid, - Arc::clone(&runtime), - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, participant_channels.a2s_disconnect_s, ); @@ -573,13 +555,7 @@ impl Scheduler { oneshot::channel(); //From now on wire connects directly with bparticipant! s2b_create_channel_s - .send(( - cid, - sid, - protocol, - leftover_cid_frame, - b2s_create_channel_done_s, - )) + .send((cid, sid, protocol, b2s_create_channel_done_s)) .unwrap(); b2s_create_channel_done_r.await.unwrap(); if let Some(pid_oneshot) = s2a_return_pid_s { @@ -627,8 +603,8 @@ impl Scheduler { //From now on this CHANNEL can receiver other frames! // move directly to participant! }, - Err(()) => { - debug!(?cid, "Handshake from a new connection failed"); + Err(e) => { + debug!(?cid, ?e, "Handshake from a new connection failed"); if let Some(pid_oneshot) = s2a_return_pid_s { // someone is waiting with `connect`, so give them their Error trace!(?cid, "returning the Err to api who requested the connect"); diff --git a/server/Cargo.toml b/server/Cargo.toml index f997749565..8726a037d8 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -16,7 +16,7 @@ common = { package = "veloren-common", path = "../common" } common-sys = { package = "veloren-common-sys", path = "../common/sys" } common-net = { package = "veloren-common-net", path = "../common/net" } world = { package = "veloren-world", path = "../world" } -network = { package = "veloren_network", path = "../network", features = ["metrics", "compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["metrics", "compression"], default-features = false } specs = { git = "https://github.com/amethyst/specs.git", features = ["shred-derive"], rev = "d4435bdf496cf322c74886ca09dd8795984919b4" } specs-idvs = { git = "https://gitlab.com/veloren/specs-idvs.git", rev = "9fab7b396acd6454585486e50ae4bfe2069858a9" } diff --git a/voxygen/src/hud/chat.rs b/voxygen/src/hud/chat.rs index b2dc69c780..1ecc5fe621 100644 --- a/voxygen/src/hud/chat.rs +++ b/voxygen/src/hud/chat.rs @@ -373,7 +373,7 @@ impl<'a> Widget for Chat<'a> { let ChatMsg { chat_type, .. } = &message; // For each ChatType needing localization get/set matching pre-formatted // localized string. This string will be formatted with the data - // provided in ChatType in the client/src/lib.rs + // provided in ChatType in the client/src/mod.rs // fn format_message called below message.message = match chat_type { ChatType::Online(_) => self