From 1b77b6dc41841a867220dd94bae7fb9058ff05e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Wed, 13 Jan 2021 14:16:22 +0100 Subject: [PATCH 1/6] Initial switch to tokio for network, minimum working example. --- Cargo.lock | 143 ++++++++------------------- client/Cargo.toml | 1 + client/src/lib.rs | 10 +- network/Cargo.toml | 5 +- network/src/api.rs | 68 +++++++------ network/src/lib.rs | 2 +- network/src/participant.rs | 13 ++- network/src/protocols.rs | 27 ++--- network/src/scheduler.rs | 37 +++---- server-cli/Cargo.toml | 1 + server-cli/src/main.rs | 3 +- server/Cargo.toml | 1 + server/src/lib.rs | 7 +- voxygen/Cargo.toml | 2 + voxygen/src/menu/main/client_init.rs | 5 +- voxygen/src/singleplayer.rs | 4 +- 16 files changed, 149 insertions(+), 180 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d2595f8d09..56d9e2a87a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,41 +248,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "async-std" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "538ecb01eb64eecd772087e5b6f7540cbc917f047727339a472dafed2185b267" -dependencies = [ - "async-task", - "crossbeam-channel 0.4.4", - "crossbeam-deque 0.7.3", - "crossbeam-utils 0.7.2", - "futures-core", - "futures-io", - "futures-timer 2.0.2", - "kv-log-macro", - "log", - "memchr", - "mio 0.6.23", - "mio-uds", - "num_cpus", - "once_cell", - "pin-project-lite 0.1.11", - "pin-utils", - "slab", -] - -[[package]] -name = "async-task" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ac2c016b079e771204030951c366db398864f5026f84a44dafb0ff20f02085d" -dependencies = [ - "libc", - "winapi 0.3.9", -] - [[package]] name = "atom" version = "0.3.6" @@ -1114,16 +1079,6 @@ dependencies = [ "crossbeam-utils 0.6.6", ] -[[package]] -name = "crossbeam-channel" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87" -dependencies = [ - "crossbeam-utils 0.7.2", - "maybe-uninit", -] - [[package]] name = "crossbeam-channel" version = "0.5.0" @@ -1290,16 +1245,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "ctor" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8f45d9ad417bcef4817d614a501ab55cdd96a6fdb24f49aab89a54acfd66b19" -dependencies = [ - "quote 1.0.9", - "syn 1.0.60", -] - [[package]] name = "daggy" version = "0.5.0" @@ -1861,12 +1806,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "futures-timer" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1de7508b218029b0f01662ed8f61b1c964b3ae99d6f25462d0f55a595109df6" - [[package]] name = "futures-timer" version = "3.0.2" @@ -2242,7 +2181,7 @@ dependencies = [ "http", "indexmap", "slab", - "tokio", + "tokio 0.2.25", "tokio-util", "tracing", "tracing-futures", @@ -2393,7 +2332,7 @@ dependencies = [ "itoa", "pin-project 1.0.5", "socket2", - "tokio", + "tokio 0.2.25", "tower-service", "tracing", "want", @@ -2410,7 +2349,7 @@ dependencies = [ "hyper", "log", "rustls 0.18.1", - "tokio", + "tokio 0.2.25", "tokio-rustls", "webpki", ] @@ -2687,15 +2626,6 @@ version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" -[[package]] -name = "kv-log-macro" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f" -dependencies = [ - "log", -] - [[package]] name = "lazy-bytes-cast" version = "5.0.1" @@ -2863,7 +2793,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if 1.0.0", - "value-bag", ] [[package]] @@ -3088,17 +3017,6 @@ dependencies = [ 
"slab", ] -[[package]] -name = "mio-uds" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" -dependencies = [ - "iovec", - "libc", - "mio 0.6.23", -] - [[package]] name = "miow" version = "0.2.2" @@ -4282,7 +4200,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "tokio", + "tokio 0.2.25", "tokio-rustls", "url", "wasm-bindgen", @@ -5162,6 +5080,33 @@ dependencies = [ "slab", ] +[[package]] +name = "tokio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8190d04c665ea9e6b6a0dc45523ade572c088d2e6566244c1122671dbf4ae3a" +dependencies = [ + "autocfg", + "bytes 1.0.1", + "libc", + "memchr", + "mio 0.7.7", + "num_cpus", + "pin-project-lite 0.2.4", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.60", +] + [[package]] name = "tokio-rustls" version = "0.14.1" @@ -5170,7 +5115,7 @@ checksum = "e12831b255bcfa39dc0436b01e19fea231a37db570686c06ee72c423479f889a" dependencies = [ "futures-core", "rustls 0.18.1", - "tokio", + "tokio 0.2.25", "webpki", ] @@ -5185,7 +5130,7 @@ dependencies = [ "futures-sink", "log", "pin-project-lite 0.1.11", - "tokio", + "tokio 0.2.25", ] [[package]] @@ -5528,15 +5473,6 @@ dependencies = [ "num_cpus", ] -[[package]] -name = "value-bag" -version = "1.0.0-alpha.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b676010e055c99033117c2343b33a40a30b91fecd6c49055ac9cd2d6c305ab1" -dependencies = [ - "ctor", -] - [[package]] name = "vcpkg" version = "0.2.11" @@ -5596,7 +5532,7 @@ dependencies = [ "authc", "byteorder", "futures-executor", - "futures-timer 3.0.2", + "futures-timer", "futures-util", "hashbrown 
0.9.1", "image", @@ -5604,6 +5540,7 @@ dependencies = [ "num_cpus", "rayon", "specs", + "tokio 1.2.0", "tracing", "tracing-subscriber", "uvth 3.1.1", @@ -5731,7 +5668,7 @@ dependencies = [ "dotenv", "futures-channel", "futures-executor", - "futures-timer 3.0.2", + "futures-timer", "futures-util", "hashbrown 0.9.1", "itertools 0.9.0", @@ -5749,6 +5686,7 @@ dependencies = [ "specs", "specs-idvs", "tiny_http", + "tokio 1.2.0", "tracing", "uvth 3.1.1", "vek 0.12.0", @@ -5772,6 +5710,7 @@ dependencies = [ "serde", "signal-hook 0.2.3", "termcolor", + "tokio 1.2.0", "tracing", "tracing-subscriber", "tracing-tracy", @@ -5818,6 +5757,7 @@ dependencies = [ "lazy_static", "native-dialog", "num 0.3.1", + "num_cpus", "old_school_gfx_glutin_ext", "ordered-float 2.1.1", "rand 0.8.3", @@ -5827,6 +5767,7 @@ dependencies = [ "specs", "specs-idvs", "termcolor", + "tokio 1.2.0", "tracing", "tracing-appender", "tracing-log", @@ -5893,9 +5834,8 @@ dependencies = [ [[package]] name = "veloren_network" -version = "0.2.0" +version = "0.3.0" dependencies = [ - "async-std", "bincode", "bitflags", "clap", @@ -5908,6 +5848,7 @@ dependencies = [ "serde", "shellexpand", "tiny_http", + "tokio 1.2.0", "tracing", "tracing-futures", "tracing-subscriber", diff --git a/client/Cargo.toml b/client/Cargo.toml index c775dcfc79..b2ebcfead1 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -21,6 +21,7 @@ uvth = "3.1.1" futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" +tokio = { version = "1.0.1", default-features = false, features = ["rt"] } image = { version = "0.23.12", default-features = false, features = ["png"] } num = "0.3.1" num_cpus = "1.10.1" diff --git a/client/src/lib.rs b/client/src/lib.rs index 1f15a66bfe..9b13ed80f1 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -64,6 +64,7 @@ use std::{ time::{Duration, Instant}, }; use tracing::{debug, error, trace, warn}; +use tokio::runtime::Runtime; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -129,6 
+130,7 @@ impl WorldData { pub struct Client { registered: bool, presence: Option, + runtime: Arc, thread_pool: ThreadPool, server_info: ServerInfo, world_data: WorldData, @@ -185,15 +187,14 @@ pub struct CharacterList { impl Client { /// Create a new `Client`. - pub fn new>(addr: A, view_distance: Option) -> Result { + pub fn new>(addr: A, view_distance: Option, runtime: Arc) -> Result { let mut thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".into()) .build(); // We reduce the thread count by 1 to keep rendering smooth thread_pool.set_num_threads((num_cpus::get() - 1).max(1)); - let (network, scheduler) = Network::new(Pid::new()); - thread_pool.execute(scheduler); + let network = Network::new(Pid::new(), Arc::clone(&runtime)); let participant = block_on(network.connect(ProtocolAddr::Tcp(addr.into())))?; let stream = block_on(participant.opened())?; @@ -417,6 +418,7 @@ impl Client { Ok(Self { registered: false, presence: None, + runtime, thread_pool, server_info, world_data: WorldData { @@ -1733,6 +1735,8 @@ impl Client { /// exempt). pub fn thread_pool(&self) -> &ThreadPool { &self.thread_pool } + pub fn runtime(&self) -> &Arc { &self.runtime } + /// Get a reference to the client's game state. pub fn state(&self) -> &State { &self.state } diff --git a/network/Cargo.toml b/network/Cargo.toml index 49caa4d62d..d477be73e1 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "veloren_network" -version = "0.2.0" +version = "0.3.0" authors = ["Marcel Märtens "] edition = "2018" @@ -19,8 +19,7 @@ bincode = "1.3.1" serde = { version = "1.0" } #sending crossbeam-channel = "0.5" -# NOTE: Upgrading async-std can trigger spontanious crashes for `network`ing. 
Consider elaborate tests before upgrading -async-std = { version = "~1.5", default-features = false, features = ["std", "async-task", "default"] } +tokio = { version = "1.2", default-features = false, features = ["io-util", "macros", "rt", "net", "time"] } #tracing and metrics tracing = { version = "0.1", default-features = false } tracing-futures = "0.2" diff --git a/network/src/api.rs b/network/src/api.rs index 66ffa82096..8baaa72581 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -8,7 +8,8 @@ use crate::{ scheduler::Scheduler, types::{Mid, Pid, Prio, Promises, Sid}, }; -use async_std::{io, sync::Mutex, task}; +use tokio::{io, sync::Mutex}; +use tokio::runtime::Runtime; use futures::{ channel::{mpsc, oneshot}, sink::SinkExt, @@ -50,6 +51,7 @@ pub enum ProtocolAddr { pub struct Participant { local_pid: Pid, remote_pid: Pid, + runtime: Arc, a2b_stream_open_s: Mutex>, b2a_stream_opened_r: Mutex>, a2s_disconnect_s: A2sDisconnect, @@ -76,6 +78,7 @@ pub struct Stream { prio: Prio, promises: Promises, send_closed: Arc, + runtime: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, @@ -150,9 +153,10 @@ pub enum StreamError { /// [`connected`]: Network::connected pub struct Network { local_pid: Pid, + runtime: Arc, participant_disconnect_sender: Mutex>, listen_sender: - Mutex>)>>, + Mutex>)>>, connect_sender: Mutex>)>>, connected_receiver: Mutex>, @@ -165,17 +169,12 @@ impl Network { /// # Arguments /// * `participant_id` - provide it by calling [`Pid::new()`], usually you /// don't want to reuse a Pid for 2 `Networks` + /// * `runtime` - provide a tokio::Runtime, it's used to internally spawn tasks /// /// # Result /// * `Self` - returns a `Network` which can be `Send` to multiple areas of /// your code, including multiple threads. This is the base strct of this /// crate. - /// * `FnOnce` - you need to run the returning FnOnce exactly once, probably - /// in it's own thread. 
this is NOT done internally, so that you are free - /// to choose the threadpool implementation of your choice. We recommend - /// using [`ThreadPool`] from [`uvth`] crate. This fn will run the - /// Scheduler to handle all `Network` internals. Additional threads will - /// be allocated on an internal async-aware threadpool /// /// # Examples /// ```rust @@ -204,9 +203,10 @@ impl Network { /// [`Pid::new()`]: crate::types::Pid::new /// [`ThreadPool`]: https://docs.rs/uvth/newest/uvth/struct.ThreadPool.html /// [`uvth`]: https://docs.rs/uvth - pub fn new(participant_id: Pid) -> (Self, impl std::ops::FnOnce()) { + pub fn new(participant_id: Pid, runtime: Arc) -> Self { Self::internal_new( participant_id, + runtime, #[cfg(feature = "metrics")] None, ) @@ -232,42 +232,46 @@ impl Network { #[cfg(feature = "metrics")] pub fn new_with_registry( participant_id: Pid, + runtime: Arc, registry: &Registry, - ) -> (Self, impl std::ops::FnOnce()) { - Self::internal_new(participant_id, Some(registry)) + ) -> Self { + Self::internal_new(participant_id, runtime, Some(registry)) } fn internal_new( participant_id: Pid, + runtime: Arc, #[cfg(feature = "metrics")] registry: Option<&Registry>, - ) -> (Self, impl std::ops::FnOnce()) { + ) -> Self { let p = participant_id; debug!(?p, "Starting Network"); let (scheduler, listen_sender, connect_sender, connected_receiver, shutdown_sender) = Scheduler::new( participant_id, + Arc::clone(&runtime), #[cfg(feature = "metrics")] registry, ); - ( - Self { - local_pid: participant_id, - participant_disconnect_sender: Mutex::new(HashMap::new()), - listen_sender: Mutex::new(listen_sender), - connect_sender: Mutex::new(connect_sender), - connected_receiver: Mutex::new(connected_receiver), - shutdown_sender: Some(shutdown_sender), - }, - move || { + runtime.spawn( + async move { trace!(?p, "Starting scheduler in own thread"); - let _handle = task::block_on( + let _handle = tokio::spawn( scheduler .run() .instrument(tracing::info_span!("scheduler", 
?p)), ); trace!(?p, "Stopping scheduler and his own thread"); - }, - ) + } + ); + Self { + local_pid: participant_id, + runtime: runtime, + participant_disconnect_sender: Mutex::new(HashMap::new()), + listen_sender: Mutex::new(listen_sender), + connect_sender: Mutex::new(connect_sender), + connected_receiver: Mutex::new(connected_receiver), + shutdown_sender: Some(shutdown_sender), + } } /// starts listening on an [`ProtocolAddr`]. @@ -300,7 +304,7 @@ impl Network { /// /// [`connected`]: Network::connected pub async fn listen(&self, address: ProtocolAddr) -> Result<(), NetworkError> { - let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); + let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); debug!(?address, "listening on address"); self.listen_sender .lock() @@ -426,6 +430,7 @@ impl Participant { pub(crate) fn new( local_pid: Pid, remote_pid: Pid, + runtime: Arc, a2b_stream_open_s: mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, S2bShutdownBparticipant)>, @@ -433,6 +438,7 @@ impl Participant { Self { local_pid, remote_pid, + runtime, a2b_stream_open_s: Mutex::new(a2b_stream_open_s), b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), @@ -655,6 +661,7 @@ impl Stream { prio: Prio, promises: Promises, send_closed: Arc, + runtime: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, b2a_msg_recv_r: mpsc::UnboundedReceiver, a2b_close_stream_s: mpsc::UnboundedSender, @@ -666,6 +673,7 @@ impl Stream { prio, promises, send_closed, + runtime, a2b_msg_s, b2a_msg_recv_r: Some(b2a_msg_recv_r), a2b_close_stream_s: Some(a2b_close_stream_s), @@ -960,7 +968,7 @@ impl Drop for Network { "Shutting down Participants of Network, while we still have metrics" ); let mut finished_receiver_list = vec![]; - task::block_on(async { + self.runtime.block_on(async { // we MUST avoid nested block_on, good that Network::Drop no 
longer triggers // Participant::Drop directly but just the BParticipant for (remote_pid, a2s_disconnect_s) in @@ -1013,14 +1021,14 @@ impl Drop for Participant { let pid = self.remote_pid; debug!(?pid, "Shutting down Participant"); - match task::block_on(self.a2s_disconnect_s.lock()).take() { + match self.runtime.block_on(self.a2s_disconnect_s.lock()).take() { None => trace!( ?pid, "Participant has been shutdown cleanly, no further waiting is required!" ), Some(mut a2s_disconnect_s) => { debug!(?pid, "Disconnect from Scheduler"); - task::block_on(async { + self.runtime.block_on(async { let (finished_sender, finished_receiver) = oneshot::channel(); a2s_disconnect_s .send((self.remote_pid, finished_sender)) @@ -1051,7 +1059,7 @@ impl Drop for Stream { let sid = self.sid; let pid = self.pid; debug!(?pid, ?sid, "Shutting down Stream"); - task::block_on(self.a2b_close_stream_s.take().unwrap().send(self.sid)) + self.runtime.block_on(self.a2b_close_stream_s.take().unwrap().send(self.sid)) .expect("bparticipant part of a gracefully shutdown must have crashed"); } else { let sid = self.sid; diff --git a/network/src/lib.rs b/network/src/lib.rs index bb14782a69..69bd5f07c0 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -39,7 +39,7 @@ //! //! # Examples //! ```rust -//! use async_std::task::sleep; +//! use tokio::task::sleep; //! use futures::{executor::block_on, join}; //! use veloren_network::{Network, Pid, Promises, ProtocolAddr}; //! 
diff --git a/network/src/participant.rs b/network/src/participant.rs index 8e0d0a1904..78d1dacd41 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -8,7 +8,8 @@ use crate::{ protocols::Protocols, types::{Cid, Frame, Pid, Prio, Promises, Sid}, }; -use async_std::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock}; +use tokio::runtime::Runtime; use futures::{ channel::{mpsc, oneshot}, future::FutureExt, @@ -71,6 +72,7 @@ pub struct BParticipant { remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, + runtime: Arc, channels: Arc>>>, streams: RwLock>, running_mgr: AtomicUsize, @@ -86,6 +88,7 @@ impl BParticipant { pub(crate) fn new( remote_pid: Pid, offset_sid: Sid, + runtime: Arc, #[cfg(feature = "metrics")] metrics: Arc, ) -> ( Self, @@ -120,6 +123,7 @@ impl BParticipant { remote_pid, remote_pid_string: remote_pid.to_string(), offset_sid, + runtime, channels: Arc::new(RwLock::new(HashMap::new())), streams: RwLock::new(HashMap::new()), running_mgr: AtomicUsize::new(0), @@ -213,7 +217,7 @@ impl BParticipant { .send((self.remote_pid, len as u64, /* */ 0)) .await .unwrap(); - async_std::task::sleep(TICK_TIME).await; + tokio::time::sleep(TICK_TIME).await; i += 1; if i.rem_euclid(1000) == 0 { trace!("Did 1000 ticks"); @@ -659,7 +663,7 @@ impl BParticipant { //Wait for other bparticipants mgr to close via AtomicUsize const SLEEP_TIME: Duration = Duration::from_millis(5); const ALLOWED_MANAGER: usize = 1; - async_std::task::sleep(SLEEP_TIME).await; + tokio::time::sleep(SLEEP_TIME).await; let mut i: u32 = 1; while self.running_mgr.load(Ordering::Relaxed) > ALLOWED_MANAGER { i += 1; @@ -670,7 +674,7 @@ impl BParticipant { self.running_mgr.load(Ordering::Relaxed) - ALLOWED_MANAGER ); } - async_std::task::sleep(SLEEP_TIME * i).await; + tokio::time::sleep(SLEEP_TIME * i).await; } trace!("All BParticipant mgr (except me) are shut down now"); @@ -843,6 +847,7 @@ impl BParticipant { prio, promises, send_closed, + 
Arc::clone(&self.runtime), a2p_msg_s, b2a_msg_recv_r, a2b_close_stream_s.clone(), diff --git a/network/src/protocols.rs b/network/src/protocols.rs index 7b0b8651b6..771ea649e5 100644 --- a/network/src/protocols.rs +++ b/network/src/protocols.rs @@ -4,8 +4,8 @@ use crate::{ participant::C2pFrame, types::{Cid, Frame}, }; -use async_std::{ - io::prelude::*, +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, net::{TcpStream, UdpSocket}, }; @@ -43,7 +43,8 @@ pub(crate) enum Protocols { #[derive(Debug)] pub(crate) struct TcpProtocol { - stream: TcpStream, + read_stream: tokio::sync::Mutex, + write_stream: tokio::sync::Mutex, #[cfg(feature = "metrics")] metrics: Arc, } @@ -63,14 +64,16 @@ impl TcpProtocol { stream: TcpStream, #[cfg(feature = "metrics")] metrics: Arc, ) -> Self { + let (read_stream, write_stream) = stream.into_split(); Self { - stream, + read_stream: tokio::sync::Mutex::new(read_stream), + write_stream: tokio::sync::Mutex::new(write_stream), #[cfg(feature = "metrics")] metrics, } } - async fn read_frame( + async fn read_frame( r: &mut R, mut end_receiver: &mut Fuse>, ) -> Result> { @@ -167,11 +170,11 @@ impl TcpProtocol { .metrics .wire_in_throughput .with_label_values(&[&cid.to_string()]); - let mut stream = self.stream.clone(); + let mut read_stream = self.read_stream.lock().await; let mut end_r = end_r.fuse(); loop { - match Self::read_frame(&mut stream, &mut end_r).await { + match Self::read_frame(&mut *read_stream, &mut end_r).await { Ok(frame) => { #[cfg(feature = "metrics")] { @@ -209,7 +212,7 @@ impl TcpProtocol { trace!("Shutting down tcp read()"); } - pub async fn write_frame( + pub async fn write_frame( w: &mut W, frame: Frame, ) -> Result<(), std::io::Error> { @@ -270,7 +273,7 @@ impl TcpProtocol { pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { trace!("Starting up tcp write()"); - let mut stream = self.stream.clone(); + let mut write_stream = self.write_stream.lock().await; #[cfg(feature = "metrics")] 
let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); #[cfg(feature = "metrics")] @@ -294,7 +297,7 @@ impl TcpProtocol { throughput_cache.inc_by(data.len() as u64); } } - if let Err(e) = Self::write_frame(&mut stream, frame).await { + if let Err(e) = Self::write_frame(&mut *write_stream, frame).await { info!( ?e, "Got an error writing to tcp, going to close this channel" @@ -498,7 +501,7 @@ impl UdpProtocol { mod tests { use super::*; use crate::{metrics::NetworkMetrics, types::Pid}; - use async_std::net; + use tokio::net; use futures::{executor::block_on, stream::StreamExt}; use std::sync::Arc; @@ -534,7 +537,7 @@ mod tests { }) }); // Assert than we get some value back! Its a Handshake! - //async_std::task::sleep(std::time::Duration::from_millis(1000)); + //tokio::task::sleep(std::time::Duration::from_millis(1000)); let (cid_r, frame) = w2c_cid_frame_r.next().await.unwrap(); assert_eq!(cid, cid_r); if let Ok(Frame::Handshake { diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 33cb4ed054..e0c3b0ef84 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -7,10 +7,10 @@ use crate::{ protocols::{Protocols, TcpProtocol, UdpProtocol}, types::Pid, }; -use async_std::{io, net, sync::Mutex}; +use tokio::{io, net, sync::Mutex}; +use tokio::runtime::Runtime; use futures::{ channel::{mpsc, oneshot}, - executor::ThreadPool, future::FutureExt, select, sink::SinkExt, @@ -68,9 +68,9 @@ struct ParticipantChannels { #[derive(Debug)] pub struct Scheduler { local_pid: Pid, + runtime: Arc, local_secret: u128, closed: AtomicBool, - pool: Arc, run_channels: Option, participant_channels: Arc>>, participants: Arc>>, @@ -83,6 +83,7 @@ pub struct Scheduler { impl Scheduler { pub fn new( local_pid: Pid, + runtime: Arc, #[cfg(feature = "metrics")] registry: Option<&Registry>, ) -> ( Self, @@ -128,9 +129,9 @@ impl Scheduler { ( Self { local_pid, + runtime, local_secret, closed: AtomicBool::new(false), - pool: 
Arc::new(ThreadPool::new().unwrap()), run_channels, participant_channels: Arc::new(Mutex::new(Some(participant_channels))), participants: Arc::new(Mutex::new(HashMap::new())), @@ -247,7 +248,7 @@ impl Scheduler { Arc::clone(&self.metrics), udp_data_receiver, ); - self.pool.spawn_ok( + self.runtime.spawn( Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) .instrument(tracing::info_span!("udp", ?addr)), ); @@ -377,27 +378,19 @@ impl Scheduler { }, }; trace!(?addr, "Listener bound"); - let mut incoming = listener.incoming(); let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(stream) = select! { - next = incoming.next().fuse() => next, + while let Some(data) = select! { + next = listener.accept().fuse() => Some(next), _ = end_receiver => None, } { - let stream = match stream { - Ok(s) => s, + let (stream, remote_addr) = match data { + Ok((s, p)) => (s, p), Err(e) => { warn!(?e, "TcpStream Error, ignoring connection attempt"); continue; }, }; - let peer_addr = match stream.peer_addr() { - Ok(s) => s, - Err(e) => { - warn!(?e, "TcpStream Error, ignoring connection attempt"); - continue; - }, - }; - info!("Accepting Tcp from: {}", peer_addr); + info!("Accepting Tcp from: {}", remote_addr); let protocol = TcpProtocol::new( stream, #[cfg(feature = "metrics")] @@ -505,13 +498,13 @@ impl Scheduler { // the UDP listening is done in another place. 
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let participants = Arc::clone(&self.participants); + let runtime = Arc::clone(&self.runtime); #[cfg(feature = "metrics")] let metrics = Arc::clone(&self.metrics); - let pool = Arc::clone(&self.pool); let local_pid = self.local_pid; let local_secret = self.local_secret; // this is necessary for UDP to work at all and to remove code duplication - self.pool.spawn_ok( + self.runtime.spawn( async move { trace!(?cid, "Open channel and be ready for Handshake"); let handshake = Handshake::new( @@ -545,6 +538,7 @@ impl Scheduler { ) = BParticipant::new( pid, sid, + Arc::clone(&runtime), #[cfg(feature = "metrics")] Arc::clone(&metrics), ); @@ -552,6 +546,7 @@ impl Scheduler { let participant = Participant::new( local_pid, pid, + Arc::clone(&runtime), a2b_stream_open_s, b2a_stream_opened_r, participant_channels.a2s_disconnect_s, @@ -566,7 +561,7 @@ impl Scheduler { }); drop(participants); trace!("dropped participants lock"); - pool.spawn_ok( + runtime.spawn( bparticipant .run(participant_channels.b2s_prio_statistic_s) .instrument(tracing::info_span!("participant", ?pid)), diff --git a/server-cli/Cargo.toml b/server-cli/Cargo.toml index 6269f04932..4d8a3cf866 100644 --- a/server-cli/Cargo.toml +++ b/server-cli/Cargo.toml @@ -15,6 +15,7 @@ server = { package = "veloren-server", path = "../server", default-features = fa common = { package = "veloren-common", path = "../common" } common-net = { package = "veloren-common-net", path = "../common/net" } +tokio = { version = "1.0.1", default-features = false, features = ["rt-multi-thread"] } ansi-parser = "0.7" clap = "2.33" crossterm = "0.18" diff --git a/server-cli/src/main.rs b/server-cli/src/main.rs index ba83e1d67f..18b3390199 100644 --- a/server-cli/src/main.rs +++ b/server-cli/src/main.rs @@ -129,7 +129,8 @@ fn main() -> io::Result<()> { let server_port = &server_settings.gameserver_address.port(); let metrics_port = &server_settings.metrics_address.port(); // Create 
server - let mut server = Server::new(server_settings, editable_settings, &server_data_dir) + let runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()); + let mut server = Server::new(server_settings, editable_settings, &server_data_dir, runtime) .expect("Failed to create server instance!"); info!( diff --git a/server/Cargo.toml b/server/Cargo.toml index e1f4b7c7e3..f997749565 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -28,6 +28,7 @@ futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" futures-channel = "0.3" +tokio = { version = "1.0.1", default-features = false, features = ["rt"] } itertools = "0.9" lazy_static = "1.4.0" scan_fmt = { git = "https://github.com/Imberflur/scan_fmt" } diff --git a/server/src/lib.rs b/server/src/lib.rs index 93c5916d72..fe389ae266 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -92,6 +92,7 @@ use std::{ #[cfg(not(feature = "worldgen"))] use test_world::{IndexOwned, World}; use tracing::{debug, error, info, trace}; +use tokio::runtime::Runtime; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -120,6 +121,7 @@ pub struct Server { connection_handler: ConnectionHandler, + runtime: Arc, thread_pool: ThreadPool, metrics: ServerMetrics, @@ -136,6 +138,7 @@ impl Server { settings: Settings, editable_settings: EditableSettings, data_dir: &std::path::Path, + runtime: Arc, ) -> Result { info!("Server is data dir is: {}", data_dir.display()); if settings.auth_server_address.is_none() { @@ -364,11 +367,10 @@ impl Server { let thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".to_string()) .build(); - let (network, f) = Network::new_with_registry(Pid::new(), &metrics.registry()); + let network = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), &metrics.registry()); metrics .run(settings.metrics_address) .expect("Failed to initialize server metrics submodule."); - thread_pool.execute(f); 
block_on(network.listen(ProtocolAddr::Tcp(settings.gameserver_address)))?; let connection_handler = ConnectionHandler::new(network); @@ -386,6 +388,7 @@ impl Server { connection_handler, + runtime, thread_pool, metrics, diff --git a/voxygen/Cargo.toml b/voxygen/Cargo.toml index 7480a7ed2f..aeebeb0a30 100644 --- a/voxygen/Cargo.toml +++ b/voxygen/Cargo.toml @@ -82,6 +82,8 @@ ron = {version = "0.6", default-features = false} serde = {version = "1.0", features = [ "rc", "derive" ]} treeculler = "0.1.0" uvth = "3.1.1" +tokio = { version = "1.0.1", default-features = false, features = ["rt-multi-thread"] } +num_cpus = "1.0" # vec_map = { version = "0.8.2" } inline_tweak = "1.0.2" itertools = "0.10.0" diff --git a/voxygen/src/menu/main/client_init.rs b/voxygen/src/menu/main/client_init.rs index a1e01ab71a..010071cd56 100644 --- a/voxygen/src/menu/main/client_init.rs +++ b/voxygen/src/menu/main/client_init.rs @@ -71,6 +71,9 @@ impl ClientInit { let mut last_err = None; + let cores = num_cpus::get(); + let runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().worker_threads(if cores > 4 {cores-1} else {cores}).build().unwrap()); + const FOUR_MINUTES_RETRIES: u64 = 48; 'tries: for _ in 0..FOUR_MINUTES_RETRIES { if cancel2.load(Ordering::Relaxed) { @@ -79,7 +82,7 @@ impl ClientInit { for socket_addr in first_addrs.clone().into_iter().chain(second_addrs.clone()) { - match Client::new(socket_addr, view_distance) { + match Client::new(socket_addr, view_distance, Arc::clone(&runtime)) { Ok(mut client) => { if let Err(e) = client.register(username, password, |auth_server| { diff --git a/voxygen/src/singleplayer.rs b/voxygen/src/singleplayer.rs index 0e1a050988..32368a19fc 100644 --- a/voxygen/src/singleplayer.rs +++ b/voxygen/src/singleplayer.rs @@ -82,6 +82,8 @@ impl Singleplayer { let editable_settings = server::EditableSettings::singleplayer(&server_data_dir); let thread_pool = client.map(|c| c.thread_pool().clone()); + let cores = num_cpus::get(); + let 
runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().worker_threads(if cores > 4 {cores-1} else {cores}).build().unwrap()); let settings2 = settings.clone(); let paused = Arc::new(AtomicBool::new(false)); @@ -92,7 +94,7 @@ impl Singleplayer { let thread = thread::spawn(move || { let mut server = None; if let Err(e) = result_sender.send( - match Server::new(settings2, editable_settings, &server_data_dir) { + match Server::new(settings2, editable_settings, &server_data_dir, runtime) { Ok(s) => { server = Some(s); Ok(()) From 5aa1940ef88afbda76894a82ad0cb5ef659416ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Fri, 15 Jan 2021 14:04:32 +0100 Subject: [PATCH 2/6] get rid of `async_std::channel` switch to `tokio` and `async_channel` crate. I wanted to do tokio first, but it doesnt feature Sender::close(), thus i included async_channel Got rid of `futures` and only need `futures_core` and `futures_util`. Tokio does not support `Stream` and `StreamExt` so for now i need to use `tokio-stream`, i think this will go in `std` in the future Created `b2b_close_stream_opened_sender_r` as the shutdown procedure does not need a copy of a Sender, it just need to stop it. Various adjustments, e.g. for `select!` which now requieres a `&mut` for oneshots. Future things to do: - Use some better signalling than oneshot<()> in some cases. - Use a Watch for the Prio propergation (impl. 
it ofc) - Use Bounded Channels in order to improve performance - adjust tests coding bring tests to work --- Cargo.lock | 66 +++++++--- client/examples/chat-cli/main.rs | 12 +- client/src/lib.rs | 8 +- network/Cargo.toml | 9 +- network/examples/chat.rs | 26 ++-- network/examples/fileshare/commands.rs | 6 +- network/examples/fileshare/main.rs | 36 ++---- network/examples/fileshare/server.rs | 27 +++-- network/examples/network-speed/main.rs | 39 +++--- network/src/api.rs | 136 ++++++++++----------- network/src/channel.rs | 34 +++--- network/src/lib.rs | 2 +- network/src/message.rs | 6 +- network/src/participant.rs | 130 +++++++++++--------- network/src/prios.rs | 68 ++++++++--- network/src/protocols.rs | 54 ++++----- network/src/scheduler.rs | 60 ++++----- network/tests/closing.rs | 161 ++++++++++++------------- network/tests/helper.rs | 40 ++++-- network/tests/integration.rs | 80 ++++++------ server-cli/src/main.rs | 16 ++- server/src/lib.rs | 9 +- voxygen/src/menu/main/client_init.rs | 8 +- voxygen/src/singleplayer.rs | 8 +- 24 files changed, 571 insertions(+), 470 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 56d9e2a87a..69b549cd2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,6 +248,17 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-channel" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59740d83946db6a5af71ae25ddf9562c2b176b2ca42cf99a455f09f4a220d6b9" +dependencies = [ + "concurrent-queue", + "event-listener", + "futures-core", +] + [[package]] name = "atom" version = "0.3.6" @@ -452,6 +463,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +[[package]] +name = "cache-padded" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "631ae5198c9be5e753e5cc215e1bd73c2b466a3565173db433f52bb9d3e66dba" + [[package]] name = "calloop" 
version = "0.6.5" @@ -712,6 +729,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "concurrent-queue" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ed07550be01594c6026cff2a1d7fe9c8f683caa798e12b68694ac9e88286a3" +dependencies = [ + "cache-padded", +] + [[package]] name = "conrod_core" version = "0.63.0" @@ -1566,6 +1592,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "event-listener" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7531096570974c3a9dcf9e4b8e1cede1ec26cf5046219fb3b9d897503b9be59" + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -5119,6 +5151,17 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-stream" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1981ad97df782ab506a1f43bf82c967326960d278acf3bf8279809648c3ff3ea" +dependencies = [ + "futures-core", + "pin-project-lite 0.2.4", + "tokio 1.2.0", +] + [[package]] name = "tokio-util" version = "0.3.1" @@ -5462,17 +5505,6 @@ dependencies = [ "num_cpus", ] -[[package]] -name = "uvth" -version = "4.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e5910f9106b96334c6cae1f1d77a764bda66ac4ca9f507f73259f184fe1bb6b" -dependencies = [ - "crossbeam-channel 0.3.9", - "log", - "num_cpus", -] - [[package]] name = "vcpkg" version = "0.2.11" @@ -5543,7 +5575,7 @@ dependencies = [ "tokio 1.2.0", "tracing", "tracing-subscriber", - "uvth 3.1.1", + "uvth", "vek 0.12.0", "veloren-common", "veloren-common-net", @@ -5688,7 +5720,7 @@ dependencies = [ "tiny_http", "tokio 1.2.0", "tracing", - "uvth 3.1.1", + "uvth", "vek 0.12.0", "veloren-common", "veloren-common-net", @@ -5774,7 +5806,7 @@ dependencies = [ "tracing-subscriber", "tracing-tracy", "treeculler", - "uvth 3.1.1", + "uvth", "vek 0.12.0", "veloren-client", "veloren-common", @@ -5836,11 +5868,13 @@ dependencies = [ name = "veloren_network" 
version = "0.3.0" dependencies = [ + "async-channel", "bincode", "bitflags", "clap", "crossbeam-channel 0.5.0", - "futures", + "futures-core", + "futures-util", "lazy_static", "lz-fear", "prometheus", @@ -5849,10 +5883,10 @@ dependencies = [ "shellexpand", "tiny_http", "tokio 1.2.0", + "tokio-stream", "tracing", "tracing-futures", "tracing-subscriber", - "uvth 4.0.1", ] [[package]] diff --git a/client/examples/chat-cli/main.rs b/client/examples/chat-cli/main.rs index 3ffaa7d5b9..115d9ae50c 100644 --- a/client/examples/chat-cli/main.rs +++ b/client/examples/chat-cli/main.rs @@ -3,7 +3,14 @@ #![deny(clippy::clone_on_ref_ptr)] use common::{clock::Clock, comp}; -use std::{io, net::ToSocketAddrs, sync::mpsc, thread, time::Duration}; +use std::{ + io, + net::ToSocketAddrs, + sync::{mpsc, Arc}, + thread, + time::Duration, +}; +use tokio::runtime::Runtime; use tracing::{error, info}; use veloren_client::{Client, Event}; @@ -37,6 +44,8 @@ fn main() { println!("Enter your password"); let password = read_input(); + let runtime = Arc::new(Runtime::new().unwrap()); + // Create a client. let mut client = Client::new( server_addr @@ -45,6 +54,7 @@ fn main() { .next() .unwrap(), None, + runtime, ) .expect("Failed to create client instance"); diff --git a/client/src/lib.rs b/client/src/lib.rs index 9b13ed80f1..b48fa7335b 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -63,8 +63,8 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use tracing::{debug, error, trace, warn}; use tokio::runtime::Runtime; +use tracing::{debug, error, trace, warn}; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -187,7 +187,11 @@ pub struct CharacterList { impl Client { /// Create a new `Client`. 
- pub fn new>(addr: A, view_distance: Option, runtime: Arc) -> Result { + pub fn new>( + addr: A, + view_distance: Option, + runtime: Arc, + ) -> Result { let mut thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".into()) .build(); diff --git a/network/Cargo.toml b/network/Cargo.toml index d477be73e1..0a540ca6dc 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -20,12 +20,15 @@ serde = { version = "1.0" } #sending crossbeam-channel = "0.5" tokio = { version = "1.2", default-features = false, features = ["io-util", "macros", "rt", "net", "time"] } +tokio-stream = { version = "0.1.2", default-features = false } #tracing and metrics tracing = { version = "0.1", default-features = false } tracing-futures = "0.2" prometheus = { version = "0.11", default-features = false, optional = true } #async -futures = { version = "0.3", features = ["thread-pool"] } +futures-core = { version = "0.3", default-features = false } +futures-util = { version = "0.3", default-features = false, features = ["std"] } +async-channel = "1.5.1" #use for .close() channels #mpsc channel registry lazy_static = { version = "1.4", default-features = false } rand = { version = "0.8" } @@ -35,8 +38,8 @@ lz-fear = { version = "0.1.1", optional = true } [dev-dependencies] tracing-subscriber = { version = "0.2.3", default-features = false, features = ["env-filter", "fmt", "chrono", "ansi", "smallvec"] } -# `uvth` needed for doc tests -uvth = { version = ">= 3.0, <= 4.0", default-features = false } +tokio = { version = "1.0.1", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } +futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" tiny_http = "0.8.0" diff --git a/network/examples/chat.rs b/network/examples/chat.rs index 91fcdea733..a1a3f09cf0 100644 --- a/network/examples/chat.rs +++ b/network/examples/chat.rs @@ -3,10 +3,9 @@ //! 
RUST_BACKTRACE=1 cargo run --example chat -- --trace=info --port 15006 //! RUST_BACKTRACE=1 cargo run --example chat -- --trace=info --port 15006 --mode=client //! ``` -use async_std::{io, sync::RwLock}; use clap::{App, Arg}; -use futures::executor::{block_on, ThreadPool}; use std::{sync::Arc, thread, time::Duration}; +use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::RwLock}; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr}; @@ -100,18 +99,17 @@ fn main() { } fn server(address: ProtocolAddr) { - let (server, f) = Network::new(Pid::new()); + let r = Arc::new(Runtime::new().unwrap()); + let server = Network::new(Pid::new(), Arc::clone(&r)); let server = Arc::new(server); - std::thread::spawn(f); - let pool = ThreadPool::new().unwrap(); let participants = Arc::new(RwLock::new(Vec::new())); - block_on(async { + r.block_on(async { server.listen(address).await.unwrap(); loop { let p1 = Arc::new(server.connected().await.unwrap()); let server1 = server.clone(); participants.write().await.push(p1.clone()); - pool.spawn_ok(client_connection(server1, p1, participants.clone())); + tokio::spawn(client_connection(server1, p1, participants.clone())); } }); } @@ -144,27 +142,27 @@ async fn client_connection( } fn client(address: ProtocolAddr) { - let (client, f) = Network::new(Pid::new()); - std::thread::spawn(f); - let pool = ThreadPool::new().unwrap(); + let r = Arc::new(Runtime::new().unwrap()); + let client = Network::new(Pid::new(), Arc::clone(&r)); - block_on(async { + r.block_on(async { let p1 = client.connect(address.clone()).await.unwrap(); //remote representation of p1 let mut s1 = p1 .open(16, Promises::ORDERED | Promises::CONSISTENCY) .await .unwrap(); //remote representation of s1 + let mut input_lines = io::BufReader::new(io::stdin()); println!("Enter your username:"); let mut username = String::new(); - io::stdin().read_line(&mut username).await.unwrap(); + 
input_lines.read_line(&mut username).await.unwrap(); username = username.split_whitespace().collect(); println!("Your username is: {}", username); println!("write /quit to close"); - pool.spawn_ok(read_messages(p1)); + tokio::spawn(read_messages(p1)); s1.send(username).unwrap(); loop { let mut line = String::new(); - io::stdin().read_line(&mut line).await.unwrap(); + input_lines.read_line(&mut line).await.unwrap(); line = line.split_whitespace().collect(); if line.as_str() == "/quit" { println!("goodbye"); diff --git a/network/examples/fileshare/commands.rs b/network/examples/fileshare/commands.rs index 3967631a56..a18c90b38e 100644 --- a/network/examples/fileshare/commands.rs +++ b/network/examples/fileshare/commands.rs @@ -1,9 +1,7 @@ -use async_std::{ - fs, - path::{Path, PathBuf}, -}; use rand::Rng; use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use tokio::fs; use veloren_network::{Participant, ProtocolAddr, Stream}; use std::collections::HashMap; diff --git a/network/examples/fileshare/main.rs b/network/examples/fileshare/main.rs index d4b7b832e7..f000f371e0 100644 --- a/network/examples/fileshare/main.rs +++ b/network/examples/fileshare/main.rs @@ -4,14 +4,9 @@ //! --profile=release -Z unstable-options -- --trace=info --port 15006) //! (cd network/examples/fileshare && RUST_BACKTRACE=1 cargo run //! 
--profile=release -Z unstable-options -- --trace=info --port 15007) ``` -use async_std::{io, path::PathBuf}; use clap::{App, Arg, SubCommand}; -use futures::{ - channel::mpsc, - executor::{block_on, ThreadPool}, - sink::SinkExt, -}; -use std::{thread, time::Duration}; +use std::{path::PathBuf, sync::Arc, thread, time::Duration}; +use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::mpsc}; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::ProtocolAddr; @@ -56,14 +51,14 @@ fn main() { let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); let address = ProtocolAddr::Tcp(format!("{}:{}", "127.0.0.1", port).parse().unwrap()); + let runtime = Arc::new(Runtime::new().unwrap()); - let (server, cmd_sender) = Server::new(); - let pool = ThreadPool::new().unwrap(); - pool.spawn_ok(server.run(address)); + let (server, cmd_sender) = Server::new(Arc::clone(&runtime)); + runtime.spawn(server.run(address)); thread::sleep(Duration::from_millis(50)); //just for trace - block_on(client(cmd_sender)); + runtime.block_on(client(cmd_sender)); } fn file_exists(file: String) -> Result<(), String> { @@ -130,14 +125,15 @@ fn get_options<'a, 'b>() -> App<'a, 'b> { ) } -async fn client(mut cmd_sender: mpsc::UnboundedSender) { +async fn client(cmd_sender: mpsc::UnboundedSender) { use std::io::Write; loop { let mut line = String::new(); + let mut input_lines = io::BufReader::new(io::stdin()); print!("==> "); std::io::stdout().flush().unwrap(); - io::stdin().read_line(&mut line).await.unwrap(); + input_lines.read_line(&mut line).await.unwrap(); let matches = match get_options().get_matches_from_safe(line.split_whitespace()) { Err(e) => { println!("{}", e.message); @@ -148,12 +144,12 @@ async fn client(mut cmd_sender: mpsc::UnboundedSender) { match matches.subcommand() { ("quit", _) => { - cmd_sender.send(LocalCommand::Shutdown).await.unwrap(); + cmd_sender.send(LocalCommand::Shutdown).unwrap(); println!("goodbye"); break; }, ("disconnect", _) => { - 
cmd_sender.send(LocalCommand::Disconnect).await.unwrap(); + cmd_sender.send(LocalCommand::Disconnect).unwrap(); }, ("connect", Some(connect_matches)) => { let socketaddr = connect_matches @@ -163,7 +159,6 @@ async fn client(mut cmd_sender: mpsc::UnboundedSender) { .unwrap(); cmd_sender .send(LocalCommand::Connect(ProtocolAddr::Tcp(socketaddr))) - .await .unwrap(); }, ("t", _) => { @@ -171,28 +166,23 @@ async fn client(mut cmd_sender: mpsc::UnboundedSender) { .send(LocalCommand::Connect(ProtocolAddr::Tcp( "127.0.0.1:1231".parse().unwrap(), ))) - .await .unwrap(); }, ("serve", Some(serve_matches)) => { let path = shellexpand::tilde(serve_matches.value_of("file").unwrap()); let path: PathBuf = path.parse().unwrap(); if let Some(fileinfo) = FileInfo::new(&path).await { - cmd_sender - .send(LocalCommand::Serve(fileinfo)) - .await - .unwrap(); + cmd_sender.send(LocalCommand::Serve(fileinfo)).unwrap(); } }, ("list", _) => { - cmd_sender.send(LocalCommand::List).await.unwrap(); + cmd_sender.send(LocalCommand::List).unwrap(); }, ("get", Some(get_matches)) => { let id: u32 = get_matches.value_of("id").unwrap().parse().unwrap(); let file = get_matches.value_of("file"); cmd_sender .send(LocalCommand::Get(id, file.map(|s| s.to_string()))) - .await .unwrap(); }, diff --git a/network/examples/fileshare/server.rs b/network/examples/fileshare/server.rs index 080b85fff5..5db8345d46 100644 --- a/network/examples/fileshare/server.rs +++ b/network/examples/fileshare/server.rs @@ -1,11 +1,12 @@ use crate::commands::{Command, FileInfo, LocalCommand, RemoteInfo}; -use async_std::{ - fs, - path::PathBuf, - sync::{Mutex, RwLock}, +use futures_util::{FutureExt, StreamExt}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; +use tokio::{ + fs, join, + runtime::Runtime, + sync::{mpsc, Mutex, RwLock}, }; -use futures::{channel::mpsc, future::FutureExt, stream::StreamExt}; -use std::{collections::HashMap, sync::Arc}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; 
use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; @@ -23,11 +24,10 @@ pub struct Server { } impl Server { - pub fn new() -> (Self, mpsc::UnboundedSender) { - let (command_sender, command_receiver) = mpsc::unbounded(); + pub fn new(runtime: Arc) -> (Self, mpsc::UnboundedSender) { + let (command_sender, command_receiver) = mpsc::unbounded_channel(); - let (network, f) = Network::new(Pid::new()); - std::thread::spawn(f); + let network = Network::new(Pid::new(), runtime); let run_channels = Some(ControlChannels { command_receiver }); ( @@ -47,7 +47,7 @@ impl Server { self.network.listen(address).await.unwrap(); - futures::join!( + join!( self.command_manager(run_channels.command_receiver,), self.connect_manager(), ); @@ -55,6 +55,7 @@ impl Server { async fn command_manager(&self, command_receiver: mpsc::UnboundedReceiver) { trace!("Start command_manager"); + let command_receiver = UnboundedReceiverStream::new(command_receiver); command_receiver .for_each_concurrent(None, async move |cmd| { match cmd { @@ -106,7 +107,7 @@ impl Server { async fn connect_manager(&self) { trace!("Start connect_manager"); - let iter = futures::stream::unfold((), |_| { + let iter = futures_util::stream::unfold((), |_| { self.network.connected().map(|r| r.ok().map(|v| (v, ()))) }); @@ -129,7 +130,7 @@ impl Server { let id = p.remote_pid(); let ri = Arc::new(Mutex::new(RemoteInfo::new(cmd_out, file_out, p))); self.remotes.write().await.insert(id, ri.clone()); - futures::join!( + join!( self.handle_remote_cmd(cmd_in, ri.clone()), self.handle_files(file_in, ri.clone()), ); diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index 5f0617ec68..9814cec998 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -6,12 +6,13 @@ mod metrics; use clap::{App, Arg}; -use futures::executor::block_on; use serde::{Deserialize, Serialize}; use std::{ + sync::Arc, thread, time::{Duration, 
Instant}, }; +use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::{Message, Network, Pid, Promises, ProtocolAddr}; @@ -101,14 +102,16 @@ fn main() { }; let mut background = None; + let runtime = Arc::new(Runtime::new().unwrap()); match matches.value_of("mode") { - Some("server") => server(address), - Some("client") => client(address), + Some("server") => server(address, Arc::clone(&runtime)), + Some("client") => client(address, Arc::clone(&runtime)), Some("both") => { let address1 = address.clone(); - background = Some(thread::spawn(|| server(address1))); + let runtime2 = Arc::clone(&runtime); + background = Some(thread::spawn(|| server(address1, runtime2))); thread::sleep(Duration::from_millis(200)); //start client after server - client(address); + client(address, Arc::clone(&runtime)); }, _ => panic!("Invalid mode, run --help!"), }; @@ -117,18 +120,17 @@ fn main() { } } -fn server(address: ProtocolAddr) { +fn server(address: ProtocolAddr, runtime: Arc) { let mut metrics = metrics::SimpleMetrics::new(); - let (server, f) = Network::new_with_registry(Pid::new(), metrics.registry()); - std::thread::spawn(f); + let server = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), metrics.registry()); metrics.run("0.0.0.0:59112".parse().unwrap()).unwrap(); - block_on(server.listen(address)).unwrap(); + runtime.block_on(server.listen(address)).unwrap(); loop { info!("Waiting for participant to connect"); - let p1 = block_on(server.connected()).unwrap(); //remote representation of p1 - let mut s1 = block_on(p1.opened()).unwrap(); //remote representation of s1 - block_on(async { + let p1 = runtime.block_on(server.connected()).unwrap(); //remote representation of p1 + let mut s1 = runtime.block_on(p1.opened()).unwrap(); //remote representation of s1 + runtime.block_on(async { let mut last = Instant::now(); let mut id = 0u64; while let Ok(_msg) = s1.recv_raw().await { @@ -145,14 +147,15 @@ fn server(address: 
ProtocolAddr) { } } -fn client(address: ProtocolAddr) { +fn client(address: ProtocolAddr, runtime: Arc) { let mut metrics = metrics::SimpleMetrics::new(); - let (client, f) = Network::new_with_registry(Pid::new(), metrics.registry()); - std::thread::spawn(f); + let client = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), metrics.registry()); metrics.run("0.0.0.0:59111".parse().unwrap()).unwrap(); - let p1 = block_on(client.connect(address)).unwrap(); //remote representation of p1 - let mut s1 = block_on(p1.open(16, Promises::ORDERED | Promises::CONSISTENCY)).unwrap(); //remote representation of s1 + let p1 = runtime.block_on(client.connect(address)).unwrap(); //remote representation of p1 + let mut s1 = runtime + .block_on(p1.open(16, Promises::ORDERED | Promises::CONSISTENCY)) + .unwrap(); //remote representation of s1 let mut last = Instant::now(); let mut id = 0u64; let raw_msg = Message::serialize( @@ -180,7 +183,7 @@ fn client(address: ProtocolAddr) { drop(s1); std::thread::sleep(std::time::Duration::from_millis(5000)); info!("Closing participant"); - block_on(p1.disconnect()).unwrap(); + runtime.block_on(p1.disconnect()).unwrap(); std::thread::sleep(std::time::Duration::from_millis(25000)); info!("DROPPING! 
client"); drop(client); diff --git a/network/src/api.rs b/network/src/api.rs index 8baaa72581..1b349f3248 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -8,13 +8,6 @@ use crate::{ scheduler::Scheduler, types::{Mid, Pid, Prio, Promises, Sid}, }; -use tokio::{io, sync::Mutex}; -use tokio::runtime::Runtime; -use futures::{ - channel::{mpsc, oneshot}, - sink::SinkExt, - stream::StreamExt, -}; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; #[cfg(feature = "metrics")] @@ -28,6 +21,11 @@ use std::{ Arc, }, }; +use tokio::{ + io, + runtime::Runtime, + sync::{mpsc, oneshot, Mutex}, +}; use tracing::*; use tracing_futures::Instrument; @@ -78,9 +76,8 @@ pub struct Stream { prio: Prio, promises: Promises, send_closed: Arc, - runtime: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: Option>, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -169,7 +166,8 @@ impl Network { /// # Arguments /// * `participant_id` - provide it by calling [`Pid::new()`], usually you /// don't want to reuse a Pid for 2 `Networks` - /// * `runtime` - provide a tokio::Runtime, it's used to internally spawn tasks + /// * `runtime` - provide a tokio::Runtime, it's used to internally spawn + /// tasks. It is necessary to clean up in the non-async `Drop`. 
/// /// # Result /// * `Self` - returns a `Network` which can be `Send` to multiple areas of @@ -178,22 +176,15 @@ impl Network { /// /// # Examples /// ```rust - /// //Example with uvth - /// use uvth::ThreadPoolBuilder; + /// //Example with tokio + /// use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// - /// let pool = ThreadPoolBuilder::new().build(); - /// let (network, f) = Network::new(Pid::new()); - /// pool.execute(f); + /// let runtime = Runtime::new(); + /// let network = Network::new(Pid::new(), Arc::new(runtime)); /// ``` /// - /// ```rust - /// //Example with std::thread - /// use veloren_network::{Network, Pid, ProtocolAddr}; - /// - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// ``` /// /// Usually you only create a single `Network` for an application, /// except when client and server are in the same application, then you @@ -252,20 +243,18 @@ impl Network { #[cfg(feature = "metrics")] registry, ); - runtime.spawn( - async move { - trace!(?p, "Starting scheduler in own thread"); - let _handle = tokio::spawn( - scheduler - .run() - .instrument(tracing::info_span!("scheduler", ?p)), - ); - trace!(?p, "Stopping scheduler and his own thread"); - } - ); + runtime.spawn(async move { + trace!(?p, "Starting scheduler in own thread"); + let _handle = tokio::spawn( + scheduler + .run() + .instrument(tracing::info_span!("scheduler", ?p)), + ); + trace!(?p, "Stopping scheduler and his own thread"); + }); Self { local_pid: participant_id, - runtime: runtime, + runtime, participant_disconnect_sender: Mutex::new(HashMap::new()), listen_sender: Mutex::new(listen_sender), connect_sender: Mutex::new(connect_sender), @@ -309,8 +298,7 @@ impl Network { self.listen_sender .lock() .await - .send((address, s2a_result_s)) - .await?; + .send((address, s2a_result_s))?; match s2a_result_r.await? 
{ //waiting guarantees that we either listened successfully or get an error like port in // use @@ -365,8 +353,7 @@ impl Network { self.connect_sender .lock() .await - .send((address, pid_sender)) - .await?; + .send((address, pid_sender))?; let participant = match pid_receiver.await? { Ok(p) => p, Err(e) => return Err(NetworkError::ConnectFailed(e)), @@ -417,7 +404,7 @@ impl Network { /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen pub async fn connected(&self) -> Result { - let participant = self.connected_receiver.lock().await.next().await?; + let participant = self.connected_receiver.lock().await.recv().await?; self.participant_disconnect_sender.lock().await.insert( participant.remote_pid, Arc::clone(&participant.a2s_disconnect_s), @@ -489,12 +476,11 @@ impl Participant { /// [`Streams`]: crate::api::Stream pub async fn open(&self, prio: u8, promises: Promises) -> Result { let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - if let Err(e) = self - .a2b_stream_open_s - .lock() - .await - .send((prio, promises, p2a_return_stream_s)) - .await + if let Err(e) = + self.a2b_stream_open_s + .lock() + .await + .send((prio, promises, p2a_return_stream_s)) { debug!(?e, "bParticipant is already closed, notifying"); return Err(ParticipantError::ParticipantDisconnected); @@ -546,7 +532,7 @@ impl Participant { /// [`connected`]: Network::connected /// [`open`]: Participant::open pub async fn opened(&self) -> Result { - match self.b2a_stream_opened_r.lock().await.next().await { + match self.b2a_stream_opened_r.lock().await.recv().await { Some(stream) => { let sid = stream.sid; debug!(?sid, ?self.remote_pid, "Receive opened stream"); @@ -609,13 +595,12 @@ impl Participant { //Streams will be closed by BParticipant match self.a2s_disconnect_s.lock().await.take() { - Some(mut a2s_disconnect_s) => { + Some(a2s_disconnect_s) => { let (finished_sender, finished_receiver) = oneshot::channel(); // Participant is connecting to Scheduler 
here, not as usual // Participant<->BParticipant a2s_disconnect_s .send((pid, finished_sender)) - .await .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { Ok(res) => { @@ -661,9 +646,8 @@ impl Stream { prio: Prio, promises: Promises, send_closed: Arc, - runtime: Arc, a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: mpsc::UnboundedReceiver, + b2a_msg_recv_r: async_channel::Receiver, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { @@ -673,7 +657,6 @@ impl Stream { prio, promises, send_closed, - runtime, a2b_msg_s, b2a_msg_recv_r: Some(b2a_msg_recv_r), a2b_close_stream_s: Some(a2b_close_stream_s), @@ -877,13 +860,13 @@ impl Stream { pub async fn recv_raw(&mut self) -> Result { match &mut self.b2a_msg_recv_r { Some(b2a_msg_recv_r) => { - match b2a_msg_recv_r.next().await { - Some(msg) => Ok(Message { + match b2a_msg_recv_r.recv().await { + Ok(msg) => Ok(Message { buffer: Arc::new(msg.buffer), #[cfg(feature = "compression")] compressed: self.promises.contains(Promises::COMPRESSED), }), - None => { + Err(_) => { self.b2a_msg_recv_r = None; //prevent panic Err(StreamError::StreamClosed) }, @@ -929,13 +912,8 @@ impl Stream { #[inline] pub fn try_recv(&mut self) -> Result, StreamError> { match &mut self.b2a_msg_recv_r { - Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_next() { - Err(_) => Ok(None), - Ok(None) => { - self.b2a_msg_recv_r = None; //prevent panic - Err(StreamError::StreamClosed) - }, - Ok(Some(msg)) => Ok(Some( + Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_recv() { + Ok(msg) => Ok(Some( Message { buffer: Arc::new(msg.buffer), #[cfg(feature = "compression")] @@ -943,6 +921,11 @@ impl Stream { } .deserialize()?, )), + Err(async_channel::TryRecvError::Empty) => Ok(None), + Err(async_channel::TryRecvError::Closed) => { + self.b2a_msg_recv_r = None; //prevent panic + Err(StreamError::StreamClosed) + }, }, None => Err(StreamError::StreamClosed), } @@ -975,16 +958,13 @@ 
impl Drop for Network { self.participant_disconnect_sender.lock().await.drain() { match a2s_disconnect_s.lock().await.take() { - Some(mut a2s_disconnect_s) => { + Some(a2s_disconnect_s) => { trace!(?remote_pid, "Participants will be closed"); let (finished_sender, finished_receiver) = oneshot::channel(); finished_receiver_list.push((remote_pid, finished_receiver)); - a2s_disconnect_s - .send((remote_pid, finished_sender)) - .await - .expect( - "Scheduler is closed, but nobody other should be able to close it", - ); + a2s_disconnect_s.send((remote_pid, finished_sender)).expect( + "Scheduler is closed, but nobody other should be able to close it", + ); }, None => trace!(?remote_pid, "Participant already disconnected gracefully"), } @@ -1026,13 +1006,12 @@ impl Drop for Participant { ?pid, "Participant has been shutdown cleanly, no further waiting is required!" ), - Some(mut a2s_disconnect_s) => { + Some(a2s_disconnect_s) => { debug!(?pid, "Disconnect from Scheduler"); self.runtime.block_on(async { let (finished_sender, finished_receiver) = oneshot::channel(); a2s_disconnect_s .send((self.remote_pid, finished_sender)) - .await .expect("Something is wrong in internal scheduler coding"); if let Err(e) = finished_receiver .await @@ -1059,7 +1038,10 @@ impl Drop for Stream { let sid = self.sid; let pid = self.pid; debug!(?pid, ?sid, "Shutting down Stream"); - self.runtime.block_on(self.a2b_close_stream_s.take().unwrap().send(self.sid)) + self.a2b_close_stream_s + .take() + .unwrap() + .send(self.sid) .expect("bparticipant part of a gracefully shutdown must have crashed"); } else { let sid = self.sid; @@ -1096,12 +1078,16 @@ impl From for NetworkError { fn from(_err: std::option::NoneError) -> Self { NetworkError::NetworkClosed } } -impl From for NetworkError { - fn from(_err: mpsc::SendError) -> Self { NetworkError::NetworkClosed } +impl From> for NetworkError { + fn from(_err: mpsc::error::SendError) -> Self { NetworkError::NetworkClosed } } -impl From for NetworkError { 
- fn from(_err: oneshot::Canceled) -> Self { NetworkError::NetworkClosed } +impl From for NetworkError { + fn from(_err: oneshot::error::RecvError) -> Self { NetworkError::NetworkClosed } +} + +impl From for NetworkError { + fn from(_err: std::io::Error) -> Self { NetworkError::NetworkClosed } } impl From> for StreamError { diff --git a/network/src/channel.rs b/network/src/channel.rs index c591fb0b88..7928337bd1 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -8,14 +8,16 @@ use crate::{ VELOREN_NETWORK_VERSION, }, }; -use futures::{ - channel::{mpsc, oneshot}, - join, - sink::SinkExt, - stream::StreamExt, +use futures_core::task::Poll; +use futures_util::{ + task::{noop_waker, Context}, FutureExt, }; #[cfg(feature = "metrics")] use std::sync::Arc; +use tokio::{ + join, + sync::{mpsc, oneshot}, +}; use tracing::*; pub(crate) struct Channel { @@ -26,7 +28,7 @@ pub(crate) struct Channel { impl Channel { pub fn new(cid: u64) -> (Self, mpsc::UnboundedSender, oneshot::Sender<()>) { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded::(); + let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); let (read_stop_sender, read_stop_receiver) = oneshot::channel(); ( Self { @@ -52,7 +54,7 @@ impl Channel { let cnt = leftover_cid_frame.len(); trace!(?cnt, "Reapplying leftovers"); for cid_frame in leftover_cid_frame.drain(..) 
{ - w2c_cid_frame_s.send(cid_frame).await.unwrap(); + w2c_cid_frame_s.send(cid_frame).unwrap(); } trace!(?cnt, "All leftovers reapplied"); @@ -115,8 +117,8 @@ impl Handshake { } pub async fn setup(self, protocol: &Protocols) -> Result<(Pid, Sid, u128, Vec), ()> { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded::(); - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); + let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); + let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); let (read_stop_sender, read_stop_receiver) = oneshot::channel(); let handler_future = @@ -142,8 +144,10 @@ impl Handshake { match res { Ok(res) => { + let fake_waker = noop_waker(); + let mut ctx = Context::from_waker(&fake_waker); let mut leftover_frames = vec![]; - while let Ok(Some(cid_frame)) = w2c_cid_frame_r.try_next() { + while let Poll::Ready(Some(cid_frame)) = w2c_cid_frame_r.poll_recv(&mut ctx) { leftover_frames.push(cid_frame); } let cnt = leftover_frames.len(); @@ -175,7 +179,7 @@ impl Handshake { self.send_handshake(&mut c2w_frame_s).await; } - let frame = w2c_cid_frame_r.next().await.map(|(_cid, frame)| frame); + let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); #[cfg(feature = "metrics")] { if let Some(Ok(ref frame)) = frame { @@ -254,7 +258,7 @@ impl Handshake { return Err(()); } - let frame = w2c_cid_frame_r.next().await.map(|(_cid, frame)| frame); + let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); let r = match frame { Some(Ok(Frame::Init { pid, secret })) => { debug!(?pid, "Participant send their ID"); @@ -315,7 +319,6 @@ impl Handshake { magic_number: VELOREN_MAGIC_NUMBER, version: VELOREN_NETWORK_VERSION, }) - .await .unwrap(); } @@ -330,7 +333,6 @@ impl Handshake { pid: self.local_pid, secret: self.secret, }) - .await .unwrap(); } @@ -353,7 +355,7 @@ impl Handshake { .with_label_values(&[&cid_string, "Shutdown"]) .inc(); } - c2w_frame_s.send(Frame::Raw(data)).await.unwrap(); - 
c2w_frame_s.send(Frame::Shutdown).await.unwrap(); + c2w_frame_s.send(Frame::Raw(data)).unwrap(); + c2w_frame_s.send(Frame::Shutdown).unwrap(); } } diff --git a/network/src/lib.rs b/network/src/lib.rs index 69bd5f07c0..ffba192643 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -39,8 +39,8 @@ //! //! # Examples //! ```rust -//! use tokio::task::sleep; //! use futures::{executor::block_on, join}; +//! use tokio::task::sleep; //! use veloren_network::{Network, Pid, Promises, ProtocolAddr}; //! //! // Client diff --git a/network/src/message.rs b/network/src/message.rs index ad668908b6..9ab9941599 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -256,8 +256,8 @@ impl std::fmt::Debug for MessageBuffer { #[cfg(test)] mod tests { use crate::{api::Stream, message::*}; - use futures::channel::mpsc; use std::sync::{atomic::AtomicBool, Arc}; + use tokio::sync::mpsc; fn stub_stream(compressed: bool) -> Stream { use crate::{api::*, types::*}; @@ -273,8 +273,8 @@ mod tests { let promises = Promises::empty(); let (a2b_msg_s, _a2b_msg_r) = crossbeam_channel::unbounded(); - let (_b2a_msg_recv_s, b2a_msg_recv_r) = mpsc::unbounded(); - let (a2b_close_stream_s, _a2b_close_stream_r) = mpsc::unbounded(); + let (_b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded(); + let (a2b_close_stream_s, _a2b_close_stream_r) = mpsc::unbounded_channel(); Stream::new( Pid::fake(0), diff --git a/network/src/participant.rs b/network/src/participant.rs index 78d1dacd41..764f407cdf 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -8,15 +8,7 @@ use crate::{ protocols::Protocols, types::{Cid, Frame, Pid, Prio, Promises, Sid}, }; -use tokio::sync::{Mutex, RwLock}; -use tokio::runtime::Runtime; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - select, - sink::SinkExt, - stream::StreamExt, -}; +use futures_util::{FutureExt, StreamExt}; use std::{ collections::{HashMap, VecDeque}, sync::{ @@ -25,6 +17,12 @@ use std::{ }, 
time::{Duration, Instant}, }; +use tokio::{ + runtime::Runtime, + select, + sync::{mpsc, oneshot, Mutex, RwLock}, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; use tracing_futures::Instrument; @@ -47,13 +45,14 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: Mutex>, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] struct ControlChannels { a2b_stream_open_r: mpsc::UnboundedReceiver, b2a_stream_opened_s: mpsc::UnboundedSender, + b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, s2b_create_channel_r: mpsc::UnboundedReceiver, a2b_close_stream_r: mpsc::UnboundedReceiver, a2b_close_stream_s: mpsc::UnboundedSender, @@ -63,7 +62,7 @@ struct ControlChannels { #[derive(Debug)] struct ShutdownInfo { //a2b_stream_open_r: mpsc::UnboundedReceiver, - b2a_stream_opened_s: mpsc::UnboundedSender, + b2b_close_stream_opened_sender_s: Option>, error: Option, } @@ -84,6 +83,12 @@ pub struct BParticipant { } impl BParticipant { + const BANDWIDTH: u64 = 25_000_000; + const FRAMES_PER_TICK: u64 = Self::BANDWIDTH * Self::TICK_TIME_MS / 1000 / 1400 /*TCP FRAME*/; + const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS); + //in bit/s + const TICK_TIME_MS: u64 = 10; + #[allow(clippy::type_complexity)] pub(crate) fn new( remote_pid: Pid, @@ -97,21 +102,24 @@ impl BParticipant { mpsc::UnboundedSender, oneshot::Sender, ) { - let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded::(); - let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded::(); - let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded(); + let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded_channel::(); + let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::(); + let (b2b_close_stream_opened_sender_s, b2b_close_stream_opened_sender_r) = + oneshot::channel(); + let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel(); let (s2b_shutdown_bparticipant_s, 
s2b_shutdown_bparticipant_r) = oneshot::channel(); - let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded(); + let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel(); let shutdown_info = RwLock::new(ShutdownInfo { //a2b_stream_open_r: a2b_stream_open_r.clone(), - b2a_stream_opened_s: b2a_stream_opened_s.clone(), + b2b_close_stream_opened_sender_s: Some(b2b_close_stream_opened_sender_s), error: None, }); let run_channels = Some(ControlChannels { a2b_stream_open_r, b2a_stream_opened_s, + b2b_close_stream_opened_sender_r, s2b_create_channel_r, a2b_close_stream_r, a2b_close_stream_s, @@ -147,7 +155,7 @@ impl BParticipant { let (shutdown_stream_close_mgr_sender, shutdown_stream_close_mgr_receiver) = oneshot::channel(); let (shutdown_open_mgr_sender, shutdown_open_mgr_receiver) = oneshot::channel(); - let (w2b_frames_s, w2b_frames_r) = mpsc::unbounded::(); + let (w2b_frames_s, w2b_frames_r) = mpsc::unbounded_channel::(); let (prios, a2p_msg_s, b2p_notify_empty_stream_s) = PrioManager::new( #[cfg(feature = "metrics")] Arc::clone(&self.metrics), @@ -155,7 +163,7 @@ impl BParticipant { ); let run_channels = self.run_channels.take().unwrap(); - futures::join!( + tokio::join!( self.open_mgr( run_channels.a2b_stream_open_r, run_channels.a2b_close_stream_s.clone(), @@ -165,6 +173,7 @@ impl BParticipant { self.handle_frames_mgr( w2b_frames_r, run_channels.b2a_stream_opened_s, + run_channels.b2b_close_stream_opened_sender_r, run_channels.a2b_close_stream_s, a2p_msg_s.clone(), ), @@ -188,13 +197,11 @@ impl BParticipant { &self, mut prios: PrioManager, mut shutdown_send_mgr_receiver: oneshot::Receiver>, - mut b2s_prio_statistic_s: mpsc::UnboundedSender, + b2s_prio_statistic_s: mpsc::UnboundedSender, ) { //This time equals the MINIMUM Latency in average, so keep it down and //Todo: // make it configurable or switch to await E.g. 
Prio 0 = await, prio 50 // wait for more messages - const TICK_TIME: Duration = Duration::from_millis(10); - const FRAMES_PER_TICK: usize = 10005; self.running_mgr.fetch_add(1, Ordering::Relaxed); let mut b2b_prios_flushed_s = None; //closing up trace!("Start send_mgr"); @@ -203,7 +210,9 @@ impl BParticipant { let mut i: u64 = 0; loop { let mut frames = VecDeque::new(); - prios.fill_frames(FRAMES_PER_TICK, &mut frames).await; + prios + .fill_frames(Self::FRAMES_PER_TICK as usize, &mut frames) + .await; let len = frames.len(); for (_, frame) in frames { self.send_frame( @@ -215,9 +224,8 @@ impl BParticipant { } b2s_prio_statistic_s .send((self.remote_pid, len as u64, /* */ 0)) - .await .unwrap(); - tokio::time::sleep(TICK_TIME).await; + tokio::time::sleep(Self::TICK_TIME).await; i += 1; if i.rem_euclid(1000) == 0 { trace!("Did 1000 ticks"); @@ -229,7 +237,7 @@ impl BParticipant { break; } if b2b_prios_flushed_s.is_none() { - if let Some(prios_flushed_s) = shutdown_send_mgr_receiver.try_recv().unwrap() { + if let Ok(prios_flushed_s) = shutdown_send_mgr_receiver.try_recv() { b2b_prios_flushed_s = Some(prios_flushed_s); } } @@ -252,8 +260,9 @@ impl BParticipant { ) -> bool { let mut drop_cid = None; // TODO: find out ideal channel here + let res = if let Some(ci) = self.channels.read().await.values().next() { - let mut ci = ci.lock().await; + let ci = ci.lock().await; //we are increasing metrics without checking the result to please // borrow_checker. otherwise we would need to close `frame` what we // dont want! 
@@ -261,7 +270,7 @@ impl BParticipant { frames_out_total_cache .with_label_values(ci.cid, &frame) .inc(); - if let Err(e) = ci.b2w_frame_s.send(frame).await { + if let Err(e) = ci.b2w_frame_s.send(frame) { let cid = ci.cid; info!(?e, ?cid, "channel no longer available"); drop_cid = Some(cid); @@ -294,7 +303,6 @@ impl BParticipant { if let Err(e) = ci.b2r_read_shutdown.send(()) { trace!(?cid, ?e, "seems like was already shut down"); } - ci.b2w_frame_s.close_channel(); } //TODO FIXME tags: takeover channel multiple info!( @@ -311,7 +319,8 @@ impl BParticipant { async fn handle_frames_mgr( &self, mut w2b_frames_r: mpsc::UnboundedReceiver, - mut b2a_stream_opened_s: mpsc::UnboundedSender, + b2a_stream_opened_s: mpsc::UnboundedSender, + b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, a2b_close_stream_s: mpsc::UnboundedSender, a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, ) { @@ -323,21 +332,24 @@ impl BParticipant { let mut dropped_instant = Instant::now(); let mut dropped_cnt = 0u64; let mut dropped_sid = Sid::new(0); + let mut b2a_stream_opened_s = Some(b2a_stream_opened_s); + let mut b2b_close_stream_opened_sender_r = b2b_close_stream_opened_sender_r.fuse(); - while let Some((cid, result_frame)) = w2b_frames_r.next().await { + while let Some((cid, result_frame)) = select!( + next = w2b_frames_r.recv().fuse() => next, + _ = &mut b2b_close_stream_opened_sender_r => { + b2a_stream_opened_s = None; + None + }, + ) { //trace!(?result_frame, "handling frame"); let frame = match result_frame { Ok(frame) => frame, Err(()) => { - // The read protocol stopped, i need to make sure that write gets stopped - debug!("read protocol was closed. 
Stopping write protocol"); - if let Some(ci) = self.channels.read().await.get(&cid) { - let mut ci = ci.lock().await; - ci.b2w_frame_s - .close() - .await - .expect("couldn't stop write protocol"); - } + // The read protocol stopped, i need to make sure that write gets stopped, can + // drop channel as it's dead anyway + debug!("read protocol was closed. Stopping channel"); + self.channels.write().await.remove(&cid); continue; }, }; @@ -360,13 +372,18 @@ impl BParticipant { let stream = self .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) .await; - if let Err(e) = b2a_stream_opened_s.send(stream).await { - warn!( - ?e, - ?sid, - "couldn't notify api::Participant that a stream got opened. Is the \ - participant already dropped?" - ); + match &b2a_stream_opened_s { + None => debug!("dropping openStream as Channel is already closing"), + Some(s) => { + if let Err(e) = s.send(stream) { + warn!( + ?e, + ?sid, + "couldn't notify api::Participant that a stream got opened. \ + Is the participant already dropped?" + ); + } + }, } }, Frame::CloseStream { sid } => { @@ -465,6 +482,7 @@ impl BParticipant { ) { self.running_mgr.fetch_add(1, Ordering::Relaxed); trace!("Start create_channel_mgr"); + let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r .for_each_concurrent( None, @@ -549,8 +567,8 @@ impl BParticipant { let mut shutdown_open_mgr_receiver = shutdown_open_mgr_receiver.fuse(); //from api or shutdown signal while let Some((prio, promises, p2a_return_stream)) = select! { - next = a2b_stream_open_r.next().fuse() => next, - _ = shutdown_open_mgr_receiver => None, + next = a2b_stream_open_r.recv().fuse() => next, + _ = &mut shutdown_open_mgr_receiver => None, } { debug!(?prio, ?promises, "Got request to open a new steam"); //TODO: a2b_stream_open_r isn't closed on api_close yet. This needs to change. 
@@ -657,7 +675,6 @@ impl BParticipant { itself, ignoring" ); }; - ci.b2w_frame_s.close_channel(); } //Wait for other bparticipants mgr to close via AtomicUsize @@ -712,8 +729,8 @@ impl BParticipant { //from api or shutdown signal while let Some(sid) = select! { - next = a2b_close_stream_r.next().fuse() => next, - sender = shutdown_stream_close_mgr_receiver => { + next = a2b_close_stream_r.recv().fuse() => next, + sender = &mut shutdown_stream_close_mgr_receiver => { b2b_stream_close_shutdown_confirmed_s = Some(sender.unwrap()); None } @@ -779,7 +796,7 @@ impl BParticipant { match self.streams.read().await.get(&sid) { Some(si) => { si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.lock().await.close_channel(); + si.b2a_msg_recv_s.lock().await.close(); }, None => trace!( "Couldn't find the stream, might be simultaneous close from local/remote" @@ -828,7 +845,7 @@ impl BParticipant { a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { - let (b2a_msg_recv_s, b2a_msg_recv_r) = mpsc::unbounded::(); + let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, @@ -847,7 +864,6 @@ impl BParticipant { prio, promises, send_closed, - Arc::clone(&self.runtime), a2p_msg_s, b2a_msg_recv_r, a2b_close_stream_s.clone(), @@ -860,7 +876,9 @@ impl BParticipant { if let Some(r) = reason { lock.error = Some(r); } - lock.b2a_stream_opened_s.close_channel(); + lock.b2b_close_stream_opened_sender_s + .take() + .map(|s| s.send(())); debug!("Closing all streams for write"); for (sid, si) in self.streams.read().await.iter() { @@ -876,7 +894,7 @@ impl BParticipant { debug!("Closing all streams"); for (sid, si) in self.streams.read().await.iter() { trace!(?sid, "Shutting down Stream"); - si.b2a_msg_recv_s.lock().await.close_channel(); + si.b2a_msg_recv_s.lock().await.close(); } } } diff 
--git a/network/src/prios.rs b/network/src/prios.rs index 46d31024b5..a544a31241 100644 --- a/network/src/prios.rs +++ b/network/src/prios.rs @@ -11,9 +11,9 @@ use crate::{ types::{Frame, Prio, Sid}, }; use crossbeam_channel::{unbounded, Receiver, Sender}; -use futures::channel::oneshot; use std::collections::{HashMap, HashSet, VecDeque}; #[cfg(feature = "metrics")] use std::sync::Arc; +use tokio::sync::oneshot; use tracing::trace; const PRIO_MAX: usize = 64; @@ -289,8 +289,8 @@ mod tests { types::{Frame, Pid, Prio, Sid}, }; use crossbeam_channel::Sender; - use futures::{channel::oneshot, executor::block_on}; use std::{collections::VecDeque, sync::Arc}; + use tokio::{runtime::Runtime, sync::oneshot}; const SIZE: u64 = OutgoingMessage::FRAME_DATA_SIZE; const USIZE: usize = OutgoingMessage::FRAME_DATA_SIZE as usize; @@ -366,7 +366,9 @@ mod tests { let (mut mgr, msg_tx, _flush_tx) = mock_new(); msg_tx.send(mock_out(16, 1337)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); assert_header(&mut frames, 1337, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); @@ -380,7 +382,9 @@ mod tests { msg_tx.send(mock_out(20, 42)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); assert_header(&mut frames, 1337, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); assert_header(&mut frames, 42, 3); @@ -394,7 +398,9 @@ mod tests { msg_tx.send(mock_out(20, 42)).unwrap(); msg_tx.send(mock_out(16, 1337)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); assert_header(&mut frames, 1337, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); @@ -420,7 +426,9 @@ mod tests { msg_tx.send(mock_out(16, 11)).unwrap(); msg_tx.send(mock_out(20, 
13)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); for i in 1..14 { assert_header(&mut frames, i, 3); @@ -447,13 +455,17 @@ mod tests { msg_tx.send(mock_out(20, 13)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(3, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(3, &mut frames)); for i in 1..4 { assert_header(&mut frames, i, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); } assert!(frames.is_empty()); - block_on(mgr.fill_frames(11, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(11, &mut frames)); for i in 4..14 { assert_header(&mut frames, i, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); @@ -466,7 +478,9 @@ mod tests { let (mut mgr, msg_tx, _flush_tx) = mock_new(); msg_tx.send(mock_out_large(16, 1)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); assert_header(&mut frames, 1, SIZE * 2 + 20); assert_data(&mut frames, 0, vec![48; USIZE]); @@ -481,7 +495,9 @@ mod tests { msg_tx.send(mock_out_large(16, 1)).unwrap(); msg_tx.send(mock_out_large(16, 2)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); assert_header(&mut frames, 1, SIZE * 2 + 20); assert_data(&mut frames, 0, vec![48; USIZE]); @@ -500,14 +516,18 @@ mod tests { msg_tx.send(mock_out_large(16, 1)).unwrap(); msg_tx.send(mock_out_large(16, 2)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2, &mut frames)); assert_header(&mut frames, 1, SIZE * 2 + 20); assert_data(&mut frames, 0, vec![48; USIZE]); assert_data(&mut frames, SIZE, vec![49; USIZE]); msg_tx.send(mock_out(0, 
3)).unwrap(); - block_on(mgr.fill_frames(100, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(100, &mut frames)); assert_header(&mut frames, 3, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); @@ -530,7 +550,9 @@ mod tests { msg_tx.send(mock_out(16, 2)).unwrap(); msg_tx.send(mock_out(16, 2)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2000, &mut frames)); assert_header(&mut frames, 2, 3); assert_data(&mut frames, 0, vec![48, 49, 50]); @@ -549,13 +571,17 @@ mod tests { msg_tx.send(mock_out(16, 2)).unwrap(); } let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2000, &mut frames)); //^unimportant frames, gonna be dropped msg_tx.send(mock_out(20, 1)).unwrap(); msg_tx.send(mock_out(16, 2)).unwrap(); msg_tx.send(mock_out(16, 2)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2000, &mut frames)); //important in that test is, that after the first frames got cleared i reset // the Points even though 998 prio 16 messages have been send at this @@ -589,7 +615,9 @@ mod tests { .unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2000, &mut frames)); assert_header(&mut frames, 2, 7000); assert_data(&mut frames, 0, vec![1; USIZE]); @@ -619,7 +647,9 @@ mod tests { msg_tx.send(mock_out(16, 8)).unwrap(); let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2000, &mut frames)); assert_header(&mut frames, 2, 7000); assert_data(&mut frames, 0, vec![1; USIZE]); @@ -651,7 +681,9 @@ mod tests { msg_tx.send(mock_out(20, 8)).unwrap(); let mut frames = VecDeque::new(); - 
block_on(mgr.fill_frames(2000, &mut frames)); + Runtime::new() + .unwrap() + .block_on(mgr.fill_frames(2000, &mut frames)); assert_header(&mut frames, 2, 7000); assert_data(&mut frames, 0, vec![1; USIZE]); diff --git a/network/src/protocols.rs b/network/src/protocols.rs index 771ea649e5..b92ef27eee 100644 --- a/network/src/protocols.rs +++ b/network/src/protocols.rs @@ -4,19 +4,14 @@ use crate::{ participant::C2pFrame, types::{Cid, Frame}, }; +use futures_util::{future::Fuse, FutureExt}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpStream, UdpSocket}, + select, + sync::{mpsc, oneshot, Mutex}, }; -use futures::{ - channel::{mpsc, oneshot}, - future::{Fuse, FutureExt}, - lock::Mutex, - select, - sink::SinkExt, - stream::StreamExt, -}; use std::{convert::TryFrom, net::SocketAddr, sync::Arc}; use tracing::*; @@ -75,7 +70,7 @@ impl TcpProtocol { async fn read_frame( r: &mut R, - mut end_receiver: &mut Fuse>, + end_receiver: &mut Fuse>, ) -> Result> { let handle = |read_result| match read_result { Ok(_) => Ok(()), @@ -190,7 +185,6 @@ impl TcpProtocol { } w2c_cid_frame_s .send((cid, Ok(frame))) - .await .expect("Channel or Participant seems no longer to exist"); }, Err(e_option) => { @@ -201,7 +195,6 @@ impl TcpProtocol { // need a explicit STOP here w2c_cid_frame_s .send((cid, Err(()))) - .await .expect("Channel or Participant seems no longer to exist"); } //None is clean shutdown @@ -284,7 +277,7 @@ impl TcpProtocol { #[cfg(not(feature = "metrics"))] let _cid = cid; - while let Some(frame) = c2w_frame_r.next().await { + while let Some(frame) = c2w_frame_r.recv().await { #[cfg(feature = "metrics")] { metrics_cache.with_label_values(&frame).inc(); @@ -343,15 +336,15 @@ impl UdpProtocol { let mut data_in = self.data_in.lock().await; let mut end_r = end_r.fuse(); while let Some(bytes) = select! 
{ - r = data_in.next().fuse() => match r { + r = data_in.recv().fuse() => match r { Some(r) => Some(r), None => { info!("Udp read ended"); - w2c_cid_frame_s.send((cid, Err(()))).await.expect("Channel or Participant seems no longer to exist"); + w2c_cid_frame_s.send((cid, Err(()))).expect("Channel or Participant seems no longer to exist"); None } }, - _ = end_r => None, + _ = &mut end_r => None, } { trace!("Got raw UDP message with len: {}", bytes.len()); let frame_no = bytes[0]; @@ -389,7 +382,7 @@ impl UdpProtocol { }; #[cfg(feature = "metrics")] metrics_cache.with_label_values(&frame).inc(); - w2c_cid_frame_s.send((cid, Ok(frame))).await.unwrap(); + w2c_cid_frame_s.send((cid, Ok(frame))).unwrap(); } trace!("Shutting down udp read()"); } @@ -406,7 +399,7 @@ impl UdpProtocol { .with_label_values(&[&cid.to_string()]); #[cfg(not(feature = "metrics"))] let _cid = cid; - while let Some(frame) = c2w_frame_r.next().await { + while let Some(frame) = c2w_frame_r.recv().await { #[cfg(feature = "metrics")] metrics_cache.with_label_values(&frame).inc(); let len = match frame { @@ -501,9 +494,8 @@ impl UdpProtocol { mod tests { use super::*; use crate::{metrics::NetworkMetrics, types::Pid}; - use tokio::net; - use futures::{executor::block_on, stream::StreamExt}; use std::sync::Arc; + use tokio::{net, runtime::Runtime, sync::mpsc}; #[test] fn tcp_read_handshake() { @@ -511,11 +503,11 @@ mod tests { let cid = 80085; let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50500); - block_on(async { + Runtime::new().unwrap().block_on(async { let server = net::TcpListener::bind(addr).await.unwrap(); let mut client = net::TcpStream::connect(addr).await.unwrap(); - let s_stream = server.incoming().next().await.unwrap().unwrap(); + let (s_stream, _) = server.accept().await.unwrap(); let prot = TcpProtocol::new(s_stream, metrics); //Send Handshake @@ -524,21 +516,21 @@ mod tests { 
client.write_all(&1337u32.to_le_bytes()).await.unwrap(); client.write_all(&0u32.to_le_bytes()).await.unwrap(); client.write_all(&42u32.to_le_bytes()).await.unwrap(); - client.flush(); + client.flush().await.unwrap(); //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); + let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); let (read_stop_sender, read_stop_receiver) = oneshot::channel(); let cid2 = cid; let t = std::thread::spawn(move || { - block_on(async { + Runtime::new().unwrap().block_on(async { prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) .await; }) }); // Assert than we get some value back! Its a Handshake! //tokio::task::sleep(std::time::Duration::from_millis(1000)); - let (cid_r, frame) = w2c_cid_frame_r.next().await.unwrap(); + let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); assert_eq!(cid, cid_r); if let Ok(Frame::Handshake { magic_number, @@ -561,11 +553,11 @@ mod tests { let cid = 80085; let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50501); - block_on(async { + Runtime::new().unwrap().block_on(async { let server = net::TcpListener::bind(addr).await.unwrap(); let mut client = net::TcpStream::connect(addr).await.unwrap(); - let s_stream = server.incoming().next().await.unwrap().unwrap(); + let (s_stream, _) = server.accept().await.unwrap(); let prot = TcpProtocol::new(s_stream, metrics); //Send Handshake @@ -573,19 +565,19 @@ mod tests { .write_all("x4hrtzsektfhxugzdtz5r78gzrtzfhxfdthfthuzhfzzufasgasdfg".as_bytes()) .await .unwrap(); - client.flush(); + client.flush().await.unwrap(); //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); + let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); let (read_stop_sender, read_stop_receiver) = oneshot::channel(); let cid2 = cid; let t = std::thread::spawn(move || { - 
block_on(async { + Runtime::new().unwrap().block_on(async { prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) .await; }) }); // Assert than we get some value back! Its a Raw! - let (cid_r, frame) = w2c_cid_frame_r.next().await.unwrap(); + let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); assert_eq!(cid, cid_r); if let Ok(Frame::Raw(data)) = frame { assert_eq!(&data.as_slice(), b"x4hrtzsektfhxugzdtz5r78gzrtzfhxf"); diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index e0c3b0ef84..f648d48a15 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -7,15 +7,7 @@ use crate::{ protocols::{Protocols, TcpProtocol, UdpProtocol}, types::Pid, }; -use tokio::{io, net, sync::Mutex}; -use tokio::runtime::Runtime; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - select, - sink::SinkExt, - stream::StreamExt, -}; +use futures_util::{FutureExt, StreamExt}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -26,6 +18,13 @@ use std::{ Arc, }, }; +use tokio::{ + io, net, + runtime::Runtime, + select, + sync::{mpsc, oneshot, Mutex}, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; use tracing_futures::Instrument; @@ -92,12 +91,13 @@ impl Scheduler { mpsc::UnboundedReceiver, oneshot::Sender<()>, ) { - let (a2s_listen_s, a2s_listen_r) = mpsc::unbounded::(); - let (a2s_connect_s, a2s_connect_r) = mpsc::unbounded::(); - let (s2a_connected_s, s2a_connected_r) = mpsc::unbounded::(); + let (a2s_listen_s, a2s_listen_r) = mpsc::unbounded_channel::(); + let (a2s_connect_s, a2s_connect_r) = mpsc::unbounded_channel::(); + let (s2a_connected_s, s2a_connected_r) = mpsc::unbounded_channel::(); let (a2s_scheduler_shutdown_s, a2s_scheduler_shutdown_r) = oneshot::channel::<()>(); - let (a2s_disconnect_s, a2s_disconnect_r) = mpsc::unbounded::(); - let (b2s_prio_statistic_s, b2s_prio_statistic_r) = mpsc::unbounded::(); + let (a2s_disconnect_s, a2s_disconnect_r) = 
mpsc::unbounded_channel::(); + let (b2s_prio_statistic_s, b2s_prio_statistic_r) = + mpsc::unbounded_channel::(); let run_channels = Some(ControlChannels { a2s_listen_r, @@ -150,7 +150,7 @@ impl Scheduler { pub async fn run(mut self) { let run_channels = self.run_channels.take().unwrap(); - futures::join!( + tokio::join!( self.listen_mgr(run_channels.a2s_listen_r), self.connect_mgr(run_channels.a2s_connect_r), self.disconnect_mgr(run_channels.a2s_disconnect_r), @@ -161,6 +161,7 @@ impl Scheduler { async fn listen_mgr(&self, a2s_listen_r: mpsc::UnboundedReceiver) { trace!("Start listen_mgr"); + let a2s_listen_r = UnboundedReceiverStream::new(a2s_listen_r); a2s_listen_r .for_each_concurrent(None, |(address, s2a_listen_result_s)| { let address = address; @@ -197,7 +198,7 @@ impl Scheduler { )>, ) { trace!("Start connect_mgr"); - while let Some((addr, pid_sender)) = a2s_connect_r.next().await { + while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { let (protocol, handshake) = match addr { ProtocolAddr::Tcp(addr) => { #[cfg(feature = "metrics")] @@ -240,7 +241,7 @@ impl Scheduler { continue; }; info!("Connecting Udp to: {}", addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded::>(); + let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); let protocol = UdpProtocol::new( Arc::clone(&socket), addr, @@ -264,7 +265,7 @@ impl Scheduler { async fn disconnect_mgr(&self, mut a2s_disconnect_r: mpsc::UnboundedReceiver) { trace!("Start disconnect_mgr"); - while let Some((pid, return_once_successful_shutdown)) = a2s_disconnect_r.next().await { + while let Some((pid, return_once_successful_shutdown)) = a2s_disconnect_r.recv().await { //Closing Participants is done the following way: // 1. We drop our senders and receivers // 2. 
we need to close BParticipant, this will drop its senderns and receivers @@ -299,7 +300,7 @@ impl Scheduler { mut b2s_prio_statistic_r: mpsc::UnboundedReceiver, ) { trace!("Start prio_adj_mgr"); - while let Some((_pid, _frame_cnt, _unused)) = b2s_prio_statistic_r.next().await { + while let Some((_pid, _frame_cnt, _unused)) = b2s_prio_statistic_r.recv().await { //TODO adjust prios in participants here! } @@ -381,7 +382,7 @@ impl Scheduler { let mut end_receiver = s2s_stop_listening_r.fuse(); while let Some(data) = select! { next = listener.accept().fuse() => Some(next), - _ = end_receiver => None, + _ = &mut end_receiver => None, } { let (stream, remote_addr) = match data { Ok((s, p)) => (s, p), @@ -425,7 +426,7 @@ impl Scheduler { let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER]; while let Ok((size, remote_addr)) = select! { next = socket.recv_from(&mut data).fuse() => next, - _ = end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), + _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), } { let mut datavec = Vec::with_capacity(size); datavec.extend_from_slice(&data[0..size]); @@ -434,7 +435,8 @@ impl Scheduler { #[allow(clippy::map_entry)] if !listeners.contains_key(&remote_addr) { info!("Accepting Udp from: {}", &remote_addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded::>(); + let (udp_data_sender, udp_data_receiver) = + mpsc::unbounded_channel::>(); listeners.insert(remote_addr, udp_data_sender); let protocol = UdpProtocol::new( Arc::clone(&socket), @@ -447,7 +449,7 @@ impl Scheduler { .await; } let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); - udp_data_sender.send(datavec).await.unwrap(); + udp_data_sender.send(datavec).unwrap(); } }, _ => unimplemented!(), @@ -457,7 +459,7 @@ impl Scheduler { async fn udp_single_channel_connect( socket: Arc, - mut w2p_udp_package_s: mpsc::UnboundedSender>, + w2p_udp_package_s: mpsc::UnboundedSender>, ) { let addr = socket.local_addr(); 
trace!(?addr, "Start udp_single_channel_connect"); @@ -470,11 +472,11 @@ impl Scheduler { let mut data = [0u8; 9216]; while let Ok(size) = select! { next = socket.recv(&mut data).fuse() => next, - _ = end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), + _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), } { let mut datavec = Vec::with_capacity(size); datavec.extend_from_slice(&data[0..size]); - w2p_udp_package_s.send(datavec).await.unwrap(); + w2p_udp_package_s.send(datavec).unwrap(); } trace!(?addr, "Stop udp_single_channel_connect"); } @@ -491,7 +493,7 @@ impl Scheduler { Contra: - DOS possibility because we answer first - Speed, because otherwise the message can be send with the creation */ - let mut participant_channels = self.participant_channels.lock().await.clone().unwrap(); + let participant_channels = self.participant_channels.lock().await.clone().unwrap(); // spawn is needed here, e.g. for TCP connect it would mean that only 1 // participant can be in handshake phase ever! Someone could deadlock // the whole server easily for new clients UDP doesnt work at all, as @@ -533,7 +535,7 @@ impl Scheduler { bparticipant, a2b_stream_open_s, b2a_stream_opened_r, - mut s2b_create_channel_s, + s2b_create_channel_s, s2b_shutdown_bparticipant_s, ) = BParticipant::new( pid, @@ -578,7 +580,6 @@ impl Scheduler { leftover_cid_frame, b2s_create_channel_done_s, )) - .await .unwrap(); b2s_create_channel_done_r.await.unwrap(); if let Some(pid_oneshot) = s2a_return_pid_s { @@ -589,7 +590,6 @@ impl Scheduler { participant_channels .s2a_connected_s .send(participant) - .await .unwrap(); } } else { diff --git a/network/tests/closing.rs b/network/tests/closing.rs index fac118ff5a..e3abb25533 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -18,8 +18,8 @@ //! - You sometimes see sleep(1000ms) this is used when we rely on the //! 
underlying TCP functionality, as this simulates client and server -use async_std::task; -use task::block_on; +use std::sync::Arc; +use tokio::runtime::Runtime; use veloren_network::{Network, ParticipantError, Pid, Promises, StreamError}; mod helper; use helper::{network_participant_stream, tcp}; @@ -27,26 +27,26 @@ use helper::{network_participant_stream, tcp}; #[test] fn close_network() { let (_, _) = helper::setup(false, 0); - let (_, _p1_a, mut s1_a, _, _p1_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _, _p1_a, mut s1_a, _, _p1_b, mut s1_b) = network_participant_stream(tcp()); std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); - let msg1: Result = block_on(s1_b.recv()); + let msg1: Result = r.block_on(s1_b.recv()); assert_eq!(msg1, Err(StreamError::StreamClosed)); } #[test] fn close_participant() { let (_, _) = helper::setup(false, 0); - let (_n_a, p1_a, mut s1_a, _n_b, p1_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, p1_a, mut s1_a, _n_b, p1_b, mut s1_b) = network_participant_stream(tcp()); - block_on(p1_a.disconnect()).unwrap(); - block_on(p1_b.disconnect()).unwrap(); + r.block_on(p1_a.disconnect()).unwrap(); + r.block_on(p1_b.disconnect()).unwrap(); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); assert_eq!( - block_on(s1_b.recv::()), + r.block_on(s1_b.recv::()), Err(StreamError::StreamClosed) ); } @@ -54,14 +54,14 @@ fn close_participant() { #[test] fn close_stream() { let (_, _) = helper::setup(false, 0); - let (_n_a, _, mut s1_a, _n_b, _, _) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _, mut s1_a, _n_b, _, _) = network_participant_stream(tcp()); // s1_b is dropped directly while s1_a isn't std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); assert_eq!( - block_on(s1_a.recv::()), + r.block_on(s1_a.recv::()), 
Err(StreamError::StreamClosed) ); } @@ -72,8 +72,8 @@ fn close_stream() { #[test] fn close_streams_in_block_on() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = block_on(network_participant_stream(tcp())); - block_on(async { + let (r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = network_participant_stream(tcp()); + r.block_on(async { //make it locally so that they are dropped later let mut s1_a = s1_a; let mut s1_b = s1_b; @@ -86,14 +86,14 @@ fn close_streams_in_block_on() { #[test] fn stream_simple_3msg_then_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(1u8).unwrap(); s1_a.send(42).unwrap(); s1_a.send("3rdMessage").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok(1u8)); - assert_eq!(block_on(s1_b.recv()), Ok(42)); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1u8)); + assert_eq!(r.block_on(s1_b.recv()), Ok(42)); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); drop(s1_a); std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(s1_b.send("Hello World"), Err(StreamError::StreamClosed)); @@ -103,43 +103,43 @@ fn stream_simple_3msg_then_close() { fn stream_send_first_then_receive() { // recv should still be possible even if stream got closed if they are in queue let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(1u8).unwrap(); s1_a.send(42).unwrap(); s1_a.send("3rdMessage").unwrap(); drop(s1_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - assert_eq!(block_on(s1_b.recv()), Ok(1u8)); - assert_eq!(block_on(s1_b.recv()), Ok(42)); - 
assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1u8)); + assert_eq!(r.block_on(s1_b.recv()), Ok(42)); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); assert_eq!(s1_b.send("Hello World"), Err(StreamError::StreamClosed)); } #[test] fn stream_send_1_then_close_stream() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("this message must be received, even if stream is closed already!") .unwrap(); drop(s1_a); std::thread::sleep(std::time::Duration::from_millis(1000)); let exp = Ok("this message must be received, even if stream is closed already!".to_string()); - assert_eq!(block_on(s1_b.recv()), exp); + assert_eq!(r.block_on(s1_b.recv()), exp); println!("all received and done"); } #[test] fn stream_send_100000_then_close_stream() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } drop(s1_a); let exp = Ok("woop_PARTY_HARD_woop".to_string()); println!("start receiving"); - block_on(async { + r.block_on(async { for _ in 0..100000 { assert_eq!(s1_b.recv().await, exp); } @@ -150,7 +150,7 @@ fn stream_send_100000_then_close_stream() { #[test] fn stream_send_100000_then_close_stream_remote() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -162,7 +162,7 @@ fn stream_send_100000_then_close_stream_remote() { #[test] fn 
stream_send_100000_then_close_stream_remote2() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -175,7 +175,7 @@ fn stream_send_100000_then_close_stream_remote2() { #[test] fn stream_send_100000_then_close_stream_remote3() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -188,7 +188,7 @@ fn stream_send_100000_then_close_stream_remote3() { #[test] fn close_part_then_network() { let (_, _) = helper::setup(false, 0); - let (n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..1000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -201,7 +201,7 @@ fn close_part_then_network() { #[test] fn close_network_then_part() { let (_, _) = helper::setup(false, 0); - let (n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..1000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -214,39 +214,39 @@ fn close_network_then_part() { #[test] fn close_network_then_disconnect_part() { let (_, _) = helper::setup(false, 0); - let (n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (r, n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..1000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } drop(n_a); - assert!(block_on(p_a.disconnect()).is_err()); + 
assert!(r.block_on(p_a.disconnect()).is_err()); std::thread::sleep(std::time::Duration::from_millis(1000)); } #[test] fn opened_stream_before_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); - let mut s2_a = block_on(p_a.open(10, Promises::empty())).unwrap(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); + let mut s2_a = r.block_on(p_a.open(10, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); - let mut s2_b = block_on(p_b.opened()).unwrap(); + let mut s2_b = r.block_on(p_b.opened()).unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); } #[test] fn opened_stream_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); - let mut s2_a = block_on(p_a.open(10, Promises::empty())).unwrap(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); + let mut s2_a = r.block_on(p_a.open(10, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - let mut s2_b = block_on(p_b.opened()).unwrap(); - assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + let mut s2_b = r.block_on(p_b.opened()).unwrap(); + assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); assert_eq!( - block_on(p_b.opened()).unwrap_err(), + r.block_on(p_b.opened()).unwrap_err(), ParticipantError::ParticipantDisconnected ); } @@ -254,15 +254,15 @@ fn opened_stream_after_remote_part_is_closed() { #[test] fn open_stream_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); - let mut s2_a = 
block_on(p_a.open(10, Promises::empty())).unwrap(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); + let mut s2_a = r.block_on(p_a.open(10, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - let mut s2_b = block_on(p_b.opened()).unwrap(); - assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + let mut s2_b = r.block_on(p_b.opened()).unwrap(); + assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); assert_eq!( - block_on(p_b.open(20, Promises::empty())).unwrap_err(), + r.block_on(p_b.open(20, Promises::empty())).unwrap_err(), ParticipantError::ParticipantDisconnected ); } @@ -270,11 +270,11 @@ fn open_stream_after_remote_part_is_closed() { #[test] fn failed_stream_open_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!( - block_on(p_b.opened()).unwrap_err(), + r.block_on(p_b.opened()).unwrap_err(), ParticipantError::ParticipantDisconnected ); } @@ -282,72 +282,69 @@ fn failed_stream_open_after_remote_part_is_closed() { #[test] fn open_participant_before_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (n_a, f) = Network::new(Pid::fake(0)); - std::thread::spawn(f); - let (n_b, f) = Network::new(Pid::fake(1)); - std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let n_a = Network::new(Pid::fake(0), Arc::clone(&r)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&r)); let addr = tcp(); - block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = block_on(p_b.open(10, Promises::empty())).unwrap(); + r.block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = 
r.block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = r.block_on(p_b.open(10, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); - let p_a = block_on(n_a.connected()).unwrap(); + let p_a = r.block_on(n_a.connected()).unwrap(); drop(s1_b); drop(p_b); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); - let mut s1_a = block_on(p_a.opened()).unwrap(); - assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); + let mut s1_a = r.block_on(p_a.opened()).unwrap(); + assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); } #[test] fn open_participant_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (n_a, f) = Network::new(Pid::fake(0)); - std::thread::spawn(f); - let (n_b, f) = Network::new(Pid::fake(1)); - std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let n_a = Network::new(Pid::fake(0), Arc::clone(&r)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&r)); let addr = tcp(); - block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = block_on(p_b.open(10, Promises::empty())).unwrap(); + r.block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = r.block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = r.block_on(p_b.open(10, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); drop(s1_b); drop(p_b); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); - let p_a = block_on(n_a.connected()).unwrap(); - let mut s1_a = block_on(p_a.opened()).unwrap(); - assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); + let p_a = r.block_on(n_a.connected()).unwrap(); + let mut s1_a = r.block_on(p_a.opened()).unwrap(); + assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); } #[test] fn close_network_scheduler_completely() { let (_, _) = helper::setup(false, 0); - let (n_a, f) = Network::new(Pid::fake(0)); - let ha = std::thread::spawn(f); - let (n_b, f) = 
Network::new(Pid::fake(1)); - let hb = std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let n_a = Network::new(Pid::fake(0), Arc::clone(&r)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&r)); let addr = tcp(); - block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = block_on(p_b.open(10, Promises::empty())).unwrap(); + r.block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = r.block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = r.block_on(p_b.open(10, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); - let p_a = block_on(n_a.connected()).unwrap(); - let mut s1_a = block_on(p_a.opened()).unwrap(); - assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); + let p_a = r.block_on(n_a.connected()).unwrap(); + let mut s1_a = r.block_on(p_a.opened()).unwrap(); + assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); drop(n_a); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); - ha.join().unwrap(); - hb.join().unwrap(); + let runtime = Arc::try_unwrap(r).expect("runtime is not alone, there still exist a reference"); + runtime.shutdown_timeout(std::time::Duration::from_secs(300)); } #[test] fn dont_panic_on_multiply_recv_after_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(11u32).unwrap(); drop(s1_a); @@ -362,7 +359,7 @@ fn dont_panic_on_multiply_recv_after_close() { #[test] fn dont_panic_on_recv_send_after_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(11u32).unwrap(); drop(s1_a); @@ -375,7 +372,7 @@ fn 
dont_panic_on_recv_send_after_close() { #[test] fn dont_panic_on_multiple_send_after_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(11u32).unwrap(); drop(s1_a); diff --git a/network/tests/helper.rs b/network/tests/helper.rs index 93ee64a1e6..64c65b0e91 100644 --- a/network/tests/helper.rs +++ b/network/tests/helper.rs @@ -1,10 +1,14 @@ use lazy_static::*; use std::{ net::SocketAddr, - sync::atomic::{AtomicU16, Ordering}, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, + }, thread, time::Duration, }; +use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; @@ -43,22 +47,32 @@ pub fn setup(tracing: bool, sleep: u64) -> (u64, u64) { } #[allow(dead_code)] -pub async fn network_participant_stream( +pub fn network_participant_stream( addr: ProtocolAddr, -) -> (Network, Participant, Stream, Network, Participant, Stream) { - let (n_a, f_a) = Network::new(Pid::fake(0)); - std::thread::spawn(f_a); - let (n_b, f_b) = Network::new(Pid::fake(1)); - std::thread::spawn(f_b); +) -> ( + Arc, + Network, + Participant, + Stream, + Network, + Participant, + Stream, +) { + let runtime = Arc::new(Runtime::new().unwrap()); + let (n_a, p1_a, s1_a, n_b, p1_b, s1_b) = runtime.block_on(async { + let n_a = Network::new(Pid::fake(0), Arc::clone(&runtime)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&runtime)); - n_a.listen(addr.clone()).await.unwrap(); - let p1_b = n_b.connect(addr).await.unwrap(); - let p1_a = n_a.connected().await.unwrap(); + n_a.listen(addr.clone()).await.unwrap(); + let p1_b = n_b.connect(addr).await.unwrap(); + let p1_a = n_a.connected().await.unwrap(); - let s1_a = p1_a.open(10, Promises::empty()).await.unwrap(); - let s1_b = p1_b.opened().await.unwrap(); + let 
s1_a = p1_a.open(10, Promises::empty()).await.unwrap(); + let s1_b = p1_b.opened().await.unwrap(); - (n_a, p1_a, s1_a, n_b, p1_b, s1_b) + (n_a, p1_a, s1_a, n_b, p1_b, s1_b) + }); + (runtime, n_a, p1_a, s1_a, n_b, p1_b, s1_b) } #[allow(dead_code)] diff --git a/network/tests/integration.rs b/network/tests/integration.rs index f4c8367841..b83f50b570 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -1,5 +1,5 @@ -use async_std::task; -use task::block_on; +use std::sync::Arc; +use tokio::runtime::Runtime; use veloren_network::{NetworkError, StreamError}; mod helper; use helper::{network_participant_stream, tcp, udp}; @@ -10,23 +10,23 @@ use veloren_network::{Network, Pid, Promises, ProtocolAddr}; #[ignore] fn network_20s() { let (_, _) = helper::setup(false, 0); - let (_n_a, _, _, _n_b, _, _) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _, _, _n_b, _, _) = network_participant_stream(tcp()); std::thread::sleep(std::time::Duration::from_secs(30)); } #[test] fn stream_simple() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("Hello World").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); } #[test] fn stream_try_recv() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(4242u32).unwrap(); std::thread::sleep(std::time::Duration::from_secs(1)); @@ -36,47 +36,46 @@ fn stream_try_recv() { #[test] fn stream_simple_3msg() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let 
(r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("Hello World").unwrap(); s1_a.send(1337).unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); - assert_eq!(block_on(s1_b.recv()), Ok(1337)); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); s1_a.send("3rdMessage").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); } #[test] fn stream_simple_udp() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(udp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(udp()); s1_a.send("Hello World").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); } #[test] fn stream_simple_udp_3msg() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(udp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(udp()); s1_a.send("Hello World").unwrap(); s1_a.send(1337).unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); - assert_eq!(block_on(s1_b.recv()), Ok(1337)); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); s1_a.send("3rdMessage").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); } #[test] #[ignore] fn tcp_and_udp_2_connections() -> std::result::Result<(), Box> { let (_, _) = helper::setup(false, 0); - let (network, f) = Network::new(Pid::new()); - let (remote, fr) = Network::new(Pid::new()); - std::thread::spawn(f); - std::thread::spawn(fr); - 
block_on(async { + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); + let remote = Network::new(Pid::new(), Arc::clone(&r)); + r.block_on(async { remote .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) .await?; @@ -97,18 +96,17 @@ fn tcp_and_udp_2_connections() -> std::result::Result<(), Box std::result::Result<(), Box> { let (_, _) = helper::setup(false, 0); - let (network, f) = Network::new(Pid::new()); - std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); let udp1 = udp(); let tcp1 = tcp(); - block_on(network.listen(udp1.clone()))?; - block_on(network.listen(tcp1.clone()))?; + r.block_on(network.listen(udp1.clone()))?; + r.block_on(network.listen(tcp1.clone()))?; std::thread::sleep(std::time::Duration::from_millis(200)); - let (network2, f2) = Network::new(Pid::new()); - std::thread::spawn(f2); - let e1 = block_on(network2.listen(udp1)); - let e2 = block_on(network2.listen(tcp1)); + let network2 = Network::new(Pid::new(), Arc::clone(&r)); + let e1 = r.block_on(network2.listen(udp1)); + let e2 = r.block_on(network2.listen(tcp1)); match e1 { Err(NetworkError::ListenFailed(e)) if e.kind() == ErrorKind::AddrInUse => (), _ => panic!(), @@ -130,11 +128,10 @@ fn api_stream_send_main() -> std::result::Result<(), Box> let (_, _) = helper::setup(false, 0); // Create a Network, listen on Port `1200` and wait for a Stream to be opened, // then answer `Hello World` - let (network, f) = Network::new(Pid::new()); - let (remote, fr) = Network::new(Pid::new()); - std::thread::spawn(f); - std::thread::spawn(fr); - block_on(async { + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); + let remote = Network::new(Pid::new(), Arc::clone(&r)); + r.block_on(async { network .listen(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; @@ -158,11 +155,10 @@ fn api_stream_recv_main() -> 
std::result::Result<(), Box> let (_, _) = helper::setup(false, 0); // Create a Network, listen on Port `1220` and wait for a Stream to be opened, // then listen on it - let (network, f) = Network::new(Pid::new()); - let (remote, fr) = Network::new(Pid::new()); - std::thread::spawn(f); - std::thread::spawn(fr); - block_on(async { + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); + let remote = Network::new(Pid::new(), Arc::clone(&r)); + r.block_on(async { network .listen(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; @@ -184,10 +180,10 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> #[test] fn wrong_parse() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(1337).unwrap(); - match block_on(s1_b.recv::()) { + match r.block_on(s1_b.recv::()) { Err(StreamError::Deserialize(_)) => (), _ => panic!("this should fail, but it doesnt!"), } @@ -196,7 +192,7 @@ fn wrong_parse() { #[test] fn multiple_try_recv() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("asd").unwrap(); s1_a.send(11u32).unwrap(); diff --git a/server-cli/src/main.rs b/server-cli/src/main.rs index 18b3390199..b930bd9de9 100644 --- a/server-cli/src/main.rs +++ b/server-cli/src/main.rs @@ -129,9 +129,19 @@ fn main() -> io::Result<()> { let server_port = &server_settings.gameserver_address.port(); let metrics_port = &server_settings.metrics_address.port(); // Create server - let runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()); - let mut server = Server::new(server_settings, editable_settings, &server_data_dir, 
runtime) - .expect("Failed to create server instance!"); + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ); + let mut server = Server::new( + server_settings, + editable_settings, + &server_data_dir, + runtime, + ) + .expect("Failed to create server instance!"); info!( ?server_port, diff --git a/server/src/lib.rs b/server/src/lib.rs index fe389ae266..a8ac9487fe 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -91,8 +91,8 @@ use std::{ }; #[cfg(not(feature = "worldgen"))] use test_world::{IndexOwned, World}; -use tracing::{debug, error, info, trace}; use tokio::runtime::Runtime; +use tracing::{debug, error, info, trace}; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -121,7 +121,7 @@ pub struct Server { connection_handler: ConnectionHandler, - runtime: Arc, + _runtime: Arc, thread_pool: ThreadPool, metrics: ServerMetrics, @@ -367,7 +367,8 @@ impl Server { let thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".to_string()) .build(); - let network = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), &metrics.registry()); + let network = + Network::new_with_registry(Pid::new(), Arc::clone(&runtime), &metrics.registry()); metrics .run(settings.metrics_address) .expect("Failed to initialize server metrics submodule."); @@ -388,7 +389,7 @@ impl Server { connection_handler, - runtime, + _runtime: runtime, thread_pool, metrics, diff --git a/voxygen/src/menu/main/client_init.rs b/voxygen/src/menu/main/client_init.rs index 010071cd56..2297bf9232 100644 --- a/voxygen/src/menu/main/client_init.rs +++ b/voxygen/src/menu/main/client_init.rs @@ -72,7 +72,13 @@ impl ClientInit { let mut last_err = None; let cores = num_cpus::get(); - let runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().worker_threads(if cores > 4 {cores-1} else {cores}).build().unwrap()); + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + 
.worker_threads(if cores > 4 { cores - 1 } else { cores }) + .build() + .unwrap(), + ); const FOUR_MINUTES_RETRIES: u64 = 48; 'tries: for _ in 0..FOUR_MINUTES_RETRIES { diff --git a/voxygen/src/singleplayer.rs b/voxygen/src/singleplayer.rs index 32368a19fc..bda1c1fd2b 100644 --- a/voxygen/src/singleplayer.rs +++ b/voxygen/src/singleplayer.rs @@ -83,7 +83,13 @@ impl Singleplayer { let thread_pool = client.map(|c| c.thread_pool().clone()); let cores = num_cpus::get(); - let runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().worker_threads(if cores > 4 {cores-1} else {cores}).build().unwrap()); + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .worker_threads(if cores > 4 { cores - 1 } else { cores }) + .build() + .unwrap(), + ); let settings2 = settings.clone(); let paused = Arc::new(AtomicBool::new(false)); From 3f85506761f05b536f357fc341d9302793737a35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Tue, 19 Jan 2021 09:48:33 +0100 Subject: [PATCH 3/6] fix most unittests (not all) by a) dropping network/participant BEFORE runtime and by transfering a expect into a warn! in the protocol --- network/src/api.rs | 27 +++++++++++++++------------ network/src/participant.rs | 3 ++- network/src/protocols.rs | 28 ++++++++++++++-------------- network/tests/closing.rs | 17 +++++++++++++---- network/tests/integration.rs | 14 ++++++++++++++ 5 files changed, 58 insertions(+), 31 deletions(-) diff --git a/network/src/api.rs b/network/src/api.rs index 1b349f3248..ef6fb113db 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -167,7 +167,10 @@ impl Network { /// * `participant_id` - provide it by calling [`Pid::new()`], usually you /// don't want to reuse a Pid for 2 `Networks` /// * `runtime` - provide a tokio::Runtime, it's used to internally spawn - /// tasks. It is necessary to clean up in the non-async `Drop`. + /// tasks. It is necessary to clean up in the non-async `Drop`. 
**All** + /// network related components **must** be dropped before the runtime is + /// stopped. dropping the runtime while a shutdown is still in progress + /// leaves the network in a bad state which might cause a panic! /// /// # Result /// * `Self` - returns a `Network` which can be `Send` to multiple areas of @@ -245,11 +248,10 @@ impl Network { ); runtime.spawn(async move { trace!(?p, "Starting scheduler in own thread"); - let _handle = tokio::spawn( - scheduler - .run() - .instrument(tracing::info_span!("scheduler", ?p)), - ); + scheduler + .run() + .instrument(tracing::info_span!("scheduler", ?p)) + .await; trace!(?p, "Stopping scheduler and his own thread"); }); Self { @@ -985,11 +987,7 @@ impl Drop for Network { }); trace!(?pid, "Participants have shut down!"); trace!(?pid, "Shutting down Scheduler"); - self.shutdown_sender - .take() - .unwrap() - .send(()) - .expect("Scheduler is closed, but nobody other should be able to close it"); + self.shutdown_sender.take().unwrap().send(()).expect("Scheduler is closed, but nobody other should be able to close it"); debug!(?pid, "Network has shut down"); } } @@ -1001,7 +999,12 @@ impl Drop for Participant { let pid = self.remote_pid; debug!(?pid, "Shutting down Participant"); - match self.runtime.block_on(self.a2s_disconnect_s.lock()).take() { + match self + .a2s_disconnect_s + .try_lock() + .expect("Participant in use while beeing dropped") + .take() + { None => trace!( ?pid, "Participant has been shutdown cleanly, no further waiting is required!" 
diff --git a/network/src/participant.rs b/network/src/participant.rs index 764f407cdf..6986a70e8f 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -204,6 +204,7 @@ impl BParticipant { // wait for more messages self.running_mgr.fetch_add(1, Ordering::Relaxed); let mut b2b_prios_flushed_s = None; //closing up + let mut interval = tokio::time::interval(Self::TICK_TIME); trace!("Start send_mgr"); #[cfg(feature = "metrics")] let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); @@ -225,7 +226,7 @@ impl BParticipant { b2s_prio_statistic_s .send((self.remote_pid, len as u64, /* */ 0)) .unwrap(); - tokio::time::sleep(Self::TICK_TIME).await; + interval.tick().await; i += 1; if i.rem_euclid(1000) == 0 { trace!("Did 1000 ticks"); diff --git a/network/src/protocols.rs b/network/src/protocols.rs index b92ef27eee..a18c1e1cbd 100644 --- a/network/src/protocols.rs +++ b/network/src/protocols.rs @@ -172,20 +172,20 @@ impl TcpProtocol { match Self::read_frame(&mut *read_stream, &mut end_r).await { Ok(frame) => { #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame { - throughput_cache.inc_by(data.len() as u64); + metrics_cache.with_label_values(&frame).inc(); + if let Frame::Data { + mid: _, + start: _, + ref data, + } = frame + { + throughput_cache.inc_by(data.len() as u64); + } } + if let Err(e) = w2c_cid_frame_s.send((cid, Ok(frame))) { + warn!(?e, "Channel or Participant seems no longer to exist"); } - w2c_cid_frame_s - .send((cid, Ok(frame))) - .expect("Channel or Participant seems no longer to exist"); }, Err(e_option) => { if let Some(e) = e_option { @@ -193,9 +193,9 @@ impl TcpProtocol { //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as // every channel is holding a sender! 
thats why Ne // need a explicit STOP here - w2c_cid_frame_s - .send((cid, Err(()))) - .expect("Channel or Participant seems no longer to exist"); + if let Err(e) = w2c_cid_frame_s.send((cid, Err(()))) { + warn!(?e, "Channel or Participant seems no longer to exist"); + } } //None is clean shutdown break; diff --git a/network/tests/closing.rs b/network/tests/closing.rs index e3abb25533..b0e8e180a9 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -66,9 +66,8 @@ fn close_stream() { ); } -///THIS is actually a bug which currently luckily doesn't trigger, but with new -/// async-std WE must make sure, if a stream is `drop`ed inside a `block_on`, -/// that no panic is thrown. +///WE must NOT create runtimes inside a Runtime, this check needs to verify +/// that we dont panic there #[test] fn close_streams_in_block_on() { let (_, _) = helper::setup(false, 0); @@ -81,6 +80,7 @@ fn close_streams_in_block_on() { assert_eq!(s1_b.recv().await, Ok("ping".to_string())); drop(s1_a); }); + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] @@ -157,6 +157,7 @@ fn stream_send_100000_then_close_stream_remote() { drop(s1_a); drop(_s1_b); //no receiving + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] @@ -170,6 +171,7 @@ fn stream_send_100000_then_close_stream_remote2() { std::thread::sleep(std::time::Duration::from_millis(1000)); drop(s1_a); //no receiving + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] @@ -183,6 +185,7 @@ fn stream_send_100000_then_close_stream_remote3() { std::thread::sleep(std::time::Duration::from_millis(1000)); drop(s1_a); //no receiving + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] @@ -233,6 +236,7 @@ fn opened_stream_before_remote_part_is_closed() { drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] @@ -249,6 +253,7 @@ fn 
opened_stream_after_remote_part_is_closed() { r.block_on(p_b.opened()).unwrap_err(), ParticipantError::ParticipantDisconnected ); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] @@ -265,6 +270,7 @@ fn open_stream_after_remote_part_is_closed() { r.block_on(p_b.open(20, Promises::empty())).unwrap_err(), ParticipantError::ParticipantDisconnected ); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] @@ -272,11 +278,11 @@ fn failed_stream_open_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); drop(p_a); - std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!( r.block_on(p_b.opened()).unwrap_err(), ParticipantError::ParticipantDisconnected ); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] @@ -337,6 +343,9 @@ fn close_network_scheduler_completely() { drop(n_a); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); + + drop(p_b); + drop(p_a); let runtime = Arc::try_unwrap(r).expect("runtime is not alone, there still exist a reference"); runtime.shutdown_timeout(std::time::Duration::from_secs(300)); } diff --git a/network/tests/integration.rs b/network/tests/integration.rs index b83f50b570..b78619d65d 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -21,6 +21,7 @@ fn stream_simple() { s1_a.send("Hello World").unwrap(); assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] @@ -31,6 +32,7 @@ fn stream_try_recv() { s1_a.send(4242u32).unwrap(); std::thread::sleep(std::time::Duration::from_secs(1)); assert_eq!(s1_b.try_recv(), Ok(Some(4242u32))); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] @@ -44,6 +46,7 @@ fn stream_simple_3msg() { assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); s1_a.send("3rdMessage").unwrap(); assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); 
//clean teardown } #[test] @@ -53,6 +56,7 @@ fn stream_simple_udp() { s1_a.send("Hello World").unwrap(); assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] @@ -66,6 +70,7 @@ fn stream_simple_udp_3msg() { assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); s1_a.send("3rdMessage").unwrap(); assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] @@ -76,6 +81,8 @@ fn tcp_and_udp_2_connections() -> std::result::Result<(), Box std::result::Result<(), Box (), _ => panic!(), }; + drop((network, network2)); //clean teardown Ok(()) } @@ -132,6 +140,8 @@ fn api_stream_send_main() -> std::result::Result<(), Box> let network = Network::new(Pid::new(), Arc::clone(&r)); let remote = Network::new(Pid::new(), Arc::clone(&r)); r.block_on(async { + let network = network; + let remote = remote; network .listen(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; @@ -159,6 +169,8 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> let network = Network::new(Pid::new(), Arc::clone(&r)); let remote = Network::new(Pid::new(), Arc::clone(&r)); r.block_on(async { + let network = network; + let remote = remote; network .listen(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; @@ -187,6 +199,7 @@ fn wrong_parse() { Err(StreamError::Deserialize(_)) => (), _ => panic!("this should fail, but it doesnt!"), } + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] @@ -204,4 +217,5 @@ fn multiple_try_recv() { drop(s1_a); std::thread::sleep(std::time::Duration::from_secs(1)); assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } From 9884019963241823326d564fe7103dd9e9198344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Fri, 22 Jan 2021 17:09:20 +0100 Subject: [PATCH 4/6] COMPLETE REDESIGN of network crate - Implementing a async non-io 
protocol crate a) no tokio / no channels b) I/O is based on abstraction Sink/Drain c) different Protocols can have a different Drain Type This allow MPSC to send its content without splitting up messages at all! It allows UDP to have internal extra frames to care for security It allows better abstraction for tests Allows benchmarks on the mpsc variant Custom Handshakes to allow sth like Quic protocol easily - reduce the participant managers to 4: channel creations, send, recv and shutdown. keeping the `mut data` in one manager removes the need for all RwLocks. reducing complexity and parallel access problems - more strategic participant shutdown. first send. then wait for remote side to notice recv stop, then remote side will stop send, then local side can stop recv. - metrics are internally abstracted to fit protocol and network layer - in this commit network/protocol tests work and network tests work someway, veloren compiles but does not work - handshake compatible to async_std --- Cargo.lock | 83 +- Cargo.toml | 5 +- client/Cargo.toml | 2 +- network/Cargo.toml | 10 +- network/protocol/Cargo.toml | 33 + network/protocol/benches/protocols.rs | 243 +++++ network/protocol/src/event.rs | 74 ++ network/protocol/src/frame.rs | 634 ++++++++++++ network/protocol/src/handshake.rs | 227 +++++ network/protocol/src/io.rs | 62 ++ network/protocol/src/lib.rs | 75 ++ network/protocol/src/message.rs | 127 +++ network/protocol/src/metrics.rs | 414 ++++++++ network/protocol/src/mpsc.rs | 217 ++++ network/protocol/src/prio.rs | 139 +++ network/protocol/src/tcp.rs | 584 +++++++++++ network/{ => protocol}/src/types.rs | 152 +-- network/protocol/src/udp.rs | 37 + network/src/api.rs | 170 ++-- network/src/channel.rs | 560 ++++------ network/src/lib.rs | 6 +- network/src/message.rs | 89 +- network/src/metrics.rs | 284 +----- network/src/participant.rs | 1350 ++++++++++++------------- network/src/prios.rs | 697 ------------- network/src/protocols.rs | 591 ----------- 
network/src/scheduler.rs | 140 ++- server/Cargo.toml | 2 +- voxygen/src/hud/chat.rs | 2 +- 29 files changed, 3987 insertions(+), 3022 deletions(-) create mode 100644 network/protocol/Cargo.toml create mode 100644 network/protocol/benches/protocols.rs create mode 100644 network/protocol/src/event.rs create mode 100644 network/protocol/src/frame.rs create mode 100644 network/protocol/src/handshake.rs create mode 100644 network/protocol/src/io.rs create mode 100644 network/protocol/src/lib.rs create mode 100644 network/protocol/src/message.rs create mode 100644 network/protocol/src/metrics.rs create mode 100644 network/protocol/src/mpsc.rs create mode 100644 network/protocol/src/prio.rs create mode 100644 network/protocol/src/tcp.rs rename network/{ => protocol}/src/types.rs (52%) create mode 100644 network/protocol/src/udp.rs delete mode 100644 network/src/prios.rs delete mode 100644 network/src/protocols.rs diff --git a/Cargo.lock b/Cargo.lock index 69b549cd2b..bd0dcb62ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,6 +259,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "async-trait" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.60", +] + [[package]] name = "atom" version = "0.3.6" @@ -1057,6 +1068,7 @@ dependencies = [ "clap", "criterion-plot", "csv", + "futures", "itertools 0.10.0", "lazy_static", "num-traits", @@ -1069,6 +1081,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio 1.2.0", "walkdir 2.3.1", ] @@ -5580,7 +5593,7 @@ dependencies = [ "veloren-common", "veloren-common-net", "veloren-common-sys", - "veloren_network", + "veloren-network", ] [[package]] @@ -5661,6 +5674,47 @@ dependencies = [ "wasmer", ] +[[package]] +name = "veloren-network" +version = "0.3.0" +dependencies = [ + "async-channel", + "async-trait", + "bincode", + 
"bitflags", + "clap", + "crossbeam-channel 0.5.0", + "futures-core", + "futures-util", + "lazy_static", + "lz-fear", + "prometheus", + "rand 0.8.3", + "serde", + "shellexpand", + "tiny_http", + "tokio 1.2.0", + "tokio-stream", + "tracing", + "tracing-futures", + "tracing-subscriber", + "veloren-network-protocol", +] + +[[package]] +name = "veloren-network-protocol" +version = "0.5.0" +dependencies = [ + "async-channel", + "async-trait", + "bitflags", + "criterion", + "prometheus", + "rand 0.8.3", + "tokio 1.2.0", + "tracing", +] + [[package]] name = "veloren-plugin-api" version = "0.1.0" @@ -5725,9 +5779,9 @@ dependencies = [ "veloren-common", "veloren-common-net", "veloren-common-sys", + "veloren-network", "veloren-plugin-api", "veloren-world", - "veloren_network", ] [[package]] @@ -5864,31 +5918,6 @@ dependencies = [ "veloren-common-net", ] -[[package]] -name = "veloren_network" -version = "0.3.0" -dependencies = [ - "async-channel", - "bincode", - "bitflags", - "clap", - "crossbeam-channel 0.5.0", - "futures-core", - "futures-util", - "lazy_static", - "lz-fear", - "prometheus", - "rand 0.8.3", - "serde", - "shellexpand", - "tiny_http", - "tokio 1.2.0", - "tokio-stream", - "tracing", - "tracing-futures", - "tracing-subscriber", -] - [[package]] name = "version-compare" version = "0.0.10" diff --git a/Cargo.toml b/Cargo.toml index 6bd656493d..e8104a7195 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "voxygen/anim", "world", "network", + "network/protocol" ] # default profile for devs, fast to compile, okay enough to run, no debug information @@ -30,8 +31,10 @@ incremental = true # All dependencies (but not this crate itself) [profile.dev.package."*"] opt-level = 3 -[profile.dev.package."veloren_network"] +[profile.dev.package."veloren-network"] opt-level = 2 +[profile.dev.package."veloren-network-protocol"] +opt-level = 3 [profile.dev.package."veloren-common"] opt-level = 2 [profile.dev.package."veloren-client"] diff --git 
a/client/Cargo.toml b/client/Cargo.toml index b2ebcfead1..fcda01cb65 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,7 +14,7 @@ default = ["simd"] common = { package = "veloren-common", path = "../common", features = ["no-assets"] } common-sys = { package = "veloren-common-sys", path = "../common/sys", default-features = false } common-net = { package = "veloren-common-net", path = "../common/net" } -network = { package = "veloren_network", path = "../network", features = ["compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["compression"], default-features = false } byteorder = "1.3.2" uvth = "3.1.1" diff --git a/network/Cargo.toml b/network/Cargo.toml index 0a540ca6dc..f548278896 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "veloren_network" +name = "veloren-network" version = "0.3.0" authors = ["Marcel Märtens "] edition = "2018" @@ -7,13 +7,15 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -metrics = ["prometheus"] +metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] default = ["metrics","compression"] [dependencies] +network-protocol = { package = "veloren-network-protocol", path = "protocol", default-features = false } + #serialisation bincode = "1.3.1" serde = { version = "1.0" } @@ -35,10 +37,12 @@ rand = { version = "0.8" } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } +# async traits +async-trait = "0.1.42" [dev-dependencies] tracing-subscriber = { version = "0.2.3", default-features = false, features = ["env-filter", "fmt", "chrono", "ansi", "smallvec"] } -tokio = { version = "1.0.1", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } +tokio = { version = "1.1.0", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } futures-util = { version = 
"0.3", default-features = false, features = ["sink", "std"] } clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml new file mode 100644 index 0000000000..e097314b6b --- /dev/null +++ b/network/protocol/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "veloren-network-protocol" +description = "pure Protocol without any I/O itself" +version = "0.5.0" +authors = ["Marcel Märtens "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +metrics = ["prometheus"] + +default = ["metrics"] + +[dependencies] + +#tracing and metrics +tracing = { version = "0.1", default-features = false } +prometheus = { version = "0.11", default-features = false, optional = true } +#stream flags +bitflags = "1.2.1" +rand = { version = "0.8" } +# async traits +async-trait = "0.1.42" + +[dev-dependencies] +async-channel = "1.5.1" +tokio = { version = "1.2", default-features = false, features = ["rt", "macros"] } +criterion = { version = "0.3.4", features = ["default", "async_tokio"] } + +[[bench]] +name = "protocols" +harness = false \ No newline at end of file diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs new file mode 100644 index 0000000000..5151083b98 --- /dev/null +++ b/network/protocol/benches/protocols.rs @@ -0,0 +1,243 @@ +use async_channel::*; +use async_trait::async_trait; +use criterion::{criterion_group, criterion_main, Criterion}; +use std::{sync::Arc, time::Duration}; +use veloren_network_protocol::{ + InitProtocol, MessageBuffer, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, Promises, + ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, + Sid, TcpRecvProtcol, TcpSendProtcol, UnreliableDrain, UnreliableSink, _internal::Frame, +}; + +fn frame_serialize(frame: Frame, buffer: &mut [u8]) -> usize { frame.to_bytes(buffer).0 } 
+ +async fn mpsc_msg(buffer: Arc) { + // Arrrg, need to include constructor here + let [p1, p2] = utils::ac_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + s.send(ProtocolEvent::Message { + sid: Sid::new(12), + mid: 0, + buffer, + }) + .await + .unwrap(); + r.recv().await.unwrap(); +} + +async fn mpsc_handshake() { + let [mut p1, mut p2] = utils::ac_bound(10, None); + let r1 = tokio::spawn(async move { + p1.initialize(true, Pid::fake(2), 1337).await.unwrap(); + p1 + }); + let r2 = tokio::spawn(async move { + p2.initialize(false, Pid::fake(3), 42).await.unwrap(); + p2 + }); + let (r1, r2) = tokio::join!(r1, r2); + r1.unwrap(); + r2.unwrap(); +} + +async fn tcp_msg(buffer: Arc, cnt: usize) { + let [p1, p2] = utils::tcp_bound(10000, None); /*10kbit*/ + let (mut s, mut r) = (p1.0, p2.1); + + let buffer = Arc::clone(&buffer); + let bandwidth = buffer.data.len() as u64 + 1000; + + let r1 = tokio::spawn(async move { + s.send(ProtocolEvent::OpenStream { + sid: Sid::new(12), + prio: 0, + promises: Promises::ORDERED, + guaranteed_bandwidth: 100_000, + }) + .await + .unwrap(); + + for i in 0..cnt { + s.send(ProtocolEvent::Message { + sid: Sid::new(12), + mid: i as u64, + buffer: Arc::clone(&buffer), + }) + .await + .unwrap(); + s.flush(bandwidth, Duration::from_secs(1)).await.unwrap(); + } + }); + let r2 = tokio::spawn(async move { + r.recv().await.unwrap(); + + for _ in 0..cnt { + r.recv().await.unwrap(); + } + }); + let (r1, r2) = tokio::join!(r1, r2); + r1.unwrap(); + r2.unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rt = || { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() + }; + + c.bench_function("mpsc_short_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: b"hello_world".to_vec(), + }); + b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer))) + }); + c.bench_function("mpsc_long_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: vec![150u8; 500_000], + }); + b.to_async(rt()).iter(|| 
mpsc_msg(Arc::clone(&buffer))) + }); + c.bench_function("mpsc_handshake", |b| { + b.to_async(rt()).iter(|| mpsc_handshake()) + }); + + let mut buffer = [0u8; 1500]; + + c.bench_function("frame_serialize_short", |b| { + let frame = Frame::Data { + mid: 65, + start: 89u64, + data: b"hello_world".to_vec(), + }; + b.iter(move || frame_serialize(frame.clone(), &mut buffer)) + }); + + c.bench_function("tcp_short_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: b"hello_world".to_vec(), + }); + b.to_async(rt()).iter(|| tcp_msg(Arc::clone(&buffer), 1)) + }); + c.bench_function("tcp_1GB_in_10000_msg", |b| { + let buffer = Arc::new(MessageBuffer { + data: vec![155u8; 100_000], + }); + b.to_async(rt()) + .iter(|| tcp_msg(Arc::clone(&buffer), 10_000)) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + +mod utils { + use super::*; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + ), + ] + } + + pub struct TcpDrain { + sender: Sender>, + } + + pub struct TcpSink { + receiver: Receiver>, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", 
Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} diff --git a/network/protocol/src/event.rs b/network/protocol/src/event.rs new file mode 100644 index 0000000000..14b74de558 --- /dev/null +++ b/network/protocol/src/event.rs @@ -0,0 +1,74 @@ +use crate::{ + frame::Frame, + message::MessageBuffer, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use std::sync::Arc; + +/* used for communication with Protocols */ +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ProtocolEvent { + Shutdown, + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + }, + CloseStream { + sid: Sid, + }, + Message { + buffer: Arc, + mid: Mid, + sid: Sid, + }, +} + +impl ProtocolEvent { + pub(crate) fn to_frame(&self) -> Frame { + match self { + ProtocolEvent::Shutdown 
=> Frame::Shutdown, + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + } => Frame::OpenStream { + sid: *sid, + prio: *prio, + promises: *promises, + }, + ProtocolEvent::CloseStream { sid } => Frame::CloseStream { sid: *sid }, + ProtocolEvent::Message { .. } => { + unimplemented!("Event::Message to Frame IS NOT supported") + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_to_frame() { + assert_eq!(ProtocolEvent::Shutdown.to_frame(), Frame::Shutdown); + assert_eq!( + ProtocolEvent::CloseStream { sid: Sid::new(42) }.to_frame(), + Frame::CloseStream { sid: Sid::new(42) } + ); + } + + #[test] + #[should_panic] + fn test_sixlet_to_str() { + let _ = ProtocolEvent::Message { + buffer: Arc::new(MessageBuffer { data: vec![] }), + mid: 0, + sid: Sid::new(23), + } + .to_frame(); + } +} diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs new file mode 100644 index 0000000000..4824498940 --- /dev/null +++ b/network/protocol/src/frame.rs @@ -0,0 +1,634 @@ +use crate::types::{Mid, Pid, Prio, Promises, Sid}; +use std::{collections::VecDeque, convert::TryFrom}; + +// const FRAME_RESERVED_1: u8 = 0; +const FRAME_HANDSHAKE: u8 = 1; +const FRAME_INIT: u8 = 2; +const FRAME_SHUTDOWN: u8 = 3; +const FRAME_OPEN_STREAM: u8 = 4; +const FRAME_CLOSE_STREAM: u8 = 5; +const FRAME_DATA_HEADER: u8 = 6; +const FRAME_DATA: u8 = 7; +const FRAME_RAW: u8 = 8; +//const FRAME_RESERVED_2: u8 = 10; +//const FRAME_RESERVED_3: u8 = 13; + +/// Used for Communication between Channel <----(TCP/UDP)----> Channel +#[derive(Debug, PartialEq, Clone)] +pub /* should be crate only */ enum InitFrame { + Handshake { + magic_number: [u8; 7], + version: [u32; 3], + }, + Init { + pid: Pid, + secret: u128, + }, + /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API + * against veloren Server! 
*/ + Raw(Vec), +} + +/// Used for Communication between Channel <----(TCP/UDP)----> Channel +#[derive(Debug, PartialEq, Clone)] +pub enum Frame { + Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully), + * Participant is deleted */ + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + }, + CloseStream { + sid: Sid, + }, + DataHeader { + mid: Mid, + sid: Sid, + length: u64, + }, + Data { + mid: Mid, + start: u64, + data: Vec, + }, +} + +impl InitFrame { + // Size WITHOUT the 1rst indicating byte + pub(crate) const HANDSHAKE_CNS: usize = 19; + pub(crate) const INIT_CNS: usize = 32; + /// const part of the RAW frame, actual size is variable + pub(crate) const RAW_CNS: usize = 2; + + //provide an appropriate buffer size. > 1500 + pub(crate) fn to_bytes(self, bytes: &mut [u8]) -> usize { + match self { + InitFrame::Handshake { + magic_number, + version, + } => { + let x = FRAME_HANDSHAKE.to_be_bytes(); + bytes[0] = x[0]; + bytes[1..8].copy_from_slice(&magic_number); + bytes[8..12].copy_from_slice(&version[0].to_le_bytes()); + bytes[12..16].copy_from_slice(&version[1].to_le_bytes()); + bytes[16..Self::HANDSHAKE_CNS + 1].copy_from_slice(&version[2].to_le_bytes()); + Self::HANDSHAKE_CNS + 1 + }, + InitFrame::Init { pid, secret } => { + bytes[0] = FRAME_INIT.to_be_bytes()[0]; + bytes[1..17].copy_from_slice(&pid.to_le_bytes()); + bytes[17..Self::INIT_CNS + 1].copy_from_slice(&secret.to_le_bytes()); + Self::INIT_CNS + 1 + }, + InitFrame::Raw(data) => { + bytes[0] = FRAME_RAW.to_be_bytes()[0]; + bytes[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); + bytes[Self::RAW_CNS + 1..(data.len() + Self::RAW_CNS + 1)] + .clone_from_slice(&data[..]); + Self::RAW_CNS + 1 + data.len() + }, + } + } + + pub(crate) fn to_frame(bytes: Vec) -> Option { + let frame_no = match bytes.get(0) { + Some(&f) => f, + None => return None, + }; + let frame = match frame_no { + FRAME_HANDSHAKE => { + if bytes.len() < Self::HANDSHAKE_CNS + 1 { + 
return None; + } + InitFrame::gen_handshake( + *<&[u8; Self::HANDSHAKE_CNS]>::try_from(&bytes[1..Self::HANDSHAKE_CNS + 1]) + .unwrap(), + ) + }, + FRAME_INIT => { + if bytes.len() < Self::INIT_CNS + 1 { + return None; + } + InitFrame::gen_init( + *<&[u8; Self::INIT_CNS]>::try_from(&bytes[1..Self::INIT_CNS + 1]).unwrap(), + ) + }, + FRAME_RAW => { + if bytes.len() < Self::RAW_CNS + 1 { + return None; + } + let length = InitFrame::gen_raw( + *<&[u8; Self::RAW_CNS]>::try_from(&bytes[1..Self::RAW_CNS + 1]).unwrap(), + ); + let mut data = vec![0; length as usize]; + let slice = &bytes[Self::RAW_CNS + 1..]; + if slice.len() != length as usize { + return None; + } + data.copy_from_slice(&bytes[Self::RAW_CNS + 1..]); + InitFrame::Raw(data) + }, + _ => InitFrame::Raw(bytes), + }; + Some(frame) + } + + fn gen_handshake(buf: [u8; Self::HANDSHAKE_CNS]) -> Self { + let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); + InitFrame::Handshake { + magic_number, + version: [ + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..Self::HANDSHAKE_CNS]).unwrap()), + ], + } + } + + fn gen_init(buf: [u8; Self::INIT_CNS]) -> Self { + InitFrame::Init { + pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), + secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..Self::INIT_CNS]).unwrap()), + } + } + + fn gen_raw(buf: [u8; Self::RAW_CNS]) -> u16 { + u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..Self::RAW_CNS]).unwrap()) + } +} + +impl Frame { + pub(crate) const CLOSE_STREAM_CNS: usize = 8; + /// const part of the DATA frame, actual size is variable + pub(crate) const DATA_CNS: usize = 18; + pub(crate) const DATA_HEADER_CNS: usize = 24; + #[cfg(feature = "metrics")] + pub const FRAMES_LEN: u8 = 5; + pub(crate) const OPEN_STREAM_CNS: usize = 10; + // Size WITHOUT the 1rst indicating byte + pub(crate) const SHUTDOWN_CNS: usize = 
0; + + #[cfg(feature = "metrics")] + pub const fn int_to_string(i: u8) -> &'static str { + match i { + 0 => "Shutdown", + 1 => "OpenStream", + 2 => "CloseStream", + 3 => "DataHeader", + 4 => "Data", + _ => "", + } + } + + #[cfg(feature = "metrics")] + pub fn get_int(&self) -> u8 { + match self { + Frame::Shutdown => 0, + Frame::OpenStream { .. } => 1, + Frame::CloseStream { .. } => 2, + Frame::DataHeader { .. } => 3, + Frame::Data { .. } => 4, + } + } + + #[cfg(feature = "metrics")] + pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } + + //provide an appropriate buffer size. > 1500 + pub fn to_bytes(self, bytes: &mut [u8]) -> (/* buf */ usize, /* actual data */ u64) { + match self { + Frame::Shutdown => { + bytes[Self::SHUTDOWN_CNS] = FRAME_SHUTDOWN.to_be_bytes()[0]; + (Self::SHUTDOWN_CNS + 1, 0) + }, + Frame::OpenStream { + sid, + prio, + promises, + } => { + bytes[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&sid.to_le_bytes()); + bytes[9] = prio.to_le_bytes()[0]; + bytes[Self::OPEN_STREAM_CNS] = promises.to_le_bytes()[0]; + (Self::OPEN_STREAM_CNS + 1, 0) + }, + Frame::CloseStream { sid } => { + bytes[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; + bytes[1..Self::CLOSE_STREAM_CNS + 1].copy_from_slice(&sid.to_le_bytes()); + (Self::CLOSE_STREAM_CNS + 1, 0) + }, + Frame::DataHeader { mid, sid, length } => { + bytes[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&mid.to_le_bytes()); + bytes[9..17].copy_from_slice(&sid.to_le_bytes()); + bytes[17..Self::DATA_HEADER_CNS + 1].copy_from_slice(&length.to_le_bytes()); + (Self::DATA_HEADER_CNS + 1, 0) + }, + Frame::Data { mid, start, data } => { + bytes[0] = FRAME_DATA.to_be_bytes()[0]; + bytes[1..9].copy_from_slice(&mid.to_le_bytes()); + bytes[9..17].copy_from_slice(&start.to_le_bytes()); + bytes[17..Self::DATA_CNS + 1].copy_from_slice(&(data.len() as u16).to_le_bytes()); + bytes[Self::DATA_CNS + 1..(data.len() + Self::DATA_CNS + 1)] + 
.clone_from_slice(&data[..]); + (Self::DATA_CNS + 1 + data.len(), data.len() as u64) + }, + } + } + + pub(crate) fn to_frame(bytes: &mut VecDeque) -> Option { + let frame_no = match bytes.get(0) { + Some(&f) => f, + None => return None, + }; + let size = match frame_no { + FRAME_SHUTDOWN => Self::SHUTDOWN_CNS, + FRAME_OPEN_STREAM => Self::OPEN_STREAM_CNS, + FRAME_CLOSE_STREAM => Self::CLOSE_STREAM_CNS, + FRAME_DATA_HEADER => Self::DATA_HEADER_CNS, + FRAME_DATA => { + u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + Self::DATA_CNS + }, + _ => return None, + }; + + if bytes.len() < size + 1 { + return None; + } + + let frame = match frame_no { + FRAME_SHUTDOWN => { + let _ = bytes.drain(..size + 1); + Frame::Shutdown + }, + FRAME_OPEN_STREAM => { + let bytes = bytes.drain(..size + 1).skip(1).collect::>(); + Frame::gen_open_stream(<[u8; 10]>::try_from(bytes).unwrap()) + }, + FRAME_CLOSE_STREAM => { + let bytes = bytes.drain(..size + 1).skip(1).collect::>(); + Frame::gen_close_stream(<[u8; 8]>::try_from(bytes).unwrap()) + }, + FRAME_DATA_HEADER => { + let bytes = bytes.drain(..size + 1).skip(1).collect::>(); + Frame::gen_data_header(<[u8; 24]>::try_from(bytes).unwrap()) + }, + FRAME_DATA => { + let info = bytes + .drain(..Self::DATA_CNS + 1) + .skip(1) + .collect::>(); + let (mid, start, length) = Frame::gen_data(<[u8; 18]>::try_from(info).unwrap()); + debug_assert_eq!(length as usize, size - Self::DATA_CNS); + let data = bytes.drain(..length as usize).collect::>(); + Frame::Data { mid, start, data } + }, + _ => unreachable!("Frame::to_frame should be handled before!"), + }; + Some(frame) + } + + fn gen_open_stream(buf: [u8; Self::OPEN_STREAM_CNS]) -> Self { + Frame::OpenStream { + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + prio: buf[8], + promises: Promises::from_bits_truncate(buf[Self::OPEN_STREAM_CNS - 1]), + } + } + + fn gen_close_stream(buf: [u8; Self::CLOSE_STREAM_CNS]) -> Self { + Frame::CloseStream { + sid: 
Sid::from_le_bytes( + *<&[u8; 8]>::try_from(&buf[0..Self::CLOSE_STREAM_CNS]).unwrap(), + ), + } + } + + fn gen_data_header(buf: [u8; Self::DATA_HEADER_CNS]) -> Self { + Frame::DataHeader { + mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), + length: u64::from_le_bytes( + *<&[u8; 8]>::try_from(&buf[16..Self::DATA_HEADER_CNS]).unwrap(), + ), + } + } + + fn gen_data(buf: [u8; Self::DATA_CNS]) -> (Mid, u64, u16) { + let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); + let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); + let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..Self::DATA_CNS]).unwrap()); + (mid, start, length) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{VELOREN_MAGIC_NUMBER, VELOREN_NETWORK_VERSION}; + + fn get_initframes() -> Vec { + vec![ + InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }, + InitFrame::Init { + pid: Pid::fake(0), + secret: 0u128, + }, + InitFrame::Raw(vec![1, 2, 3]), + ] + } + + fn get_frames() -> Vec { + vec![ + Frame::OpenStream { + sid: Sid::new(1337), + prio: 14, + promises: Promises::GUARANTEED_DELIVERY, + }, + Frame::DataHeader { + sid: Sid::new(1337), + mid: 0, + length: 36, + }, + Frame::Data { + mid: 0, + start: 0, + data: vec![77u8; 20], + }, + Frame::Data { + mid: 0, + start: 20, + data: vec![42u8; 16], + }, + Frame::CloseStream { + sid: Sid::new(1337), + }, + Frame::Shutdown, + ] + } + + #[test] + fn initframe_individual() { + let dupl = |frame: InitFrame| { + let mut buffer = vec![0u8; 1500]; + let size = InitFrame::to_bytes(frame.clone(), &mut buffer); + buffer.truncate(size); + InitFrame::to_frame(buffer) + }; + + for frame in get_initframes() { + println!("initframe: {:?}", &frame); + assert_eq!(Some(frame.clone()), dupl(frame)); + } + } + + #[test] + fn initframe_multiple() { + let mut buffer = 
vec![0u8; 3000]; + + let mut frames = get_initframes(); + let mut last = 0; + // to string + let sizes = frames + .iter() + .map(|f| { + let s = InitFrame::to_bytes(f.clone(), &mut buffer[last..]); + last += s; + s + }) + .collect::>(); + + // from string + let mut last = 0; + let mut framesd = sizes + .iter() + .map(|&s| { + let f = InitFrame::to_frame(buffer[last..last + s].to_vec()); + last += s; + f + }) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("initframe: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_individual() { + let dupl = |frame: Frame| { + let mut buffer = vec![0u8; 1500]; + let (size, _) = Frame::to_bytes(frame.clone(), &mut buffer); + let mut deque = buffer[..size].iter().map(|b| *b).collect(); + Frame::to_frame(&mut deque) + }; + + for frame in get_frames() { + println!("frame: {:?}", &frame); + assert_eq!(Some(frame.clone()), dupl(frame)); + } + } + + #[test] + fn frame_multiple() { + let mut buffer = vec![0u8; 3000]; + + let mut frames = get_frames(); + let mut last = 0; + // to string + let sizes = frames + .iter() + .map(|f| { + let s = Frame::to_bytes(f.clone(), &mut buffer[last..]).0; + last += s; + s + }) + .collect::>(); + + assert_eq!(sizes[0], 1 + Frame::OPEN_STREAM_CNS); + assert_eq!(sizes[1], 1 + Frame::DATA_HEADER_CNS); + assert_eq!(sizes[2], 1 + Frame::DATA_CNS + 20); + assert_eq!(sizes[3], 1 + Frame::DATA_CNS + 16); + assert_eq!(sizes[4], 1 + Frame::CLOSE_STREAM_CNS); + assert_eq!(sizes[5], 1 + Frame::SHUTDOWN_CNS); + + let mut buffer = buffer.drain(..).collect::>(); + + // from string + let mut framesd = sizes + .iter() + .map(|&_| Frame::to_frame(&mut buffer)) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("frame: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_exact_size() { + let mut buffer = vec![0u8; Frame::CLOSE_STREAM_CNS+1/*first byte*/]; + + let frame1 = 
Frame::CloseStream { + sid: Sid::new(1337), + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + let mut deque = buffer.iter().map(|b| *b).collect(); + let frame2 = Frame::to_frame(&mut deque); + assert_eq!(Some(frame1), frame2); + } + + #[test] + #[should_panic] + fn initframe_too_short_buffer() { + let mut buffer = vec![0u8; 10]; + + let frame1 = InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }; + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + } + + #[test] + fn initframe_too_less_data() { + let mut buffer = vec![0u8; 20]; + + let frame1 = InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }; + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let frame1d = InitFrame::to_frame(buffer[..6].to_vec()); + assert_eq!(frame1d, None); + } + + #[test] + fn initframe_rubish() { + let buffer = b"dtrgwcser".to_vec(); + assert_eq!( + InitFrame::to_frame(buffer), + Some(InitFrame::Raw(b"dtrgwcser".to_vec())) + ); + } + + #[test] + fn initframe_attack_too_much_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = InitFrame::Raw(b"foobar".to_vec()); + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + buffer[2] = 255; + let framed = InitFrame::to_frame(buffer); + assert_eq!(framed, None); + } + + #[test] + fn initframe_attack_too_low_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = InitFrame::Raw(b"foobar".to_vec()); + let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + buffer[2] = 3; + let framed = InitFrame::to_frame(buffer); + assert_eq!(framed, None); + } + + #[test] + #[should_panic] + fn frame_too_short_buffer() { + let mut buffer = vec![0u8; 10]; + + let frame1 = Frame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + } + + #[test] + fn frame_too_less_data() { + 
let mut buffer = vec![0u8; 20]; + + let frame1 = Frame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let mut buffer = buffer.drain(..6).collect::>(); + let frame1d = Frame::to_frame(&mut buffer); + assert_eq!(frame1d, None); + } + + #[test] + fn frame_rubish() { + let mut buffer = b"dtrgwcser".iter().map(|u| *u).collect::>(); + assert_eq!(Frame::to_frame(&mut buffer), None); + } + + #[test] + fn frame_attack_too_much_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foobar".to_vec(), + }; + + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer[17] = 255; + let mut buffer = buffer.drain(..).collect::>(); + let framed = Frame::to_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_attack_too_low_length() { + let mut buffer = vec![0u8; 50]; + + let frame1 = Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foobar".to_vec(), + }; + + let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + buffer[17] = 3; + let mut buffer = buffer.drain(..).collect::>(); + let framed = Frame::to_frame(&mut buffer); + assert_eq!( + framed, + Some(Frame::Data { + mid: 7u64, + start: 1u64, + data: b"foo".to_vec(), + }) + ); + //next = Invalid => Empty + let framed = Frame::to_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_int2str() { + assert_eq!(Frame::int_to_string(0), "Shutdown"); + assert_eq!(Frame::int_to_string(1), "OpenStream"); + assert_eq!(Frame::int_to_string(2), "CloseStream"); + assert_eq!(Frame::int_to_string(3), "DataHeader"); + assert_eq!(Frame::int_to_string(4), "Data"); + } +} diff --git a/network/protocol/src/handshake.rs b/network/protocol/src/handshake.rs new file mode 100644 index 0000000000..cc46791fc6 --- /dev/null +++ b/network/protocol/src/handshake.rs @@ -0,0 +1,227 @@ +use crate::{ + 
frame::InitFrame, + types::{ + Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, + VELOREN_NETWORK_VERSION, + }, + InitProtocol, InitProtocolError, ProtocolError, +}; +use async_trait::async_trait; +use tracing::{debug, error, info, trace}; + +// Protocols might define a Reliable Variant for auto Handshake discovery +// this doesn't need to be effective +#[async_trait] +pub trait ReliableDrain { + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait ReliableSink { + async fn recv(&mut self) -> Result; +} + +#[async_trait] +impl InitProtocol for (D, S) +where + D: ReliableDrain + Send, + S: ReliableSink + Send, +{ + async fn initialize( + &mut self, + initializer: bool, + local_pid: Pid, + local_secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + #[cfg(debug_assertions)] + const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required \ + by veloren server.\nWe are not sure if you are a \ + valid veloren client.\nClosing the connection" + .as_bytes(); + #[cfg(debug_assertions)] + const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ + invalid version.\nWe don't know how to communicate \ + with you.\nClosing the connection"; + const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ + something went wrong on network layer and connection will be closed"; + + let drain = &mut self.0; + let sink = &mut self.1; + + if initializer { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + + match sink.recv().await? 
{ + InitFrame::Handshake { + magic_number, + version, + } => { + trace!(?magic_number, ?version, "Recv handshake"); + if magic_number != VELOREN_MAGIC_NUMBER { + error!(?magic_number, "Connection with invalid magic_number"); + #[cfg(debug_assertions)] + drain.send(InitFrame::Raw(WRONG_NUMBER.to_vec())).await?; + Err(InitProtocolError::WrongMagicNumber(magic_number)) + } else if version != VELOREN_NETWORK_VERSION { + error!(?version, "Connection with wrong network version"); + #[cfg(debug_assertions)] + drain + .send(InitFrame::Raw( + format!( + "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", + WRONG_VERSION, VELOREN_NETWORK_VERSION, version, + ) + .as_bytes() + .to_vec(), + )) + .await?; + Err(InitProtocolError::WrongVersion(version)) + } else { + trace!("Handshake Frame completed"); + if initializer { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + } else { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + Ok(()) + } + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + }?; + + match sink.recv().await? 
{ + InitFrame::Init { pid, secret } => { + debug!(?pid, "Participant send their ID"); + let stream_id_offset = if initializer { + STREAM_ID_OFFSET1 + } else { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + STREAM_ID_OFFSET2 + }; + info!(?pid, "This Handshake is now configured!"); + Ok((pid, stream_id_offset, secret)) + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{mpsc::test_utils::*, InitProtocolError}; + + #[tokio::test] + async fn handshake_drop_start() { + let [mut p1, p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2; + }); + let (r1, _) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_wrong_magic_number() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: *b"woopsie", + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!( + r1.unwrap(), + Err(InitProtocolError::WrongMagicNumber(*b"woopsie")) + ); + assert_eq!(r2.unwrap(), Ok(())); + } + + #[tokio::test] + async fn handshake_wrong_version() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + 
p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: [0, 1, 2], + }) + .await?; + let _ = p2.1.recv().await?; + let _ = p2.1.recv().await?; //this should be closed now + Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2]))); + assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_unexpected_raw() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + assert_eq!(r2.unwrap(), Ok(())); + } +} diff --git a/network/protocol/src/io.rs b/network/protocol/src/io.rs new file mode 100644 index 0000000000..c4e3eba43e --- /dev/null +++ b/network/protocol/src/io.rs @@ -0,0 +1,62 @@ +use crate::ProtocolError; +use async_trait::async_trait; +use std::collections::VecDeque; +///! I/O-Free (Sans-I/O) protocol https://sans-io.readthedocs.io/how-to-sans-io.html + +// Protocols should base on the Unrealiable variants to get something effective! 
+#[async_trait] +pub trait UnreliableDrain: Send { + type DataFormat; + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait UnreliableSink: Send { + type DataFormat; + async fn recv(&mut self) -> Result; +} + +pub struct BaseDrain { + data: VecDeque>, +} + +pub struct BaseSink { + data: VecDeque>, +} + +impl BaseDrain { + pub fn new() -> Self { + Self { + data: VecDeque::new(), + } + } +} + +impl BaseSink { + pub fn new() -> Self { + Self { + data: VecDeque::new(), + } + } +} + +//TODO: Test Sinks that drop 20% by random and log that + +#[async_trait] +impl UnreliableDrain for BaseDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.data.push_back(data); + Ok(()) + } +} + +#[async_trait] +impl UnreliableSink for BaseSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.data.pop_front().ok_or(ProtocolError::Closed) + } +} diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs new file mode 100644 index 0000000000..8d49ed58c9 --- /dev/null +++ b/network/protocol/src/lib.rs @@ -0,0 +1,75 @@ +mod event; +mod frame; +mod handshake; +mod io; +mod message; +mod metrics; +mod mpsc; +mod prio; +mod tcp; +mod types; + +pub use event::ProtocolEvent; +pub use io::{BaseDrain, BaseSink, UnreliableDrain, UnreliableSink}; +pub use message::MessageBuffer; +pub use metrics::ProtocolMetricCache; +#[cfg(feature = "metrics")] +pub use metrics::ProtocolMetrics; +pub use mpsc::{MpscMsg, MpscRecvProtcol, MpscSendProtcol}; +pub use tcp::{TcpRecvProtcol, TcpSendProtcol}; +pub use types::{Bandwidth, Cid, Mid, Pid, Prio, Promises, Sid, VELOREN_NETWORK_VERSION}; + +///use at own risk, might change any time, for internal benchmarks +pub mod _internal { + pub use crate::frame::Frame; +} + +use async_trait::async_trait; + +#[async_trait] +pub trait InitProtocol { + async fn initialize( + &mut self, + initializer: bool, + 
local_pid: Pid, + secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError>; +} + +#[async_trait] +pub trait SendProtocol { + //a stream MUST be bound to a specific Protocol, there will be a failover + // feature comming for the case where a Protocol fails completly + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: std::time::Duration, + ) -> Result<(), ProtocolError>; +} + +#[async_trait] +pub trait RecvProtocol { + async fn recv(&mut self) -> Result; +} + +#[derive(Debug, PartialEq)] +pub enum InitProtocolError { + Closed, + WrongMagicNumber([u8; 7]), + WrongVersion([u32; 3]), +} + +#[derive(Debug, PartialEq)] +/// When you return closed you must stay closed! +pub enum ProtocolError { + Closed, +} + +impl From for InitProtocolError { + fn from(err: ProtocolError) -> Self { + match err { + ProtocolError::Closed => InitProtocolError::Closed, + } + } +} diff --git a/network/protocol/src/message.rs b/network/protocol/src/message.rs new file mode 100644 index 0000000000..1bda1325ad --- /dev/null +++ b/network/protocol/src/message.rs @@ -0,0 +1,127 @@ +use crate::{ + frame::Frame, + types::{Mid, Sid}, +}; +use std::{collections::VecDeque, sync::Arc}; + +//Todo: Evaluate switching to VecDeque for quickly adding and removing data +// from front, back. +// - It would prob require custom bincode code but thats possible. +#[cfg_attr(test, derive(PartialEq))] +pub struct MessageBuffer { + pub data: Vec, +} + +impl std::fmt::Debug for MessageBuffer { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + //TODO: small messages! 
+ let len = self.data.len(); + if len > 20 { + write!( + f, + "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", + len, + u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), + u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), + u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), + &self.data[13..16], + &self.data[len - 8..len] + ) + } else { + write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) + } + } +} + +/// Contains a outgoing message and store what was *send* and *confirmed* +/// All Chunks have the same size, except for the last chunk which can end +/// earlier. E.g. +/// ```ignore +/// msg = OutgoingMessage::new(); +/// msg.next(); +/// msg.next(); +/// msg.confirm(1); +/// msg.confirm(2); +/// ``` +#[derive(Debug)] +pub(crate) struct OutgoingMessage { + buffer: Arc, + send_index: u64, // 3 => 4200 (3*FRAME_DATA_SIZE) + send_header: bool, + mid: Mid, + sid: Sid, + max_index: u64, //speedup + missing_header: bool, + missing_indices: VecDeque, +} + +impl OutgoingMessage { + pub(crate) const FRAME_DATA_SIZE: u64 = 1400; + + pub(crate) fn new(buffer: Arc, mid: Mid, sid: Sid) -> Self { + let max_index = + (buffer.data.len() as u64 + Self::FRAME_DATA_SIZE - 1) / Self::FRAME_DATA_SIZE; + Self { + buffer, + send_index: 0, + send_header: false, + mid, + sid, + max_index, + missing_header: false, + missing_indices: VecDeque::new(), + } + } + + /// all has been send once, but might been resend due to failures. 
+ #[allow(dead_code)] + pub(crate) fn initial_sent(&self) -> bool { self.send_index == self.max_index } + + pub fn get_header(&self) -> Frame { + Frame::DataHeader { + mid: self.mid, + sid: self.sid, + length: self.buffer.data.len() as u64, + } + } + + pub fn get_data(&self, index: u64) -> Frame { + let start = index * Self::FRAME_DATA_SIZE; + let to_send = std::cmp::min( + self.buffer.data[start as usize..].len() as u64, + Self::FRAME_DATA_SIZE, + ); + Frame::Data { + mid: self.mid, + start, + data: self.buffer.data[start as usize..][..to_send as usize].to_vec(), + } + } + + #[allow(dead_code)] + pub(crate) fn set_missing(&mut self, missing_header: bool, missing_indicies: VecDeque) { + self.missing_header = missing_header; + self.missing_indices = missing_indicies; + } + + /// returns if something was added + pub(crate) fn next(&mut self) -> Option { + if !self.send_header { + self.send_header = true; + Some(self.get_header()) + } else if self.send_index < self.max_index { + self.send_index += 1; + Some(self.get_data(self.send_index - 1)) + } else if self.missing_header { + self.missing_header = false; + Some(self.get_header()) + } else if let Some(index) = self.missing_indices.pop_front() { + Some(self.get_data(index)) + } else { + None + } + } + + pub(crate) fn get_sid_len(&self) -> (Sid, u64) { (self.sid, self.buffer.data.len() as u64) } +} diff --git a/network/protocol/src/metrics.rs b/network/protocol/src/metrics.rs new file mode 100644 index 0000000000..715a06fc9d --- /dev/null +++ b/network/protocol/src/metrics.rs @@ -0,0 +1,414 @@ +use crate::types::Sid; +#[cfg(feature = "metrics")] +use prometheus::{IntCounterVec, IntGaugeVec, Opts, Registry}; +#[cfg(feature = "metrics")] +use std::{error::Error, sync::Arc}; + +#[allow(dead_code)] +pub enum RemoveReason { + Finished, + Dropped, +} + +#[cfg(feature = "metrics")] +pub struct ProtocolMetrics { + // smsg=send_msg rdata=receive_data + // i=in o=out + // t=total b=byte throughput + //e.g smsg_it = sending 
messages, in (responsibility of protocol) total + + // based on CHANNEL/STREAM + /// messages added to be send total, by STREAM, + smsg_it: IntCounterVec, + /// messages bytes added to be send throughput, by STREAM, + smsg_ib: IntCounterVec, + /// messages removed from to be send, because they where finished total, by + /// STREAM AND REASON(finished/canceled), + smsg_ot: IntCounterVec, + /// messages bytes removed from to be send throughput, because they where + /// finished total, by STREAM AND REASON(finished/dropped), + smsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + sdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + sdata_frames_b: IntCounterVec, + + // based on CHANNEL/STREAM + /// messages added to be received total, by STREAM, + rmsg_it: IntCounterVec, + /// messages bytes added to be received throughput, by STREAM, + rmsg_ib: IntCounterVec, + /// messages removed from to be received, because they where finished total, + /// by STREAM AND REASON(finished/canceled), + rmsg_ot: IntCounterVec, + /// messages bytes removed from to be received throughput, because they + /// where finished total, by STREAM AND REASON(finished/dropped), + rmsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + rdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + rdata_frames_b: IntCounterVec, + /// ping per CHANNEL //TODO: implement + ping: IntGaugeVec, +} + +#[cfg(feature = "metrics")] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache { + cid: String, + m: Arc, +} + +#[cfg(not(feature = "metrics"))] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache {} + +#[cfg(feature = "metrics")] +impl ProtocolMetrics { + pub fn new() -> Result> { + let smsg_it = IntCounterVec::new( + Opts::new( + "send_messages_in_total", + "All Messages that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ib = IntCounterVec::new( + Opts::new( + 
"send_messages_in_throughput", + "All Message bytes that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ot = IntCounterVec::new( + Opts::new( + "send_messages_out_total", + "All Messages that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let smsg_ob = IntCounterVec::new( + Opts::new( + "send_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let sdata_frames_t = IntCounterVec::new( + Opts::new( + "send_data_frames_total", + "Number of data frames send per channel", + ), + &["channel"], + )?; + let sdata_frames_b = IntCounterVec::new( + Opts::new( + "send_data_frames_throughput", + "Number of data frames bytes send per channel", + ), + &["channel"], + )?; + + let rmsg_it = IntCounterVec::new( + Opts::new( + "recv_messages_in_total", + "All Messages that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ib = IntCounterVec::new( + Opts::new( + "recv_messages_in_throughput", + "All Message bytes that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ot = IntCounterVec::new( + Opts::new( + "recv_messages_out_total", + "All Messages that are removed from this Protocol to be received at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rmsg_ob = IntCounterVec::new( + Opts::new( + "recv_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be received at stream \ + and reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rdata_frames_t = IntCounterVec::new( + Opts::new( + "recv_data_frames_total", + "Number of data frames received per channel", + ), + 
&["channel"], + )?; + let rdata_frames_b = IntCounterVec::new( + Opts::new( + "recv_data_frames_throughput", + "Number of data frames bytes received per channel", + ), + &["channel"], + )?; + let ping = IntGaugeVec::new(Opts::new("ping", "Ping per channel"), &["channel"])?; + + Ok(Self { + smsg_it, + smsg_ib, + smsg_ot, + smsg_ob, + sdata_frames_t, + sdata_frames_b, + rmsg_it, + rmsg_ib, + rmsg_ot, + rmsg_ob, + rdata_frames_t, + rdata_frames_b, + ping, + }) + } + + pub fn register(&self, registry: &Registry) -> Result<(), Box> { + registry.register(Box::new(self.smsg_it.clone()))?; + registry.register(Box::new(self.smsg_ib.clone()))?; + registry.register(Box::new(self.smsg_ot.clone()))?; + registry.register(Box::new(self.smsg_ob.clone()))?; + registry.register(Box::new(self.sdata_frames_t.clone()))?; + registry.register(Box::new(self.sdata_frames_b.clone()))?; + registry.register(Box::new(self.rmsg_it.clone()))?; + registry.register(Box::new(self.rmsg_ib.clone()))?; + registry.register(Box::new(self.rmsg_ot.clone()))?; + registry.register(Box::new(self.rmsg_ob.clone()))?; + registry.register(Box::new(self.rdata_frames_t.clone()))?; + registry.register(Box::new(self.rdata_frames_b.clone()))?; + registry.register(Box::new(self.ping.clone()))?; + Ok(()) + } +} + +#[cfg(feature = "metrics")] +impl ProtocolMetricCache { + pub fn new(channel_key: &str, metrics: Arc) -> Self { + Self { + cid: channel_key.to_string(), + m: metrics, + } + } + + pub(crate) fn smsg_it(&self, sid: Sid) { + self.m + .smsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc(); + } + + pub(crate) fn smsg_ib(&self, sid: Sid, bytes: u64) { + self.m + .smsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc_by(bytes); + } + + pub(crate) fn smsg_ot(&self, sid: Sid, reason: RemoveReason) { + self.m + .smsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc(); + } + + pub(crate) fn smsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { + self.m 
+ .smsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc_by(bytes); + } + + pub(crate) fn sdata_frames_t(&self) { + self.m.sdata_frames_t.with_label_values(&[&self.cid]).inc(); + } + + pub(crate) fn sdata_frames_b(&self, bytes: u64) { + self.m + .sdata_frames_b + .with_label_values(&[&self.cid]) + .inc_by(bytes); + } + + pub(crate) fn rmsg_it(&self, sid: Sid) { + self.m + .rmsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc(); + } + + pub(crate) fn rmsg_ib(&self, sid: Sid, bytes: u64) { + self.m + .rmsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .inc_by(bytes); + } + + pub(crate) fn rmsg_ot(&self, sid: Sid, reason: RemoveReason) { + self.m + .rmsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc(); + } + + pub(crate) fn rmsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { + self.m + .rmsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .inc_by(bytes); + } + + pub(crate) fn rdata_frames_t(&self) { + self.m.rdata_frames_t.with_label_values(&[&self.cid]).inc(); + } + + pub(crate) fn rdata_frames_b(&self, bytes: u64) { + self.m + .rdata_frames_b + .with_label_values(&[&self.cid]) + .inc_by(bytes); + } + + #[cfg(test)] + pub(crate) fn assert_msg(&self, sid: Sid, cnt: u64, reason: RemoveReason) { + assert_eq!( + self.m + .smsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + cnt + ); + assert_eq!( + self.m + .smsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + cnt + ); + assert_eq!( + self.m + .rmsg_it + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + cnt + ); + assert_eq!( + self.m + .rmsg_ot + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + cnt + ); + } + + #[cfg(test)] + pub(crate) fn assert_msg_bytes(&self, sid: Sid, bytes: u64, reason: RemoveReason) { + assert_eq!( + self.m + .smsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + 
.get(), + bytes + ); + assert_eq!( + self.m + .smsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + bytes + ); + assert_eq!( + self.m + .rmsg_ib + .with_label_values(&[&self.cid, &sid.to_string()]) + .get(), + bytes + ); + assert_eq!( + self.m + .rmsg_ob + .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) + .get(), + bytes + ); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames(&self, cnt: u64) { + assert_eq!( + self.m.sdata_frames_t.with_label_values(&[&self.cid]).get(), + cnt + ); + assert_eq!( + self.m.rdata_frames_t.with_label_values(&[&self.cid]).get(), + cnt + ); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames_bytes(&self, bytes: u64) { + assert_eq!( + self.m.sdata_frames_b.with_label_values(&[&self.cid]).get(), + bytes + ); + assert_eq!( + self.m.rdata_frames_b.with_label_values(&[&self.cid]).get(), + bytes + ); + } +} + +#[cfg(feature = "metrics")] +impl std::fmt::Debug for ProtocolMetrics { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ProtocolMetrics()") + } +} + +#[cfg(not(feature = "metrics"))] +impl ProtocolMetricCache { + pub(crate) fn smsg_it(&self, _sid: Sid) {} + + pub(crate) fn smsg_ib(&self, _sid: Sid, _b: u64) {} + + pub(crate) fn smsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + + pub(crate) fn smsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn sdata_frames_t(&self) {} + + pub(crate) fn sdata_frames_b(&self, _b: u64) {} + + pub(crate) fn rmsg_it(&self, _sid: Sid) {} + + pub(crate) fn rmsg_ib(&self, _sid: Sid, _b: u64) {} + + pub(crate) fn rmsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + + pub(crate) fn rmsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn rdata_frames_t(&self) {} + + pub(crate) fn rdata_frames_b(&self, _b: u64) {} +} + +impl RemoveReason { + #[cfg(feature = "metrics")] + fn to_str(&self) -> &str { + match self { + RemoveReason::Dropped => "Dropped", + 
RemoveReason::Finished => "Finished", + } + } +} diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs new file mode 100644 index 0000000000..3e9e5d55fe --- /dev/null +++ b/network/protocol/src/mpsc.rs @@ -0,0 +1,217 @@ +use crate::{ + event::ProtocolEvent, + frame::InitFrame, + handshake::{ReliableDrain, ReliableSink}, + io::{UnreliableDrain, UnreliableSink}, + metrics::{ProtocolMetricCache, RemoveReason}, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, +}; +use async_trait::async_trait; +use std::time::{Duration, Instant}; + +pub /* should be private */ enum MpscMsg { + Event(ProtocolEvent), + InitFrame(InitFrame), +} + +#[derive(Debug)] +pub struct MpscSendProtcol +where + D: UnreliableDrain, +{ + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +#[derive(Debug)] +pub struct MpscRecvProtcol +where + S: UnreliableSink, +{ + sink: S, + metrics: ProtocolMetricCache, +} + +impl MpscSendProtcol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + drain, + last: Instant::now(), + metrics, + } + } +} + +impl MpscRecvProtcol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { sink, metrics } } +} + +#[async_trait] +impl SendProtocol for MpscSendProtcol +where + D: UnreliableDrain, +{ + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match &event { + ProtocolEvent::Message { + buffer, + mid: _, + sid, + } => { + let sid = *sid; + let bytes = buffer.data.len() as u64; + self.metrics.smsg_it(sid); + self.metrics.smsg_ib(sid, bytes); + let r = self.drain.send(MpscMsg::Event(event)).await; + self.metrics.smsg_ot(sid, RemoveReason::Finished); + self.metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + r + }, + _ => self.drain.send(MpscMsg::Event(event)).await, + } + } + + async fn flush(&mut self, _: Bandwidth, _: Duration) -> Result<(), ProtocolError> { Ok(()) } +} + +#[async_trait] +impl 
RecvProtocol for MpscRecvProtcol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(e) => { + if let ProtocolEvent::Message { + buffer, + mid: _, + sid, + } = &e + { + let sid = *sid; + let bytes = buffer.data.len() as u64; + self.metrics.rmsg_it(sid); + self.metrics.rmsg_ib(sid, bytes); + self.metrics.rmsg_ot(sid, RemoveReason::Finished); + self.metrics.rmsg_ob(sid, RemoveReason::Finished, bytes); + } + Ok(e) + }, + MpscMsg::InitFrame(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl ReliableDrain for MpscSendProtcol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.drain.send(MpscMsg::InitFrame(frame)).await + } +} + +#[async_trait] +impl ReliableSink for MpscRecvProtcol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(_) => Err(ProtocolError::Closed), + MpscMsg::InitFrame(f) => Ok(f), + } + } +} + +#[cfg(test)] +pub mod test_utils { + use super::*; + use crate::{ + io::*, + metrics::{ProtocolMetricCache, ProtocolMetrics}, + }; + use async_channel::*; + use std::sync::Arc; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn 
send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + mpsc::test_utils::*, + types::{Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, + }; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } +} diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs new file mode 100644 index 0000000000..35b7067352 --- /dev/null +++ b/network/protocol/src/prio.rs @@ -0,0 +1,139 @@ +use crate::{ + frame::Frame, + message::{MessageBuffer, OutgoingMessage}, + metrics::{ProtocolMetricCache, RemoveReason}, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use std::{collections::HashMap, sync::Arc, time::Duration}; + +#[derive(Debug)] +struct StreamInfo { + pub(crate) guaranteed_bandwidth: Bandwidth, + pub(crate) prio: Prio, + pub(crate) promises: Promises, + pub(crate) messages: Vec, +} + +/// Responsible for queueing messages. +/// every stream has a guaranteed bandwidth and a prio 0-7. +/// when `n` Bytes are available in the buffer, first the guaranteed bandwidth +/// is used. Then remaining bandwidth is used to fill up the prios. 
+#[derive(Debug)] +pub(crate) struct PrioManager { + streams: HashMap, + metrics: ProtocolMetricCache, +} + +// Send everything ONCE, then keep it till it's confirmed + +impl PrioManager { + const HIGHEST_PRIO: u8 = 7; + + pub fn new(metrics: ProtocolMetricCache) -> Self { + Self { + streams: HashMap::new(), + metrics, + } + } + + pub fn open_stream( + &mut self, + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + ) { + self.streams.insert(sid, StreamInfo { + guaranteed_bandwidth, + prio, + promises, + messages: vec![], + }); + } + + pub fn try_close_stream(&mut self, sid: Sid) -> bool { + if let Some(si) = self.streams.get(&sid) { + if si.messages.is_empty() { + self.streams.remove(&sid); + return true; + } + } + false + } + + pub fn is_empty(&self) -> bool { self.streams.is_empty() } + + pub fn add(&mut self, buffer: Arc, mid: Mid, sid: Sid) { + self.streams + .get_mut(&sid) + .unwrap() + .messages + .push(OutgoingMessage::new(buffer, mid, sid)); + } + + /// bandwidth might be extended, as for technical reasons + /// guaranteed_bandwidth is used and frames are always 1400 bytes. + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> Vec { + let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; + let mut cur_bytes = 0u64; + let mut frames = vec![]; + + let mut prios = [0u64; (Self::HIGHEST_PRIO + 1) as usize]; + let metrics = &self.metrics; + + let mut process_stream = + |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + let mut finished = vec![]; + 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { + while let Some(frame) = msg.next() { + let b = if matches!(frame, Frame::DataHeader { .. 
}) { + 25 + } else { + 19 + OutgoingMessage::FRAME_DATA_SIZE + }; + bandwidth -= b as i64; + *cur_bytes += b; + frames.push(frame); + if bandwidth <= 0 { + break 'outer; + } + } + finished.push(i); + } + + //cleanup + for i in finished.iter().rev() { + let msg = stream.messages.remove(*i); + let (sid, bytes) = msg.get_sid_len(); + metrics.smsg_ot(sid, RemoveReason::Finished); + metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + } + }; + + // Add guaranteed bandwidth + for (_, stream) in &mut self.streams { + prios[stream.prio.min(Self::HIGHEST_PRIO) as usize] += 1; + let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; + process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + } + + if cur_bytes < total_bytes { + // Add optional bandwidth + for prio in 0..=Self::HIGHEST_PRIO { + if prios[prio as usize] == 0 { + continue; + } + let per_stream_bytes = (total_bytes - cur_bytes) / prios[prio as usize]; + + for (_, stream) in &mut self.streams { + if stream.prio != prio { + continue; + } + process_stream(stream, per_stream_bytes as i64, &mut cur_bytes); + } + } + } + + frames + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs new file mode 100644 index 0000000000..e1c8e10e84 --- /dev/null +++ b/network/protocol/src/tcp.rs @@ -0,0 +1,584 @@ +use crate::{ + event::ProtocolEvent, + frame::{Frame, InitFrame}, + handshake::{ReliableDrain, ReliableSink}, + io::{UnreliableDrain, UnreliableSink}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, +}; +use async_trait::async_trait; +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, + time::{Duration, Instant}, +}; +use tracing::info; + +#[derive(Debug)] +pub struct TcpSendProtcol +where + D: UnreliableDrain>, +{ + buffer: Vec, + store: PrioManager, + closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + 
+#[derive(Debug)] +pub struct TcpRecvProtcol +where + S: UnreliableSink>, +{ + buffer: VecDeque, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +impl TcpSendProtcol +where + D: UnreliableDrain>, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: vec![0u8; 1500], + store: PrioManager::new(metrics.clone()), + closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } +} + +impl TcpRecvProtcol +where + S: UnreliableSink>, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: VecDeque::new(), + incoming: HashMap::new(), + sink, + metrics, + } + } +} + +#[async_trait] +impl SendProtocol for TcpSendProtcol +where + D: UnreliableDrain>, +{ + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + } else { + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + tracing::error!(?event, "send frame"); + let frame = event.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + } else { + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { buffer, mid, sid } => { + self.metrics.smsg_it(sid); + self.metrics.smsg_ib(sid, buffer.data.len() as u64); + self.store.add(buffer, mid, sid); + }, + } + Ok(()) + } + + async fn flush(&mut self, bandwidth: Bandwidth, dt: Duration) -> Result<(), 
ProtocolError> { + let frames = self.store.grab(bandwidth, dt); + for frame in frames { + if let Frame::Data { + mid: _, + start: _, + data, + } = &frame + { + self.metrics.sdata_frames_t(); + self.metrics.sdata_frames_b(data.len() as u64); + } + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + tracing::warn!("send data frame, woop"); + } + let mut finished_streams = vec![]; + for (i, sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + let frame = ProtocolEvent::CloseStream { sid: *sid }.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + if self.pending_shutdown && self.store.is_empty() { + tracing::error!("send shutdown frame"); + let frame = ProtocolEvent::Shutdown {}.to_frame(); + let (s, _) = frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer[..s].to_vec()).await?; + self.pending_shutdown = false; + } + Ok(()) + } +} + +use crate::{ + message::MessageBuffer, + types::{Mid, Sid}, +}; + +#[derive(Debug)] +struct IncomingMsg { + sid: Sid, + length: u64, + data: MessageBuffer, +} + +#[async_trait] +impl RecvProtocol for TcpRecvProtcol +where + S: UnreliableSink>, +{ + async fn recv(&mut self) -> Result { + tracing::error!(?self.buffer, "enter loop"); + 'outer: loop { + tracing::error!(?self.buffer, "continue loop"); + while let Some(frame) = Frame::to_frame(&mut self.buffer) { + tracing::error!(?frame, "recv frame"); + match frame { + Frame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + Frame::OpenStream { + sid, + prio, + promises, + } => { + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: 1_000_000, + }); + }, + Frame::CloseStream { sid } => { + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + Frame::DataHeader 
{ sid, mid, length } => { + let m = IncomingMsg { + sid, + length, + data: MessageBuffer { data: vec![] }, + }; + self.metrics.rmsg_it(sid); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + Frame::Data { + mid, + start: _, + mut data, + } => { + self.metrics.rdata_frames_t(); + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + info!("protocol violation by remote side: send Data before Header"); + break 'outer Err(ProtocolError::Closed); + }, + }; + m.data.data.append(&mut data); + if m.data.data.len() == m.length as usize { + // finished, yay + drop(m); + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ot(m.sid, RemoveReason::Finished); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + mid, + buffer: Arc::new(m.data), + }); + } + }, + }; + } + tracing::error!(?self.buffer, "receiving on tcp sink"); + let chunk = self.sink.recv().await?; + self.buffer.reserve(chunk.len()); + for b in chunk { + self.buffer.push_back(b); + } + tracing::error!(?self.buffer,"receiving on tcp sink done"); + } + } +} + +#[async_trait] +impl ReliableDrain for TcpSendProtcol +where + D: UnreliableDrain>, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + let mut buffer = vec![0u8; 1500]; + let s = frame.to_bytes(&mut buffer); + buffer.truncate(s); + self.drain.send(buffer).await + } +} + +#[async_trait] +impl ReliableSink for TcpRecvProtcol +where + S: UnreliableSink>, +{ + async fn recv(&mut self) -> Result { + while self.buffer.len() < 100 { + let chunk = self.sink.recv().await?; + self.buffer.reserve(chunk.len()); + for b in chunk { + self.buffer.push_back(b); + } + let todo_use_bytes_instead = self.buffer.iter().map(|b| *b).collect(); + if let Some(frame) = InitFrame::to_frame(todo_use_bytes_instead) { + match frame { + InitFrame::Handshake { 
.. } => self.buffer.drain(.. InitFrame::HANDSHAKE_CNS + 1), + InitFrame::Init { .. } => self.buffer.drain(.. InitFrame::INIT_CNS + 1), + InitFrame::Raw { .. } => self.buffer.drain(.. InitFrame::RAW_CNS + 1), + }; + return Ok(frame); + } + } + Err(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod test_utils { + //TCP protocol based on Channel + use super::*; + use crate::{ + io::*, + metrics::{ProtocolMetricCache, ProtocolMetrics}, + }; + use async_channel::*; + + pub struct TcpDrain { + pub sender: Sender>, + } + + pub struct TcpSink { + pub receiver: Receiver>, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + tcp::test_utils::*, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, MessageBuffer, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn 
handshake_all_good() { + let [mut p1, mut p2] = tcp_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + buffer: Arc::new(MessageBuffer { + data: vec![188u8; 600], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 1, + buffer: Arc::new(MessageBuffer { + data: vec![7u8; 30], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let metrics = + ProtocolMetricCache::new("long_tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, Some(metrics.clone())); + let 
(mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. 
})); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + buffer: Arc::new(MessageBuffer { + data: vec![99u8; 500_000], + }), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::TcpRecvProtcol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. 
and then keep the connection open"; + let mut buf = vec![0u8; 1500]; + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + let (i, _) = event.to_frame().to_bytes(&mut buf); + let (i2, _) = crate::frame::Frame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .to_bytes(&mut buf[i..]); + buf.truncate(i + i2); + s.send(buf).await.unwrap(); + + let mut buf = vec![0u8; 1500]; + let (i, _) = crate::frame::Frame::Data { + mid: 99, + start: 0, + data: DATA1.to_vec(), + } + .to_bytes(&mut buf); + let (i2, _) = crate::frame::Frame::Data { + mid: 99, + start: DATA1.len() as u64, + data: DATA2.to_vec(), + } + .to_bytes(&mut buf[i..]); + let (i3, _) = crate::frame::Frame::CloseStream { sid }.to_bytes(&mut buf[i + i2..]); + buf.truncate(i + i2 + i3); + s.send(buf).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } +} diff --git a/network/src/types.rs b/network/protocol/src/types.rs similarity index 52% rename from network/src/types.rs rename to network/protocol/src/types.rs index d257ed808f..b6f63ca208 100644 --- a/network/src/types.rs +++ b/network/protocol/src/types.rs @@ -1,10 +1,10 @@ use bitflags::bitflags; use rand::Rng; -use std::convert::TryFrom; pub type Mid = u64; pub type Cid = u64; pub type Prio = u8; +pub type Bandwidth = u64; bitflags! { /// use promises to modify the behavior of [`Streams`]. @@ -21,9 +21,8 @@ bitflags! 
{ /// this will guarantee that the other side will receive every message exactly /// once no messages are dropped const GUARANTEED_DELIVERY = 0b00000100; - /// this will enable the internal compression on this + /// this will enable the internal compression on this, only useable with #[cfg(feature = "compression")] /// [`Stream`](crate::api::Stream) - #[cfg(feature = "compression")] const COMPRESSED = 0b00001000; /// this will enable the internal encryption on this /// [`Stream`](crate::api::Stream) @@ -35,7 +34,7 @@ impl Promises { pub const fn to_le_bytes(self) -> [u8; 1] { self.bits.to_le_bytes() } } -pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = [86, 69, 76, 79, 82, 69, 78]; //VELOREN +pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = *b"VELOREN"; pub const VELOREN_NETWORK_VERSION: [u32; 3] = [0, 5, 0]; pub(crate) const STREAM_ID_OFFSET1: Sid = Sid::new(0); pub(crate) const STREAM_ID_OFFSET2: Sid = Sid::new(u64::MAX / 2); @@ -51,144 +50,18 @@ pub struct Pid { } #[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub(crate) struct Sid { +pub struct Sid { internal: u64, } -// Used for Communication between Channel <----(TCP/UDP)----> Channel -#[derive(Debug)] -pub(crate) enum Frame { - Handshake { - magic_number: [u8; 7], - version: [u32; 3], - }, - Init { - pid: Pid, - secret: u128, - }, - Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown, Participant - * is deleted */ - OpenStream { - sid: Sid, - prio: Prio, - promises: Promises, - }, - CloseStream { - sid: Sid, - }, - DataHeader { - mid: Mid, - sid: Sid, - length: u64, - }, - Data { - mid: Mid, - start: u64, - data: Vec, - }, - /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API - * against veloren Server! 
*/ - Raw(Vec), -} - -impl Frame { - #[cfg(feature = "metrics")] - pub const FRAMES_LEN: u8 = 8; - - #[cfg(feature = "metrics")] - pub const fn int_to_string(i: u8) -> &'static str { - match i { - 0 => "Handshake", - 1 => "Init", - 2 => "Shutdown", - 3 => "OpenStream", - 4 => "CloseStream", - 5 => "DataHeader", - 6 => "Data", - 7 => "Raw", - _ => "", - } - } - - #[cfg(feature = "metrics")] - pub fn get_int(&self) -> u8 { - match self { - Frame::Handshake { .. } => 0, - Frame::Init { .. } => 1, - Frame::Shutdown => 2, - Frame::OpenStream { .. } => 3, - Frame::CloseStream { .. } => 4, - Frame::DataHeader { .. } => 5, - Frame::Data { .. } => 6, - Frame::Raw(_) => 7, - } - } - - #[cfg(feature = "metrics")] - pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } - - pub fn gen_handshake(buf: [u8; 19]) -> Self { - let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); - Frame::Handshake { - magic_number, - version: [ - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..19]).unwrap()), - ], - } - } - - pub fn gen_init(buf: [u8; 32]) -> Self { - Frame::Init { - pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), - secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..32]).unwrap()), - } - } - - pub fn gen_open_stream(buf: [u8; 10]) -> Self { - Frame::OpenStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - prio: buf[8], - promises: Promises::from_bits_truncate(buf[9]), - } - } - - pub fn gen_close_stream(buf: [u8; 8]) -> Self { - Frame::CloseStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - } - } - - pub fn gen_data_header(buf: [u8; 24]) -> Self { - Frame::DataHeader { - mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), - length: 
u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[16..24]).unwrap()), - } - } - - pub fn gen_data(buf: [u8; 18]) -> (Mid, u64, u16) { - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); - let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); - let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..18]).unwrap()); - (mid, start, length) - } - - pub fn gen_raw(buf: [u8; 2]) -> u16 { - u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..2]).unwrap()) - } -} - impl Pid { /// create a new Pid with a random interior value /// /// # Example /// ```rust - /// use veloren_network::{Network, Pid}; + /// use veloren_network_protocol::Pid; /// /// let pid = Pid::new(); - /// let _ = Network::new(pid); /// ``` pub fn new() -> Self { Self { @@ -295,20 +168,7 @@ fn sixlet_to_str(sixlet: u128) -> char { #[cfg(test)] mod tests { - use crate::types::*; - - #[test] - fn frame_int2str() { - assert_eq!(Frame::int_to_string(3), "OpenStream"); - assert_eq!(Frame::int_to_string(7), "Raw"); - assert_eq!(Frame::int_to_string(8), ""); - } - - #[test] - fn frame_get_int() { - assert_eq!(Frame::get_int(&Frame::Raw(b"Foo".to_vec())), 7); - assert_eq!(Frame::get_int(&Frame::Shutdown), 2); - } + use super::*; #[test] fn frame_creation() { diff --git a/network/protocol/src/udp.rs b/network/protocol/src/udp.rs new file mode 100644 index 0000000000..ad5c31a126 --- /dev/null +++ b/network/protocol/src/udp.rs @@ -0,0 +1,37 @@ +// TODO: quick and dirty which activly waits for an ack! +/* +UDP protocol + +All Good Case: +S --HEADER--> R +S --DATA--> R +S --DATA--> R +S <--FINISHED-- R + + +Delayed HEADER: +S --HEADER--> +S --DATA--> R // STORE IT + --HEADER--> R // apply left data and continue +S --DATA--> R +S <--FINISHED-- R + + +NO HEADER: +S --HEADER--> ! 
+S --DATA--> R // STORE IT +S --DATA--> R // STORE IT +S <--MISSING_HEADER-- R // SEND AFTER 10 ms after DATA1 +S --HEADER--> R +S <--FINISHED-- R + + +NO DATA: +S --HEADER--> R +S --DATA--> R +S --DATA--> ! +S --STATUS--> R +S <--MISSING_DATA -- R +S --DATA--> R +S <--FINISHED-- R +*/ diff --git a/network/src/api.rs b/network/src/api.rs index ef6fb113db..08274c90be 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -3,13 +3,13 @@ //! //! (cd network/examples/async_recv && RUST_BACKTRACE=1 cargo run) use crate::{ - message::{partial_eq_bincode, IncomingMessage, Message, OutgoingMessage}, + message::{partial_eq_bincode, Message}, participant::{A2bStreamOpen, S2bShutdownBparticipant}, scheduler::Scheduler, - types::{Mid, Pid, Prio, Promises, Sid}, }; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; +use network_protocol::{Bandwidth, MessageBuffer, Mid, Pid, Prio, Promises, Sid}; #[cfg(feature = "metrics")] use prometheus::Registry; use serde::{de::DeserializeOwned, Serialize}; @@ -20,6 +20,7 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::Duration, }; use tokio::{ io, @@ -49,8 +50,7 @@ pub enum ProtocolAddr { pub struct Participant { local_pid: Pid, remote_pid: Pid, - runtime: Arc, - a2b_stream_open_s: Mutex>, + a2b_open_stream_s: Mutex>, b2a_stream_opened_r: Mutex>, a2s_disconnect_s: A2sDisconnect, } @@ -75,9 +75,10 @@ pub struct Stream { mid: Mid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: Option>, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -419,16 +420,14 @@ impl Participant { pub(crate) fn new( local_pid: Pid, remote_pid: Pid, - runtime: Arc, - a2b_stream_open_s: mpsc::UnboundedSender, + a2b_open_stream_s: mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, 
S2bShutdownBparticipant)>, ) -> Self { Self { local_pid, remote_pid, - runtime, - a2b_stream_open_s: Mutex::new(a2b_stream_open_s), + a2b_open_stream_s: Mutex::new(a2b_open_stream_s), b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } @@ -477,13 +476,13 @@ impl Participant { /// /// [`Streams`]: crate::api::Stream pub async fn open(&self, prio: u8, promises: Promises) -> Result { - let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - if let Err(e) = - self.a2b_stream_open_s - .lock() - .await - .send((prio, promises, p2a_return_stream_s)) - { + let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel::(); + if let Err(e) = self.a2b_open_stream_s.lock().await.send(( + prio, + promises, + 100000u64, + p2a_return_stream_s, + )) { debug!(?e, "bParticipant is already closed, notifying"); return Err(ParticipantError::ParticipantDisconnected); } @@ -602,7 +601,7 @@ impl Participant { // Participant is connecting to Scheduler here, not as usual // Participant<->BParticipant a2s_disconnect_s - .send((pid, finished_sender)) + .send((pid, (Duration::from_secs(120), finished_sender))) .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { Ok(res) => { @@ -647,9 +646,10 @@ impl Stream { sid: Sid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: async_channel::Receiver, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2a_msg_recv_r: async_channel::Receiver, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { @@ -658,6 +658,7 @@ impl Stream { mid: 0, prio, promises, + guaranteed_bandwidth, send_closed, a2b_msg_s, b2a_msg_recv_r: Some(b2a_msg_recv_r), @@ -776,12 +777,8 @@ impl Stream { } #[cfg(debug_assertions)] message.verify(&self); - self.a2b_msg_s.send((self.prio, self.sid, OutgoingMessage { - buffer: 
Arc::clone(&message.buffer), - cursor: 0, - mid: self.mid, - sid: self.sid, - }))?; + self.a2b_msg_s + .send((self.sid, Arc::clone(&message.buffer)))?; self.mid += 1; Ok(()) } @@ -864,7 +861,7 @@ impl Stream { Some(b2a_msg_recv_r) => { match b2a_msg_recv_r.recv().await { Ok(msg) => Ok(Message { - buffer: Arc::new(msg.buffer), + buffer: Arc::new(msg), #[cfg(feature = "compression")] compressed: self.promises.contains(Promises::COMPRESSED), }), @@ -917,7 +914,7 @@ impl Stream { Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_recv() { Ok(msg) => Ok(Some( Message { - buffer: Arc::new(msg.buffer), + buffer: Arc::new(msg), #[cfg(feature = "compression")] compressed: self.promises().contains(Promises::COMPRESSED), } @@ -953,47 +950,62 @@ impl Drop for Network { "Shutting down Participants of Network, while we still have metrics" ); let mut finished_receiver_list = vec![]; - self.runtime.block_on(async { - // we MUST avoid nested block_on, good that Network::Drop no longer triggers - // Participant::Drop directly but just the BParticipant - for (remote_pid, a2s_disconnect_s) in - self.participant_disconnect_sender.lock().await.drain() - { - match a2s_disconnect_s.lock().await.take() { - Some(a2s_disconnect_s) => { - trace!(?remote_pid, "Participants will be closed"); - let (finished_sender, finished_receiver) = oneshot::channel(); - finished_receiver_list.push((remote_pid, finished_receiver)); - a2s_disconnect_s.send((remote_pid, finished_sender)).expect( - "Scheduler is closed, but nobody other should be able to close it", - ); - }, - None => trace!(?remote_pid, "Participant already disconnected gracefully"), + + if tokio::runtime::Handle::try_current().is_ok() { + error!("we have a runtime but we mustn't, DROP NETWORK from async runtime is illegal") + } + + tokio::task::block_in_place(|| { + /* This context prevents panic if Dropped in a async fn */ + self.runtime.block_on(async { + for (remote_pid, a2s_disconnect_s) in + 
self.participant_disconnect_sender.lock().await.drain() + { + match a2s_disconnect_s.lock().await.take() { + Some(a2s_disconnect_s) => { + trace!(?remote_pid, "Participants will be closed"); + let (finished_sender, finished_receiver) = oneshot::channel(); + finished_receiver_list.push((remote_pid, finished_receiver)); + a2s_disconnect_s + .send((remote_pid, (Duration::from_secs(120), finished_sender))) + .expect( + "Scheduler is closed, but nobody other should be able to \ + close it", + ); + }, + None => trace!(?remote_pid, "Participant already disconnected gracefully"), + } } - } - //wait after close is requested for all - for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) { - match finished_receiver.await { - Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), - Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), - Err(e) => warn!( - ?remote_pid, - ?e, - "Failed to get a message back from the scheduler, seems like the network \ - is already closed" - ), + //wait after close is requested for all + for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) 
{ + match finished_receiver.await { + Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), + Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), + Err(e) => warn!( + ?remote_pid, + ?e, + "Failed to get a message back from the scheduler, seems like the \ + network is already closed" + ), + } } - } + }); }); trace!(?pid, "Participants have shut down!"); trace!(?pid, "Shutting down Scheduler"); - self.shutdown_sender.take().unwrap().send(()).expect("Scheduler is closed, but nobody other should be able to close it"); + self.shutdown_sender + .take() + .unwrap() + .send(()) + .expect("Scheduler is closed, but nobody other should be able to close it"); debug!(?pid, "Network has shut down"); } } impl Drop for Participant { fn drop(&mut self) { + use tokio::sync::oneshot::error::TryRecvError; + // ignore closed, as we need to send it even though we disconnected the // participant from network let pid = self.remote_pid; @@ -1011,23 +1023,28 @@ impl Drop for Participant { ), Some(a2s_disconnect_s) => { debug!(?pid, "Disconnect from Scheduler"); - self.runtime.block_on(async { - let (finished_sender, finished_receiver) = oneshot::channel(); - a2s_disconnect_s - .send((self.remote_pid, finished_sender)) - .expect("Something is wrong in internal scheduler coding"); - if let Err(e) = finished_receiver - .await - .expect("Something is wrong in internal scheduler/participant coding") - { - error!( + let (finished_sender, mut finished_receiver) = oneshot::channel(); + a2s_disconnect_s + .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) + .expect("Something is wrong in internal scheduler coding"); + loop { + match finished_receiver.try_recv() { + Ok(Ok(())) => break, + Ok(Err(e)) => error!( ?pid, ?e, "Error while dropping the participant, couldn't send all outgoing \ messages, dropping remaining" - ); - }; - }); + ), + Err(TryRecvError::Closed) => { + panic!("Something is wrong in internal scheduler/participant coding") + }, + 
Err(TryRecvError::Empty) => { + trace!("activly sleeping"); + std::thread::sleep(Duration::from_millis(20)); + }, + } + } }, } debug!(?pid, "Participant dropped"); @@ -1041,11 +1058,12 @@ impl Drop for Stream { let sid = self.sid; let pid = self.pid; debug!(?pid, ?sid, "Shutting down Stream"); - self.a2b_close_stream_s - .take() - .unwrap() - .send(self.sid) - .expect("bparticipant part of a gracefully shutdown must have crashed"); + if let Err(e) = self.a2b_close_stream_s.take().unwrap().send(self.sid) { + debug!( + ?e, + "bparticipant part of a gracefully shutdown was already closed" + ); + } } else { let sid = self.sid; let pid = self.pid; diff --git a/network/src/channel.rs b/network/src/channel.rs index 7928337bd1..654175fb1d 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,361 +1,231 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - participant::C2pFrame, - protocols::Protocols, - types::{ - Cid, Frame, Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, - VELOREN_NETWORK_VERSION, - }, -}; -use futures_core::task::Poll; -use futures_util::{ - task::{noop_waker, Context}, - FutureExt, +use async_trait::async_trait; +use network_protocol::{ + InitProtocolError, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtcol, TcpSendProtcol, + UnreliableDrain, UnreliableSink, }; #[cfg(feature = "metrics")] use std::sync::Arc; +use std::time::Duration; use tokio::{ - join, - sync::{mpsc, oneshot}, + io::{AsyncReadExt, AsyncWriteExt}, + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + sync::mpsc, }; -use tracing::*; -pub(crate) struct Channel { - cid: Cid, - c2w_frame_r: Option>, - read_stop_receiver: Option>, -} - -impl Channel { - pub fn new(cid: u64) -> (Self, mpsc::UnboundedSender, oneshot::Sender<()>) { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = 
oneshot::channel(); - ( - Self { - cid, - c2w_frame_r: Some(c2w_frame_r), - read_stop_receiver: Some(read_stop_receiver), - }, - c2w_frame_s, - read_stop_sender, - ) - } - - pub async fn run( - mut self, - protocol: Protocols, - mut w2c_cid_frame_s: mpsc::UnboundedSender, - mut leftover_cid_frame: Vec, - ) { - let c2w_frame_r = self.c2w_frame_r.take().unwrap(); - let read_stop_receiver = self.read_stop_receiver.take().unwrap(); - - //reapply leftovers from handshake - let cnt = leftover_cid_frame.len(); - trace!(?cnt, "Reapplying leftovers"); - for cid_frame in leftover_cid_frame.drain(..) { - w2c_cid_frame_s.send(cid_frame).unwrap(); - } - trace!(?cnt, "All leftovers reapplied"); - - trace!("Start up channel"); - match protocol { - Protocols::Tcp(tcp) => { - join!( - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - Protocols::Udp(udp) => { - join!( - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - } - - trace!("Shut down channel"); - } +#[derive(Debug)] +pub(crate) enum Protocols { + Tcp((TcpSendProtcol, TcpRecvProtcol)), + Mpsc((MpscSendProtcol, MpscRecvProtcol)), } #[derive(Debug)] -pub(crate) struct Handshake { - cid: Cid, - local_pid: Pid, - secret: u128, - init_handshake: bool, - #[cfg(feature = "metrics")] - metrics: Arc, +pub(crate) enum SendProtocols { + Tcp(TcpSendProtcol), + Mpsc(MpscSendProtcol), } -impl Handshake { - #[cfg(debug_assertions)] - const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required by \ - veloren server.\nWe are not sure if you are a valid \ - veloren client.\nClosing the connection" - .as_bytes(); - #[cfg(debug_assertions)] - const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ - invalid version.\nWe don't know how to communicate with \ - you.\nClosing the connection"; +#[derive(Debug)] +pub(crate) enum 
RecvProtocols { + Tcp(TcpRecvProtcol), + Mpsc(MpscRecvProtcol), +} - pub fn new( - cid: u64, +impl Protocols { + pub(crate) fn new_tcp(stream: tokio::net::TcpStream) -> Self { + let (r, w) = stream.into_split(); + #[cfg(feature = "metrics")] + let metrics = ProtocolMetricCache::new( + "foooobaaaarrrrrrrr", + Arc::new(ProtocolMetrics::new().unwrap()), + ); + #[cfg(not(feature = "metrics"))] + let metrics = ProtocolMetricCache {}; + + let sp = TcpSendProtcol::new(TcpDrain { half: w }, metrics.clone()); + let rp = TcpRecvProtcol::new(TcpSink { half: r }, metrics.clone()); + Protocols::Tcp((sp, rp)) + } + + pub(crate) fn new_mpsc( + sender: mpsc::Sender, + receiver: mpsc::Receiver, + ) -> Self { + #[cfg(feature = "metrics")] + let metrics = + ProtocolMetricCache::new("mppppsssscccc", Arc::new(ProtocolMetrics::new().unwrap())); + #[cfg(not(feature = "metrics"))] + let metrics = ProtocolMetricCache {}; + + let sp = MpscSendProtcol::new(MpscDrain { sender }, metrics.clone()); + let rp = MpscRecvProtcol::new(MpscSink { receiver }, metrics.clone()); + Protocols::Mpsc((sp, rp)) + } + + pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) { + match self { + Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), + Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + } + } +} + +#[async_trait] +impl network_protocol::InitProtocol for Protocols { + async fn initialize( + &mut self, + initializer: bool, local_pid: Pid, secret: u128, - #[cfg(feature = "metrics")] metrics: Arc, - init_handshake: bool, - ) -> Self { - Self { - cid, - local_pid, - secret, - #[cfg(feature = "metrics")] - metrics, - init_handshake, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + match self { + Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, + Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + } + } +} + +#[async_trait] +impl network_protocol::SendProtocol for SendProtocols { + async fn send(&mut 
self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.send(event).await, + SendProtocols::Mpsc(s) => s.send(event).await, + } + } + + async fn flush(&mut self, bandwidth: u64, dt: Duration) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, + SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + } + } +} + +#[async_trait] +impl network_protocol::RecvProtocol for RecvProtocols { + async fn recv(&mut self) -> Result { + match self { + RecvProtocols::Tcp(r) => r.recv().await, + RecvProtocols::Mpsc(r) => r.recv().await, + } + } +} + +/////////////////////////////////////// +//// TCP +#[derive(Debug)] +pub struct TcpDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct TcpSink { + half: OwnedReadHalf, +} + +#[async_trait] +impl UnreliableDrain for TcpDrain { + type DataFormat = Vec; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + //self.half.recv + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for TcpSink { + type DataFormat = Vec; + + async fn recv(&mut self) -> Result { + let mut data = vec![0u8; 1500]; + match self.half.read(&mut data).await { + Ok(n) => { + data.truncate(n); + Ok(data) + }, + Err(_) => Err(ProtocolError::Closed), + } + } +} + +/////////////////////////////////////// +//// MPSC +#[derive(Debug)] +pub struct MpscDrain { + sender: tokio::sync::mpsc::Sender, +} + +#[derive(Debug)] +pub struct MpscSink { + receiver: tokio::sync::mpsc::Receiver, +} + +#[async_trait] +impl UnreliableDrain for MpscDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } +} + +#[async_trait] +impl UnreliableSink for MpscSink { + type DataFormat = MpscMsg; + + async fn recv(&mut 
self) -> Result { + self.receiver.recv().await.ok_or(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn tokio_sinks() { + let listener = TcpListener::bind("127.0.0.1:5000").await.unwrap(); + let r1 = tokio::spawn(async move { + let (server, _) = listener.accept().await.unwrap(); + (listener, server) + }); + let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); + let (_listener, server) = r1.await.unwrap(); + let client = Protocols::new_tcp(client); + let server = Protocols::new_tcp(server); + let (mut s, _) = client.split(); + let (_, mut r) = server.split(); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1), + prio: 4u8, + promises: Promises::GUARANTEED_DELIVERY, + guaranteed_bandwidth: 1_000, + }; + s.send(event.clone()).await.unwrap(); + let r = r.recv().await; + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + }) => { + assert_eq!(sid, Sid::new(1)); + assert_eq!(prio, 4u8); + assert_eq!(promises, Promises::GUARANTEED_DELIVERY); + }, + _ => { + panic!("wrong type {:?}", r); + }, } } - - pub async fn setup(self, protocol: &Protocols) -> Result<(Pid, Sid, u128, Vec), ()> { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded_channel::(); - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let handler_future = - self.frame_handler(&mut w2c_cid_frame_r, c2w_frame_s, read_stop_sender); - let res = match protocol { - Protocols::Tcp(tcp) => { - (join! { - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r).fuse(), - handler_future, - }) - .2 - }, - Protocols::Udp(udp) => { - (join! 
{ - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - handler_future, - }) - .2 - }, - }; - - match res { - Ok(res) => { - let fake_waker = noop_waker(); - let mut ctx = Context::from_waker(&fake_waker); - let mut leftover_frames = vec![]; - while let Poll::Ready(Some(cid_frame)) = w2c_cid_frame_r.poll_recv(&mut ctx) { - leftover_frames.push(cid_frame); - } - let cnt = leftover_frames.len(); - if cnt > 0 { - debug!( - ?cnt, - "Some additional frames got already transferred, piping them to the \ - bparticipant as leftover_frames" - ); - } - Ok((res.0, res.1, res.2, leftover_frames)) - }, - Err(()) => Err(()), - } - } - - async fn frame_handler( - &self, - w2c_cid_frame_r: &mut mpsc::UnboundedReceiver, - mut c2w_frame_s: mpsc::UnboundedSender, - read_stop_sender: oneshot::Sender<()>, - ) -> Result<(Pid, Sid, u128), ()> { - const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ - something went wrong on network layer and connection will be closed"; - #[cfg(feature = "metrics")] - let cid_string = self.cid.to_string(); - - if self.init_handshake { - self.send_handshake(&mut c2w_frame_s).await; - } - - let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); - #[cfg(feature = "metrics")] - { - if let Some(Ok(ref frame)) = frame { - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, &frame.get_string()]) - .inc(); - } - } - let r = match frame { - Some(Ok(Frame::Handshake { - magic_number, - version, - })) => { - trace!(?magic_number, ?version, "Recv handshake"); - if magic_number != VELOREN_MAGIC_NUMBER { - error!(?magic_number, "Connection with invalid magic_number"); - #[cfg(debug_assertions)] - self.send_raw_and_shutdown(&mut c2w_frame_s, Self::WRONG_NUMBER.to_vec()) - .await; - Err(()) - } else if version != VELOREN_NETWORK_VERSION { - error!(?version, "Connection with wrong network version"); - #[cfg(debug_assertions)] - 
self.send_raw_and_shutdown( - &mut c2w_frame_s, - format!( - "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", - Self::WRONG_VERSION, - VELOREN_NETWORK_VERSION, - version, - ) - .as_bytes() - .to_vec(), - ) - .await; - Err(()) - } else { - debug!("Handshake completed"); - if self.init_handshake { - self.send_init(&mut c2w_frame_s).await; - } else { - self.send_handshake(&mut c2w_frame_s).await; - } - Ok(()) - } - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if let Err(()) = r { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - return Err(()); - } - - let frame = w2c_cid_frame_r.recv().await.map(|(_cid, frame)| frame); - let r = match frame { - Some(Ok(Frame::Init { pid, secret })) => { - debug!(?pid, "Participant send their ID"); - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, "ParticipantId"]) - .inc(); - let stream_id_offset = if self.init_handshake { - STREAM_ID_OFFSET1 - } else { - self.send_init(&mut c2w_frame_s).await; - STREAM_ID_OFFSET2 - }; - info!(?pid, "This Handshake is now configured!"); - Ok((pid, stream_id_offset, secret)) - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - 
Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if r.is_err() { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - } - r - } - - async fn send_handshake(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "Handshake"]) - .inc(); - c2w_frame_s - .send(Frame::Handshake { - magic_number: VELOREN_MAGIC_NUMBER, - version: VELOREN_NETWORK_VERSION, - }) - .unwrap(); - } - - async fn send_init(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "ParticipantId"]) - .inc(); - c2w_frame_s - .send(Frame::Init { - pid: self.local_pid, - secret: self.secret, - }) - .unwrap(); - } - - #[cfg(debug_assertions)] - async fn send_raw_and_shutdown( - &self, - c2w_frame_s: &mut mpsc::UnboundedSender, - data: Vec, - ) { - debug!("Sending client instructions before killing"); - #[cfg(feature = "metrics")] - { - let cid_string = self.cid.to_string(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Raw"]) - .inc(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Shutdown"]) - .inc(); - } - c2w_frame_s.send(Frame::Raw(data)).unwrap(); - c2w_frame_s.send(Frame::Shutdown).unwrap(); - } } diff --git a/network/src/lib.rs b/network/src/lib.rs index ffba192643..7593f1edd8 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -104,14 +104,10 @@ mod channel; mod message; #[cfg(feature = "metrics")] mod metrics; mod participant; -mod prios; -mod protocols; mod scheduler; -#[macro_use] -mod types; pub use api::{ Network, NetworkError, Participant, ParticipantError, ProtocolAddr, Stream, StreamError, }; pub use message::Message; -pub use types::{Pid, 
Promises}; +pub use network_protocol::{Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index 9ab9941599..0ad24c63ad 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -1,11 +1,9 @@ use serde::{de::DeserializeOwned, Serialize}; //use std::collections::VecDeque; +use crate::api::{Stream, StreamError}; +use network_protocol::MessageBuffer; #[cfg(feature = "compression")] -use crate::types::Promises; -use crate::{ - api::{Stream, StreamError}, - types::{Frame, Mid, Sid}, -}; +use network_protocol::Promises; use std::{io, sync::Arc}; #[cfg(all(feature = "compression", debug_assertions))] use tracing::warn; @@ -23,29 +21,6 @@ pub struct Message { pub(crate) compressed: bool, } -//Todo: Evaluate switching to VecDeque for quickly adding and removing data -// from front, back. -// - It would prob require custom bincode code but thats possible. -pub(crate) struct MessageBuffer { - pub data: Vec, -} - -#[derive(Debug)] -pub(crate) struct OutgoingMessage { - pub buffer: Arc, - pub cursor: u64, - pub mid: Mid, - pub sid: Sid, -} - -#[derive(Debug)] -pub(crate) struct IncomingMessage { - pub buffer: MessageBuffer, - pub length: u64, - pub mid: Mid, - pub sid: Sid, -} - impl Message { /// This serializes any message, according to the [`Streams`] [`Promises`]. 
/// You can reuse this `Message` and send it via other [`Streams`], if the @@ -170,38 +145,6 @@ impl Message { } } -impl OutgoingMessage { - pub(crate) const FRAME_DATA_SIZE: u64 = 1400; - - /// returns if msg is empty - pub(crate) fn fill_next>( - &mut self, - msg_sid: Sid, - frames: &mut E, - ) -> bool { - let to_send = std::cmp::min( - self.buffer.data[self.cursor as usize..].len() as u64, - Self::FRAME_DATA_SIZE, - ); - if to_send > 0 { - if self.cursor == 0 { - frames.extend(std::iter::once((msg_sid, Frame::DataHeader { - mid: self.mid, - sid: self.sid, - length: self.buffer.data.len() as u64, - }))); - } - frames.extend(std::iter::once((msg_sid, Frame::Data { - mid: self.mid, - start: self.cursor, - data: self.buffer.data[self.cursor as usize..][..to_send as usize].to_vec(), - }))); - }; - self.cursor += to_send; - self.cursor >= self.buffer.data.len() as u64 - } -} - ///wouldn't trust this aaaassss much, fine for tests pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool { if let Some(f) = first.raw_os_error() { @@ -231,28 +174,6 @@ pub(crate) fn partial_eq_bincode(first: &bincode::ErrorKind, second: &bincode::E } } -impl std::fmt::Debug for MessageBuffer { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //TODO: small messages! 
- let len = self.data.len(); - if len > 20 { - write!( - f, - "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", - len, - u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), - u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), - u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), - &self.data[13..16], - &self.data[len - 8..len] - ) - } else { - write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) - } - } -} - #[cfg(test)] mod tests { use crate::{api::Stream, message::*}; @@ -260,7 +181,8 @@ mod tests { use tokio::sync::mpsc; fn stub_stream(compressed: bool) -> Stream { - use crate::{api::*, types::*}; + use crate::api::*; + use network_protocol::*; #[cfg(feature = "compression")] let promises = if compressed { @@ -281,6 +203,7 @@ mod tests { Sid::new(0), 0u8, promises, + 1_000_000, Arc::new(AtomicBool::new(true)), a2b_msg_s, b2a_msg_recv_r, diff --git a/network/src/metrics.rs b/network/src/metrics.rs index d43aeaae3a..60650f68fe 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,10 +1,6 @@ -use crate::types::{Cid, Frame, Pid}; -use prometheus::{ - core::{AtomicU64, GenericCounter}, - IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry, -}; +use network_protocol::Pid; +use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; use std::error::Error; -use tracing::*; /// 1:1 relation between NetworkMetrics and Network /// use 2NF here and avoid redundant data like CHANNEL AND PARTICIPANT encoding. 
@@ -25,29 +21,6 @@ pub struct NetworkMetrics { pub streams_opened_total: IntCounterVec, pub streams_closed_total: IntCounterVec, pub network_info: IntGauge, - // Frames counted a channel level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_out_total: IntCounterVec, - pub frames_in_total: IntCounterVec, - // Frames counted at protocol level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_wire_out_total: IntCounterVec, - pub frames_wire_in_total: IntCounterVec, - // throughput at protocol level, seperated by CHANNEL (and PARTICIPANT), - pub wire_out_throughput: IntCounterVec, - pub wire_in_throughput: IntCounterVec, - // send(prio) Messages count, seperated by STREAM AND PARTICIPANT, - pub message_out_total: IntCounterVec, - // send(prio) Messages throughput, seperated by STREAM AND PARTICIPANT, - pub message_out_throughput: IntCounterVec, - // flushed(prio) stream count, seperated by PARTICIPANT, - pub streams_flushed: IntCounterVec, - // TODO: queued Messages, seperated by STREAM (add PART, CHANNEL), - // queued Messages, seperated by PARTICIPANT - pub queued_count: IntGaugeVec, - // TODO: queued Messages bytes, seperated by STREAM (add PART, CHANNEL), - // queued Messages bytes, seperated by PARTICIPANT - pub queued_bytes: IntGaugeVec, - // ping calculated based on last msg seperated by PARTICIPANT - pub participants_ping: IntGaugeVec, } impl NetworkMetrics { @@ -115,99 +88,13 @@ impl NetworkMetrics { "version", &format!( "{}.{}.{}", - &crate::types::VELOREN_NETWORK_VERSION[0], - &crate::types::VELOREN_NETWORK_VERSION[1], - &crate::types::VELOREN_NETWORK_VERSION[2] + &network_protocol::VELOREN_NETWORK_VERSION[0], + &network_protocol::VELOREN_NETWORK_VERSION[1], + &network_protocol::VELOREN_NETWORK_VERSION[2] ), ) .const_label("local_pid", &format!("{}", &local_pid)); let network_info = IntGauge::with_opts(opts)?; - let frames_out_total = IntCounterVec::new( - Opts::new( - "frames_out_total", - "Number of all frames send 
per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_in_total = IntCounterVec::new( - Opts::new( - "frames_in_total", - "Number of all frames received per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_wire_out_total = IntCounterVec::new( - Opts::new( - "frames_wire_out_total", - "Number of all frames send per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let frames_wire_in_total = IntCounterVec::new( - Opts::new( - "frames_wire_in_total", - "Number of all frames received per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let wire_out_throughput = IntCounterVec::new( - Opts::new( - "wire_out_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - let wire_in_throughput = IntCounterVec::new( - Opts::new( - "wire_in_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - //TODO IN - let message_out_total = IntCounterVec::new( - Opts::new( - "message_out_total", - "Number of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - //TODO IN - let message_out_throughput = IntCounterVec::new( - Opts::new( - "message_out_throughput", - "Throughput of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - let streams_flushed = IntCounterVec::new( - Opts::new( - "stream_flushed", - "Number of flushed streams requested to PrioManager at participant level", - ), - &["participant"], - )?; - let queued_count = IntGaugeVec::new( - Opts::new( - "queued_count", - "Queued number of messages by participant on the network", - ), - &["channel"], - )?; - let queued_bytes = IntGaugeVec::new( - Opts::new( - "queued_bytes", - "Queued bytes of messages by participant on the network", - ), - &["channel"], - )?; - let participants_ping = IntGaugeVec::new( - Opts::new( - 
"participants_ping", - "Ping time to participants on the network", - ), - &["channel"], - )?; Ok(Self { listen_requests_total, @@ -220,18 +107,6 @@ impl NetworkMetrics { streams_opened_total, streams_closed_total, network_info, - frames_out_total, - frames_in_total, - frames_wire_out_total, - frames_wire_in_total, - wire_out_throughput, - wire_in_throughput, - message_out_total, - message_out_throughput, - streams_flushed, - queued_count, - queued_bytes, - participants_ping, }) } @@ -246,22 +121,8 @@ impl NetworkMetrics { registry.register(Box::new(self.streams_opened_total.clone()))?; registry.register(Box::new(self.streams_closed_total.clone()))?; registry.register(Box::new(self.network_info.clone()))?; - registry.register(Box::new(self.frames_out_total.clone()))?; - registry.register(Box::new(self.frames_in_total.clone()))?; - registry.register(Box::new(self.frames_wire_out_total.clone()))?; - registry.register(Box::new(self.frames_wire_in_total.clone()))?; - registry.register(Box::new(self.wire_out_throughput.clone()))?; - registry.register(Box::new(self.wire_in_throughput.clone()))?; - registry.register(Box::new(self.message_out_total.clone()))?; - registry.register(Box::new(self.message_out_throughput.clone()))?; - registry.register(Box::new(self.queued_count.clone()))?; - registry.register(Box::new(self.queued_bytes.clone()))?; - registry.register(Box::new(self.participants_ping.clone()))?; Ok(()) } - - //pub fn _is_100th_tick(&self) -> bool { - // self.tick.load(Ordering::Relaxed).rem_euclid(100) == 0 } } impl std::fmt::Debug for NetworkMetrics { @@ -270,138 +131,3 @@ impl std::fmt::Debug for NetworkMetrics { write!(f, "NetworkMetrics()") } } - -/* -pub(crate) struct PidCidFrameCache { - metric: MetricVec, - pid: String, - cache: Vec<[T::M; 8]>, -} -*/ - -pub(crate) struct MultiCidFrameCache { - metric: IntCounterVec, - cache: Vec<[Option>; Frame::FRAMES_LEN as usize]>, -} - -impl MultiCidFrameCache { - const CACHE_SIZE: usize = 2048; - - pub fn new(metric: 
IntCounterVec) -> Self { - Self { - metric, - cache: vec![], - } - } - - fn populate(&mut self, cid: Cid) { - let start_cid = self.cache.len(); - if cid >= start_cid as u64 && cid > (Self::CACHE_SIZE as Cid) { - warn!( - ?cid, - "cid, getting quite high, is this a attack on the cache?" - ); - } - self.cache.resize((cid + 1) as usize, [ - None, None, None, None, None, None, None, None, - ]); - } - - pub fn with_label_values(&mut self, cid: Cid, frame: &Frame) -> &GenericCounter { - self.populate(cid); - let frame_int = frame.get_int() as usize; - let r = &mut self.cache[cid as usize][frame_int]; - if r.is_none() { - *r = Some( - self.metric - .with_label_values(&[&cid.to_string(), &frame_int.to_string()]), - ); - } - r.as_ref().unwrap() - } -} - -pub(crate) struct CidFrameCache { - cache: [GenericCounter; Frame::FRAMES_LEN as usize], -} - -impl CidFrameCache { - pub fn new(metric: IntCounterVec, cid: Cid) -> Self { - let cid = cid.to_string(); - let cache = [ - metric.with_label_values(&[&cid, Frame::int_to_string(0)]), - metric.with_label_values(&[&cid, Frame::int_to_string(1)]), - metric.with_label_values(&[&cid, Frame::int_to_string(2)]), - metric.with_label_values(&[&cid, Frame::int_to_string(3)]), - metric.with_label_values(&[&cid, Frame::int_to_string(4)]), - metric.with_label_values(&[&cid, Frame::int_to_string(5)]), - metric.with_label_values(&[&cid, Frame::int_to_string(6)]), - metric.with_label_values(&[&cid, Frame::int_to_string(7)]), - ]; - Self { cache } - } - - pub fn with_label_values(&mut self, frame: &Frame) -> &GenericCounter { - &self.cache[frame.get_int() as usize] - } -} - -#[cfg(test)] -mod tests { - use crate::{ - metrics::*, - types::{Frame, Pid}, - }; - - #[test] - fn register_metrics() { - let registry = Registry::new(); - let metrics = NetworkMetrics::new(&Pid::fake(1)).unwrap(); - metrics.register(®istry).unwrap(); - } - - #[test] - fn multi_cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let 
frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = MultiCidFrameCache::new(metrics.frames_in_total); - let v1 = cache.with_label_values(1, &frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(1, &frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(1, &frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(3, &frame1); - v4.inc(); - assert_eq!(v4.get(), 1); - let v5 = cache.with_label_values(3, &Frame::Shutdown); - v5.inc(); - assert_eq!(v5.get(), 1); - } - - #[test] - fn cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = CidFrameCache::new(metrics.frames_wire_out_total, 1); - let v1 = cache.with_label_values(&frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(&frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(&frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(&Frame::Shutdown); - v4.inc(); - assert_eq!(v4.get(), 1); - } -} diff --git a/network/src/participant.rs b/network/src/participant.rs index 6986a70e8f..a942632f7b 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,43 +1,39 @@ #[cfg(feature = "metrics")] -use crate::metrics::{MultiCidFrameCache, NetworkMetrics}; +use crate::metrics::NetworkMetrics; use crate::{ api::{ParticipantError, Stream}, - channel::Channel, - message::{IncomingMessage, MessageBuffer, OutgoingMessage}, - prios::PrioManager, - protocols::Protocols, - types::{Cid, Frame, Pid, Prio, Promises, Sid}, + channel::{Protocols, RecvProtocols, SendProtocols}, }; use futures_util::{FutureExt, StreamExt}; +use network_protocol::{ + Bandwidth, Cid, MessageBuffer, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, + Sid, 
+}; use std::{ - collections::{HashMap, VecDeque}, + collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicI32, Ordering}, Arc, }, time::{Duration, Instant}, }; use tokio::{ - runtime::Runtime, select, sync::{mpsc, oneshot, Mutex, RwLock}, + task::JoinHandle, }; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use tracing_futures::Instrument; -pub(crate) type A2bStreamOpen = (Prio, Promises, oneshot::Sender); -pub(crate) type C2pFrame = (Cid, Result); -pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, Vec, oneshot::Sender<()>); -pub(crate) type S2bShutdownBparticipant = oneshot::Sender>; +pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender); +pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, oneshot::Sender<()>); +pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender>); pub(crate) type B2sPrioStatistic = (Pid, u64, u64); #[derive(Debug)] struct ChannelInfo { cid: Cid, cid_string: String, //optimisationmetrics - b2w_frame_s: mpsc::UnboundedSender, - b2r_read_shutdown: oneshot::Sender<()>, } #[derive(Debug)] @@ -45,23 +41,19 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: Mutex>, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] struct ControlChannels { - a2b_stream_open_r: mpsc::UnboundedReceiver, + a2b_open_stream_r: mpsc::UnboundedReceiver, b2a_stream_opened_s: mpsc::UnboundedSender, - b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, s2b_create_channel_r: mpsc::UnboundedReceiver, - a2b_close_stream_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, s2b_shutdown_bparticipant_r: oneshot::Receiver, /* own */ } #[derive(Debug)] struct ShutdownInfo { - //a2b_stream_open_r: mpsc::UnboundedReceiver, b2b_close_stream_opened_sender_s: Option>, error: Option, } @@ -71,29 +63,27 @@ pub struct BParticipant { remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, - 
runtime: Arc, channels: Arc>>>, streams: RwLock>, - running_mgr: AtomicUsize, run_channels: Option, + shutdown_barrier: AtomicI32, #[cfg(feature = "metrics")] metrics: Arc, no_channel_error_info: RwLock<(Instant, u64)>, - shutdown_info: RwLock, } impl BParticipant { - const BANDWIDTH: u64 = 25_000_000; - const FRAMES_PER_TICK: u64 = Self::BANDWIDTH * Self::TICK_TIME_MS / 1000 / 1400 /*TCP FRAME*/; + // We use integer instead of Barrier to not block mgr from freeing at the end + const BARR_CHANNEL: i32 = 1; + const BARR_RECV: i32 = 4; + const BARR_SEND: i32 = 2; const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS); - //in bit/s const TICK_TIME_MS: u64 = 10; #[allow(clippy::type_complexity)] pub(crate) fn new( remote_pid: Pid, offset_sid: Sid, - runtime: Arc, #[cfg(feature = "metrics")] metrics: Arc, ) -> ( Self, @@ -102,27 +92,15 @@ impl BParticipant { mpsc::UnboundedSender, oneshot::Sender, ) { - let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded_channel::(); + let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::(); let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::(); - let (b2b_close_stream_opened_sender_s, b2b_close_stream_opened_sender_r) = - oneshot::channel(); - let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel(); let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel(); let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel(); - let shutdown_info = RwLock::new(ShutdownInfo { - //a2b_stream_open_r: a2b_stream_open_r.clone(), - b2b_close_stream_opened_sender_s: Some(b2b_close_stream_opened_sender_s), - error: None, - }); - let run_channels = Some(ControlChannels { - a2b_stream_open_r, + a2b_open_stream_r, b2a_stream_opened_s, - b2b_close_stream_opened_sender_r, s2b_create_channel_r, - a2b_close_stream_r, - a2b_close_stream_s, s2b_shutdown_bparticipant_r, }); @@ -131,17 +109,17 @@ impl BParticipant { remote_pid, 
remote_pid_string: remote_pid.to_string(), offset_sid, - runtime, channels: Arc::new(RwLock::new(HashMap::new())), streams: RwLock::new(HashMap::new()), - running_mgr: AtomicUsize::new(0), + shutdown_barrier: AtomicI32::new( + Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV, + ), run_channels, #[cfg(feature = "metrics")] metrics, no_channel_error_info: RwLock::new((Instant::now(), 0)), - shutdown_info, }, - a2b_steam_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, @@ -149,693 +127,486 @@ impl BParticipant { } pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender) { - //those managers that listen on api::Participant need an additional oneshot for - // shutdown scenario, those handled by scheduler will be closed by it. - let (shutdown_send_mgr_sender, shutdown_send_mgr_receiver) = oneshot::channel(); - let (shutdown_stream_close_mgr_sender, shutdown_stream_close_mgr_receiver) = - oneshot::channel(); - let (shutdown_open_mgr_sender, shutdown_open_mgr_receiver) = oneshot::channel(); - let (w2b_frames_s, w2b_frames_r) = mpsc::unbounded_channel::(); - let (prios, a2p_msg_s, b2p_notify_empty_stream_s) = PrioManager::new( - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - self.remote_pid_string.clone(), - ); + let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) = + mpsc::unbounded_channel::<(Cid, SendProtocols)>(); + let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) = + mpsc::unbounded_channel::<(Cid, RecvProtocols)>(); + let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) = + async_channel::unbounded::(); + let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = + async_channel::unbounded::(); + + let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); + const STREAM_BOUND: usize = 10_000; + let (a2b_msg_s, a2b_msg_r) = + crossbeam_channel::bounded::<(Sid, Arc)>(STREAM_BOUND); let run_channels = self.run_channels.take().unwrap(); 
tokio::join!( - self.open_mgr( - run_channels.a2b_stream_open_r, - run_channels.a2b_close_stream_s.clone(), - a2p_msg_s.clone(), - shutdown_open_mgr_receiver, + self.send_mgr( + run_channels.a2b_open_stream_r, + a2b_close_stream_r, + a2b_msg_r, + b2b_add_send_protocol_r, + b2b_close_send_protocol_r, + b2s_prio_statistic_s, + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self ), - self.handle_frames_mgr( - w2b_frames_r, + self.recv_mgr( run_channels.b2a_stream_opened_s, - run_channels.b2b_close_stream_opened_sender_r, - run_channels.a2b_close_stream_s, - a2p_msg_s.clone(), + b2b_add_recv_protocol_r, + b2b_force_close_recv_protocol_r, + b2b_close_send_protocol_s.clone(), + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self ), - self.create_channel_mgr(run_channels.s2b_create_channel_r, w2b_frames_s), - self.send_mgr(prios, shutdown_send_mgr_receiver, b2s_prio_statistic_s), - self.stream_close_mgr( - run_channels.a2b_close_stream_r, - shutdown_stream_close_mgr_receiver, - b2p_notify_empty_stream_s, + self.create_channel_mgr( + run_channels.s2b_create_channel_r, + b2b_add_send_protocol_s, + b2b_add_recv_protocol_s, ), self.participant_shutdown_mgr( run_channels.s2b_shutdown_bparticipant_r, - shutdown_open_mgr_sender, - shutdown_stream_close_mgr_sender, - shutdown_send_mgr_sender, + b2b_close_send_protocol_s.clone(), + b2b_force_close_recv_protocol_s, ), ); } + //TODO: local stream_cid: HashMap to know the respective protocol async fn send_mgr( &self, - mut prios: PrioManager, - mut shutdown_send_mgr_receiver: oneshot::Receiver>, - b2s_prio_statistic_s: mpsc::UnboundedSender, - ) { - //This time equals the MINIMUM Latency in average, so keep it down and //Todo: - // make it configurable or switch to await E.g. 
Prio 0 = await, prio 50 - // wait for more messages - self.running_mgr.fetch_add(1, Ordering::Relaxed); - let mut b2b_prios_flushed_s = None; //closing up - let mut interval = tokio::time::interval(Self::TICK_TIME); - trace!("Start send_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut i: u64 = 0; - loop { - let mut frames = VecDeque::new(); - prios - .fill_frames(Self::FRAMES_PER_TICK as usize, &mut frames) - .await; - let len = frames.len(); - for (_, frame) in frames { - self.send_frame( - frame, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - b2s_prio_statistic_s - .send((self.remote_pid, len as u64, /* */ 0)) - .unwrap(); - interval.tick().await; - i += 1; - if i.rem_euclid(1000) == 0 { - trace!("Did 1000 ticks"); - } - //shutdown after all msg are send! - // Make sure this is called after the API is closed, and all streams are known - // to be droped to the priomgr - if b2b_prios_flushed_s.is_some() && (len == 0) { - break; - } - if b2b_prios_flushed_s.is_none() { - if let Ok(prios_flushed_s) = shutdown_send_mgr_receiver.try_recv() { - b2b_prios_flushed_s = Some(prios_flushed_s); - } - } - } - trace!("Stop send_mgr"); - b2b_prios_flushed_s - .expect("b2b_prios_flushed_s not set") - .send(()) - .unwrap(); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - //returns false if sending isn't possible. In that case we have to render the - // Participant `closed` - #[must_use = "You need to check if the send was successful and report to client!"] - async fn send_frame( - &self, - frame: Frame, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, - ) -> bool { - let mut drop_cid = None; - // TODO: find out ideal channel here - - let res = if let Some(ci) = self.channels.read().await.values().next() { - let ci = ci.lock().await; - //we are increasing metrics without checking the result to please - // borrow_checker. 
otherwise we would need to close `frame` what we - // dont want! - #[cfg(feature = "metrics")] - frames_out_total_cache - .with_label_values(ci.cid, &frame) - .inc(); - if let Err(e) = ci.b2w_frame_s.send(frame) { - let cid = ci.cid; - info!(?e, ?cid, "channel no longer available"); - drop_cid = Some(cid); - false - } else { - true - } - } else { - let mut guard = self.no_channel_error_info.write().await; - let now = Instant::now(); - if now.duration_since(guard.0) > Duration::from_secs(1) { - guard.0 = now; - let occurrences = guard.1 + 1; - guard.1 = 0; - let lastframe = frame; - error!( - ?occurrences, - ?lastframe, - "Participant has no channel to communicate on" - ); - } else { - guard.1 += 1; - } - false - }; - if let Some(cid) = drop_cid { - if let Some(ci) = self.channels.write().await.remove(&cid) { - let ci = ci.into_inner(); - trace!(?cid, "stopping read protocol"); - if let Err(e) = ci.b2r_read_shutdown.send(()) { - trace!(?cid, ?e, "seems like was already shut down"); - } - } - //TODO FIXME tags: takeover channel multiple - info!( - "FIXME: the frame is actually drop. 
which is fine for now as the participant will \ - be closed, but not if we do channel-takeover" - ); - //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant - self.close_write_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) - .await; - }; - res - } - - async fn handle_frames_mgr( - &self, - mut w2b_frames_r: mpsc::UnboundedReceiver, - b2a_stream_opened_s: mpsc::UnboundedSender, - b2b_close_stream_opened_sender_r: oneshot::Receiver<()>, + mut a2b_open_stream_r: mpsc::UnboundedReceiver, + mut a2b_close_stream_r: mpsc::UnboundedReceiver, + a2b_msg_r: crossbeam_channel::Receiver<(Sid, Arc)>, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, + b2b_close_send_protocol_r: async_channel::Receiver, + _b2s_prio_statistic_s: mpsc::UnboundedSender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start handle_frames_mgr"); - let mut messages = HashMap::new(); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut dropped_instant = Instant::now(); - let mut dropped_cnt = 0u64; - let mut dropped_sid = Sid::new(0); - let mut b2a_stream_opened_s = Some(b2a_stream_opened_s); - let mut b2b_close_stream_opened_sender_r = b2b_close_stream_opened_sender_r.fuse(); + let mut send_protocols: HashMap = HashMap::new(); + let mut interval = tokio::time::interval(Self::TICK_TIME); + let mut stream_ids = self.offset_sid; + trace!("workaround, activly wait for first protocol"); + b2b_add_protocol_r + .recv() + .await + .map(|(c, p)| send_protocols.insert(c, p)); + trace!("Start send_mgr"); + loop { + let (open, close, _, addp, remp) = select!( + next = a2b_open_stream_r.recv().fuse() => (Some(next), None, None, None, None), + next = a2b_close_stream_r.recv().fuse() => (None, Some(next), 
None, None, None), + _ = interval.tick() => (None, None, Some(()), None, None), + next = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(next), None), + next = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(next)), + ); - while let Some((cid, result_frame)) = select!( - next = w2b_frames_r.recv().fuse() => next, - _ = &mut b2b_close_stream_opened_sender_r => { - b2a_stream_opened_s = None; - None - }, - ) { - //trace!(?result_frame, "handling frame"); - let frame = match result_frame { - Ok(frame) => frame, - Err(()) => { - // The read protocol stopped, i need to make sure that write gets stopped, can - // drop channel as it's dead anyway - debug!("read protocol was closed. Stopping channel"); - self.channels.write().await.remove(&cid); + trace!(?open, ?close, ?addp, ?remp, "foobar"); + + addp.flatten().map(|(c, p)| send_protocols.insert(c, p)); + match remp { + Some(Ok(cid)) => { + trace!(?cid, "remove send protocol"); + match send_protocols.remove(&cid) { + Some(mut prot) => { + trace!("blocking flush"); + let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; + trace!("shutdown prot"); + let _ = prot.send(ProtocolEvent::Shutdown).await; + }, + None => trace!("tried to remove protocol twice"), + }; + if send_protocols.is_empty() { + break; + } + }, + _ => (), + }; + + let cid = 0; + let active = match send_protocols.get_mut(&cid) { + Some(a) => a, + None => { + warn!("no channel arrg"); continue; }, }; - #[cfg(feature = "metrics")] - { - let cid_string = cid.to_string(); - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - } - match frame { - Frame::OpenStream { - sid, - prio, - promises, - } => { - trace!(?sid, ?prio, ?promises, "Opened frame from remote"); - let a2p_msg_s = a2p_msg_s.clone(); + + let active_err = async { + if let Some(Some((prio, promises, guaranteed_bandwidth, return_s))) = open { + trace!(?stream_ids, "openuing some new stream"); + let sid = 
stream_ids; + stream_ids += Sid::from(1); let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) + .create_stream( + sid, + prio, + promises, + guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) .await; - match &b2a_stream_opened_s { - None => debug!("dropping openStream as Channel is already closing"), - Some(s) => { - if let Err(e) = s.send(stream) { - warn!( - ?e, - ?sid, - "couldn't notify api::Participant that a stream got opened. \ - Is the participant already dropped?" - ); - } - }, - } - }, - Frame::CloseStream { sid } => { - // no need to keep flushing as the remote no longer knows about this stream - // anyway - self.delete_stream( - sid, - None, - true, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - }, - Frame::DataHeader { mid, sid, length } => { - let imsg = IncomingMessage { - buffer: MessageBuffer { data: Vec::new() }, - length, - mid, + + let event = ProtocolEvent::OpenStream { sid, + prio, + promises, + guaranteed_bandwidth, }; - messages.insert(mid, imsg); - }, - Frame::Data { - mid, - start: _, - mut data, - } => { - let finished = if let Some(imsg) = messages.get_mut(&mid) { - imsg.buffer.data.append(&mut data); - imsg.buffer.data.len() as u64 == imsg.length - } else { - false - }; - if finished { - //trace!(?mid, "finished receiving message"); - let imsg = messages.remove(&mid).unwrap(); - if let Some(si) = self.streams.read().await.get(&imsg.sid) { - if let Err(e) = si.b2a_msg_recv_s.lock().await.send(imsg).await { - warn!( - ?e, - ?mid, - "Dropping message, as streams seem to be in act of being \ - dropped right now" - ); - } - } else { - //aggregate errors - let n = Instant::now(); - if dropped_cnt > 0 - && (dropped_sid != imsg.sid - || n.duration_since(dropped_instant) > Duration::from_secs(1)) - { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to \ - exist because it was dropped probably." 
- ); - dropped_cnt = 0; - dropped_instant = n; - dropped_sid = imsg.sid; - } else { - dropped_cnt += 1; - } - } - } - }, - Frame::Shutdown => { - debug!("Shutdown received from remote side"); - self.close_api(Some(ParticipantError::ParticipantDisconnected)) - .await; - }, - f => { - unreachable!( - "Frame should never reach participant!: {:?}, cid: {}", - f, cid - ); - }, + + return_s.send(stream).unwrap(); + active.send(event).await?; + } + + // get all messages and assign it to a channel + for (sid, buffer) in a2b_msg_r.try_iter() { + warn!(?sid, "sending!"); + active + .send(ProtocolEvent::Message { + buffer, + mid: 0u64, + sid, + }) + .await? + } + + if let Some(Some(sid)) = close { + warn!(?sid, "delete_stream!"); + self.delete_stream(sid).await; + // Fire&Forget the protocol will take care to verify that this Frame is delayed + // till the last msg was received! + active.send(ProtocolEvent::CloseStream { sid }).await?; + } + + warn!("flush!"); + active + .flush(1_000_000, Duration::from_secs(1) /* TODO */) + .await?; //this actually blocks, so we cant set streams whilte it. + let r: Result<(), network_protocol::ProtocolError> = Ok(()); + r + } + .await; + if let Err(e) = active_err { + info!(?cid, ?e, "send protocol failed, shutting down channel"); + // remote recv will now fail, which will trigger remote send which will trigger + // recv + send_protocols.remove(&cid).unwrap(); } } - if dropped_cnt > 0 { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to exist because it was \ - dropped probably." 
+ trace!("Stop send_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_SEND, Ordering::Relaxed); + } + + async fn recv_mgr( + &self, + b2a_stream_opened_s: mpsc::UnboundedSender, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, + b2b_force_close_recv_protocol_r: async_channel::Receiver, + b2b_close_send_protocol_s: async_channel::Sender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + a2b_close_stream_s: mpsc::UnboundedSender, + ) { + let mut recv_protocols: HashMap> = HashMap::new(); + // we should be able to directly await futures imo + let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel(); + + let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| { + let hacky_recv_s = hacky_recv_s.clone(); + let handle = tokio::spawn(async move { + let cid = cid; + let r = p.recv().await; + let _ = hacky_recv_s.send((cid, r, p)); // ignoring failed + }); + map.insert(cid, handle); + }; + + let remove_c = |recv_protocols: &mut HashMap>, cid: &Cid| { + match recv_protocols.remove(&cid) { + Some(h) => h.abort(), + None => trace!("tried to remove protocol twice"), + }; + recv_protocols.is_empty() + }; + + trace!("Start recv_mgr"); + loop { + let (event, addp, remp) = select!( + next = hacky_recv_r.recv().fuse() => (Some(next), None, None), + Some(next) = b2b_add_protocol_r.recv().fuse() => (None, Some(next), None), + next = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(next)), ); + + addp.map(|(cid, p)| { + retrigger(cid, p, &mut recv_protocols); + }); + if let Some(Ok(cid)) = remp { + // no need to stop the send_mgr here as it has been canceled before + if remove_c(&mut recv_protocols, &cid) { + break; + } + }; + + warn!(?event, "recv event!"); + if let Some(Some((cid, r, p))) = event { + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + }) => { + trace!(?sid, "open stream"); + let stream = self + .create_stream( + sid, + prio, + promises, + 
guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) + .await; + b2a_stream_opened_s.send(stream).unwrap(); + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::CloseStream { sid }) => { + trace!(?sid, "close stream"); + self.delete_stream(sid).await; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Message { + buffer, + mid: _, + sid, + }) => { + let buffer = Arc::try_unwrap(buffer).unwrap(); + let lock = self.streams.read().await; + match lock.get(&sid) { + Some(stream) => { + stream + .b2a_msg_recv_s + .lock() + .await + .send(buffer) + .await + .unwrap(); + }, + None => warn!("recv a msg with orphan stream"), + }; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Shutdown) => { + info!(?cid, "shutdown protocol"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + Err(e) => { + info!(?cid, ?e, "recv protocol failed, shutting down channel"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + } + } } - trace!("Stop handle_frames_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + + trace!("Stop recv_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_RECV, Ordering::Relaxed); } async fn create_channel_mgr( &self, s2b_create_channel_r: mpsc::UnboundedReceiver, - w2b_frames_s: mpsc::UnboundedSender, + b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>, + b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); trace!("Start create_channel_mgr"); let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r - .for_each_concurrent( - None, - |(cid, _, protocol, leftover_cid_frame, 
b2s_create_channel_done_s)| { - // This channel is now configured, and we are running it in scope of the - // participant. - let w2b_frames_s = w2b_frames_s.clone(); - let channels = Arc::clone(&self.channels); - async move { - let (channel, b2w_frame_s, b2r_read_shutdown) = Channel::new(cid); - let mut lock = channels.write().await; - #[cfg(feature = "metrics")] - let mut channel_no = lock.len(); - #[cfg(not(feature = "metrics"))] - let channel_no = lock.len(); - lock.insert( + .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { + // This channel is now configured, and we are running it in scope of the + // participant. + //let w2b_frames_s = w2b_frames_s.clone(); + let channels = Arc::clone(&self.channels); + let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); + let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); + async move { + let mut lock = channels.write().await; + #[cfg(feature = "metrics")] + let mut channel_no = lock.len(); + lock.insert( + cid, + Mutex::new(ChannelInfo { cid, - Mutex::new(ChannelInfo { - cid, - cid_string: cid.to_string(), - b2w_frame_s, - b2r_read_shutdown, - }), - ); - drop(lock); - b2s_create_channel_done_s.send(()).unwrap(); - #[cfg(feature = "metrics")] - { - self.metrics - .channels_connected_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - if channel_no > 5 { - debug!(?channel_no, "metrics will overwrite channel #5"); - channel_no = 5; - } - self.metrics - .participants_channel_ids - .with_label_values(&[ - &self.remote_pid_string, - &channel_no.to_string(), - ]) - .set(cid as i64); - } - trace!(?cid, ?channel_no, "Running channel in participant"); - channel - .run(protocol, w2b_frames_s, leftover_cid_frame) - .instrument(tracing::info_span!("", ?cid)) - .await; - #[cfg(feature = "metrics")] + cid_string: cid.to_string(), + }), + ); + drop(lock); + let (send, recv) = protocol.split(); + b2b_add_send_protocol_s.send((cid, send)).unwrap(); + 
b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); + b2s_create_channel_done_s.send(()).unwrap(); + #[cfg(feature = "metrics")] + { self.metrics - .channels_disconnected_total + .channels_connected_total .with_label_values(&[&self.remote_pid_string]) .inc(); - info!(?cid, "Channel got closed"); - //maybe channel got already dropped, we don't know. - channels.write().await.remove(&cid); - trace!(?cid, "Channel cleanup completed"); - //TEMP FIX: as we dont have channel takeover yet drop the whole - // bParticipant - self.close_write_api(None).await; + if channel_no > 5 { + debug!(?channel_no, "metrics will overwrite channel #5"); + channel_no = 5; + } + self.metrics + .participants_channel_ids + .with_label_values(&[&self.remote_pid_string, &channel_no.to_string()]) + .set(cid as i64); } - }, - ) + } + }) .await; trace!("Stop create_channel_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + self.shutdown_barrier + .fetch_sub(Self::BARR_CHANNEL, Ordering::Relaxed); } - async fn open_mgr( - &self, - mut a2b_stream_open_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - shutdown_open_mgr_receiver: oneshot::Receiver<()>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start open_mgr"); - let mut stream_ids = self.offset_sid; - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_open_mgr_receiver = shutdown_open_mgr_receiver.fuse(); - //from api or shutdown signal - while let Some((prio, promises, p2a_return_stream)) = select! { - next = a2b_stream_open_r.recv().fuse() => next, - _ = &mut shutdown_open_mgr_receiver => None, - } { - debug!(?prio, ?promises, "Got request to open a new steam"); - //TODO: a2b_stream_open_r isn't closed on api_close yet. This needs to change. 
- //till then just check here if we are closed and in that case do nothing (not - // even answer) - if self.shutdown_info.read().await.error.is_some() { - continue; - } - - let a2p_msg_s = a2p_msg_s.clone(); - let sid = stream_ids; - let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) - .await; - if self - .send_frame( - Frame::OpenStream { - sid, - prio, - promises, - }, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - //On error, we drop this, so it gets closed and client will handle this as an - // Err any way (: - p2a_return_stream.send(stream).unwrap(); - stream_ids += Sid::from(1); - } - } - trace!("Stop open_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - /// when activated this function will drop the participant completely and - /// wait for everything to go right! Then return 1. Shutting down - /// Streams for API and End user! 2. Wait for all "prio queued" Messages - /// to be send. 3. Send Stream + /// sink shutdown: + /// Situation AS, AR, BS, BR. A wants to close. + /// AS shutdown. + /// BR notices shutdown and tries to stops BS. (success) + /// BS shutdown + /// AR notices shutdown and tries to stop AS. (fails) + /// For the case where BS didn't get shutdowned, e.g. by a handing situation + /// on the remote, we have a timeout to also force close AR. + /// + /// This fn will: + /// - 1. stop api to interact with bparticipant by closing sendmsg and + /// openstream + /// - 2. stop the send_mgr (it will take care of clearing the + /// queue and finish with a Shutdown) + /// - (3). force stop recv after 60 + /// seconds + /// - (4). this fn finishes last and afterwards BParticipant + /// drops + /// + /// before calling this fn, make sure `s2b_create_channel` is closed! 
/// If BParticipant kills itself managers stay active till this function is /// called by api to get the result status async fn participant_shutdown_mgr( &self, s2b_shutdown_bparticipant_r: oneshot::Receiver, - shutdown_open_mgr_sender: oneshot::Sender<()>, - shutdown_stream_close_mgr_sender: oneshot::Sender>, - shutdown_send_mgr_sender: oneshot::Sender>, + b2b_close_send_protocol_s: async_channel::Sender, + b2b_force_close_recv_protocol_s: async_channel::Sender, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); + let wait_for_manager = || async { + let mut sleep = 0.01f64; + loop { + let bytes = self.shutdown_barrier.load(Ordering::Relaxed); + if bytes == 0 { + break; + } + sleep *= 1.4; + tokio::time::sleep(Duration::from_secs_f64(sleep)).await; + if sleep > 0.2 { + trace!(?bytes, "wait for mgr to close"); + } + } + }; + trace!("Start participant_shutdown_mgr"); - let sender = s2b_shutdown_bparticipant_r.await.unwrap(); + let (timeout_time, sender) = s2b_shutdown_bparticipant_r.await.unwrap(); + debug!("participant_shutdown_mgr triggered"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - - self.close_api(None).await; - - debug!("Closing all managers"); - shutdown_open_mgr_sender - .send(()) - .expect("open_mgr must have crashed before"); - let (b2b_stream_close_shutdown_confirmed_s, b2b_stream_close_shutdown_confirmed_r) = - oneshot::channel(); - shutdown_stream_close_mgr_sender - .send(b2b_stream_close_shutdown_confirmed_s) - .expect("stream_close_mgr must have crashed before"); - // We need to wait for the stream_close_mgr BEFORE send_mgr, as the - // stream_close_mgr needs to wait on the API to drop `Stream` and be triggered - // It will then sleep for streams to be flushed in PRIO, and send_mgr is - // responsible for ticking PRIO WHILE this happens, so we cant close it before! 
- b2b_stream_close_shutdown_confirmed_r.await.unwrap(); - - //closing send_mgr now: - let (b2b_prios_flushed_s, b2b_prios_flushed_r) = oneshot::channel(); - shutdown_send_mgr_sender - .send(b2b_prios_flushed_s) - .expect("stream_close_mgr must have crashed before"); - b2b_prios_flushed_r.await.unwrap(); - - if Some(ParticipantError::ParticipantDisconnected) != self.shutdown_info.read().await.error + debug!("Closing all streams for send"); { - debug!("Sending shutdown frame after flushed all prios"); - if !self - .send_frame( - Frame::Shutdown, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - warn!("couldn't send shutdown frame, are channels already closed?"); + let lock = self.streams.read().await; + for si in lock.values() { + si.send_closed.store(true, Ordering::Relaxed); } } - debug!("Closing all channels, after flushed prios"); - for (cid, ci) in self.channels.write().await.drain() { - let ci = ci.into_inner(); - if let Err(e) = ci.b2r_read_shutdown.send(()) { + let lock = self.channels.read().await; + assert!( + !lock.is_empty(), + "no channel existed remote_pid={}", + self.remote_pid + ); + for cid in lock.keys() { + if let Err(e) = b2b_close_send_protocol_s.send(*cid).await { debug!( ?e, ?cid, - "Seems like this read protocol got already dropped by closing the Stream \ - itself, ignoring" - ); - }; - } - - //Wait for other bparticipants mgr to close via AtomicUsize - const SLEEP_TIME: Duration = Duration::from_millis(5); - const ALLOWED_MANAGER: usize = 1; - tokio::time::sleep(SLEEP_TIME).await; - let mut i: u32 = 1; - while self.running_mgr.load(Ordering::Relaxed) > ALLOWED_MANAGER { - i += 1; - if i.rem_euclid(10) == 1 { - trace!( - ?ALLOWED_MANAGER, - "Waiting for bparticipant mgr to shut down, remaining {}", - self.running_mgr.load(Ordering::Relaxed) - ALLOWED_MANAGER + "closing send_mgr may fail if we got a recv error simultaneously" ); } - tokio::time::sleep(SLEEP_TIME * i).await; } - trace!("All BParticipant mgr (except me) are shut 
down now"); + drop(lock); + + trace!("wait for other managers"); + let timeout = tokio::time::sleep(timeout_time); + let timeout = tokio::select! { + _ = wait_for_manager() => false, + _ = timeout => true, + }; + if timeout { + warn!("timeout triggered: for killing recv"); + let lock = self.channels.read().await; + for cid in lock.keys() { + if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await { + debug!( + ?e, + ?cid, + "closing recv_mgr may fail if we got a recv error simultaneously" + ); + } + } + } + + trace!("wait again"); + wait_for_manager().await; + + sender.send(Ok(())).unwrap(); #[cfg(feature = "metrics")] self.metrics.participants_disconnected_total.inc(); - debug!("BParticipant close done"); - - let mut lock = self.shutdown_info.write().await; - sender - .send(match lock.error.take() { - None => Ok(()), - Some(ParticipantError::ProtocolFailedUnrecoverable) => { - Err(ParticipantError::ProtocolFailedUnrecoverable) - }, - Some(ParticipantError::ParticipantDisconnected) => Ok(()), - }) - .unwrap(); - trace!("Stop participant_shutdown_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - async fn stream_close_mgr( - &self, - mut a2b_close_stream_r: mpsc::UnboundedReceiver, - shutdown_stream_close_mgr_receiver: oneshot::Receiver>, - b2p_notify_empty_stream_s: crossbeam_channel::Sender<(Sid, oneshot::Sender<()>)>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start stream_close_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_stream_close_mgr_receiver = shutdown_stream_close_mgr_receiver.fuse(); - let mut b2b_stream_close_shutdown_confirmed_s = None; - - //from api or shutdown signal - while let Some(sid) = select! 
{ - next = a2b_close_stream_r.recv().fuse() => next, - sender = &mut shutdown_stream_close_mgr_receiver => { - b2b_stream_close_shutdown_confirmed_s = Some(sender.unwrap()); - None - } - } { - //TODO: make this concurrent! - //TODO: Performance, closing is slow! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - trace!("deleting all leftover streams"); - let sids = self - .streams - .read() - .await - .keys() - .cloned() - .collect::>(); - for sid in sids { - //flushing is still important, e.g. when Participant::drop is called (but - // Stream:drop isn't)! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - if b2b_stream_close_shutdown_confirmed_s.is_none() { - b2b_stream_close_shutdown_confirmed_s = - Some(shutdown_stream_close_mgr_receiver.await.unwrap()); - } - b2b_stream_close_shutdown_confirmed_s - .unwrap() - .send(()) - .unwrap(); - trace!("Stop stream_close_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); } + /// Stopping API and participant usage + /// Protocol will take care of the order of the frame async fn delete_stream( &self, sid: Sid, - b2p_notify_empty_stream_s: Option)>>, - from_remote: bool, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, + /* #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, */ ) { - //This needs to first stop clients from sending any more. - //Then it will wait for all pending messages (in prio) to be send to the - // protocol After this happened the stream is closed - //Only after all messages are send to the protocol, we can send the CloseStream - // frame! If we would send it before, all followup messages couldn't - // be handled at the remote side. 
- async { - trace!("Stopping api to use this stream"); - match self.streams.read().await.get(&sid) { - Some(si) => { - si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.lock().await.close(); - }, - None => trace!( - "Couldn't find the stream, might be simultaneous close from local/remote" - ), - } - - if !from_remote { - trace!("Wait for stream to be flushed"); - let (s2b_stream_finished_closed_s, s2b_stream_finished_closed_r) = - oneshot::channel(); - b2p_notify_empty_stream_s - .expect("needs to be set when from_remote is false") - .send((sid, s2b_stream_finished_closed_s)) - .unwrap(); - s2b_stream_finished_closed_r.await.unwrap(); - - trace!("Stream was successfully flushed"); - } - - #[cfg(feature = "metrics")] - self.metrics - .streams_closed_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - //only now remove the Stream, that means we can still recv on it. - self.streams.write().await.remove(&sid); - - if !from_remote { - self.send_frame( - Frame::CloseStream { sid }, - #[cfg(feature = "metrics")] - frames_out_total_cache, - ) - .await; - } + let stream = { self.streams.write().await.remove(&sid) }; + match stream { + Some(si) => { + si.send_closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.lock().await.close(); + }, + None => { + trace!("Couldn't find the stream, might be simultaneous close from local/remote") + }, } - .instrument(tracing::info_span!("close", ?sid, ?from_remote)) - .await; + /* + #[cfg(feature = "metrics")] + self.metrics + .streams_closed_total + .with_label_values(&[&self.remote_pid_string]) + .inc();*/ } async fn create_stream( @@ -843,10 +614,11 @@ impl BParticipant { sid: Sid, prio: Prio, promises: Promises, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, + guaranteed_bandwidth: Bandwidth, + a2b_msg_s: &crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { - let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); + 
let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, @@ -864,38 +636,248 @@ impl BParticipant { sid, prio, promises, + guaranteed_bandwidth, send_closed, - a2p_msg_s, + a2b_msg_s.clone(), b2a_msg_recv_r, a2b_close_stream_s.clone(), ) } +} - async fn close_write_api(&self, reason: Option) { - trace!(?reason, "close_api"); - let mut lock = self.shutdown_info.write().await; - if let Some(r) = reason { - lock.error = Some(r); - } - lock.b2b_close_stream_opened_sender_s - .take() - .map(|s| s.send(())); +#[cfg(test)] +mod tests { + use super::*; + use tokio::{ + runtime::Runtime, + sync::{mpsc, oneshot}, + task::JoinHandle, + }; - debug!("Closing all streams for write"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream for write"); - si.send_closed.store(true, Ordering::Relaxed); - } + fn mock_bparticipant() -> ( + Arc, + mpsc::UnboundedSender, + mpsc::UnboundedReceiver, + mpsc::UnboundedSender, + oneshot::Sender, + mpsc::UnboundedReceiver, + JoinHandle<()>, + ) { + let runtime = Arc::new(tokio::runtime::Runtime::new().unwrap()); + let runtime_clone = Arc::clone(&runtime); + + let (b2s_prio_statistic_s, b2s_prio_statistic_r) = + mpsc::unbounded_channel::(); + + let ( + bparticipant, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + ) = runtime_clone.block_on(async move { + let pid = Pid::fake(1); + let sid = Sid::new(1000); + let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); + + BParticipant::new(pid, sid, Arc::clone(&metrics)) + }); + + let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s)); + ( + runtime_clone, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) } - ///closing api::Participant is done by closing all channels, expect for the - 
/// shutdown channel at this point! - async fn close_api(&self, reason: Option) { - self.close_write_api(reason).await; - debug!("Closing all streams"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream"); - si.b2a_msg_recv_s.lock().await.close(); - } + async fn mock_mpsc( + cid: Cid, + _runtime: &Arc, + create_channel: &mut mpsc::UnboundedSender, + ) -> Protocols { + let (s1, r1) = mpsc::channel(100); + let (s2, r2) = mpsc::channel(100); + let p1 = Protocols::new_mpsc(s1, r2); + let (complete_s, complete_r) = oneshot::channel(); + create_channel + .send((cid, Sid::new(0), p1, complete_s)) + .unwrap(); + complete_r.await.unwrap(); + Protocols::new_mpsc(s2, r1) + } + + #[test] + fn close_bparticipant_by_timeout_during_close() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() > Duration::from_millis(900), + "timeout wasn't triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn close_bparticipant_cleanly() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); 
+ runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(2), s)) + .unwrap(); + drop(remote); // remote needs to be dropped as soon as local.sender is closed + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() < Duration::from_millis(1900), + "timeout was triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn create_stream() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // created stream + let (rs, mut rr) = remote.split(); + let (stream_sender, _stream_receiver) = oneshot::channel(); + a2b_open_stream_s + .send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender)) + .unwrap(); + + let stream_event = runtime.block_on(rr.recv()).unwrap(); + match stream_event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + assert_eq!(sid, Sid::new(1000)); + assert_eq!(prio, 7u8); + assert_eq!(promises, Promises::ENCRYPTED); + assert_eq!(guaranteed_bandwidth, 1_000_000); + }, + _ => panic!("wrong event"), + }; + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn created_stream() { + let ( + runtime, + a2b_open_stream_s, + mut b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + 
+ let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // create stream + let (mut rs, rr) = remote.split(); + runtime + .block_on(rs.send(ProtocolEvent::OpenStream { + sid: Sid::new(1000), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + })) + .unwrap(); + + let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap(); + assert_eq!(stream.promises(), Promises::ORDERED); + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); } } diff --git a/network/src/prios.rs b/network/src/prios.rs deleted file mode 100644 index a544a31241..0000000000 --- a/network/src/prios.rs +++ /dev/null @@ -1,697 +0,0 @@ -//!Priorities are handled the following way. -//!Prios from 0-63 are allowed. -//!all 5 numbers the throughput is halved. -//!E.g. in the same time 100 prio0 messages are send, only 50 prio5, 25 prio10, -//! 12 prio15 or 6 prio20 messages are send. Note: TODO: prio0 will be send -//! immediately when found! 
-#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - message::OutgoingMessage, - types::{Frame, Prio, Sid}, -}; -use crossbeam_channel::{unbounded, Receiver, Sender}; -use std::collections::{HashMap, HashSet, VecDeque}; -#[cfg(feature = "metrics")] use std::sync::Arc; -use tokio::sync::oneshot; -use tracing::trace; - -const PRIO_MAX: usize = 64; - -#[derive(Default)] -struct PidSidInfo { - len: u64, - empty_notify: Option>, -} - -pub(crate) struct PrioManager { - points: [u32; PRIO_MAX], - messages: [VecDeque<(Sid, OutgoingMessage)>; PRIO_MAX], - messages_rx: Receiver<(Prio, Sid, OutgoingMessage)>, - sid_owned: HashMap, - //you can register to be notified if a pid_sid combination is flushed completely here - sid_flushed_rx: Receiver<(Sid, oneshot::Sender<()>)>, - queued: HashSet, - #[cfg(feature = "metrics")] - metrics: Arc, - #[cfg(feature = "metrics")] - pid: String, -} - -impl PrioManager { - const PRIOS: [u32; PRIO_MAX] = [ - 100, 115, 132, 152, 174, 200, 230, 264, 303, 348, 400, 459, 528, 606, 696, 800, 919, 1056, - 1213, 1393, 1600, 1838, 2111, 2425, 2786, 3200, 3676, 4222, 4850, 5572, 6400, 7352, 8445, - 9701, 11143, 12800, 14703, 16890, 19401, 22286, 25600, 29407, 33779, 38802, 44572, 51200, - 58813, 67559, 77605, 89144, 102400, 117627, 135118, 155209, 178289, 204800, 235253, 270235, - 310419, 356578, 409600, 470507, 540470, 620838, - ]; - - #[allow(clippy::type_complexity)] - pub fn new( - #[cfg(feature = "metrics")] metrics: Arc, - pid: String, - ) -> ( - Self, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - #[cfg(not(feature = "metrics"))] - let _pid = pid; - // (a2p_msg_s, a2p_msg_r) - let (messages_tx, messages_rx) = unbounded(); - let (sid_flushed_tx, sid_flushed_rx) = unbounded(); - ( - Self { - points: [0; PRIO_MAX], - messages: [ - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - 
VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - ], - messages_rx, - queued: HashSet::new(), //TODO: optimize with u64 and 64 bits - sid_flushed_rx, - sid_owned: HashMap::new(), - #[cfg(feature = "metrics")] - metrics, - #[cfg(feature = "metrics")] - pid, - }, - messages_tx, - sid_flushed_tx, - ) - } - - async fn tick(&mut self) { - // Check Range - for (prio, sid, msg) in self.messages_rx.try_iter() { - debug_assert!(prio as usize <= PRIO_MAX); - #[cfg(feature = "metrics")] - { - let sid_string = sid.to_string(); - self.metrics - .message_out_total - .with_label_values(&[&self.pid, &sid_string]) - .inc(); - self.metrics - .message_out_throughput - .with_label_values(&[&self.pid, &sid_string]) - .inc_by(msg.buffer.data.len() as u64); - } - - //trace!(?prio, ?sid_string, "tick"); - self.queued.insert(prio); - self.messages[prio as usize].push_back((sid, msg)); - self.sid_owned.entry(sid).or_default().len += 1; - } - //this must be AFTER messages - for (sid, 
return_sender) in self.sid_flushed_rx.try_iter() { - #[cfg(feature = "metrics")] - self.metrics - .streams_flushed - .with_label_values(&[&self.pid]) - .inc(); - if let Some(cnt) = self.sid_owned.get_mut(&sid) { - // register sender - cnt.empty_notify = Some(return_sender); - trace!(?sid, "register empty notify"); - } else { - // return immediately - return_sender.send(()).unwrap(); - trace!(?sid, "return immediately that stream is empty"); - } - } - } - - //if None returned, we are empty! - fn calc_next_prio(&self) -> Option { - // compare all queued prios, max 64 operations - let mut lowest = std::u32::MAX; - let mut lowest_id = None; - for &n in &self.queued { - let n_points = self.points[n as usize]; - if n_points < lowest { - lowest = n_points; - lowest_id = Some(n) - } else if n_points == lowest && lowest_id.is_some() && n < lowest_id.unwrap() { - //on equal points lowest first! - lowest_id = Some(n) - } - } - lowest_id - /* - self.queued - .iter() - .min_by_key(|&n| self.points[*n as usize]).cloned()*/ - } - - /// no_of_frames = frames.len() - /// Your goal is to try to find a realistic no_of_frames! - /// no_of_frames should be choosen so, that all Frames can be send out till - /// the next tick! - /// - if no_of_frames is too high you will fill either the Socket buffer, - /// or your internal buffer. In that case you will increase latency for - /// high prio messages! 
- /// - if no_of_frames is too low you wont saturate your Socket fully, thus - /// have a lower bandwidth as possible - pub async fn fill_frames>( - &mut self, - no_of_frames: usize, - frames: &mut E, - ) { - for v in self.messages.iter_mut() { - v.reserve_exact(no_of_frames) - } - self.tick().await; - for _ in 0..no_of_frames { - match self.calc_next_prio() { - Some(prio) => { - //let prio2 = self.calc_next_prio().unwrap(); - //trace!(?prio, "handle next prio"); - self.points[prio as usize] += Self::PRIOS[prio as usize]; - //pop message from front of VecDeque, handle it and push it back, so that all - // => messages with same prio get a fair chance :) - //TODO: evaluate not popping every time - let (sid, mut msg) = self.messages[prio as usize].pop_front().unwrap(); - if msg.fill_next(sid, frames) { - //trace!(?m.mid, "finish message"); - //check if prio is empty - if self.messages[prio as usize].is_empty() { - self.queued.remove(&prio); - } - //decrease pid_sid counter by 1 again - let cnt = self.sid_owned.get_mut(&sid).expect( - "The pid_sid_owned counter works wrong, more pid,sid removed than \ - inserted", - ); - cnt.len -= 1; - if cnt.len == 0 { - let cnt = self.sid_owned.remove(&sid).unwrap(); - if let Some(empty_notify) = cnt.empty_notify { - empty_notify.send(()).unwrap(); - trace!(?sid, "returned that stream is empty"); - } - } - } else { - self.messages[prio as usize].push_front((sid, msg)); - } - }, - None => { - //QUEUE is empty, we are clearing the POINTS to not build up huge pipes of - // POINTS on a prio from the past - self.points = [0; PRIO_MAX]; - break; - }, - } - } - } -} - -impl std::fmt::Debug for PrioManager { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut cnt = 0; - for m in self.messages.iter() { - cnt += m.len(); - } - write!(f, "PrioManager(len: {}, queued: {:?})", cnt, &self.queued,) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - message::{MessageBuffer, OutgoingMessage}, - 
metrics::NetworkMetrics, - prios::*, - types::{Frame, Pid, Prio, Sid}, - }; - use crossbeam_channel::Sender; - use std::{collections::VecDeque, sync::Arc}; - use tokio::{runtime::Runtime, sync::oneshot}; - - const SIZE: u64 = OutgoingMessage::FRAME_DATA_SIZE; - const USIZE: usize = OutgoingMessage::FRAME_DATA_SIZE as usize; - - #[allow(clippy::type_complexity)] - fn mock_new() -> ( - PrioManager, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - let pid = Pid::fake(1); - PrioManager::new( - Arc::new(NetworkMetrics::new(&pid).unwrap()), - pid.to_string(), - ) - } - - fn mock_out(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { - data: vec![48, 49, 50], - }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn mock_out_large(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - let mut data = vec![48; USIZE]; - data.append(&mut vec![49; USIZE]); - data.append(&mut vec![50; 20]); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn assert_header(frames: &mut VecDeque<(Sid, Frame)>, f_sid: u64, f_length: u64) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::DataHeader { mid, sid, length } = frame { - assert_eq!(mid, 1); - assert_eq!(sid, Sid::new(f_sid)); - assert_eq!(length, f_length); - } else { - panic!("Wrong frame type!, expected DataHeader"); - } - } - - fn assert_data(frames: &mut VecDeque<(Sid, Frame)>, f_start: u64, f_data: Vec) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::Data { mid, start, data } = frame { - assert_eq!(mid, 1); - assert_eq!(start, f_start); - assert_eq!(data, f_data); - } else { - panic!("Wrong frame type!, expected Data"); - } - } - - #[test] - fn single_p16() { - let 
(mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - msg_tx.send(mock_out(20, 42)).unwrap(); - let mut frames = VecDeque::new(); - - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 42)).unwrap(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - 
.unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - for i in 1..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn multiple_fill_frames_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(3, &mut frames)); - for i in 1..4 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(11, &mut frames)); - for i in 4..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn single_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = 
VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16_sudden_p0() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - - msg_tx.send(mock_out(0, 3)).unwrap(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 3, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_thousand_p16_at_once() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 
3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn single_p20_thousand_p16_later() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - //^unimportant frames, gonna be dropped - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - //important in that test is, that after the first frames got cleared i reset - // the Points even though 998 prio 16 messages have been send at this - // point and 0 prio20 messages the next message is a prio16 message - // again, and only then prio20! 
we dont want to build dept over a idling - // connection - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn gigantic_message() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } - - #[test] - fn gigantic_message_order() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 
5600, vec![5; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - - #[test] - fn gigantic_message_order_other_prio() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(20, 8)).unwrap(); - - let mut frames = VecDeque::new(); - Runtime::new() - .unwrap() - .block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } -} diff --git a/network/src/protocols.rs b/network/src/protocols.rs deleted file mode 100644 index a18c1e1cbd..0000000000 --- a/network/src/protocols.rs +++ /dev/null @@ -1,591 +0,0 @@ -#[cfg(feature = "metrics")] -use crate::metrics::{CidFrameCache, NetworkMetrics}; -use crate::{ - participant::C2pFrame, - types::{Cid, Frame}, -}; -use futures_util::{future::Fuse, FutureExt}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpStream, UdpSocket}, - select, - sync::{mpsc, oneshot, Mutex}, -}; - -use std::{convert::TryFrom, net::SocketAddr, sync::Arc}; -use tracing::*; - -// Reserving bytes 0, 10, 13 as i have enough space and want to make it easy to -// detect a invalid client, e.g. 
sending an empty line would make 10 first char -// const FRAME_RESERVED_1: u8 = 0; -const FRAME_HANDSHAKE: u8 = 1; -const FRAME_INIT: u8 = 2; -const FRAME_SHUTDOWN: u8 = 3; -const FRAME_OPEN_STREAM: u8 = 4; -const FRAME_CLOSE_STREAM: u8 = 5; -const FRAME_DATA_HEADER: u8 = 6; -const FRAME_DATA: u8 = 7; -const FRAME_RAW: u8 = 8; -//const FRAME_RESERVED_2: u8 = 10; -//const FRAME_RESERVED_3: u8 = 13; - -#[derive(Debug)] -pub(crate) enum Protocols { - Tcp(TcpProtocol), - Udp(UdpProtocol), - //Mpsc(MpscChannel), -} - -#[derive(Debug)] -pub(crate) struct TcpProtocol { - read_stream: tokio::sync::Mutex, - write_stream: tokio::sync::Mutex, - #[cfg(feature = "metrics")] - metrics: Arc, -} - -#[derive(Debug)] -pub(crate) struct UdpProtocol { - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] - metrics: Arc, - data_in: Mutex>>, -} - -//TODO: PERFORMACE: Use BufWriter and BufReader from std::io! -impl TcpProtocol { - pub(crate) fn new( - stream: TcpStream, - #[cfg(feature = "metrics")] metrics: Arc, - ) -> Self { - let (read_stream, write_stream) = stream.into_split(); - Self { - read_stream: tokio::sync::Mutex::new(read_stream), - write_stream: tokio::sync::Mutex::new(write_stream), - #[cfg(feature = "metrics")] - metrics, - } - } - - async fn read_frame( - r: &mut R, - end_receiver: &mut Fuse>, - ) -> Result> { - let handle = |read_result| match read_result { - Ok(_) => Ok(()), - Err(e) => Err(Some(e)), - }; - - let mut frame_no = [0u8; 1]; - match select! 
{ - r = r.read_exact(&mut frame_no).fuse() => Some(r), - _ = end_receiver => None, - } { - Some(read_result) => handle(read_result)?, - None => { - trace!("shutdown requested"); - return Err(None); - }, - }; - - match frame_no[0] { - FRAME_HANDSHAKE => { - let mut bytes = [0u8; 19]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_handshake(bytes)) - }, - FRAME_INIT => { - let mut bytes = [0u8; 32]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_init(bytes)) - }, - FRAME_SHUTDOWN => Ok(Frame::Shutdown), - FRAME_OPEN_STREAM => { - let mut bytes = [0u8; 10]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_open_stream(bytes)) - }, - FRAME_CLOSE_STREAM => { - let mut bytes = [0u8; 8]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_close_stream(bytes)) - }, - FRAME_DATA_HEADER => { - let mut bytes = [0u8; 24]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_data_header(bytes)) - }, - FRAME_DATA => { - let mut bytes = [0u8; 18]; - handle(r.read_exact(&mut bytes).await)?; - let (mid, start, length) = Frame::gen_data(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Data { mid, start, data }) - }, - FRAME_RAW => { - let mut bytes = [0u8; 2]; - handle(r.read_exact(&mut bytes).await)?; - let length = Frame::gen_raw(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Raw(data)) - }, - other => { - // report a RAW frame, but cannot rely on the next 2 bytes to be a size. - // guessing 32 bytes, which might help to sort down issues - let mut data = vec![0; 32]; - //keep the first byte! 
- match r.read(&mut data[1..]).await { - Ok(n) => { - data.truncate(n + 1); - Ok(()) - }, - Err(e) => Err(Some(e)), - }?; - data[0] = other; - warn!(?data, "got a unexpected RAW msg"); - Ok(Frame::Raw(data)) - }, - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up tcp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut read_stream = self.read_stream.lock().await; - let mut end_r = end_r.fuse(); - - loop { - match Self::read_frame(&mut *read_stream, &mut end_r).await { - Ok(frame) => { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = w2c_cid_frame_s.send((cid, Ok(frame))) { - warn!(?e, "Channel or Participant seems no longer to exist"); - } - }, - Err(e_option) => { - if let Some(e) = e_option { - info!(?e, "Closing tcp protocol due to read error"); - //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as - // every channel is holding a sender! 
thats why Ne - // need a explicit STOP here - if let Err(e) = w2c_cid_frame_s.send((cid, Err(()))) { - warn!(?e, "Channel or Participant seems no longer to exist"); - } - } - //None is clean shutdown - break; - }, - } - } - trace!("Shutting down tcp read()"); - } - - pub async fn write_frame( - w: &mut W, - frame: Frame, - ) -> Result<(), std::io::Error> { - match frame { - Frame::Handshake { - magic_number, - version, - } => { - w.write_all(&FRAME_HANDSHAKE.to_be_bytes()).await?; - w.write_all(&magic_number).await?; - w.write_all(&version[0].to_le_bytes()).await?; - w.write_all(&version[1].to_le_bytes()).await?; - w.write_all(&version[2].to_le_bytes()).await?; - }, - Frame::Init { pid, secret } => { - w.write_all(&FRAME_INIT.to_be_bytes()).await?; - w.write_all(&pid.to_le_bytes()).await?; - w.write_all(&secret.to_le_bytes()).await?; - }, - Frame::Shutdown => { - w.write_all(&FRAME_SHUTDOWN.to_be_bytes()).await?; - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - w.write_all(&FRAME_OPEN_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&prio.to_le_bytes()).await?; - w.write_all(&promises.to_le_bytes()).await?; - }, - Frame::CloseStream { sid } => { - w.write_all(&FRAME_CLOSE_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - }, - Frame::DataHeader { mid, sid, length } => { - w.write_all(&FRAME_DATA_HEADER.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&length.to_le_bytes()).await?; - }, - Frame::Data { mid, start, data } => { - w.write_all(&FRAME_DATA.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&start.to_le_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - Frame::Raw(data) => { - w.write_all(&FRAME_RAW.to_be_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - }; - 
Ok(()) - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up tcp write()"); - let mut write_stream = self.write_stream.lock().await; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - .with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - - while let Some(frame) = c2w_frame_r.recv().await { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = Self::write_frame(&mut *write_stream, frame).await { - info!( - ?e, - "Got an error writing to tcp, going to close this channel" - ); - c2w_frame_r.close(); - break; - }; - } - trace!("shutting down tcp write()"); - } -} - -impl UdpProtocol { - pub(crate) fn new( - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] metrics: Arc, - data_in: mpsc::UnboundedReceiver>, - ) -> Self { - Self { - socket, - remote_addr, - #[cfg(feature = "metrics")] - metrics, - data_in: Mutex::new(data_in), - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up udp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut data_in = self.data_in.lock().await; - let mut end_r = end_r.fuse(); - while let Some(bytes) = select! 
{ - r = data_in.recv().fuse() => match r { - Some(r) => Some(r), - None => { - info!("Udp read ended"); - w2c_cid_frame_s.send((cid, Err(()))).expect("Channel or Participant seems no longer to exist"); - None - } - }, - _ = &mut end_r => None, - } { - trace!("Got raw UDP message with len: {}", bytes.len()); - let frame_no = bytes[0]; - let frame = match frame_no { - FRAME_HANDSHAKE => { - Frame::gen_handshake(*<&[u8; 19]>::try_from(&bytes[1..20]).unwrap()) - }, - FRAME_INIT => Frame::gen_init(*<&[u8; 32]>::try_from(&bytes[1..33]).unwrap()), - FRAME_SHUTDOWN => Frame::Shutdown, - FRAME_OPEN_STREAM => { - Frame::gen_open_stream(*<&[u8; 10]>::try_from(&bytes[1..11]).unwrap()) - }, - FRAME_CLOSE_STREAM => { - Frame::gen_close_stream(*<&[u8; 8]>::try_from(&bytes[1..9]).unwrap()) - }, - FRAME_DATA_HEADER => { - Frame::gen_data_header(*<&[u8; 24]>::try_from(&bytes[1..25]).unwrap()) - }, - FRAME_DATA => { - let (mid, start, length) = - Frame::gen_data(*<&[u8; 18]>::try_from(&bytes[1..19]).unwrap()); - let mut data = vec![0; length as usize]; - #[cfg(feature = "metrics")] - throughput_cache.inc_by(length as u64); - data.copy_from_slice(&bytes[19..]); - Frame::Data { mid, start, data } - }, - FRAME_RAW => { - let length = Frame::gen_raw(*<&[u8; 2]>::try_from(&bytes[1..3]).unwrap()); - let mut data = vec![0; length as usize]; - data.copy_from_slice(&bytes[3..]); - Frame::Raw(data) - }, - _ => Frame::Raw(bytes), - }; - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - w2c_cid_frame_s.send((cid, Ok(frame))).unwrap(); - } - trace!("Shutting down udp read()"); - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up udp write()"); - let mut buffer = [0u8; 2000]; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - 
.with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - while let Some(frame) = c2w_frame_r.recv().await { - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - let len = match frame { - Frame::Handshake { - magic_number, - version, - } => { - let x = FRAME_HANDSHAKE.to_be_bytes(); - buffer[0] = x[0]; - buffer[1..8].copy_from_slice(&magic_number); - buffer[8..12].copy_from_slice(&version[0].to_le_bytes()); - buffer[12..16].copy_from_slice(&version[1].to_le_bytes()); - buffer[16..20].copy_from_slice(&version[2].to_le_bytes()); - 20 - }, - Frame::Init { pid, secret } => { - buffer[0] = FRAME_INIT.to_be_bytes()[0]; - buffer[1..17].copy_from_slice(&pid.to_le_bytes()); - buffer[17..33].copy_from_slice(&secret.to_le_bytes()); - 33 - }, - Frame::Shutdown => { - buffer[0] = FRAME_SHUTDOWN.to_be_bytes()[0]; - 1 - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - buffer[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - buffer[9] = prio.to_le_bytes()[0]; - buffer[10] = promises.to_le_bytes()[0]; - 11 - }, - Frame::CloseStream { sid } => { - buffer[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - 9 - }, - Frame::DataHeader { mid, sid, length } => { - buffer[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&sid.to_le_bytes()); - buffer[17..25].copy_from_slice(&length.to_le_bytes()); - 25 - }, - Frame::Data { mid, start, data } => { - buffer[0] = FRAME_DATA.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&start.to_le_bytes()); - buffer[17..19].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[19..(data.len() + 19)].clone_from_slice(&data[..]); - #[cfg(feature = "metrics")] - throughput_cache.inc_by(data.len() as u64); - 19 + data.len() - }, - Frame::Raw(data) => { - 
buffer[0] = FRAME_RAW.to_be_bytes()[0]; - buffer[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[3..(data.len() + 3)].clone_from_slice(&data[..]); - 3 + data.len() - }, - }; - let mut start = 0; - while start < len { - trace!(?start, ?len, "Splitting up udp frame in multiple packages"); - match self - .socket - .send_to(&buffer[start..len], self.remote_addr) - .await - { - Ok(n) => { - start += n; - if n != len { - error!( - "THIS DOESN'T WORK, as RECEIVER CURRENTLY ONLY HANDLES 1 FRAME \ - per UDP message. splitting up will fail!" - ); - } - }, - Err(e) => error!(?e, "Need to handle that error!"), - } - } - } - trace!("Shutting down udp write()"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{metrics::NetworkMetrics, types::Pid}; - use std::sync::Arc; - use tokio::{net, runtime::Runtime, sync::mpsc}; - - #[test] - fn tcp_read_handshake() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50500); - Runtime::new().unwrap().block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let (s_stream, _) = server.accept().await.unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client.write_all(&[FRAME_HANDSHAKE]).await.unwrap(); - client.write_all(b"HELLOWO").await.unwrap(); - client.write_all(&1337u32.to_le_bytes()).await.unwrap(); - client.write_all(&0u32.to_le_bytes()).await.unwrap(); - client.write_all(&42u32.to_le_bytes()).await.unwrap(); - client.flush().await.unwrap(); - - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - Runtime::new().unwrap().block_on(async { - prot.read_from_wire(cid2, &mut 
w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Handshake! - //tokio::task::sleep(std::time::Duration::from_millis(1000)); - let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Handshake { - magic_number, - version, - }) = frame - { - assert_eq!(&magic_number, b"HELLOWO"); - assert_eq!(version, [1337, 0, 42]); - } else { - panic!("wrong handshake"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } - - #[test] - fn tcp_read_garbage() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50501); - Runtime::new().unwrap().block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let (s_stream, _) = server.accept().await.unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client - .write_all("x4hrtzsektfhxugzdtz5r78gzrtzfhxfdthfthuzhfzzufasgasdfg".as_bytes()) - .await - .unwrap(); - client.flush().await.unwrap(); - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded_channel::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - Runtime::new().unwrap().block_on(async { - prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Raw! 
- let (cid_r, frame) = w2c_cid_frame_r.recv().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Raw(data)) = frame { - assert_eq!(&data.as_slice(), b"x4hrtzsektfhxugzdtz5r78gzrtzfhxf"); - } else { - panic!("wrong frame type"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } -} diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index f648d48a15..eb6d21bd7e 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -2,12 +2,11 @@ use crate::metrics::NetworkMetrics; use crate::{ api::{Participant, ProtocolAddr}, - channel::Handshake, + channel::Protocols, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, - protocols::{Protocols, TcpProtocol, UdpProtocol}, - types::Pid, }; use futures_util::{FutureExt, StreamExt}; +use network_protocol::Pid; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -17,6 +16,7 @@ use std::{ atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, + time::Duration, }; use tokio::{ io, net, @@ -214,47 +214,40 @@ impl Scheduler { }, }; info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - ( - Protocols::Tcp(TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - )), - false, - ) - }, - ProtocolAddr::Udp(addr) => { - #[cfg(feature = "metrics")] - self.metrics - .connect_requests_total - .with_label_values(&["udp"]) - .inc(); - let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - Ok(socket) => Arc::new(socket), - Err(e) => { - pid_sender.send(Err(e)).unwrap(); - continue; - }, - }; - if let Err(e) = socket.connect(addr).await { - pid_sender.send(Err(e)).unwrap(); - continue; - }; - info!("Connecting Udp to: {}", addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.runtime.spawn( - 
Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - .instrument(tracing::info_span!("udp", ?addr)), - ); - (Protocols::Udp(protocol), true) - }, + (Protocols::new_tcp(stream), false) + }, /* */ + //ProtocolAddr::Udp(addr) => { + //#[cfg(feature = "metrics")] + //self.metrics + //.connect_requests_total + //.with_label_values(&["udp"]) + //.inc(); + //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { + //Ok(socket) => Arc::new(socket), + //Err(e) => { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}, + //}; + //if let Err(e) = socket.connect(addr).await { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}; + //info!("Connecting Udp to: {}", addr); + //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); + //let protocol = UdpProtocol::new( + //Arc::clone(&socket), + //addr, + //#[cfg(feature = "metrics")] + //Arc::clone(&self.metrics), + //udp_data_receiver, + //); + //self.runtime.spawn( + //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) + //.instrument(tracing::info_span!("udp", ?addr)), + //); + //(Protocols::Udp(protocol), true) + //}, _ => unimplemented!(), }; self.init_protocol(protocol, Some(pid_sender), handshake) @@ -265,7 +258,9 @@ impl Scheduler { async fn disconnect_mgr(&self, mut a2s_disconnect_r: mpsc::UnboundedReceiver) { trace!("Start disconnect_mgr"); - while let Some((pid, return_once_successful_shutdown)) = a2s_disconnect_r.recv().await { + while let Some((pid, (timeout_time, return_once_successful_shutdown))) = + a2s_disconnect_r.recv().await + { //Closing Participants is done the following way: // 1. We drop our senders and receivers // 2. 
we need to close BParticipant, this will drop its senderns and receivers @@ -279,7 +274,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((timeout_time, finished_sender)) .unwrap(); drop(pi); trace!(?pid, "dropped bparticipant, waiting for finish"); @@ -322,7 +317,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((Duration::from_secs(120), finished_sender)) .unwrap(); (pid, finished_receiver) }) @@ -392,15 +387,10 @@ impl Scheduler { }, }; info!("Accepting Tcp from: {}", remote_addr); - let protocol = TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - ); - self.init_protocol(Protocols::Tcp(protocol), None, true) + self.init_protocol(Protocols::new_tcp(stream), None, true) .await; } - }, + },/* ProtocolAddr::Udp(addr) => { let socket = match net::UdpSocket::bind(addr).await { Ok(socket) => { @@ -451,12 +441,13 @@ impl Scheduler { let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); udp_data_sender.send(datavec).unwrap(); } - }, + },*/ _ => unimplemented!(), } trace!(?addr, "Ending channel creator"); } + #[allow(dead_code)] async fn udp_single_channel_connect( socket: Arc, w2p_udp_package_s: mpsc::UnboundedSender>, @@ -483,7 +474,7 @@ impl Scheduler { async fn init_protocol( &self, - protocol: Protocols, + mut protocol: Protocols, s2a_return_pid_s: Option>>, send_handshake: bool, ) { @@ -509,20 +500,13 @@ impl Scheduler { self.runtime.spawn( async move { trace!(?cid, "Open channel and be ready for Handshake"); - let handshake = Handshake::new( - cid, - local_pid, - local_secret, - #[cfg(feature = "metrics")] - Arc::clone(&metrics), - send_handshake, - ); - match handshake - .setup(&protocol) + use network_protocol::InitProtocol; + let init_result = protocol + .initialize(send_handshake, local_pid, local_secret) .instrument(tracing::info_span!("handshake", ?cid)) - .await - { - Ok((pid, sid, secret, 
leftover_cid_frame)) => { + .await; + match init_result { + Ok((pid, sid, secret)) => { trace!( ?cid, ?pid, @@ -533,14 +517,13 @@ impl Scheduler { debug!(?cid, "New participant connected via a channel"); let ( bparticipant, - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, ) = BParticipant::new( pid, sid, - Arc::clone(&runtime), #[cfg(feature = "metrics")] Arc::clone(&metrics), ); @@ -548,8 +531,7 @@ impl Scheduler { let participant = Participant::new( local_pid, pid, - Arc::clone(&runtime), - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, participant_channels.a2s_disconnect_s, ); @@ -573,13 +555,7 @@ impl Scheduler { oneshot::channel(); //From now on wire connects directly with bparticipant! s2b_create_channel_s - .send(( - cid, - sid, - protocol, - leftover_cid_frame, - b2s_create_channel_done_s, - )) + .send((cid, sid, protocol, b2s_create_channel_done_s)) .unwrap(); b2s_create_channel_done_r.await.unwrap(); if let Some(pid_oneshot) = s2a_return_pid_s { @@ -627,8 +603,8 @@ impl Scheduler { //From now on this CHANNEL can receiver other frames! // move directly to participant! 
}, - Err(()) => { - debug!(?cid, "Handshake from a new connection failed"); + Err(e) => { + debug!(?cid, ?e, "Handshake from a new connection failed"); if let Some(pid_oneshot) = s2a_return_pid_s { // someone is waiting with `connect`, so give them their Error trace!(?cid, "returning the Err to api who requested the connect"); diff --git a/server/Cargo.toml b/server/Cargo.toml index f997749565..8726a037d8 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -16,7 +16,7 @@ common = { package = "veloren-common", path = "../common" } common-sys = { package = "veloren-common-sys", path = "../common/sys" } common-net = { package = "veloren-common-net", path = "../common/net" } world = { package = "veloren-world", path = "../world" } -network = { package = "veloren_network", path = "../network", features = ["metrics", "compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["metrics", "compression"], default-features = false } specs = { git = "https://github.com/amethyst/specs.git", features = ["shred-derive"], rev = "d4435bdf496cf322c74886ca09dd8795984919b4" } specs-idvs = { git = "https://gitlab.com/veloren/specs-idvs.git", rev = "9fab7b396acd6454585486e50ae4bfe2069858a9" } diff --git a/voxygen/src/hud/chat.rs b/voxygen/src/hud/chat.rs index b2dc69c780..1ecc5fe621 100644 --- a/voxygen/src/hud/chat.rs +++ b/voxygen/src/hud/chat.rs @@ -373,7 +373,7 @@ impl<'a> Widget for Chat<'a> { let ChatMsg { chat_type, .. } = &message; // For each ChatType needing localization get/set matching pre-formatted // localized string. 
This string will be formatted with the data - // provided in ChatType in the client/src/lib.rs + // provided in ChatType in the client/src/mod.rs // fn format_message called below message.message = match chat_type { ChatType::Online(_) => self From ea8ab1ce7abe5859fe4c402d79532dc784b539ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Wed, 10 Feb 2021 11:37:42 +0100 Subject: [PATCH 5/6] Great improvements to the codebase: - better logging in network - we now notify the send of what happened in recv in participant. - works with veloren master servers - works in singleplayer, using a actual mid. - add `mpsc` in whole stack incl tests - speed up internal read/write with `Bytes` crate - use `prometheus-hyper` for metrics - use a metrics cache --- Cargo.lock | 84 +++-- client/Cargo.toml | 2 +- network/Cargo.toml | 22 +- network/examples/network-speed/main.rs | 25 +- network/examples/network-speed/metrics.rs | 92 ------ network/protocol/Cargo.toml | 2 + network/protocol/benches/protocols.rs | 20 +- network/protocol/src/frame.rs | 381 ++++++++-------------- network/protocol/src/io.rs | 9 +- network/protocol/src/lib.rs | 3 + network/protocol/src/metrics.rs | 282 ++++++++-------- network/protocol/src/mpsc.rs | 16 +- network/protocol/src/prio.rs | 17 +- network/protocol/src/tcp.rs | 251 +++++++++----- network/protocol/src/types.rs | 15 +- network/src/api.rs | 123 ++++--- network/src/channel.rs | 31 +- network/src/message.rs | 1 + network/src/participant.rs | 86 +++-- network/src/scheduler.rs | 82 ++++- network/tests/helper.rs | 11 +- network/tests/integration.rs | 29 +- server-cli/Cargo.toml | 2 +- server-cli/src/logging.rs | 1 + server/Cargo.toml | 4 +- server/src/lib.rs | 48 +-- server/src/metrics.rs | 103 +----- voxygen/Cargo.toml | 2 +- voxygen/src/logging.rs | 1 + 29 files changed, 852 insertions(+), 893 deletions(-) delete mode 100644 network/examples/network-speed/metrics.rs diff --git a/Cargo.lock b/Cargo.lock index bd0dcb62ec..0642e1e5a6 
100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,12 +225,6 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e" -[[package]] -name = "ascii" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf56136a5198c7b01a49e3afcbef6cf84597273d298f54432926024107b0109" - [[package]] name = "assets_manager" version = "0.4.3" @@ -723,7 +717,7 @@ version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da3da6baa321ec19e1cc41d31bf599f00c783d0517095cdaf0332e3fe8d20680" dependencies = [ - "ascii 0.9.3", + "ascii", "byteorder", "either", "memchr", @@ -2341,6 +2335,16 @@ dependencies = [ "http", ] +[[package]] +name = "http-body" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2861bd27ee074e5ee891e8b539837a9430012e249d7f0ca2d795650f579c1994" +dependencies = [ + "bytes 1.0.1", + "http", +] + [[package]] name = "httparse" version = "1.3.5" @@ -2371,7 +2375,7 @@ dependencies = [ "futures-util", "h2", "http", - "http-body", + "http-body 0.3.1", "httparse", "httpdate", "itoa", @@ -2383,6 +2387,29 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8e946c2b1349055e0b72ae281b238baf1a3ea7307c7e9f9d64673bdd9c26ac7" +dependencies = [ + "bytes 1.0.1", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body 0.4.0", + "httparse", + "httpdate", + "itoa", + "pin-project 1.0.5", + "socket2", + "tokio 1.2.0", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper-rustls" version = "0.21.0" @@ -2391,7 +2418,7 @@ checksum = "37743cc83e8ee85eacfce90f2f4102030d9ff0a95244098d781e9bee4a90abb6" dependencies = [ "bytes 0.5.6", "futures-util", - "hyper", + "hyper 0.13.10", "log", "rustls 0.18.1", "tokio 0.2.25", @@ -3952,6 
+3979,18 @@ dependencies = [ "thiserror", ] +[[package]] +name = "prometheus-hyper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc47fa532a12d544229015dd3fae32394949af098b8fe9a327b8c1e4c911d1c8" +dependencies = [ + "hyper 0.14.4", + "prometheus", + "tokio 1.2.0", + "tracing", +] + [[package]] name = "publicsuffix" version = "1.5.4" @@ -4230,8 +4269,8 @@ dependencies = [ "futures-core", "futures-util", "http", - "http-body", - "hyper", + "http-body 0.3.1", + "hyper 0.13.10", "hyper-rustls", "ipnet", "js-sys", @@ -5069,19 +5108,6 @@ dependencies = [ "syn 1.0.60", ] -[[package]] -name = "tiny_http" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded47106b8e52d8ed8119f0ea6e8c0f5881e69783e0297b5a8462958f334bc1" -dependencies = [ - "ascii 1.0.0", - "chrono", - "chunked_transfer", - "log", - "url", -] - [[package]] name = "tinytemplate" version = "1.2.0" @@ -5137,8 +5163,11 @@ dependencies = [ "memchr", "mio 0.7.7", "num_cpus", + "once_cell", "pin-project-lite 0.2.4", + "signal-hook-registry", "tokio-macros", + "winapi 0.3.9", ] [[package]] @@ -5682,6 +5711,7 @@ dependencies = [ "async-trait", "bincode", "bitflags", + "bytes 1.0.1", "clap", "crossbeam-channel 0.5.0", "futures-core", @@ -5689,14 +5719,13 @@ dependencies = [ "lazy_static", "lz-fear", "prometheus", + "prometheus-hyper", "rand 0.8.3", "serde", "shellexpand", - "tiny_http", "tokio 1.2.0", "tokio-stream", "tracing", - "tracing-futures", "tracing-subscriber", "veloren-network-protocol", ] @@ -5708,6 +5737,7 @@ dependencies = [ "async-channel", "async-trait", "bitflags", + "bytes 1.0.1", "criterion", "prometheus", "rand 0.8.3", @@ -5762,6 +5792,7 @@ dependencies = [ "libsqlite3-sys", "portpicker", "prometheus", + "prometheus-hyper", "rand 0.8.3", "rayon", "ron", @@ -5771,7 +5802,6 @@ dependencies = [ "slab", "specs", "specs-idvs", - "tiny_http", "tokio 1.2.0", "tracing", "uvth", diff --git a/client/Cargo.toml 
b/client/Cargo.toml index fcda01cb65..d8f6cb969a 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -21,7 +21,7 @@ uvth = "3.1.1" futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" -tokio = { version = "1.0.1", default-features = false, features = ["rt"] } +tokio = { version = "1", default-features = false, features = ["rt"] } image = { version = "0.23.12", default-features = false, features = ["png"] } num = "0.3.1" num_cpus = "1.10.1" diff --git a/network/Cargo.toml b/network/Cargo.toml index f548278896..de6e028176 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -14,7 +14,7 @@ default = ["metrics","compression"] [dependencies] -network-protocol = { package = "veloren-network-protocol", path = "protocol", default-features = false } +network-protocol = { package = "veloren-network-protocol", path = "protocol" } #serialisation bincode = "1.3.1" @@ -24,8 +24,7 @@ crossbeam-channel = "0.5" tokio = { version = "1.2", default-features = false, features = ["io-util", "macros", "rt", "net", "time"] } tokio-stream = { version = "0.1.2", default-features = false } #tracing and metrics -tracing = { version = "0.1", default-features = false } -tracing-futures = "0.2" +tracing = { version = "0.1", default-features = false, features = ["attributes"]} prometheus = { version = "0.11", default-features = false, optional = true } #async futures-core = { version = "0.3", default-features = false } @@ -39,12 +38,25 @@ bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } # async traits async-trait = "0.1.42" +bytes = "^1" [dev-dependencies] tracing-subscriber = { version = "0.2.3", default-features = false, features = ["env-filter", "fmt", "chrono", "ansi", "smallvec"] } -tokio = { version = "1.1.0", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } +tokio = { version = "1.2", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } futures-util = { version = "0.3", default-features = 
false, features = ["sink", "std"] } clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" -tiny_http = "0.8.0" serde = { version = "1.0", features = ["derive"] } +prometheus-hyper = "0.1.1" + +[[example]] +name = "fileshare" + +[[example]] +name = "network-speed" + +[[example]] +name = "chat" + +[[example]] +name = "tcp_loadtest" diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index 9814cec998..37d076b5bd 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -3,11 +3,12 @@ /// (cd network/examples/network-speed && RUST_BACKTRACE=1 cargo run --profile=debuginfo -Z unstable-options -- --trace=error --protocol=tcp --mode=server) /// (cd network/examples/network-speed && RUST_BACKTRACE=1 cargo run --profile=debuginfo -Z unstable-options -- --trace=error --protocol=tcp --mode=client) /// ``` -mod metrics; - use clap::{App, Arg}; +use prometheus::Registry; +use prometheus_hyper::Server; use serde::{Deserialize, Serialize}; use std::{ + net::SocketAddr, sync::Arc, thread, time::{Duration, Instant}, @@ -121,9 +122,13 @@ fn main() { } fn server(address: ProtocolAddr, runtime: Arc) { - let mut metrics = metrics::SimpleMetrics::new(); - let server = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), metrics.registry()); - metrics.run("0.0.0.0:59112".parse().unwrap()).unwrap(); + let registry = Arc::new(Registry::new()); + let server = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), ®istry); + runtime.spawn(Server::run( + Arc::clone(®istry), + SocketAddr::from(([0; 4], 59112)), + futures_util::future::pending(), + )); runtime.block_on(server.listen(address)).unwrap(); loop { @@ -148,9 +153,13 @@ fn server(address: ProtocolAddr, runtime: Arc) { } fn client(address: ProtocolAddr, runtime: Arc) { - let mut metrics = metrics::SimpleMetrics::new(); - let client = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), metrics.registry()); - 
metrics.run("0.0.0.0:59111".parse().unwrap()).unwrap(); + let registry = Arc::new(Registry::new()); + let client = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), ®istry); + runtime.spawn(Server::run( + Arc::clone(®istry), + SocketAddr::from(([0; 4], 59111)), + futures_util::future::pending(), + )); let p1 = runtime.block_on(client.connect(address)).unwrap(); //remote representation of p1 let mut s1 = runtime diff --git a/network/examples/network-speed/metrics.rs b/network/examples/network-speed/metrics.rs deleted file mode 100644 index 657a00abf4..0000000000 --- a/network/examples/network-speed/metrics.rs +++ /dev/null @@ -1,92 +0,0 @@ -use prometheus::{Encoder, Registry, TextEncoder}; -use std::{ - error::Error, - net::SocketAddr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread, -}; -use tracing::*; - -pub struct SimpleMetrics { - running: Arc, - handle: Option>, - registry: Option, -} - -impl SimpleMetrics { - pub fn new() -> Self { - let running = Arc::new(AtomicBool::new(false)); - let registry = Some(Registry::new()); - - Self { - running, - handle: None, - registry, - } - } - - pub fn registry(&self) -> &Registry { - match self.registry { - Some(ref r) => r, - None => panic!("You cannot longer register new metrics after the server has started!"), - } - } - - pub fn run(&mut self, addr: SocketAddr) -> Result<(), Box> { - self.running.store(true, Ordering::Relaxed); - let running2 = self.running.clone(); - - let registry = self - .registry - .take() - .expect("ServerMetrics must be already started"); - - //TODO: make this a job - self.handle = Some(thread::spawn(move || { - let server = tiny_http::Server::http(addr).unwrap(); - const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1); - debug!("starting tiny_http server to serve metrics"); - while running2.load(Ordering::Relaxed) { - let request = match server.recv_timeout(TIMEOUT) { - Ok(Some(rq)) => rq, - Ok(None) => continue, - Err(e) => { - println!("Error: {}", 
e); - break; - }, - }; - let mf = registry.gather(); - let encoder = TextEncoder::new(); - let mut buffer = vec![]; - encoder - .encode(&mf, &mut buffer) - .expect("Failed to encoder metrics text."); - let response = tiny_http::Response::from_string( - String::from_utf8(buffer).expect("Failed to parse bytes as a string."), - ); - if let Err(e) = request.respond(response) { - error!( - ?e, - "The metrics HTTP server had encountered and error with answering" - ) - } - } - debug!("Stopping tiny_http server to serve metrics"); - })); - Ok(()) - } -} - -impl Drop for SimpleMetrics { - fn drop(&mut self) { - self.running.store(false, Ordering::Relaxed); - let handle = self.handle.take(); - handle - .expect("ServerMetrics worker handle does not exist.") - .join() - .expect("Error shutting down prometheus metric exporter"); - } -} diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml index e097314b6b..a9bd701940 100644 --- a/network/protocol/Cargo.toml +++ b/network/protocol/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" [features] metrics = ["prometheus"] +trace_pedantic = [] # use for debug only default = ["metrics"] @@ -22,6 +23,7 @@ bitflags = "1.2.1" rand = { version = "0.8" } # async traits async-trait = "0.1.42" +bytes = "^1" [dev-dependencies] async-channel = "1.5.1" diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs index 5151083b98..f9ad557682 100644 --- a/network/protocol/benches/protocols.rs +++ b/network/protocol/benches/protocols.rs @@ -1,5 +1,6 @@ use async_channel::*; use async_trait::async_trait; +use bytes::BytesMut; use criterion::{criterion_group, criterion_main, Criterion}; use std::{sync::Arc, time::Duration}; use veloren_network_protocol::{ @@ -8,7 +9,7 @@ use veloren_network_protocol::{ Sid, TcpRecvProtcol, TcpSendProtcol, UnreliableDrain, UnreliableSink, _internal::Frame, }; -fn frame_serialize(frame: Frame, buffer: &mut [u8]) -> usize { frame.to_bytes(buffer).0 } +fn frame_serialize(frame: 
Frame, buffer: &mut BytesMut) { frame.to_bytes(buffer); } async fn mpsc_msg(buffer: Arc) { // Arrrg, need to include constructor here @@ -102,7 +103,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.to_async(rt()).iter(|| mpsc_handshake()) }); - let mut buffer = [0u8; 1500]; + let mut buffer = BytesMut::with_capacity(1500); c.bench_function("frame_serialize_short", |b| { let frame = Frame::Data { @@ -110,7 +111,7 @@ fn criterion_benchmark(c: &mut Criterion) { start: 89u64, data: b"hello_world".to_vec(), }; - b.iter(move || frame_serialize(frame.clone(), &mut buffer)) + b.iter(|| frame_serialize(frame.clone(), &mut buffer)) }); c.bench_function("tcp_short_msg", |b| { @@ -126,6 +127,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.to_async(rt()) .iter(|| tcp_msg(Arc::clone(&buffer), 10_000)) }); + c.bench_function("tcp_1000000_tiny_msg", |b| { + let buffer = Arc::new(MessageBuffer { data: vec![3u8; 5] }); + b.to_async(rt()) + .iter(|| tcp_msg(Arc::clone(&buffer), 1_000_000)) + }); } criterion_group!(benches, criterion_benchmark); @@ -164,11 +170,11 @@ mod utils { } pub struct TcpDrain { - sender: Sender>, + sender: Sender, } pub struct TcpSink { - receiver: Receiver>, + receiver: Receiver, } /// emulate Tcp protocol on Channels @@ -219,7 +225,7 @@ mod utils { #[async_trait] impl UnreliableDrain for TcpDrain { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { self.sender @@ -231,7 +237,7 @@ mod utils { #[async_trait] impl UnreliableSink for TcpSink { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn recv(&mut self) -> Result { self.receiver diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs index 4824498940..2f1bb3eb4c 100644 --- a/network/protocol/src/frame.rs +++ b/network/protocol/src/frame.rs @@ -1,5 +1,5 @@ use crate::types::{Mid, Pid, Prio, Promises, Sid}; -use std::{collections::VecDeque, convert::TryFrom}; +use bytes::{Buf, BufMut, 
BytesMut}; // const FRAME_RESERVED_1: u8 = 0; const FRAME_HANDSHAKE: u8 = 1; @@ -62,37 +62,32 @@ impl InitFrame { pub(crate) const RAW_CNS: usize = 2; //provide an appropriate buffer size. > 1500 - pub(crate) fn to_bytes(self, bytes: &mut [u8]) -> usize { + pub(crate) fn to_bytes(self, bytes: &mut BytesMut) { match self { InitFrame::Handshake { magic_number, version, } => { - let x = FRAME_HANDSHAKE.to_be_bytes(); - bytes[0] = x[0]; - bytes[1..8].copy_from_slice(&magic_number); - bytes[8..12].copy_from_slice(&version[0].to_le_bytes()); - bytes[12..16].copy_from_slice(&version[1].to_le_bytes()); - bytes[16..Self::HANDSHAKE_CNS + 1].copy_from_slice(&version[2].to_le_bytes()); - Self::HANDSHAKE_CNS + 1 + bytes.put_u8(FRAME_HANDSHAKE); + bytes.put_slice(&magic_number); + bytes.put_u32_le(version[0]); + bytes.put_u32_le(version[1]); + bytes.put_u32_le(version[2]); }, InitFrame::Init { pid, secret } => { - bytes[0] = FRAME_INIT.to_be_bytes()[0]; - bytes[1..17].copy_from_slice(&pid.to_le_bytes()); - bytes[17..Self::INIT_CNS + 1].copy_from_slice(&secret.to_le_bytes()); - Self::INIT_CNS + 1 + bytes.put_u8(FRAME_INIT); + pid.to_bytes(bytes); + bytes.put_u128_le(secret); }, InitFrame::Raw(data) => { - bytes[0] = FRAME_RAW.to_be_bytes()[0]; - bytes[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); - bytes[Self::RAW_CNS + 1..(data.len() + Self::RAW_CNS + 1)] - .clone_from_slice(&data[..]); - Self::RAW_CNS + 1 + data.len() + bytes.put_u8(FRAME_RAW); + bytes.put_u16_le(data.len() as u16); + bytes.put_slice(&data); }, } } - pub(crate) fn to_frame(bytes: Vec) -> Option { + pub(crate) fn to_frame(bytes: &mut BytesMut) -> Option { let frame_no = match bytes.get(0) { Some(&f) => f, None => return None, @@ -102,61 +97,43 @@ impl InitFrame { if bytes.len() < Self::HANDSHAKE_CNS + 1 { return None; } - InitFrame::gen_handshake( - *<&[u8; Self::HANDSHAKE_CNS]>::try_from(&bytes[1..Self::HANDSHAKE_CNS + 1]) - .unwrap(), - ) + bytes.advance(1); + let mut magic_number_bytes = 
bytes.copy_to_bytes(7); + let mut magic_number = [0u8; 7]; + magic_number_bytes.copy_to_slice(&mut magic_number); + InitFrame::Handshake { + magic_number, + version: [bytes.get_u32_le(), bytes.get_u32_le(), bytes.get_u32_le()], + } }, FRAME_INIT => { if bytes.len() < Self::INIT_CNS + 1 { return None; } - InitFrame::gen_init( - *<&[u8; Self::INIT_CNS]>::try_from(&bytes[1..Self::INIT_CNS + 1]).unwrap(), - ) + bytes.advance(1); + InitFrame::Init { + pid: Pid::from_bytes(bytes), + secret: bytes.get_u128_le(), + } }, FRAME_RAW => { if bytes.len() < Self::RAW_CNS + 1 { return None; } - let length = InitFrame::gen_raw( - *<&[u8; Self::RAW_CNS]>::try_from(&bytes[1..Self::RAW_CNS + 1]).unwrap(), - ); - let mut data = vec![0; length as usize]; - let slice = &bytes[Self::RAW_CNS + 1..]; - if slice.len() != length as usize { - return None; - } - data.copy_from_slice(&bytes[Self::RAW_CNS + 1..]); + bytes.advance(1); + let length = bytes.get_u16_le() as usize; + // lower length is allowed + let max_length = length.min(bytes.len()); + println!("dasdasd {:?}", length); + println!("aaaaa {:?}", max_length); + let mut data = vec![0; max_length]; + data.copy_from_slice(&bytes[..max_length]); InitFrame::Raw(data) }, - _ => InitFrame::Raw(bytes), + _ => InitFrame::Raw(bytes.to_vec()), }; Some(frame) } - - fn gen_handshake(buf: [u8; Self::HANDSHAKE_CNS]) -> Self { - let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); - InitFrame::Handshake { - magic_number, - version: [ - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..Self::HANDSHAKE_CNS]).unwrap()), - ], - } - } - - fn gen_init(buf: [u8; Self::INIT_CNS]) -> Self { - InitFrame::Init { - pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), - secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..Self::INIT_CNS]).unwrap()), - } - } - - fn gen_raw(buf: [u8; 
Self::RAW_CNS]) -> u16 { - u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..Self::RAW_CNS]).unwrap()) - } } impl Frame { @@ -164,82 +141,53 @@ impl Frame { /// const part of the DATA frame, actual size is variable pub(crate) const DATA_CNS: usize = 18; pub(crate) const DATA_HEADER_CNS: usize = 24; - #[cfg(feature = "metrics")] - pub const FRAMES_LEN: u8 = 5; pub(crate) const OPEN_STREAM_CNS: usize = 10; // Size WITHOUT the 1rst indicating byte pub(crate) const SHUTDOWN_CNS: usize = 0; - #[cfg(feature = "metrics")] - pub const fn int_to_string(i: u8) -> &'static str { - match i { - 0 => "Shutdown", - 1 => "OpenStream", - 2 => "CloseStream", - 3 => "DataHeader", - 4 => "Data", - _ => "", - } - } - - #[cfg(feature = "metrics")] - pub fn get_int(&self) -> u8 { - match self { - Frame::Shutdown => 0, - Frame::OpenStream { .. } => 1, - Frame::CloseStream { .. } => 2, - Frame::DataHeader { .. } => 3, - Frame::Data { .. } => 4, - } - } - - #[cfg(feature = "metrics")] - pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } - //provide an appropriate buffer size. 
> 1500 - pub fn to_bytes(self, bytes: &mut [u8]) -> (/* buf */ usize, /* actual data */ u64) { + pub fn to_bytes(self, bytes: &mut BytesMut) -> u64 { match self { Frame::Shutdown => { - bytes[Self::SHUTDOWN_CNS] = FRAME_SHUTDOWN.to_be_bytes()[0]; - (Self::SHUTDOWN_CNS + 1, 0) + bytes.put_u8(FRAME_SHUTDOWN); + 0 }, Frame::OpenStream { sid, prio, promises, } => { - bytes[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; - bytes[1..9].copy_from_slice(&sid.to_le_bytes()); - bytes[9] = prio.to_le_bytes()[0]; - bytes[Self::OPEN_STREAM_CNS] = promises.to_le_bytes()[0]; - (Self::OPEN_STREAM_CNS + 1, 0) + bytes.put_u8(FRAME_OPEN_STREAM); + bytes.put_slice(&sid.to_le_bytes()); + bytes.put_u8(prio); + bytes.put_u8(promises.to_le_bytes()[0]); + 0 }, Frame::CloseStream { sid } => { - bytes[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; - bytes[1..Self::CLOSE_STREAM_CNS + 1].copy_from_slice(&sid.to_le_bytes()); - (Self::CLOSE_STREAM_CNS + 1, 0) + bytes.put_u8(FRAME_CLOSE_STREAM); + bytes.put_slice(&sid.to_le_bytes()); + 0 }, Frame::DataHeader { mid, sid, length } => { - bytes[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; - bytes[1..9].copy_from_slice(&mid.to_le_bytes()); - bytes[9..17].copy_from_slice(&sid.to_le_bytes()); - bytes[17..Self::DATA_HEADER_CNS + 1].copy_from_slice(&length.to_le_bytes()); - (Self::DATA_HEADER_CNS + 1, 0) + bytes.put_u8(FRAME_DATA_HEADER); + bytes.put_u64_le(mid); + bytes.put_slice(&sid.to_le_bytes()); + bytes.put_u64_le(length); + 0 }, Frame::Data { mid, start, data } => { - bytes[0] = FRAME_DATA.to_be_bytes()[0]; - bytes[1..9].copy_from_slice(&mid.to_le_bytes()); - bytes[9..17].copy_from_slice(&start.to_le_bytes()); - bytes[17..Self::DATA_CNS + 1].copy_from_slice(&(data.len() as u16).to_le_bytes()); - bytes[Self::DATA_CNS + 1..(data.len() + Self::DATA_CNS + 1)] - .clone_from_slice(&data[..]); - (Self::DATA_CNS + 1 + data.len(), data.len() as u64) + bytes.put_u8(FRAME_DATA); + bytes.put_u64_le(mid); + bytes.put_u64_le(start); + bytes.put_u16_le(data.len() as u16); + 
bytes.put_slice(&data); + data.len() as u64 }, } } - pub(crate) fn to_frame(bytes: &mut VecDeque) -> Option { - let frame_no = match bytes.get(0) { + pub(crate) fn to_frame(bytes: &mut BytesMut) -> Option { + let frame_no = match bytes.first() { Some(&f) => f, None => return None, }; @@ -249,6 +197,9 @@ impl Frame { FRAME_CLOSE_STREAM => Self::CLOSE_STREAM_CNS, FRAME_DATA_HEADER => Self::DATA_HEADER_CNS, FRAME_DATA => { + if bytes.len() < 17 + 1 + 1 { + return None; + } u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + Self::DATA_CNS }, _ => return None, @@ -260,68 +211,49 @@ impl Frame { let frame = match frame_no { FRAME_SHUTDOWN => { - let _ = bytes.drain(..size + 1); + let _ = bytes.split_to(size + 1); Frame::Shutdown }, FRAME_OPEN_STREAM => { - let bytes = bytes.drain(..size + 1).skip(1).collect::>(); - Frame::gen_open_stream(<[u8; 10]>::try_from(bytes).unwrap()) + let mut bytes = bytes.split_to(size + 1); + bytes.advance(1); + Frame::OpenStream { + sid: Sid::new(bytes.get_u64_le()), + prio: bytes.get_u8(), + promises: Promises::from_bits_truncate(bytes.get_u8()), + } }, FRAME_CLOSE_STREAM => { - let bytes = bytes.drain(..size + 1).skip(1).collect::>(); - Frame::gen_close_stream(<[u8; 8]>::try_from(bytes).unwrap()) + let mut bytes = bytes.split_to(size + 1); + bytes.advance(1); + Frame::CloseStream { + sid: Sid::new(bytes.get_u64_le()), + } }, FRAME_DATA_HEADER => { - let bytes = bytes.drain(..size + 1).skip(1).collect::>(); - Frame::gen_data_header(<[u8; 24]>::try_from(bytes).unwrap()) + let mut bytes = bytes.split_to(size + 1); + bytes.advance(1); + Frame::DataHeader { + mid: bytes.get_u64_le(), + sid: Sid::new(bytes.get_u64_le()), + length: bytes.get_u64_le(), + } }, FRAME_DATA => { - let info = bytes - .drain(..Self::DATA_CNS + 1) - .skip(1) - .collect::>(); - let (mid, start, length) = Frame::gen_data(<[u8; 18]>::try_from(info).unwrap()); + let mut info = bytes.split_to(Self::DATA_CNS + 1); + info.advance(1); + let mid = info.get_u64_le(); + 
let start = info.get_u64_le(); + let length = info.get_u16_le(); debug_assert_eq!(length as usize, size - Self::DATA_CNS); - let data = bytes.drain(..length as usize).collect::>(); + let data = bytes.split_to(length as usize); + let data = data.to_vec(); Frame::Data { mid, start, data } }, _ => unreachable!("Frame::to_frame should be handled before!"), }; Some(frame) } - - fn gen_open_stream(buf: [u8; Self::OPEN_STREAM_CNS]) -> Self { - Frame::OpenStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - prio: buf[8], - promises: Promises::from_bits_truncate(buf[Self::OPEN_STREAM_CNS - 1]), - } - } - - fn gen_close_stream(buf: [u8; Self::CLOSE_STREAM_CNS]) -> Self { - Frame::CloseStream { - sid: Sid::from_le_bytes( - *<&[u8; 8]>::try_from(&buf[0..Self::CLOSE_STREAM_CNS]).unwrap(), - ), - } - } - - fn gen_data_header(buf: [u8; Self::DATA_HEADER_CNS]) -> Self { - Frame::DataHeader { - mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), - length: u64::from_le_bytes( - *<&[u8; 8]>::try_from(&buf[16..Self::DATA_HEADER_CNS]).unwrap(), - ), - } - } - - fn gen_data(buf: [u8; Self::DATA_CNS]) -> (Mid, u64, u16) { - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); - let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); - let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..Self::DATA_CNS]).unwrap()); - (mid, start, length) - } } #[cfg(test)] @@ -375,10 +307,9 @@ mod tests { #[test] fn initframe_individual() { let dupl = |frame: InitFrame| { - let mut buffer = vec![0u8; 1500]; - let size = InitFrame::to_bytes(frame.clone(), &mut buffer); - buffer.truncate(size); - InitFrame::to_frame(buffer) + let mut buffer = BytesMut::with_capacity(1500); + InitFrame::to_bytes(frame.clone(), &mut buffer); + InitFrame::to_frame(&mut buffer) }; for frame in get_initframes() { @@ -389,29 +320,18 @@ mod tests { #[test] fn 
initframe_multiple() { - let mut buffer = vec![0u8; 3000]; + let mut buffer = BytesMut::with_capacity(3000); let mut frames = get_initframes(); - let mut last = 0; // to string - let sizes = frames - .iter() - .map(|f| { - let s = InitFrame::to_bytes(f.clone(), &mut buffer[last..]); - last += s; - s - }) - .collect::>(); + for f in &frames { + InitFrame::to_bytes(f.clone(), &mut buffer); + } // from string - let mut last = 0; - let mut framesd = sizes + let mut framesd = frames .iter() - .map(|&s| { - let f = InitFrame::to_frame(buffer[last..last + s].to_vec()); - last += s; - f - }) + .map(|&_| InitFrame::to_frame(&mut buffer)) .collect::>(); // compare @@ -424,10 +344,9 @@ mod tests { #[test] fn frame_individual() { let dupl = |frame: Frame| { - let mut buffer = vec![0u8; 1500]; - let (size, _) = Frame::to_bytes(frame.clone(), &mut buffer); - let mut deque = buffer[..size].iter().map(|b| *b).collect(); - Frame::to_frame(&mut deque) + let mut buffer = BytesMut::with_capacity(1500); + Frame::to_bytes(frame.clone(), &mut buffer); + Frame::to_frame(&mut buffer) }; for frame in get_frames() { @@ -438,31 +357,16 @@ mod tests { #[test] fn frame_multiple() { - let mut buffer = vec![0u8; 3000]; + let mut buffer = BytesMut::with_capacity(3000); let mut frames = get_frames(); - let mut last = 0; // to string - let sizes = frames - .iter() - .map(|f| { - let s = Frame::to_bytes(f.clone(), &mut buffer[last..]).0; - last += s; - s - }) - .collect::>(); - - assert_eq!(sizes[0], 1 + Frame::OPEN_STREAM_CNS); - assert_eq!(sizes[1], 1 + Frame::DATA_HEADER_CNS); - assert_eq!(sizes[2], 1 + Frame::DATA_CNS + 20); - assert_eq!(sizes[3], 1 + Frame::DATA_CNS + 16); - assert_eq!(sizes[4], 1 + Frame::CLOSE_STREAM_CNS); - assert_eq!(sizes[5], 1 + Frame::SHUTDOWN_CNS); - - let mut buffer = buffer.drain(..).collect::>(); + for f in &frames { + Frame::to_bytes(f.clone(), &mut buffer); + } // from string - let mut framesd = sizes + let mut framesd = frames .iter() .map(|&_| Frame::to_frame(&mut 
buffer)) .collect::>(); @@ -476,32 +380,31 @@ mod tests { #[test] fn frame_exact_size() { - let mut buffer = vec![0u8; Frame::CLOSE_STREAM_CNS+1/*first byte*/]; + const SIZE: usize = Frame::CLOSE_STREAM_CNS+1/*first byte*/; + let mut buffer = BytesMut::with_capacity(SIZE); - let frame1 = Frame::CloseStream { - sid: Sid::new(1337), - }; - let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + let frame1 = Frame::CloseStream { sid: Sid::new(2) }; + Frame::to_bytes(frame1.clone(), &mut buffer); + assert_eq!(buffer.len(), SIZE); let mut deque = buffer.iter().map(|b| *b).collect(); let frame2 = Frame::to_frame(&mut deque); assert_eq!(Some(frame1), frame2); } #[test] - #[should_panic] fn initframe_too_short_buffer() { - let mut buffer = vec![0u8; 10]; + let mut buffer = BytesMut::with_capacity(10); let frame1 = InitFrame::Handshake { magic_number: VELOREN_MAGIC_NUMBER, version: VELOREN_NETWORK_VERSION, }; - let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + InitFrame::to_bytes(frame1.clone(), &mut buffer); } #[test] fn initframe_too_less_data() { - let mut buffer = vec![0u8; 20]; + let mut buffer = BytesMut::with_capacity(20); let frame1 = InitFrame::Handshake { magic_number: VELOREN_MAGIC_NUMBER, @@ -509,79 +412,78 @@ mod tests { }; let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); buffer.truncate(6); // simulate partial retrieve - let frame1d = InitFrame::to_frame(buffer[..6].to_vec()); + let frame1d = InitFrame::to_frame(&mut buffer); assert_eq!(frame1d, None); } #[test] fn initframe_rubish() { - let buffer = b"dtrgwcser".to_vec(); + let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); assert_eq!( - InitFrame::to_frame(buffer), + InitFrame::to_frame(&mut buffer), Some(InitFrame::Raw(b"dtrgwcser".to_vec())) ); } #[test] fn initframe_attack_too_much_length() { - let mut buffer = vec![0u8; 50]; + let mut buffer = BytesMut::with_capacity(50); let frame1 = InitFrame::Raw(b"foobar".to_vec()); let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); - 
buffer[2] = 255; - let framed = InitFrame::to_frame(buffer); - assert_eq!(framed, None); + buffer[1] = 255; + let framed = InitFrame::to_frame(&mut buffer); + assert_eq!(framed, Some(frame1)); } #[test] fn initframe_attack_too_low_length() { - let mut buffer = vec![0u8; 50]; + let mut buffer = BytesMut::with_capacity(50); let frame1 = InitFrame::Raw(b"foobar".to_vec()); let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); - buffer[2] = 3; - let framed = InitFrame::to_frame(buffer); - assert_eq!(framed, None); + buffer[1] = 3; + let framed = InitFrame::to_frame(&mut buffer); + // we accept a different frame here, as it's RAW and debug only! + assert_eq!(framed, Some(InitFrame::Raw(b"foo".to_vec()))); } #[test] - #[should_panic] fn frame_too_short_buffer() { - let mut buffer = vec![0u8; 10]; + let mut buffer = BytesMut::with_capacity(10); let frame1 = Frame::OpenStream { sid: Sid::new(88), promises: Promises::ENCRYPTED, prio: 88, }; - let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + Frame::to_bytes(frame1.clone(), &mut buffer); } #[test] fn frame_too_less_data() { - let mut buffer = vec![0u8; 20]; + let mut buffer = BytesMut::with_capacity(20); let frame1 = Frame::OpenStream { sid: Sid::new(88), promises: Promises::ENCRYPTED, prio: 88, }; - let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + Frame::to_bytes(frame1.clone(), &mut buffer); buffer.truncate(6); // simulate partial retrieve - let mut buffer = buffer.drain(..6).collect::>(); let frame1d = Frame::to_frame(&mut buffer); assert_eq!(frame1d, None); } #[test] fn frame_rubish() { - let mut buffer = b"dtrgwcser".iter().map(|u| *u).collect::>(); + let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); assert_eq!(Frame::to_frame(&mut buffer), None); } #[test] fn frame_attack_too_much_length() { - let mut buffer = vec![0u8; 50]; + let mut buffer = BytesMut::with_capacity(50); let frame1 = Frame::Data { mid: 7u64, @@ -589,16 +491,15 @@ mod tests { data: b"foobar".to_vec(), }; - let _ = 
Frame::to_bytes(frame1.clone(), &mut buffer); + Frame::to_bytes(frame1.clone(), &mut buffer); buffer[17] = 255; - let mut buffer = buffer.drain(..).collect::>(); let framed = Frame::to_frame(&mut buffer); assert_eq!(framed, None); } #[test] fn frame_attack_too_low_length() { - let mut buffer = vec![0u8; 50]; + let mut buffer = BytesMut::with_capacity(50); let frame1 = Frame::Data { mid: 7u64, @@ -606,9 +507,8 @@ mod tests { data: b"foobar".to_vec(), }; - let _ = Frame::to_bytes(frame1.clone(), &mut buffer); + Frame::to_bytes(frame1.clone(), &mut buffer); buffer[17] = 3; - let mut buffer = buffer.drain(..).collect::>(); let framed = Frame::to_frame(&mut buffer); assert_eq!( framed, @@ -622,13 +522,4 @@ mod tests { let framed = Frame::to_frame(&mut buffer); assert_eq!(framed, None); } - - #[test] - fn frame_int2str() { - assert_eq!(Frame::int_to_string(0), "Shutdown"); - assert_eq!(Frame::int_to_string(1), "OpenStream"); - assert_eq!(Frame::int_to_string(2), "CloseStream"); - assert_eq!(Frame::int_to_string(3), "DataHeader"); - assert_eq!(Frame::int_to_string(4), "Data"); - } } diff --git a/network/protocol/src/io.rs b/network/protocol/src/io.rs index c4e3eba43e..6ccf40e7d8 100644 --- a/network/protocol/src/io.rs +++ b/network/protocol/src/io.rs @@ -1,5 +1,6 @@ use crate::ProtocolError; use async_trait::async_trait; +use bytes::BytesMut; use std::collections::VecDeque; ///! 
I/O-Free (Sans-I/O) protocol https://sans-io.readthedocs.io/how-to-sans-io.html @@ -17,11 +18,11 @@ pub trait UnreliableSink: Send { } pub struct BaseDrain { - data: VecDeque>, + data: VecDeque, } pub struct BaseSink { - data: VecDeque>, + data: VecDeque, } impl BaseDrain { @@ -44,7 +45,7 @@ impl BaseSink { #[async_trait] impl UnreliableDrain for BaseDrain { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { self.data.push_back(data); @@ -54,7 +55,7 @@ impl UnreliableDrain for BaseDrain { #[async_trait] impl UnreliableSink for BaseSink { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn recv(&mut self) -> Result { self.data.pop_front().ok_or(ProtocolError::Closed) diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs index 8d49ed58c9..295d292881 100644 --- a/network/protocol/src/lib.rs +++ b/network/protocol/src/lib.rs @@ -40,6 +40,9 @@ pub trait InitProtocol { pub trait SendProtocol { //a stream MUST be bound to a specific Protocol, there will be a failover // feature comming for the case where a Protocol fails completly + /// use this to notify the sending side of streams that were created/remove + /// from remote + fn notify_from_recv(&mut self, event: ProtocolEvent); async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; async fn flush( &mut self, diff --git a/network/protocol/src/metrics.rs b/network/protocol/src/metrics.rs index 715a06fc9d..0b5c66872a 100644 --- a/network/protocol/src/metrics.rs +++ b/network/protocol/src/metrics.rs @@ -1,8 +1,11 @@ use crate::types::Sid; #[cfg(feature = "metrics")] -use prometheus::{IntCounterVec, IntGaugeVec, Opts, Registry}; +use prometheus::{ + core::{AtomicI64, AtomicU64, GenericCounter, GenericGauge}, + IntCounterVec, IntGaugeVec, Opts, Registry, +}; #[cfg(feature = "metrics")] -use std::{error::Error, sync::Arc}; +use std::{collections::HashMap, error::Error, sync::Arc}; 
#[allow(dead_code)] pub enum RemoveReason { @@ -57,6 +60,12 @@ pub struct ProtocolMetrics { pub struct ProtocolMetricCache { cid: String, m: Arc, + cache: HashMap, + sdata_frames_t: GenericCounter, + sdata_frames_b: GenericCounter, + rdata_frames_t: GenericCounter, + rdata_frames_b: GenericCounter, + ping: GenericGauge, } #[cfg(not(feature = "metrics"))] @@ -192,179 +201,134 @@ impl ProtocolMetrics { } } +#[cfg(feature = "metrics")] +#[derive(Debug, Clone)] +pub(crate) struct CacheLine { + smsg_it: GenericCounter, + smsg_ib: GenericCounter, + smsg_ot: [GenericCounter; 2], + smsg_ob: [GenericCounter; 2], + rmsg_it: GenericCounter, + rmsg_ib: GenericCounter, + rmsg_ot: [GenericCounter; 2], + rmsg_ob: [GenericCounter; 2], +} + #[cfg(feature = "metrics")] impl ProtocolMetricCache { pub fn new(channel_key: &str, metrics: Arc) -> Self { + let cid = channel_key.to_string(); + let sdata_frames_t = metrics.sdata_frames_t.with_label_values(&[&cid]); + let sdata_frames_b = metrics.sdata_frames_b.with_label_values(&[&cid]); + let rdata_frames_t = metrics.rdata_frames_t.with_label_values(&[&cid]); + let rdata_frames_b = metrics.rdata_frames_b.with_label_values(&[&cid]); + let ping = metrics.ping.with_label_values(&[&cid]); Self { - cid: channel_key.to_string(), + cid, m: metrics, + cache: HashMap::new(), + sdata_frames_t, + sdata_frames_b, + rdata_frames_t, + rdata_frames_b, + ping, } } - pub(crate) fn smsg_it(&self, sid: Sid) { - self.m - .smsg_it - .with_label_values(&[&self.cid, &sid.to_string()]) - .inc(); + pub(crate) fn init_sid(&mut self, sid: Sid) -> &CacheLine { + let cid = &self.cid; + let m = &self.m; + self.cache.entry(sid).or_insert_with_key(|sid| { + let s = sid.to_string(); + let finished = RemoveReason::Finished.to_str(); + let dropped = RemoveReason::Dropped.to_str(); + CacheLine { + smsg_it: m.smsg_it.with_label_values(&[&cid, &s]), + smsg_ib: m.smsg_ib.with_label_values(&[&cid, &s]), + smsg_ot: [ + m.smsg_ot.with_label_values(&[&cid, &s, &finished]), + 
m.smsg_ot.with_label_values(&[&cid, &s, &dropped]), + ], + smsg_ob: [ + m.smsg_ob.with_label_values(&[&cid, &s, &finished]), + m.smsg_ob.with_label_values(&[&cid, &s, &dropped]), + ], + rmsg_it: m.rmsg_it.with_label_values(&[&cid, &s]), + rmsg_ib: m.rmsg_ib.with_label_values(&[&cid, &s]), + rmsg_ot: [ + m.rmsg_ot.with_label_values(&[&cid, &s, &finished]), + m.rmsg_ot.with_label_values(&[&cid, &s, &dropped]), + ], + rmsg_ob: [ + m.rmsg_ob.with_label_values(&[&cid, &s, &finished]), + m.rmsg_ob.with_label_values(&[&cid, &s, &dropped]), + ], + } + }) } - pub(crate) fn smsg_ib(&self, sid: Sid, bytes: u64) { - self.m - .smsg_ib - .with_label_values(&[&self.cid, &sid.to_string()]) - .inc_by(bytes); + pub(crate) fn smsg_ib(&mut self, sid: Sid, bytes: u64) { + let line = self.init_sid(sid); + line.smsg_it.inc(); + line.smsg_ib.inc_by(bytes); } - pub(crate) fn smsg_ot(&self, sid: Sid, reason: RemoveReason) { - self.m - .smsg_ot - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .inc(); + pub(crate) fn smsg_ob(&mut self, sid: Sid, reason: RemoveReason, bytes: u64) { + let line = self.init_sid(sid); + line.smsg_ot[reason.i()].inc(); + line.smsg_ob[reason.i()].inc_by(bytes); } - pub(crate) fn smsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { - self.m - .smsg_ob - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .inc_by(bytes); + pub(crate) fn sdata_frames_b(&mut self, bytes: u64) { + self.sdata_frames_t.inc(); + self.sdata_frames_b.inc_by(bytes); } - pub(crate) fn sdata_frames_t(&self) { - self.m.sdata_frames_t.with_label_values(&[&self.cid]).inc(); + pub(crate) fn rmsg_ib(&mut self, sid: Sid, bytes: u64) { + let line = self.init_sid(sid); + line.rmsg_it.inc(); + line.rmsg_ib.inc_by(bytes); } - pub(crate) fn sdata_frames_b(&self, bytes: u64) { - self.m - .sdata_frames_b - .with_label_values(&[&self.cid]) - .inc_by(bytes); + pub(crate) fn rmsg_ob(&mut self, sid: Sid, reason: RemoveReason, bytes: u64) { + let line = 
self.init_sid(sid); + line.rmsg_ot[reason.i()].inc(); + line.rmsg_ob[reason.i()].inc_by(bytes); } - pub(crate) fn rmsg_it(&self, sid: Sid) { - self.m - .rmsg_it - .with_label_values(&[&self.cid, &sid.to_string()]) - .inc(); - } - - pub(crate) fn rmsg_ib(&self, sid: Sid, bytes: u64) { - self.m - .rmsg_ib - .with_label_values(&[&self.cid, &sid.to_string()]) - .inc_by(bytes); - } - - pub(crate) fn rmsg_ot(&self, sid: Sid, reason: RemoveReason) { - self.m - .rmsg_ot - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .inc(); - } - - pub(crate) fn rmsg_ob(&self, sid: Sid, reason: RemoveReason, bytes: u64) { - self.m - .rmsg_ob - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .inc_by(bytes); - } - - pub(crate) fn rdata_frames_t(&self) { - self.m.rdata_frames_t.with_label_values(&[&self.cid]).inc(); - } - - pub(crate) fn rdata_frames_b(&self, bytes: u64) { - self.m - .rdata_frames_b - .with_label_values(&[&self.cid]) - .inc_by(bytes); + pub(crate) fn rdata_frames_b(&mut self, bytes: u64) { + self.rdata_frames_t.inc(); + self.rdata_frames_b.inc_by(bytes); } #[cfg(test)] - pub(crate) fn assert_msg(&self, sid: Sid, cnt: u64, reason: RemoveReason) { - assert_eq!( - self.m - .smsg_it - .with_label_values(&[&self.cid, &sid.to_string()]) - .get(), - cnt - ); - assert_eq!( - self.m - .smsg_ot - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .get(), - cnt - ); - assert_eq!( - self.m - .rmsg_it - .with_label_values(&[&self.cid, &sid.to_string()]) - .get(), - cnt - ); - assert_eq!( - self.m - .rmsg_ot - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .get(), - cnt - ); + pub(crate) fn assert_msg(&mut self, sid: Sid, cnt: u64, reason: RemoveReason) { + let line = self.init_sid(sid); + assert_eq!(line.smsg_it.get(), cnt); + assert_eq!(line.smsg_ot[reason.i()].get(), cnt); + assert_eq!(line.rmsg_it.get(), cnt); + assert_eq!(line.rmsg_ot[reason.i()].get(), cnt); } #[cfg(test)] - pub(crate) fn 
assert_msg_bytes(&self, sid: Sid, bytes: u64, reason: RemoveReason) { - assert_eq!( - self.m - .smsg_ib - .with_label_values(&[&self.cid, &sid.to_string()]) - .get(), - bytes - ); - assert_eq!( - self.m - .smsg_ob - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .get(), - bytes - ); - assert_eq!( - self.m - .rmsg_ib - .with_label_values(&[&self.cid, &sid.to_string()]) - .get(), - bytes - ); - assert_eq!( - self.m - .rmsg_ob - .with_label_values(&[&self.cid, &sid.to_string(), reason.to_str()]) - .get(), - bytes - ); + pub(crate) fn assert_msg_bytes(&mut self, sid: Sid, bytes: u64, reason: RemoveReason) { + let line = self.init_sid(sid); + assert_eq!(line.smsg_ib.get(), bytes); + assert_eq!(line.smsg_ob[reason.i()].get(), bytes); + assert_eq!(line.rmsg_ib.get(), bytes); + assert_eq!(line.rmsg_ob[reason.i()].get(), bytes); } #[cfg(test)] - pub(crate) fn assert_data_frames(&self, cnt: u64) { - assert_eq!( - self.m.sdata_frames_t.with_label_values(&[&self.cid]).get(), - cnt - ); - assert_eq!( - self.m.rdata_frames_t.with_label_values(&[&self.cid]).get(), - cnt - ); + pub(crate) fn assert_data_frames(&mut self, cnt: u64) { + assert_eq!(self.sdata_frames_t.get(), cnt); + assert_eq!(self.rdata_frames_t.get(), cnt); } #[cfg(test)] - pub(crate) fn assert_data_frames_bytes(&self, bytes: u64) { - assert_eq!( - self.m.sdata_frames_b.with_label_values(&[&self.cid]).get(), - bytes - ); - assert_eq!( - self.m.rdata_frames_b.with_label_values(&[&self.cid]).get(), - bytes - ); + pub(crate) fn assert_data_frames_bytes(&mut self, bytes: u64) { + assert_eq!(self.sdata_frames_b.get(), bytes); + assert_eq!(self.rdata_frames_b.get(), bytes); } } @@ -378,29 +342,29 @@ impl std::fmt::Debug for ProtocolMetrics { #[cfg(not(feature = "metrics"))] impl ProtocolMetricCache { - pub(crate) fn smsg_it(&self, _sid: Sid) {} + pub(crate) fn smsg_it(&mut self, _sid: Sid) {} - pub(crate) fn smsg_ib(&self, _sid: Sid, _b: u64) {} + pub(crate) fn smsg_ib(&mut self, _sid: Sid, _b: 
u64) {} - pub(crate) fn smsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + pub(crate) fn smsg_ot(&mut self, _sid: Sid, _reason: RemoveReason) {} - pub(crate) fn smsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + pub(crate) fn smsg_ob(&mut self, _sid: Sid, _reason: RemoveReason, _b: u64) {} - pub(crate) fn sdata_frames_t(&self) {} + pub(crate) fn sdata_frames_t(&mut self) {} - pub(crate) fn sdata_frames_b(&self, _b: u64) {} + pub(crate) fn sdata_frames_b(&mut self, _b: u64) {} - pub(crate) fn rmsg_it(&self, _sid: Sid) {} + pub(crate) fn rmsg_it(&mut self, _sid: Sid) {} - pub(crate) fn rmsg_ib(&self, _sid: Sid, _b: u64) {} + pub(crate) fn rmsg_ib(&mut self, _sid: Sid, _b: u64) {} - pub(crate) fn rmsg_ot(&self, _sid: Sid, _reason: RemoveReason) {} + pub(crate) fn rmsg_ot(&mut self, _sid: Sid, _reason: RemoveReason) {} - pub(crate) fn rmsg_ob(&self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + pub(crate) fn rmsg_ob(&mut self, _sid: Sid, _reason: RemoveReason, _b: u64) {} - pub(crate) fn rdata_frames_t(&self) {} + pub(crate) fn rdata_frames_t(&mut self) {} - pub(crate) fn rdata_frames_b(&self, _b: u64) {} + pub(crate) fn rdata_frames_b(&mut self, _b: u64) {} } impl RemoveReason { @@ -411,4 +375,12 @@ impl RemoveReason { RemoveReason::Finished => "Finished", } } + + #[cfg(feature = "metrics")] + fn i(&self) -> usize { + match self { + RemoveReason::Dropped => 0, + RemoveReason::Finished => 1, + } + } } diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs index 3e9e5d55fe..0fbbee6300 100644 --- a/network/protocol/src/mpsc.rs +++ b/network/protocol/src/mpsc.rs @@ -9,7 +9,10 @@ use crate::{ }; use async_trait::async_trait; use std::time::{Duration, Instant}; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; +#[derive(Debug)] pub /* should be private */ enum MpscMsg { Event(ProtocolEvent), InitFrame(InitFrame), @@ -59,7 +62,11 @@ impl SendProtocol for MpscSendProtcol where D: UnreliableDrain, { + fn notify_from_recv(&mut self, 
_event: ProtocolEvent) {} + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); match &event { ProtocolEvent::Message { buffer, @@ -68,10 +75,8 @@ where } => { let sid = *sid; let bytes = buffer.data.len() as u64; - self.metrics.smsg_it(sid); self.metrics.smsg_ib(sid, bytes); let r = self.drain.send(MpscMsg::Event(event)).await; - self.metrics.smsg_ot(sid, RemoveReason::Finished); self.metrics.smsg_ob(sid, RemoveReason::Finished, bytes); r }, @@ -88,7 +93,10 @@ where S: UnreliableSink, { async fn recv(&mut self) -> Result { - match self.sink.recv().await? { + let event = self.sink.recv().await?; + #[cfg(feature = "trace_pedantic")] + trace!(?event, "recv"); + match event { MpscMsg::Event(e) => { if let ProtocolEvent::Message { buffer, @@ -98,9 +106,7 @@ where { let sid = *sid; let bytes = buffer.data.len() as u64; - self.metrics.rmsg_it(sid); self.metrics.rmsg_ib(sid, bytes); - self.metrics.rmsg_ot(sid, RemoveReason::Finished); self.metrics.rmsg_ob(sid, RemoveReason::Finished, bytes); } Ok(e) diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs index 35b7067352..2348028d15 100644 --- a/network/protocol/src/prio.rs +++ b/network/protocol/src/prio.rs @@ -4,14 +4,18 @@ use crate::{ metrics::{ProtocolMetricCache, RemoveReason}, types::{Bandwidth, Mid, Prio, Promises, Sid}, }; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{ + collections::{HashMap, VecDeque}, + sync::Arc, + time::Duration, +}; #[derive(Debug)] struct StreamInfo { pub(crate) guaranteed_bandwidth: Bandwidth, pub(crate) prio: Prio, pub(crate) promises: Promises, - pub(crate) messages: Vec, + pub(crate) messages: VecDeque, } /// Responsible for queueing messages. 
@@ -47,7 +51,7 @@ impl PrioManager { guaranteed_bandwidth, prio, promises, - messages: vec![], + messages: VecDeque::new(), }); } @@ -68,7 +72,7 @@ impl PrioManager { .get_mut(&sid) .unwrap() .messages - .push(OutgoingMessage::new(buffer, mid, sid)); + .push_back(OutgoingMessage::new(buffer, mid, sid)); } /// bandwidth might be extended, as for technical reasons @@ -79,7 +83,7 @@ impl PrioManager { let mut frames = vec![]; let mut prios = [0u64; (Self::HIGHEST_PRIO + 1) as usize]; - let metrics = &self.metrics; + let metrics = &mut self.metrics; let mut process_stream = |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { @@ -103,9 +107,8 @@ impl PrioManager { //cleanup for i in finished.iter().rev() { - let msg = stream.messages.remove(*i); + let msg = stream.messages.remove(*i).unwrap(); let (sid, bytes) = msg.get_sid_len(); - metrics.smsg_ot(sid, RemoveReason::Finished); metrics.smsg_ob(sid, RemoveReason::Finished, bytes); } }; diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index e1c8e10e84..dd84d5b013 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -9,21 +9,25 @@ use crate::{ ProtocolError, RecvProtocol, SendProtocol, }; use async_trait::async_trait; +use bytes::BytesMut; use std::{ - collections::{HashMap, VecDeque}, + collections::HashMap, sync::Arc, time::{Duration, Instant}, }; use tracing::info; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; #[derive(Debug)] pub struct TcpSendProtcol where - D: UnreliableDrain>, + D: UnreliableDrain, { - buffer: Vec, + buffer: BytesMut, store: PrioManager, closing_streams: Vec, + notify_closing_streams: Vec, pending_shutdown: bool, drain: D, last: Instant, @@ -33,9 +37,9 @@ where #[derive(Debug)] pub struct TcpRecvProtcol where - S: UnreliableSink>, + S: UnreliableSink, { - buffer: VecDeque, + buffer: BytesMut, incoming: HashMap, sink: S, metrics: ProtocolMetricCache, @@ -43,13 +47,14 @@ where impl TcpSendProtcol where - D: UnreliableDrain>, + 
D: UnreliableDrain, { pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { Self { - buffer: vec![0u8; 1500], + buffer: BytesMut::new(), store: PrioManager::new(metrics.clone()), closing_streams: vec![], + notify_closing_streams: vec![], pending_shutdown: false, drain, last: Instant::now(), @@ -60,11 +65,11 @@ where impl TcpRecvProtcol where - S: UnreliableSink>, + S: UnreliableSink, { pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { - buffer: VecDeque::new(), + buffer: BytesMut::new(), incoming: HashMap::new(), sink, metrics, @@ -75,9 +80,9 @@ where #[async_trait] impl SendProtocol for TcpSendProtcol where - D: UnreliableDrain>, + D: UnreliableDrain, { - async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + fn notify_from_recv(&mut self, event: ProtocolEvent) { match event { ProtocolEvent::OpenStream { sid, @@ -87,31 +92,54 @@ where } => { self.store .open_stream(sid, prio, promises, guaranteed_bandwidth); - let frame = event.to_frame(); - let (s, _) = frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer[..s].to_vec()).await?; + }, + ProtocolEvent::CloseStream { sid } => { + if !self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back notify close stream"); + self.notify_closing_streams.push(sid); + } + }, + _ => {}, + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + event.to_frame().to_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; }, ProtocolEvent::CloseStream { sid } => { if self.store.try_close_stream(sid) { - let frame = event.to_frame(); - let (s, _) = frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer[..s].to_vec()).await?; + 
event.to_frame().to_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; } else { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back close stream"); self.closing_streams.push(sid); } }, ProtocolEvent::Shutdown => { if self.store.is_empty() { - tracing::error!(?event, "send frame"); - let frame = event.to_frame(); - let (s, _) = frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer[..s].to_vec()).await?; + event.to_frame().to_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; } else { + #[cfg(feature = "trace_pedantic")] + trace!("hold back shutdown"); self.pending_shutdown = true; } }, ProtocolEvent::Message { buffer, mid, sid } => { - self.metrics.smsg_it(sid); self.metrics.smsg_ib(sid, buffer.data.len() as u64); self.store.add(buffer, mid, sid); }, @@ -128,30 +156,43 @@ where data, } = &frame { - self.metrics.sdata_frames_t(); self.metrics.sdata_frames_b(data.len() as u64); } - let (s, _) = frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer[..s].to_vec()).await?; - tracing::warn!("send data frame, woop"); + frame.to_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; } + let mut finished_streams = vec![]; - for (i, sid) in self.closing_streams.iter().enumerate() { - if self.store.try_close_stream(*sid) { - let frame = ProtocolEvent::CloseStream { sid: *sid }.to_frame(); - let (s, _) = frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer[..s].to_vec()).await?; + for (i, &sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + Frame::CloseStream { sid }.to_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; finished_streams.push(i); } } for i in finished_streams.iter().rev() { self.closing_streams.remove(*i); } + + let mut finished_streams = vec![]; + for (i, sid) in self.notify_closing_streams.iter().enumerate() { + if 
self.store.try_close_stream(*sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.notify_closing_streams.remove(*i); + } + if self.pending_shutdown && self.store.is_empty() { - tracing::error!("send shutdown frame"); - let frame = ProtocolEvent::Shutdown {}.to_frame(); - let (s, _) = frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer[..s].to_vec()).await?; + #[cfg(feature = "trace_pedantic")] + trace!("shutdown, as it's now empty"); + Frame::Shutdown {}.to_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; self.pending_shutdown = false; } Ok(()) @@ -173,14 +214,13 @@ struct IncomingMsg { #[async_trait] impl RecvProtocol for TcpRecvProtcol where - S: UnreliableSink>, + S: UnreliableSink, { async fn recv(&mut self) -> Result { - tracing::error!(?self.buffer, "enter loop"); 'outer: loop { - tracing::error!(?self.buffer, "continue loop"); while let Some(frame) = Frame::to_frame(&mut self.buffer) { - tracing::error!(?frame, "recv frame"); + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); match frame { Frame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), Frame::OpenStream { @@ -204,7 +244,6 @@ where length, data: MessageBuffer { data: vec![] }, }; - self.metrics.rmsg_it(sid); self.metrics.rmsg_ib(sid, length); self.incoming.insert(mid, m); }, @@ -213,12 +252,14 @@ where start: _, mut data, } => { - self.metrics.rdata_frames_t(); self.metrics.rdata_frames_b(data.len() as u64); let m = match self.incoming.get_mut(&mid) { Some(m) => m, None => { - info!("protocol violation by remote side: send Data before Header"); + info!( + ?mid, + "protocol violation by remote side: send Data before Header" + ); break 'outer Err(ProtocolError::Closed); }, }; @@ -227,7 +268,6 @@ where // finished, yay drop(m); let m = self.incoming.remove(&mid).unwrap(); - self.metrics.rmsg_ot(m.sid, RemoveReason::Finished); 
self.metrics.rmsg_ob( m.sid, RemoveReason::Finished, @@ -242,13 +282,8 @@ where }, }; } - tracing::error!(?self.buffer, "receiving on tcp sink"); let chunk = self.sink.recv().await?; - self.buffer.reserve(chunk.len()); - for b in chunk { - self.buffer.push_back(b); - } - tracing::error!(?self.buffer,"receiving on tcp sink done"); + self.buffer.extend_from_slice(&chunk); } } } @@ -256,12 +291,11 @@ where #[async_trait] impl ReliableDrain for TcpSendProtcol where - D: UnreliableDrain>, + D: UnreliableDrain, { async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { - let mut buffer = vec![0u8; 1500]; - let s = frame.to_bytes(&mut buffer); - buffer.truncate(s); + let mut buffer = BytesMut::with_capacity(500); + frame.to_bytes(&mut buffer); self.drain.send(buffer).await } } @@ -269,22 +303,13 @@ where #[async_trait] impl ReliableSink for TcpRecvProtcol where - S: UnreliableSink>, + S: UnreliableSink, { async fn recv(&mut self) -> Result { while self.buffer.len() < 100 { let chunk = self.sink.recv().await?; - self.buffer.reserve(chunk.len()); - for b in chunk { - self.buffer.push_back(b); - } - let todo_use_bytes_instead = self.buffer.iter().map(|b| *b).collect(); - if let Some(frame) = InitFrame::to_frame(todo_use_bytes_instead) { - match frame { - InitFrame::Handshake { .. } => self.buffer.drain(.. InitFrame::HANDSHAKE_CNS + 1), - InitFrame::Init { .. } => self.buffer.drain(.. InitFrame::INIT_CNS + 1), - InitFrame::Raw { .. } => self.buffer.drain(.. 
InitFrame::RAW_CNS + 1), - }; + self.buffer.extend_from_slice(&chunk); + if let Some(frame) = InitFrame::to_frame(&mut self.buffer) { return Ok(frame); } } @@ -303,11 +328,11 @@ mod test_utils { use async_channel::*; pub struct TcpDrain { - pub sender: Sender>, + pub sender: Sender, } pub struct TcpSink { - pub receiver: Receiver>, + pub receiver: Receiver, } /// emulate Tcp protocol on Channels @@ -334,7 +359,7 @@ mod test_utils { #[async_trait] impl UnreliableDrain for TcpDrain { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { self.sender @@ -346,7 +371,7 @@ mod test_utils { #[async_trait] impl UnreliableSink for TcpSink { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn recv(&mut self) -> Result { self.receiver @@ -365,6 +390,7 @@ mod tests { types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, InitProtocol, MessageBuffer, ProtocolEvent, RecvProtocol, SendProtocol, }; + use bytes::BytesMut; use std::{sync::Arc, time::Duration}; #[tokio::test] @@ -431,7 +457,7 @@ mod tests { #[tokio::test] async fn send_long_msg() { - let metrics = + let mut metrics = ProtocolMetricCache::new("long_tcp", Arc::new(ProtocolMetrics::new().unwrap())); let sid = Sid::new(1); let [p1, p2] = tcp_bound(10000, Some(metrics.clone())); @@ -538,39 +564,36 @@ mod tests { const DATA1: &[u8; 69] = b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. 
and then keep the connection open"; - let mut buf = vec![0u8; 1500]; - let event = ProtocolEvent::OpenStream { + let mut bytes = BytesMut::with_capacity(1500); + use crate::frame::Frame; + Frame::OpenStream { sid, prio: 5u8, promises: Promises::COMPRESSED, - guaranteed_bandwidth: 0, - }; - let (i, _) = event.to_frame().to_bytes(&mut buf); - let (i2, _) = crate::frame::Frame::DataHeader { + } + .to_bytes(&mut bytes); + Frame::DataHeader { mid: 99, sid, length: (DATA1.len() + DATA2.len()) as u64, } - .to_bytes(&mut buf[i..]); - buf.truncate(i + i2); - s.send(buf).await.unwrap(); + .to_bytes(&mut bytes); + s.send(bytes.split()).await.unwrap(); - let mut buf = vec![0u8; 1500]; - let (i, _) = crate::frame::Frame::Data { + Frame::Data { mid: 99, start: 0, data: DATA1.to_vec(), } - .to_bytes(&mut buf); - let (i2, _) = crate::frame::Frame::Data { + .to_bytes(&mut bytes); + Frame::Data { mid: 99, start: DATA1.len() as u64, data: DATA2.to_vec(), } - .to_bytes(&mut buf[i..]); - let (i3, _) = crate::frame::Frame::CloseStream { sid }.to_bytes(&mut buf[i + i2..]); - buf.truncate(i + i2 + i3); - s.send(buf).await.unwrap(); + .to_bytes(&mut bytes); + Frame::CloseStream { sid }.to_bytes(&mut bytes); + s.send(bytes.split()).await.unwrap(); let e = r.recv().await.unwrap(); assert!(matches!(e, ProtocolEvent::OpenStream { .. })); @@ -581,4 +604,58 @@ mod tests { let e = r.recv().await.unwrap(); assert!(matches!(e, ProtocolEvent::CloseStream { .. 
})); } + + #[tokio::test] + #[should_panic] + async fn send_on_stream_from_remote_without_notify() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = tcp_bound(10, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let _ = p2.1.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + buffer: Arc::new(MessageBuffer { + data: vec![188u8; 600], + }), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_on_stream_from_remote() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = tcp_bound(10, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + buffer: Arc::new(MessageBuffer { + data: vec![188u8; 600], + }), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } } diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs index b6f63ca208..ba9348a16d 100644 --- a/network/protocol/src/types.rs +++ b/network/protocol/src/types.rs @@ -1,4 +1,5 @@ use bitflags::bitflags; +use bytes::{Buf, BufMut, BytesMut}; use rand::Rng; pub type Mid = u64; @@ -88,25 +89,19 @@ impl Pid { } } - pub(crate) fn to_le_bytes(&self) -> [u8; 16] { self.internal.to_le_bytes() } - - pub(crate) fn from_le_bytes(bytes: [u8; 16]) -> Self { + pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { Self { - internal: 
u128::from_le_bytes(bytes), + internal: bytes.get_u128_le(), } } + + pub(crate) fn to_bytes(&self, bytes: &mut BytesMut) { bytes.put_u128_le(self.internal) } } impl Sid { pub const fn new(internal: u64) -> Self { Self { internal } } pub(crate) fn to_le_bytes(&self) -> [u8; 8] { self.internal.to_le_bytes() } - - pub(crate) fn from_le_bytes(bytes: [u8; 8]) -> Self { - Self { - internal: u64::from_le_bytes(bytes), - } - } } impl std::fmt::Debug for Pid { diff --git a/network/src/api.rs b/network/src/api.rs index 08274c90be..bd06f1e472 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -9,7 +9,7 @@ use crate::{ }; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; -use network_protocol::{Bandwidth, MessageBuffer, Mid, Pid, Prio, Promises, Sid}; +use network_protocol::{Bandwidth, MessageBuffer, Pid, Prio, Promises, Sid}; #[cfg(feature = "metrics")] use prometheus::Registry; use serde::{de::DeserializeOwned, Serialize}; @@ -28,7 +28,6 @@ use tokio::{ sync::{mpsc, oneshot, Mutex}, }; use tracing::*; -use tracing_futures::Instrument; type A2sDisconnect = Arc>>>; @@ -70,9 +69,9 @@ pub struct Participant { /// [`opened`]: Participant::opened #[derive(Debug)] pub struct Stream { - pid: Pid, + local_pid: Pid, + remote_pid: Pid, sid: Sid, - mid: Mid, prio: Prio, promises: Promises, guaranteed_bandwidth: Bandwidth, @@ -239,7 +238,8 @@ impl Network { #[cfg(feature = "metrics")] registry: Option<&Registry>, ) -> Self { let p = participant_id; - debug!(?p, "Starting Network"); + let span = tracing::info_span!("network", ?p); + span.in_scope(|| trace!("Starting Network")); let (scheduler, listen_sender, connect_sender, connected_receiver, shutdown_sender) = Scheduler::new( participant_id, @@ -247,14 +247,14 @@ impl Network { #[cfg(feature = "metrics")] registry, ); - runtime.spawn(async move { - trace!(?p, "Starting scheduler in own thread"); - scheduler - .run() - .instrument(tracing::info_span!("scheduler", ?p)) - .await; - trace!(?p, "Stopping scheduler and 
his own thread"); - }); + runtime.spawn( + async move { + trace!("Starting scheduler in own thread"); + scheduler.run().await; + trace!("Stopping scheduler and his own thread"); + } + .instrument(tracing::info_span!("network", ?p)), + ); Self { local_pid: participant_id, runtime, @@ -295,6 +295,7 @@ impl Network { /// ``` /// /// [`connected`]: Network::connected + #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn listen(&self, address: ProtocolAddr) -> Result<(), NetworkError> { let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); debug!(?address, "listening on address"); @@ -350,6 +351,7 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`ProtocolAddres`]: crate::api::ProtocolAddr + #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn connect(&self, address: ProtocolAddr) -> Result { let (pid_sender, pid_receiver) = oneshot::channel::>(); debug!(?address, "Connect to address"); @@ -361,15 +363,12 @@ impl Network { Ok(p) => p, Err(e) => return Err(NetworkError::ConnectFailed(e)), }; - let pid = participant.remote_pid; - debug!( - ?pid, - "Received Participant id from remote and return to user" - ); + let remote_pid = participant.remote_pid; + trace!(?remote_pid, "connected"); self.participant_disconnect_sender .lock() .await - .insert(pid, Arc::clone(&participant.a2s_disconnect_s)); + .insert(remote_pid, Arc::clone(&participant.a2s_disconnect_s)); Ok(participant) } @@ -406,6 +405,7 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn connected(&self) -> Result { let participant = self.connected_receiver.lock().await.recv().await?; self.participant_disconnect_sender.lock().await.insert( @@ -475,6 +475,7 @@ impl Participant { /// ``` /// /// [`Streams`]: crate::api::Stream + #[instrument(name="network", skip(self, prio, promises), 
fields(p = %self.local_pid))] pub async fn open(&self, prio: u8, promises: Promises) -> Result { let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel::(); if let Err(e) = self.a2b_open_stream_s.lock().await.send(( @@ -489,11 +490,11 @@ impl Participant { match p2a_return_stream_r.await { Ok(stream) => { let sid = stream.sid; - debug!(?sid, ?self.remote_pid, "opened stream"); + trace!(?sid, "opened stream"); Ok(stream) }, Err(_) => { - debug!(?self.remote_pid, "p2a_return_stream_r failed, closing participant"); + debug!("p2a_return_stream_r failed, closing participant"); Err(ParticipantError::ParticipantDisconnected) }, } @@ -532,15 +533,16 @@ impl Participant { /// [`Streams`]: crate::api::Stream /// [`connected`]: Network::connected /// [`open`]: Participant::open + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn opened(&self) -> Result { match self.b2a_stream_opened_r.lock().await.recv().await { Some(stream) => { let sid = stream.sid; - debug!(?sid, ?self.remote_pid, "Receive opened stream"); + debug!(?sid, "Receive opened stream"); Ok(stream) }, None => { - debug!(?self.remote_pid, "stream_opened_receiver failed, closing participant"); + debug!("stream_opened_receiver failed, closing participant"); Err(ParticipantError::ParticipantDisconnected) }, } @@ -589,10 +591,10 @@ impl Participant { /// ``` /// /// [`Streams`]: crate::api::Stream + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn disconnect(self) -> Result<(), ParticipantError> { // Remove, Close and try_unwrap error when unwrap fails! 
- let pid = self.remote_pid; - debug!(?pid, "Closing participant from network"); + debug!("Closing participant from network"); //Streams will be closed by BParticipant match self.a2s_disconnect_s.lock().await.take() { @@ -601,14 +603,14 @@ impl Participant { // Participant is connecting to Scheduler here, not as usual // Participant<->BParticipant a2s_disconnect_s - .send((pid, (Duration::from_secs(120), finished_sender))) + .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { Ok(res) => { match res { - Ok(()) => trace!(?pid, "Participant is now closed"), + Ok(()) => trace!("Participant is now closed"), Err(ref e) => { - trace!(?pid, ?e, "Error occurred during shutdown of participant") + trace!(?e, "Error occurred during shutdown of participant") }, }; res @@ -616,7 +618,6 @@ impl Participant { Err(e) => { //this is a bug. but as i am Participant i can't destroy the network error!( - ?pid, ?e, "Failed to get a message back from the scheduler, seems like the \ network is already closed" @@ -642,7 +643,8 @@ impl Participant { impl Stream { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - pid: Pid, + local_pid: Pid, + remote_pid: Pid, sid: Sid, prio: Prio, promises: Promises, @@ -653,9 +655,9 @@ impl Stream { a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { - pid, + local_pid, + remote_pid, sid, - mid: 0, prio, promises, guaranteed_bandwidth, @@ -779,7 +781,6 @@ impl Stream { message.verify(&self); self.a2b_msg_s .send((self.sid, Arc::clone(&message.buffer)))?; - self.mid += 1; Ok(()) } @@ -942,13 +943,10 @@ impl core::cmp::PartialEq for Participant { } impl Drop for Network { + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] fn drop(&mut self) { - let pid = self.local_pid; - debug!(?pid, "Shutting down Network"); - trace!( - ?pid, - "Shutting down Participants of Network, while we still have metrics" - ); + 
debug!("Shutting down Network"); + trace!("Shutting down Participants of Network, while we still have metrics"); let mut finished_receiver_list = vec![]; if tokio::runtime::Handle::try_current().is_ok() { @@ -991,25 +989,25 @@ impl Drop for Network { } }); }); - trace!(?pid, "Participants have shut down!"); - trace!(?pid, "Shutting down Scheduler"); + trace!("Participants have shut down!"); + trace!("Shutting down Scheduler"); self.shutdown_sender .take() .unwrap() .send(()) .expect("Scheduler is closed, but nobody other should be able to close it"); - debug!(?pid, "Network has shut down"); + debug!("Network has shut down"); } } impl Drop for Participant { + #[instrument(name="remote", skip(self), fields(p = %self.remote_pid))] + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] fn drop(&mut self) { use tokio::sync::oneshot::error::TryRecvError; - // ignore closed, as we need to send it even though we disconnected the // participant from network - let pid = self.remote_pid; - debug!(?pid, "Shutting down Participant"); + debug!("Shutting down Participant"); match self .a2s_disconnect_s @@ -1017,25 +1015,27 @@ impl Drop for Participant { .expect("Participant in use while beeing dropped") .take() { - None => trace!( - ?pid, - "Participant has been shutdown cleanly, no further waiting is required!" 
- ), + None => info!("Participant already has been shutdown gracefully"), Some(a2s_disconnect_s) => { - debug!(?pid, "Disconnect from Scheduler"); + debug!("Disconnect from Scheduler"); let (finished_sender, mut finished_receiver) = oneshot::channel(); a2s_disconnect_s .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) .expect("Something is wrong in internal scheduler coding"); loop { match finished_receiver.try_recv() { - Ok(Ok(())) => break, - Ok(Err(e)) => error!( - ?pid, - ?e, - "Error while dropping the participant, couldn't send all outgoing \ - messages, dropping remaining" - ), + Ok(Ok(())) => { + info!("Participant dropped gracefully"); + break; + }, + Ok(Err(e)) => { + error!( + ?e, + "Error while dropping the participant, couldn't send all outgoing \ + messages, dropping remaining" + ); + break; + }, Err(TryRecvError::Closed) => { panic!("Something is wrong in internal scheduler/participant coding") }, @@ -1047,17 +1047,17 @@ impl Drop for Participant { } }, } - debug!(?pid, "Participant dropped"); } } impl Drop for Stream { + #[instrument(name="remote", skip(self), fields(p = %self.remote_pid))] + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] fn drop(&mut self) { // send if closed is unnecessary but doesn't hurt, we must not crash if !self.send_closed.load(Ordering::Relaxed) { let sid = self.sid; - let pid = self.pid; - debug!(?pid, ?sid, "Shutting down Stream"); + debug!(?sid, "Shutting down Stream"); if let Err(e) = self.a2b_close_stream_s.take().unwrap().send(self.sid) { debug!( ?e, @@ -1066,8 +1066,7 @@ impl Drop for Stream { } } else { let sid = self.sid; - let pid = self.pid; - trace!(?pid, ?sid, "Stream Drop not needed"); + trace!(?sid, "Stream Drop not needed"); } } } diff --git a/network/src/channel.rs b/network/src/channel.rs index 654175fb1d..9b6472268e 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use bytes::BytesMut; use 
network_protocol::{ InitProtocolError, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtcol, TcpSendProtcol, @@ -42,7 +43,13 @@ impl Protocols { let metrics = ProtocolMetricCache {}; let sp = TcpSendProtcol::new(TcpDrain { half: w }, metrics.clone()); - let rp = TcpRecvProtcol::new(TcpSink { half: r }, metrics.clone()); + let rp = TcpRecvProtcol::new( + TcpSink { + half: r, + buffer: BytesMut::new(), + }, + metrics.clone(), + ); Protocols::Tcp((sp, rp)) } @@ -86,6 +93,13 @@ impl network_protocol::InitProtocol for Protocols { #[async_trait] impl network_protocol::SendProtocol for SendProtocols { + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match self { + SendProtocols::Tcp(s) => s.notify_from_recv(event), + SendProtocols::Mpsc(s) => s.notify_from_recv(event), + } + } + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { match self { SendProtocols::Tcp(s) => s.send(event).await, @@ -121,14 +135,14 @@ pub struct TcpDrain { #[derive(Debug)] pub struct TcpSink { half: OwnedReadHalf, + buffer: BytesMut, } #[async_trait] impl UnreliableDrain for TcpDrain { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { - //self.half.recv match self.half.write_all(&data).await { Ok(()) => Ok(()), Err(_) => Err(ProtocolError::Closed), @@ -138,15 +152,12 @@ impl UnreliableDrain for TcpDrain { #[async_trait] impl UnreliableSink for TcpSink { - type DataFormat = Vec; + type DataFormat = BytesMut; async fn recv(&mut self) -> Result { - let mut data = vec![0u8; 1500]; - match self.half.read(&mut data).await { - Ok(n) => { - data.truncate(n); - Ok(data) - }, + self.buffer.resize(1500, 0u8); + match self.half.read(&mut self.buffer).await { + Ok(n) => Ok(self.buffer.split_to(n)), Err(_) => Err(ProtocolError::Closed), } } diff --git a/network/src/message.rs b/network/src/message.rs index 
0ad24c63ad..1969854a7a 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -200,6 +200,7 @@ mod tests { Stream::new( Pid::fake(0), + Pid::fake(1), Sid::new(0), 0u8, promises, diff --git a/network/src/participant.rs b/network/src/participant.rs index a942632f7b..7d07a5c5ac 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -60,6 +60,7 @@ struct ShutdownInfo { #[derive(Debug)] pub struct BParticipant { + local_pid: Pid, //tracing remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, @@ -82,6 +83,7 @@ impl BParticipant { #[allow(clippy::type_complexity)] pub(crate) fn new( + local_pid: Pid, remote_pid: Pid, offset_sid: Sid, #[cfg(feature = "metrics")] metrics: Arc, @@ -106,6 +108,7 @@ impl BParticipant { ( Self { + local_pid, remote_pid, remote_pid_string: remote_pid.to_string(), offset_sid, @@ -135,6 +138,8 @@ impl BParticipant { async_channel::unbounded::(); let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = async_channel::unbounded::(); + let (b2b_notify_send_of_recv_s, b2b_notify_send_of_recv_r) = + mpsc::unbounded_channel::(); let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); const STREAM_BOUND: usize = 10_000; @@ -142,6 +147,7 @@ impl BParticipant { crossbeam_channel::bounded::<(Sid, Arc)>(STREAM_BOUND); let run_channels = self.run_channels.take().unwrap(); + trace!("start all managers"); tokio::join!( self.send_mgr( run_channels.a2b_open_stream_r, @@ -149,18 +155,22 @@ impl BParticipant { a2b_msg_r, b2b_add_send_protocol_r, b2b_close_send_protocol_r, + b2b_notify_send_of_recv_r, b2s_prio_statistic_s, a2b_msg_s.clone(), //self a2b_close_stream_s.clone(), //self - ), + ) + .instrument(tracing::info_span!("send")), self.recv_mgr( run_channels.b2a_stream_opened_s, b2b_add_recv_protocol_r, b2b_force_close_recv_protocol_r, b2b_close_send_protocol_s.clone(), + b2b_notify_send_of_recv_s, a2b_msg_s.clone(), //self a2b_close_stream_s.clone(), //self - ), + ) + 
.instrument(tracing::info_span!("recv")), self.create_channel_mgr( run_channels.s2b_create_channel_r, b2b_add_send_protocol_s, @@ -182,6 +192,7 @@ impl BParticipant { a2b_msg_r: crossbeam_channel::Receiver<(Sid, Arc)>, mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, b2b_close_send_protocol_r: async_channel::Receiver, + mut b2b_notify_send_of_recv_r: mpsc::UnboundedReceiver, _b2s_prio_statistic_s: mpsc::UnboundedSender, a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: mpsc::UnboundedSender, @@ -189,27 +200,29 @@ impl BParticipant { let mut send_protocols: HashMap = HashMap::new(); let mut interval = tokio::time::interval(Self::TICK_TIME); let mut stream_ids = self.offset_sid; - trace!("workaround, activly wait for first protocol"); + let mut fake_mid = 0; //TODO: move MID to protocol, should be inc per stream ? or ? + trace!("workaround, actively wait for first protocol"); b2b_add_protocol_r .recv() .await .map(|(c, p)| send_protocols.insert(c, p)); - trace!("Start send_mgr"); loop { - let (open, close, _, addp, remp) = select!( - next = a2b_open_stream_r.recv().fuse() => (Some(next), None, None, None, None), - next = a2b_close_stream_r.recv().fuse() => (None, Some(next), None, None, None), - _ = interval.tick() => (None, None, Some(()), None, None), - next = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(next), None), - next = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(next)), + let (open, close, r_event, _, addp, remp) = select!( + n = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None, None), + n = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None, None), + n = b2b_notify_send_of_recv_r.recv().fuse() => (None, None, Some(n), None, None, None), + _ = interval.tick() => (None, None, None, Some(()), None, None), + n = b2b_add_protocol_r.recv().fuse() => (None, None, None, None, Some(n), None), + n = b2b_close_send_protocol_r.recv().fuse() => 
(None, None, None, None, None, Some(n)), ); - trace!(?open, ?close, ?addp, ?remp, "foobar"); - - addp.flatten().map(|(c, p)| send_protocols.insert(c, p)); + addp.flatten().map(|(cid, p)| { + debug!(?cid, "add protocol"); + send_protocols.insert(cid, p) + }); match remp { Some(Ok(cid)) => { - trace!(?cid, "remove send protocol"); + debug!(?cid, "remove protocol"); match send_protocols.remove(&cid) { Some(mut prot) => { trace!("blocking flush"); @@ -230,15 +243,19 @@ impl BParticipant { let active = match send_protocols.get_mut(&cid) { Some(a) => a, None => { - warn!("no channel arrg"); + warn!("no channel"); continue; }, }; let active_err = async { + if let Some(Some(event)) = r_event { + active.notify_from_recv(event); + } + if let Some(Some((prio, promises, guaranteed_bandwidth, return_s))) = open { - trace!(?stream_ids, "openuing some new stream"); let sid = stream_ids; + trace!(?sid, "open stream"); stream_ids += Sid::from(1); let stream = self .create_stream( @@ -264,25 +281,24 @@ impl BParticipant { // get all messages and assign it to a channel for (sid, buffer) in a2b_msg_r.try_iter() { - warn!(?sid, "sending!"); + fake_mid += 1; active .send(ProtocolEvent::Message { buffer, - mid: 0u64, + mid: fake_mid, sid, }) .await? } if let Some(Some(sid)) = close { - warn!(?sid, "delete_stream!"); + trace!(?stream_ids, "delete stream"); self.delete_stream(sid).await; // Fire&Forget the protocol will take care to verify that this Frame is delayed // till the last msg was received! active.send(ProtocolEvent::CloseStream { sid }).await?; } - warn!("flush!"); active .flush(1_000_000, Duration::from_secs(1) /* TODO */) .await?; //this actually blocks, so we cant set streams whilte it. 
@@ -291,7 +307,7 @@ impl BParticipant { } .await; if let Err(e) = active_err { - info!(?cid, ?e, "send protocol failed, shutting down channel"); + info!(?cid, ?e, "protocol failed, shutting down channel"); // remote recv will now fail, which will trigger remote send which will trigger // recv send_protocols.remove(&cid).unwrap(); @@ -308,6 +324,7 @@ impl BParticipant { mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, b2b_force_close_recv_protocol_r: async_channel::Receiver, b2b_close_send_protocol_s: async_channel::Sender, + b2b_notify_send_of_recv_s: mpsc::UnboundedSender, a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, a2b_close_stream_s: mpsc::UnboundedSender, ) { @@ -327,13 +344,15 @@ impl BParticipant { let remove_c = |recv_protocols: &mut HashMap>, cid: &Cid| { match recv_protocols.remove(&cid) { - Some(h) => h.abort(), + Some(h) => { + h.abort(); + debug!(?cid, "remove protocol"); + }, None => trace!("tried to remove protocol twice"), }; recv_protocols.is_empty() }; - trace!("Start recv_mgr"); loop { let (event, addp, remp) = select!( next = hacky_recv_r.recv().fuse() => (Some(next), None, None), @@ -342,6 +361,7 @@ impl BParticipant { ); addp.map(|(cid, p)| { + debug!(?cid, "add protocol"); retrigger(cid, p, &mut recv_protocols); }); if let Some(Ok(cid)) = remp { @@ -351,7 +371,6 @@ impl BParticipant { } }; - warn!(?event, "recv event!"); if let Some(Some((cid, r, p))) = event { match r { Ok(ProtocolEvent::OpenStream { @@ -361,6 +380,7 @@ impl BParticipant { guaranteed_bandwidth, }) => { trace!(?sid, "open stream"); + let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); let stream = self .create_stream( sid, @@ -376,6 +396,7 @@ impl BParticipant { }, Ok(ProtocolEvent::CloseStream { sid }) => { trace!(?sid, "close stream"); + let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); self.delete_stream(sid).await; retrigger(cid, p, &mut recv_protocols); }, @@ -410,7 +431,7 @@ impl BParticipant { } }, Err(e) => { - info!(?cid, ?e, "recv 
protocol failed, shutting down channel"); + info!(?e, ?cid, "protocol failed, shutting down channel"); if let Err(e) = b2b_close_send_protocol_s.send(cid).await { debug!(?e, ?cid, "send_mgr was already closed simultaneously"); } @@ -433,7 +454,6 @@ impl BParticipant { b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>, b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>, ) { - trace!("Start create_channel_mgr"); let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { @@ -524,12 +544,8 @@ impl BParticipant { } } }; - - trace!("Start participant_shutdown_mgr"); let (timeout_time, sender) = s2b_shutdown_bparticipant_r.await.unwrap(); - debug!("participant_shutdown_mgr triggered"); - - debug!("Closing all streams for send"); + debug!("participant_shutdown_mgr triggered. Closing all streams for send"); { let lock = self.streams.read().await; for si in lock.values() { @@ -632,6 +648,7 @@ impl BParticipant { .with_label_values(&[&self.remote_pid_string]) .inc(); Stream::new( + self.local_pid, self.remote_pid, sid, prio, @@ -676,11 +693,12 @@ mod tests { s2b_create_channel_s, s2b_shutdown_bparticipant_s, ) = runtime_clone.block_on(async move { - let pid = Pid::fake(1); + let local_pid = Pid::fake(0); + let remote_pid = Pid::fake(1); let sid = Sid::new(1000); - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); + let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap()); - BParticipant::new(pid, sid, Arc::clone(&metrics)) + BParticipant::new(local_pid, remote_pid, sid, Arc::clone(&metrics)) }); let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s)); diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index eb6d21bd7e..a1f31fc384 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -6,7 +6,7 @@ use crate::{ participant::{B2sPrioStatistic, BParticipant, 
S2bCreateChannel, S2bShutdownBparticipant}, }; use futures_util::{FutureExt, StreamExt}; -use network_protocol::Pid; +use network_protocol::{MpscMsg, Pid}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -26,16 +26,21 @@ use tokio::{ }; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use tracing_futures::Instrument; -/// Naming of Channels `x2x` -/// - a: api -/// - s: scheduler -/// - b: bparticipant -/// - p: prios -/// - r: protocol -/// - w: wire -/// - c: channel/handshake +// Naming of Channels `x2x` +// - a: api +// - s: scheduler +// - b: bparticipant +// - p: prios +// - r: protocol +// - w: wire +// - c: channel/handshake + +lazy_static::lazy_static! { + static ref MPSC_POOL: Mutex, oneshot::Sender>)>>> = { + Mutex::new(HashMap::new()) + }; +} #[derive(Debug)] struct ParticipantInfo { @@ -80,6 +85,8 @@ pub struct Scheduler { } impl Scheduler { + const MPSC_CHANNEL_BOUND: usize = 1000; + pub fn new( local_pid: Pid, runtime: Arc, @@ -215,7 +222,35 @@ impl Scheduler { }; info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); (Protocols::new_tcp(stream), false) - }, /* */ + }, + ProtocolAddr::Mpsc(addr) => { + let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { + Some(s) => s.clone(), + None => { + pid_sender + .send(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "no mpsc listen on this addr", + ))) + .unwrap(); + continue; + }, + }; + let (remote_to_local_s, remote_to_local_r) = + mpsc::channel(Self::MPSC_CHANNEL_BOUND); + let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); + mpsc_s + .send((remote_to_local_s, local_to_remote_oneshot_s)) + .unwrap(); + let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap(); + + info!(?addr, "Connecting Mpsc"); + ( + Protocols::new_mpsc(local_to_remote_s, remote_to_local_r), + false, + ) + }, + /* */ //ProtocolAddr::Udp(addr) => { //#[cfg(feature = "metrics")] //self.metrics @@ -367,7 +402,7 @@ impl Scheduler { info!( 
?addr, ?e, - "Listener couldn't be started due to error on tcp bind" + "Tcp bind error durin listener startup" ); s2a_listen_result_s.send(Err(e)).unwrap(); return; @@ -390,6 +425,25 @@ impl Scheduler { self.init_protocol(Protocols::new_tcp(stream), None, true) .await; } + }, + ProtocolAddr::Mpsc(addr) => { + let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); + MPSC_POOL.lock().await.insert(addr, mpsc_s); + s2a_listen_result_s.send(Ok(())).unwrap(); + trace!(?addr, "Listener bound"); + + let mut end_receiver = s2s_stop_listening_r.fuse(); + while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { + next = mpsc_r.recv().fuse() => next, + _ = &mut end_receiver => None, + } { + let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); + local_remote_to_local_s.send(remote_to_local_s).unwrap(); + info!(?addr, "Accepting Mpsc from"); + self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r), None, true) + .await; + } + warn!("MpscStream Failed, stopping"); },/* ProtocolAddr::Udp(addr) => { let socket = match net::UdpSocket::bind(addr).await { @@ -522,6 +576,7 @@ impl Scheduler { s2b_create_channel_s, s2b_shutdown_bparticipant_s, ) = BParticipant::new( + local_pid, pid, sid, #[cfg(feature = "metrics")] @@ -545,10 +600,11 @@ impl Scheduler { }); drop(participants); trace!("dropped participants lock"); + let p = pid; runtime.spawn( bparticipant .run(participant_channels.b2s_prio_statistic_s) - .instrument(tracing::info_span!("participant", ?pid)), + .instrument(tracing::info_span!("remote", ?p)), ); //create a new channel within BParticipant and wait for it to run let (b2s_create_channel_done_s, b2s_create_channel_done_r) = diff --git a/network/tests/helper.rs b/network/tests/helper.rs index 64c65b0e91..a06b59578c 100644 --- a/network/tests/helper.rs +++ b/network/tests/helper.rs @@ -2,7 +2,7 @@ use lazy_static::*; use std::{ net::SocketAddr, sync::{ - atomic::{AtomicU16, Ordering}, + atomic::{AtomicU16, 
AtomicU64, Ordering}, Arc, }, thread, @@ -92,3 +92,12 @@ pub fn udp() -> veloren_network::ProtocolAddr { let port = PORTS.fetch_add(1, Ordering::Relaxed); veloren_network::ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) } + +#[allow(dead_code)] +pub fn mpsc() -> veloren_network::ProtocolAddr { + lazy_static! { + static ref PORTS: AtomicU64 = AtomicU64::new(5000); + } + let port = PORTS.fetch_add(1, Ordering::Relaxed); + veloren_network::ProtocolAddr::Mpsc(port) +} diff --git a/network/tests/integration.rs b/network/tests/integration.rs index b78619d65d..fd33dab8e3 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use tokio::runtime::Runtime; use veloren_network::{NetworkError, StreamError}; mod helper; -use helper::{network_participant_stream, tcp, udp}; +use helper::{mpsc, network_participant_stream, tcp, udp}; use std::io::ErrorKind; use veloren_network::{Network, Pid, Promises, ProtocolAddr}; @@ -50,6 +50,31 @@ fn stream_simple_3msg() { } #[test] +fn stream_simple_mpsc() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(mpsc()); + + s1_a.send("Hello World").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + +#[test] +fn stream_simple_mpsc_3msg() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(mpsc()); + + s1_a.send("Hello World").unwrap(); + s1_a.send(1337).unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); + s1_a.send("3rdMessage").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + +#[test] +#[ignore] fn stream_simple_udp() { let (_, _) = helper::setup(false, 0); let (r, _n_a, _p_a, mut s1_a, 
_n_b, _p_b, mut s1_b) = network_participant_stream(udp()); @@ -60,6 +85,7 @@ fn stream_simple_udp() { } #[test] +#[ignore] fn stream_simple_udp_3msg() { let (_, _) = helper::setup(false, 0); let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(udp()); @@ -101,6 +127,7 @@ fn tcp_and_udp_2_connections() -> std::result::Result<(), Box std::result::Result<(), Box> { let (_, _) = helper::setup(false, 0); let r = Arc::new(Runtime::new().unwrap()); diff --git a/server-cli/Cargo.toml b/server-cli/Cargo.toml index 4d8a3cf866..d75e1987ec 100644 --- a/server-cli/Cargo.toml +++ b/server-cli/Cargo.toml @@ -15,7 +15,7 @@ server = { package = "veloren-server", path = "../server", default-features = fa common = { package = "veloren-common", path = "../common" } common-net = { package = "veloren-common-net", path = "../common/net" } -tokio = { version = "1.0.1", default-features = false, features = ["rt-multi-thread"] } +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } ansi-parser = "0.7" clap = "2.33" crossterm = "0.18" diff --git a/server-cli/src/logging.rs b/server-cli/src/logging.rs index 6a738ed2c5..18c0952780 100644 --- a/server-cli/src/logging.rs +++ b/server-cli/src/logging.rs @@ -19,6 +19,7 @@ pub fn init(basic: bool) { .add_directive("uvth=warn".parse().unwrap()) .add_directive("tiny_http=warn".parse().unwrap()) .add_directive("mio::sys::windows=debug".parse().unwrap()) + .add_directive("veloren_network_protocol=info".parse().unwrap()) .add_directive( "veloren_server::persistence::character=info" .parse() diff --git a/server/Cargo.toml b/server/Cargo.toml index 8726a037d8..f73ed1f02f 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -28,7 +28,8 @@ futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" futures-channel = "0.3" -tokio = { version = "1.0.1", default-features = false, features = ["rt"] } +tokio = { version = "1", default-features = false, features = ["rt"] } +prometheus-hyper 
= "0.1.1" itertools = "0.9" lazy_static = "1.4.0" scan_fmt = { git = "https://github.com/Imberflur/scan_fmt" } @@ -41,7 +42,6 @@ hashbrown = { version = "0.9", features = ["rayon", "serde", "nightly"] } rayon = "1.5" crossbeam-channel = "0.5" prometheus = { version = "0.11", default-features = false} -tiny_http = "0.8.0" portpicker = { git = "https://github.com/xMAC94x/portpicker-rs", rev = "df6b37872f3586ac3b21d08b56c8ec7cd92fb172" } authc = { git = "https://gitlab.com/veloren/auth.git", rev = "bffb5181a35c19ddfd33ee0b4aedba741aafb68d" } libsqlite3-sys = { version = "0.18", features = ["bundled"] } diff --git a/server/src/lib.rs b/server/src/lib.rs index a8ac9487fe..40aa2b0804 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -76,12 +76,14 @@ use common_net::{ use common_sys::plugin::PluginMgr; use common_sys::state::State; use futures_executor::block_on; -use metrics::{PhysicsMetrics, ServerMetrics, StateTickMetrics, TickMetrics}; +use metrics::{PhysicsMetrics, StateTickMetrics, TickMetrics}; use network::{Network, Pid, ProtocolAddr}; use persistence::{ character_loader::{CharacterLoader, CharacterLoaderResponseKind}, character_updater::CharacterUpdater, }; +use prometheus::Registry; +use prometheus_hyper::Server as PrometheusServer; use specs::{join::Join, Builder, Entity as EcsEntity, RunNow, SystemData, WorldExt}; use std::{ i32, @@ -91,7 +93,7 @@ use std::{ }; #[cfg(not(feature = "worldgen"))] use test_world::{IndexOwned, World}; -use tokio::runtime::Runtime; +use tokio::{runtime::Runtime, sync::Notify}; use tracing::{debug, error, info, trace}; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -124,7 +126,7 @@ pub struct Server { _runtime: Arc, thread_pool: ThreadPool, - metrics: ServerMetrics, + metrics_shutdown: Arc, tick_metrics: TickMetrics, state_tick_metrics: StateTickMetrics, physics_metrics: PhysicsMetrics, @@ -350,28 +352,35 @@ impl Server { state.ecs_mut().insert(DeletedEntities::default()); - let mut metrics = ServerMetrics::new(); 
// register all metrics submodules here - let (tick_metrics, registry_tick) = TickMetrics::new(metrics.tick_clone()) - .expect("Failed to initialize server tick metrics submodule."); + let (tick_metrics, registry_tick) = + TickMetrics::new().expect("Failed to initialize server tick metrics submodule."); let (state_tick_metrics, registry_state) = StateTickMetrics::new().unwrap(); let (physics_metrics, registry_physics) = PhysicsMetrics::new().unwrap(); - registry_chunk(&metrics.registry()).expect("failed to register chunk gen metrics"); - registry_network(&metrics.registry()).expect("failed to register network request metrics"); - registry_player(&metrics.registry()).expect("failed to register player metrics"); - registry_tick(&metrics.registry()).expect("failed to register tick metrics"); - registry_state(&metrics.registry()).expect("failed to register state metrics"); - registry_physics(&metrics.registry()).expect("failed to register state metrics"); + let registry = Arc::new(Registry::new()); + registry_chunk(®istry).expect("failed to register chunk gen metrics"); + registry_network(®istry).expect("failed to register network request metrics"); + registry_player(®istry).expect("failed to register player metrics"); + registry_tick(®istry).expect("failed to register tick metrics"); + registry_state(®istry).expect("failed to register state metrics"); + registry_physics(®istry).expect("failed to register state metrics"); let thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".to_string()) .build(); - let network = - Network::new_with_registry(Pid::new(), Arc::clone(&runtime), &metrics.registry()); - metrics - .run(settings.metrics_address) - .expect("Failed to initialize server metrics submodule."); + let network = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), ®istry); + let metrics_shutdown = Arc::new(Notify::new()); + let metrics_shutdown_clone = Arc::clone(&metrics_shutdown); + let addr = settings.metrics_address; + runtime.spawn(async move 
{ + PrometheusServer::run( + Arc::clone(®istry), + addr, + metrics_shutdown_clone.notified(), + ) + .await + }); block_on(network.listen(ProtocolAddr::Tcp(settings.gameserver_address)))?; let connection_handler = ConnectionHandler::new(network); @@ -392,7 +401,7 @@ impl Server { _runtime: runtime, thread_pool, - metrics, + metrics_shutdown, tick_metrics, state_tick_metrics, physics_metrics, @@ -904,7 +913,7 @@ impl Server { .tick_time .with_label_values(&["metrics"]) .set(end_of_server_tick.elapsed().as_nanos() as i64); - self.metrics.tick(); + self.tick_metrics.tick(); // 9) Finish the tick, pass control back to the frontend. @@ -1150,6 +1159,7 @@ impl Server { impl Drop for Server { fn drop(&mut self) { + self.metrics_shutdown.notify_one(); self.state .notify_players(ServerGeneral::Disconnect(DisconnectReason::Shutdown)); } diff --git a/server/src/metrics.rs b/server/src/metrics.rs index d52ae33538..ac57121437 100644 --- a/server/src/metrics.rs +++ b/server/src/metrics.rs @@ -1,19 +1,16 @@ use prometheus::{ - Encoder, Gauge, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec, - Opts, Registry, TextEncoder, + Gauge, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, + Registry, }; use std::{ convert::TryInto, error::Error, - net::SocketAddr, sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, + atomic::{AtomicU64, Ordering}, Arc, }, - thread, time::{Duration, SystemTime, UNIX_EPOCH}, }; -use tracing::{debug, error}; type RegistryFn = Box Result<(), prometheus::Error>>; @@ -60,13 +57,6 @@ pub struct TickMetrics { tick: Arc, } -pub struct ServerMetrics { - running: Arc, - handle: Option>, - registry: Option, - tick: Arc, -} - impl PhysicsMetrics { pub fn new() -> Result<(Self, RegistryFn), prometheus::Error> { let entity_entity_collision_checks_count = IntCounter::with_opts(Opts::new( @@ -265,7 +255,7 @@ impl ChunkGenMetrics { } impl TickMetrics { - pub fn new(tick: Arc) -> Result<(Self, RegistryFn), Box> 
{ + pub fn new() -> Result<(Self, RegistryFn), Box> { let chonks_count = IntGauge::with_opts(Opts::new( "chonks_count", "number of all chonks currently active on the server", @@ -315,6 +305,7 @@ impl TickMetrics { let time_of_day_clone = time_of_day.clone(); let light_count_clone = light_count.clone(); let tick_time_clone = tick_time.clone(); + let tick = Arc::new(AtomicU64::new(0)); let f = |registry: &Registry| { registry.register(Box::new(chonks_count_clone))?; @@ -346,87 +337,7 @@ impl TickMetrics { )) } + pub fn tick(&self) { self.tick.fetch_add(1, Ordering::Relaxed); } + pub fn is_100th_tick(&self) -> bool { self.tick.load(Ordering::Relaxed).rem_euclid(100) == 0 } } - -impl ServerMetrics { - #[allow(clippy::new_without_default)] // TODO: Pending review in #587 - pub fn new() -> Self { - let running = Arc::new(AtomicBool::new(false)); - let tick = Arc::new(AtomicU64::new(0)); - let registry = Some(Registry::new()); - - Self { - running, - handle: None, - registry, - tick, - } - } - - pub fn registry(&self) -> &Registry { - match self.registry { - Some(ref r) => r, - None => panic!("You cannot longer register new metrics after the server has started!"), - } - } - - pub fn run(&mut self, addr: SocketAddr) -> Result<(), Box> { - self.running.store(true, Ordering::Relaxed); - let running2 = Arc::clone(&self.running); - - let registry = self - .registry - .take() - .expect("ServerMetrics must be already started"); - - //TODO: make this a job - self.handle = Some(thread::spawn(move || { - let server = tiny_http::Server::http(addr).unwrap(); - const TIMEOUT: Duration = Duration::from_secs(1); - debug!("starting tiny_http server to serve metrics"); - while running2.load(Ordering::Relaxed) { - let request = match server.recv_timeout(TIMEOUT) { - Ok(Some(rq)) => rq, - Ok(None) => continue, - Err(e) => { - error!(?e, "metrics http server error"); - break; - }, - }; - let mf = registry.gather(); - let encoder = TextEncoder::new(); - let mut buffer = vec![]; - encoder - 
.encode(&mf, &mut buffer) - .expect("Failed to encoder metrics text."); - let response = tiny_http::Response::from_string( - String::from_utf8(buffer).expect("Failed to parse bytes as a string."), - ); - if let Err(e) = request.respond(response) { - error!( - ?e, - "The metrics HTTP server had encountered and error with answering", - ); - } - } - debug!("stopping tiny_http server to serve metrics"); - })); - Ok(()) - } - - pub fn tick(&self) -> u64 { self.tick.fetch_add(1, Ordering::Relaxed) + 1 } - - pub fn tick_clone(&self) -> Arc { Arc::clone(&self.tick) } -} - -impl Drop for ServerMetrics { - fn drop(&mut self) { - self.running.store(false, Ordering::Relaxed); - let handle = self.handle.take(); - handle - .expect("ServerMetrics worker handle does not exist.") - .join() - .expect("Error shutting down prometheus metric exporter"); - } -} diff --git a/voxygen/Cargo.toml b/voxygen/Cargo.toml index aeebeb0a30..08a7a72c46 100644 --- a/voxygen/Cargo.toml +++ b/voxygen/Cargo.toml @@ -82,7 +82,7 @@ ron = {version = "0.6", default-features = false} serde = {version = "1.0", features = [ "rc", "derive" ]} treeculler = "0.1.0" uvth = "3.1.1" -tokio = { version = "1.0.1", default-features = false, features = ["rt-multi-thread"] } +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } num_cpus = "1.0" # vec_map = { version = "0.8.2" } inline_tweak = "1.0.2" diff --git a/voxygen/src/logging.rs b/voxygen/src/logging.rs index e0b4a962df..d0c3bd98d5 100644 --- a/voxygen/src/logging.rs +++ b/voxygen/src/logging.rs @@ -45,6 +45,7 @@ pub fn init(settings: &Settings) -> Vec { .add_directive("uvth=warn".parse().unwrap()) .add_directive("tiny_http=warn".parse().unwrap()) .add_directive("mio::sys::windows=debug".parse().unwrap()) + .add_directive("veloren_network_protocol=info".parse().unwrap()) .add_directive( "veloren_server::persistence::character=info" .parse() From 03af9937cf91214e8737805ebf0d037ed5d50b83 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Sun, 14 Feb 2021 18:45:12 +0100 Subject: [PATCH 6/6] Stabilize Network again: - completely switch to Bytes, even in api. speed up TCP by factor 2 - improve benchmarks - speed up mpsc metrics - gracefully handle shutdown by interpreting Ok(0) as tokio::tcpstream closed now. - fix hotloop in participants by adding `Some(n)` to fix endless hanging. - fix closing bug by closing streams after `recv_mgr` is shutdown even if no shutdown is triggered locally. - fix prometheus - no longer throw when a `Stream` is dropped while participant still receives a msg for it. - fix the bandwidth handling, TCP network send speed is up to 1.5GiB/s while recv is 150MiB/s - add documentation - tmp require rt-multi-threaded in client for tokio, to not fail cargo check this is prob stable, I tested over 1 hour. after that some optimisations in priomgr. and impl. proper bandwidth. Speed is up to 2GB/s write and 150MB/s recv on a single core sync add documentation --- CHANGELOG.md | 1 + Cargo.lock | 1 + client/Cargo.toml | 2 +- client/src/lib.rs | 3 +- network/Cargo.toml | 5 + network/benches/speed.rs | 143 ++++++++++++++ network/examples/chat.rs | 4 +- network/examples/fileshare/server.rs | 6 +- network/examples/network-speed/main.rs | 11 +- network/protocol/benches/protocols.rs | 249 +++++++++++++----------- network/protocol/src/event.rs | 30 +-- network/protocol/src/frame.rs | 258 ++++++++++++++----------- network/protocol/src/handshake.rs | 32 ++- network/protocol/src/io.rs | 63 ------ network/protocol/src/lib.rs | 124 ++++++++++-- network/protocol/src/message.rs | 118 ++++++++--- network/protocol/src/metrics.rs | 86 ++++++--- network/protocol/src/mpsc.rs | 88 +++++---- network/protocol/src/prio.rs | 57 +++--- network/protocol/src/tcp.rs | 257 ++++++++++++++---------- network/protocol/src/types.rs | 37 +++- network/src/api.rs | 230 +++++++++++----------- network/src/channel.rs | 96 +++++---- network/src/lib.rs | 28 +-- network/src/message.rs | 87
++++----- network/src/metrics.rs | 53 ++++- network/src/participant.rs | 217 ++++++++++----------- network/src/scheduler.rs | 67 ++++--- network/tests/closing.rs | 14 +- network/tests/helper.rs | 14 +- network/tests/integration.rs | 4 +- server-cli/src/logging.rs | 4 +- server/src/connection_handler.rs | 10 +- 33 files changed, 1444 insertions(+), 955 deletions(-) create mode 100644 network/benches/speed.rs delete mode 100644 network/protocol/src/io.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 46089cca7e..6110a91441 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Cave scatter now includes all 6 gems. - Adjusted Stonework Defender loot table to remove mindflayer drops (bag, staff, glider). - Changed default controller key bindings +- Improved network efficiency by ≈ factor 10 by using tokio. ### Removed diff --git a/Cargo.lock b/Cargo.lock index 0642e1e5a6..cfdd5ebe19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5713,6 +5713,7 @@ dependencies = [ "bitflags", "bytes 1.0.1", "clap", + "criterion", "crossbeam-channel 0.5.0", "futures-core", "futures-util", diff --git a/client/Cargo.toml b/client/Cargo.toml index d8f6cb969a..1578350d29 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -21,7 +21,7 @@ uvth = "3.1.1" futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" -tokio = { version = "1", default-features = false, features = ["rt"] } +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } image = { version = "0.23.12", default-features = false, features = ["png"] } num = "0.3.1" num_cpus = "1.10.1" diff --git a/client/src/lib.rs b/client/src/lib.rs index b48fa7335b..2c26d00f86 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -2066,7 +2066,8 @@ mod tests { let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9000); let view_distance: Option = None; - let veloren_client: Result = 
Client::new(socket, view_distance); + let runtime = Arc::new(Runtime::new().unwrap()); + let veloren_client: Result = Client::new(socket, view_distance, runtime); let _ = veloren_client.map(|mut client| { //register diff --git a/network/Cargo.toml b/network/Cargo.toml index de6e028176..b5aeb4be3b 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -48,6 +48,11 @@ clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" serde = { version = "1.0", features = ["derive"] } prometheus-hyper = "0.1.1" +criterion = { version = "0.3.4", features = ["default", "async_tokio"] } + +[[bench]] +name = "speed" +harness = false [[example]] name = "fileshare" diff --git a/network/benches/speed.rs b/network/benches/speed.rs new file mode 100644 index 0000000000..8de5c78335 --- /dev/null +++ b/network/benches/speed.rs @@ -0,0 +1,143 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::{runtime::Runtime, sync::Mutex}; +use veloren_network::{Message, Network, Participant, Pid, Promises, ProtocolAddr, Stream}; + +fn serialize(data: &[u8], stream: &Stream) { let _ = Message::serialize(data, &stream); } + +async fn stream_msg(s1_a: Arc>, s1_b: Arc>, data: &[u8], cnt: usize) { + let mut s1_b = s1_b.lock().await; + let m = Message::serialize(&data, &s1_b); + std::thread::spawn(move || { + let mut s1_a = s1_a.try_lock().unwrap(); + for _ in 0..cnt { + s1_a.send_raw(&m).unwrap(); + } + }); + for _ in 0..cnt { + s1_b.recv_raw().await.unwrap(); + } +} + +fn rt() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +fn criterion_util(c: &mut Criterion) { + let mut c = c.benchmark_group("net_util"); + c.significance_level(0.1).sample_size(100); + + let (r, _n_a, p_a, s1_a, _n_b, _p_b, _s1_b) = + network_participant_stream(ProtocolAddr::Mpsc(5000)); + let s2_a = r.block_on(p_a.open(4, Promises::COMPRESSED)).unwrap(); + + 
c.throughput(Throughput::Bytes(1000)) + .bench_function("message_serialize", |b| { + let data = vec![0u8; 1000]; + b.iter(|| serialize(&data, &s1_a)) + }); + c.throughput(Throughput::Bytes(1000)) + .bench_function("message_serialize_compress", |b| { + let data = vec![0u8; 1000]; + b.iter(|| serialize(&data, &s2_a)) + }); +} + +fn criterion_mpsc(c: &mut Criterion) { + let mut c = c.benchmark_group("net_mpsc"); + c.significance_level(0.1).sample_size(10); + + let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = + network_participant_stream(ProtocolAddr::Mpsc(5000)); + let s1_a = Arc::new(Mutex::new(s1_a)); + let s1_b = Arc::new(Mutex::new(s1_b)); + + c.throughput(Throughput::Bytes(100000000)).bench_function( + BenchmarkId::new("100MB_in_10000_msg", ""), + |b| { + let data = vec![155u8; 100_000]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 1_000), + ) + }, + ); + c.throughput(Throughput::Elements(100000)).bench_function( + BenchmarkId::new("100000_tiny_msg", ""), + |b| { + let data = vec![3u8; 5]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 100_000), + ) + }, + ); + c.finish(); + drop((_n_a, _p_a, _n_b, _p_b)); +} + +fn criterion_tcp(c: &mut Criterion) { + let mut c = c.benchmark_group("net_tcp"); + c.significance_level(0.1).sample_size(10); + + let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = + network_participant_stream(ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], 5000)))); + let s1_a = Arc::new(Mutex::new(s1_a)); + let s1_b = Arc::new(Mutex::new(s1_b)); + + c.throughput(Throughput::Bytes(100000000)).bench_function( + BenchmarkId::new("100MB_in_1000_msg", ""), + |b| { + let data = vec![155u8; 100_000]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 1_000), + ) + }, + ); + 
c.throughput(Throughput::Elements(100000)).bench_function( + BenchmarkId::new("100000_tiny_msg", ""), + |b| { + let data = vec![3u8; 5]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 100_000), + ) + }, + ); + c.finish(); + drop((_n_a, _p_a, _n_b, _p_b)); +} + +criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +criterion_main!(benches); + +pub fn network_participant_stream( + addr: ProtocolAddr, +) -> ( + Arc, + Network, + Participant, + Stream, + Network, + Participant, + Stream, +) { + let runtime = Arc::new(Runtime::new().unwrap()); + let (n_a, p1_a, s1_a, n_b, p1_b, s1_b) = runtime.block_on(async { + let n_a = Network::new(Pid::fake(0), Arc::clone(&runtime)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&runtime)); + + n_a.listen(addr.clone()).await.unwrap(); + let p1_b = n_b.connect(addr).await.unwrap(); + let p1_a = n_a.connected().await.unwrap(); + + let s1_a = p1_a.open(4, Promises::empty()).await.unwrap(); + let s1_b = p1_b.opened().await.unwrap(); + + (n_a, p1_a, s1_a, n_b, p1_b, s1_b) + }); + (runtime, n_a, p1_a, s1_a, n_b, p1_b, s1_b) +} diff --git a/network/examples/chat.rs b/network/examples/chat.rs index a1a3f09cf0..e5c7737531 100644 --- a/network/examples/chat.rs +++ b/network/examples/chat.rs @@ -130,7 +130,7 @@ async fn client_connection( Ok(msg) => { println!("[{}]: {}", username, msg); for p in participants.read().await.iter() { - match p.open(32, Promises::ORDERED | Promises::CONSISTENCY).await { + match p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await { Err(_) => info!("error talking to client, //TODO drop it"), Ok(mut s) => s.send((username.clone(), msg.clone())).unwrap(), }; @@ -148,7 +148,7 @@ fn client(address: ProtocolAddr) { r.block_on(async { let p1 = client.connect(address.clone()).await.unwrap(); //remote representation of p1 let mut s1 = p1 - .open(16, Promises::ORDERED | Promises::CONSISTENCY) + .open(4, 
Promises::ORDERED | Promises::CONSISTENCY) .await .unwrap(); //remote representation of s1 let mut input_lines = io::BufReader::new(io::stdin()); diff --git a/network/examples/fileshare/server.rs b/network/examples/fileshare/server.rs index 5db8345d46..b6cf6c38dd 100644 --- a/network/examples/fileshare/server.rs +++ b/network/examples/fileshare/server.rs @@ -121,8 +121,8 @@ impl Server { #[allow(clippy::eval_order_dependence)] async fn loop_participant(&self, p: Participant) { if let (Ok(cmd_out), Ok(file_out), Ok(cmd_in), Ok(file_in)) = ( - p.open(15, Promises::ORDERED | Promises::CONSISTENCY).await, - p.open(40, Promises::CONSISTENCY).await, + p.open(3, Promises::ORDERED | Promises::CONSISTENCY).await, + p.open(6, Promises::CONSISTENCY).await, p.opened().await, p.opened().await, ) { @@ -175,7 +175,7 @@ impl Server { let mut path = std::env::current_dir().unwrap(); path.push(fi.path().file_name().unwrap()); trace!("No path provided, saving down to {:?}", path); - PathBuf::from(path) + path }, }; debug!("Received file, going to save it under {:?}", path); diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index 37d076b5bd..bb6684658a 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -132,6 +132,7 @@ fn server(address: ProtocolAddr, runtime: Arc) { runtime.block_on(server.listen(address)).unwrap(); loop { + info!("----"); info!("Waiting for participant to connect"); let p1 = runtime.block_on(server.connected()).unwrap(); //remote representation of p1 let mut s1 = runtime.block_on(p1.opened()).unwrap(); //remote representation of s1 @@ -163,7 +164,7 @@ fn client(address: ProtocolAddr, runtime: Arc) { let p1 = runtime.block_on(client.connect(address)).unwrap(); //remote representation of p1 let mut s1 = runtime - .block_on(p1.open(16, Promises::ORDERED | Promises::CONSISTENCY)) + .block_on(p1.open(4, Promises::ORDERED | Promises::CONSISTENCY)) .unwrap(); //remote representation 
of s1 let mut last = Instant::now(); let mut id = 0u64; @@ -185,16 +186,16 @@ fn client(address: ProtocolAddr, runtime: Arc) { } if id > 2000000 { println!("Stop"); - std::thread::sleep(std::time::Duration::from_millis(5000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); break; } } drop(s1); - std::thread::sleep(std::time::Duration::from_millis(5000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); info!("Closing participant"); runtime.block_on(p1.disconnect()).unwrap(); - std::thread::sleep(std::time::Duration::from_millis(25000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); info!("DROPPING! client"); drop(client); - std::thread::sleep(std::time::Duration::from_millis(25000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); } diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs index f9ad557682..dfe6a57084 100644 --- a/network/protocol/benches/protocols.rs +++ b/network/protocol/benches/protocols.rs @@ -1,140 +1,153 @@ use async_channel::*; use async_trait::async_trait; -use bytes::BytesMut; -use criterion::{criterion_group, criterion_main, Criterion}; +use bytes::{Bytes, BytesMut}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use std::{sync::Arc, time::Duration}; +use tokio::runtime::Runtime; use veloren_network_protocol::{ - InitProtocol, MessageBuffer, MpscMsg, MpscRecvProtcol, MpscSendProtcol, Pid, Promises, - ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, - Sid, TcpRecvProtcol, TcpSendProtcol, UnreliableDrain, UnreliableSink, _internal::Frame, + InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid, + TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame, }; -fn frame_serialize(frame: Frame, buffer: &mut BytesMut) { 
frame.to_bytes(buffer); } +fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); } -async fn mpsc_msg(buffer: Arc) { - // Arrrg, need to include constructor here - let [p1, p2] = utils::ac_bound(10, None); - let (mut s, mut r) = (p1.0, p2.1); - s.send(ProtocolEvent::Message { +async fn handshake(p: [(S, R); 2]) +where + S: SendProtocol, + R: RecvProtocol, + (S, R): InitProtocol, +{ + let [mut p1, mut p2] = p; + tokio::join!( + async { + p1.initialize(true, Pid::fake(2), 1337).await.unwrap(); + p1 + }, + async { + p2.initialize(false, Pid::fake(3), 42).await.unwrap(); + p2 + } + ); +} + +async fn send_msg(mut s: T, data: Bytes, cnt: usize) { + let bandwidth = data.len() as u64 + 100; + const SEC1: Duration = Duration::from_secs(1); + + s.send(ProtocolEvent::OpenStream { sid: Sid::new(12), - mid: 0, - buffer, + prio: 0, + promises: Promises::ORDERED, + guaranteed_bandwidth: 100_000, }) .await .unwrap(); - r.recv().await.unwrap(); -} -async fn mpsc_handshake() { - let [mut p1, mut p2] = utils::ac_bound(10, None); - let r1 = tokio::spawn(async move { - p1.initialize(true, Pid::fake(2), 1337).await.unwrap(); - p1 - }); - let r2 = tokio::spawn(async move { - p2.initialize(false, Pid::fake(3), 42).await.unwrap(); - p2 - }); - let (r1, r2) = tokio::join!(r1, r2); - r1.unwrap(); - r2.unwrap(); -} - -async fn tcp_msg(buffer: Arc, cnt: usize) { - let [p1, p2] = utils::tcp_bound(10000, None); /*10kbit*/ - let (mut s, mut r) = (p1.0, p2.1); - - let buffer = Arc::clone(&buffer); - let bandwidth = buffer.data.len() as u64 + 1000; - - let r1 = tokio::spawn(async move { - s.send(ProtocolEvent::OpenStream { + for i in 0..cnt { + s.send(ProtocolEvent::Message { sid: Sid::new(12), - prio: 0, - promises: Promises::ORDERED, - guaranteed_bandwidth: 100_000, + mid: i as u64, + data: data.clone(), }) .await .unwrap(); - - for i in 0..cnt { - s.send(ProtocolEvent::Message { - sid: Sid::new(12), - mid: i as u64, - buffer: Arc::clone(&buffer), - }) - .await - 
.unwrap(); - s.flush(bandwidth, Duration::from_secs(1)).await.unwrap(); + if i.rem_euclid(50) == 0 { + s.flush(bandwidth * 50_u64, SEC1).await.unwrap(); } - }); - let r2 = tokio::spawn(async move { - r.recv().await.unwrap(); - - for _ in 0..cnt { - r.recv().await.unwrap(); - } - }); - let (r1, r2) = tokio::join!(r1, r2); - r1.unwrap(); - r2.unwrap(); + } + s.flush(bandwidth * 1000_u64, SEC1).await.unwrap(); } -fn criterion_benchmark(c: &mut Criterion) { - let rt = || { - tokio::runtime::Builder::new_current_thread() - .build() - .unwrap() - }; +async fn recv_msg(mut r: T, cnt: usize) { + r.recv().await.unwrap(); - c.bench_function("mpsc_short_msg", |b| { - let buffer = Arc::new(MessageBuffer { - data: b"hello_world".to_vec(), - }); - b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer))) - }); - c.bench_function("mpsc_long_msg", |b| { - let buffer = Arc::new(MessageBuffer { - data: vec![150u8; 500_000], - }); - b.to_async(rt()).iter(|| mpsc_msg(Arc::clone(&buffer))) - }); + for _ in 0..cnt { + r.recv().await.unwrap(); + } +} + +async fn send_and_recv_msg( + p: [(S, R); 2], + data: Bytes, + cnt: usize, +) { + let [p1, p2] = p; + let (s, r) = (p1.0, p2.1); + + tokio::join!(send_msg(s, data, cnt), recv_msg(r, cnt)); +} + +fn rt() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +fn criterion_util(c: &mut Criterion) { c.bench_function("mpsc_handshake", |b| { - b.to_async(rt()).iter(|| mpsc_handshake()) + b.to_async(rt()) + .iter_with_setup(|| utils::ac_bound(10, None), handshake) }); - - let mut buffer = BytesMut::with_capacity(1500); - c.bench_function("frame_serialize_short", |b| { - let frame = Frame::Data { + let mut buffer = BytesMut::with_capacity(1500); + let frame = OTFrame::Data { mid: 65, start: 89u64, - data: b"hello_world".to_vec(), + data: Bytes::from(&b"hello_world"[..]), }; - b.iter(|| frame_serialize(frame.clone(), &mut buffer)) - }); - - c.bench_function("tcp_short_msg", |b| { - let buffer = 
Arc::new(MessageBuffer { - data: b"hello_world".to_vec(), - }); - b.to_async(rt()).iter(|| tcp_msg(Arc::clone(&buffer), 1)) - }); - c.bench_function("tcp_1GB_in_10000_msg", |b| { - let buffer = Arc::new(MessageBuffer { - data: vec![155u8; 100_000], - }); - b.to_async(rt()) - .iter(|| tcp_msg(Arc::clone(&buffer), 10_000)) - }); - c.bench_function("tcp_1000000_tiny_msg", |b| { - let buffer = Arc::new(MessageBuffer { data: vec![3u8; 5] }); - b.to_async(rt()) - .iter(|| tcp_msg(Arc::clone(&buffer), 1_000_000)) + b.iter_with_setup( + || frame.clone(), + |frame| frame_serialize(frame, &mut buffer), + ) }); } -criterion_group!(benches, criterion_benchmark); +fn criterion_mpsc(c: &mut Criterion) { + let mut c = c.benchmark_group("mpsc"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buffer = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buffer.clone(), utils::ac_bound(10, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buffer = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buffer.clone(), utils::ac_bound(10, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +fn criterion_tcp(c: &mut Criterion) { + let mut c = c.benchmark_group("tcp"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buf = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::tcp_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buf = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::tcp_bound(10000, None)), + |(b, p)| 
send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); criterion_main!(benches); mod utils { @@ -151,7 +164,7 @@ mod utils { pub fn ac_bound( cap: usize, metrics: Option, - ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + ) -> [(MpscSendProtocol, MpscRecvProtocol); 2] { let (s1, r1) = async_channel::bounded(cap); let (s2, r2) = async_channel::bounded(cap); let m = metrics.unwrap_or_else(|| { @@ -159,12 +172,12 @@ mod utils { }); [ ( - MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), - MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + MpscSendProtocol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r2 }, m.clone()), ), ( - MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), - MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + MpscSendProtocol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r1 }, m), ), ] } @@ -181,7 +194,7 @@ mod utils { pub fn tcp_bound( cap: usize, metrics: Option, - ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + ) -> [(TcpSendProtocol, TcpRecvProtocol); 2] { let (s1, r1) = async_channel::bounded(cap); let (s2, r2) = async_channel::bounded(cap); let m = metrics.unwrap_or_else(|| { @@ -189,12 +202,12 @@ mod utils { }); [ ( - TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), - TcpRecvProtcol::new(TcpSink { receiver: r2 }, m.clone()), + TcpSendProtocol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r2 }, m.clone()), ), ( - TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), - TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + TcpSendProtocol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r1 }, m), ), ] } diff --git a/network/protocol/src/event.rs b/network/protocol/src/event.rs index 14b74de558..cc332e5d3c 100644 --- a/network/protocol/src/event.rs +++ 
b/network/protocol/src/event.rs @@ -1,11 +1,13 @@ use crate::{ - frame::Frame, - message::MessageBuffer, + frame::OTFrame, types::{Bandwidth, Mid, Prio, Promises, Sid}, }; -use std::sync::Arc; +use bytes::Bytes; -/* used for communication with Protocols */ +/// used for communication with [`SendProtocol`] and [`RecvProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +/// [`RecvProtocol`]: crate::RecvProtocol #[derive(Debug, Clone)] #[cfg_attr(test, derive(PartialEq))] pub enum ProtocolEvent { @@ -20,29 +22,29 @@ pub enum ProtocolEvent { sid: Sid, }, Message { - buffer: Arc, + data: Bytes, mid: Mid, sid: Sid, }, } impl ProtocolEvent { - pub(crate) fn to_frame(&self) -> Frame { + pub(crate) fn to_frame(&self) -> OTFrame { match self { - ProtocolEvent::Shutdown => Frame::Shutdown, + ProtocolEvent::Shutdown => OTFrame::Shutdown, ProtocolEvent::OpenStream { sid, prio, promises, guaranteed_bandwidth: _, - } => Frame::OpenStream { + } => OTFrame::OpenStream { sid: *sid, prio: *prio, promises: *promises, }, - ProtocolEvent::CloseStream { sid } => Frame::CloseStream { sid: *sid }, + ProtocolEvent::CloseStream { sid } => OTFrame::CloseStream { sid: *sid }, ProtocolEvent::Message { .. 
} => { - unimplemented!("Event::Message to Frame IS NOT supported") + unimplemented!("Event::Message to OTFrame IS NOT supported") }, } } @@ -54,18 +56,18 @@ mod tests { #[test] fn test_to_frame() { - assert_eq!(ProtocolEvent::Shutdown.to_frame(), Frame::Shutdown); + assert_eq!(ProtocolEvent::Shutdown.to_frame(), OTFrame::Shutdown); assert_eq!( ProtocolEvent::CloseStream { sid: Sid::new(42) }.to_frame(), - Frame::CloseStream { sid: Sid::new(42) } + OTFrame::CloseStream { sid: Sid::new(42) } ); } #[test] #[should_panic] - fn test_sixlet_to_str() { + fn test_msg_buffer_panic() { let _ = ProtocolEvent::Message { - buffer: Arc::new(MessageBuffer { data: vec![] }), + data: Bytes::new(), mid: 0, sid: Sid::new(23), } diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs index 2f1bb3eb4c..a490d67b4d 100644 --- a/network/protocol/src/frame.rs +++ b/network/protocol/src/frame.rs @@ -1,5 +1,5 @@ use crate::types::{Mid, Pid, Prio, Promises, Sid}; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; // const FRAME_RESERVED_1: u8 = 0; const FRAME_HANDSHAKE: u8 = 1; @@ -15,7 +15,7 @@ const FRAME_RAW: u8 = 8; /// Used for Communication between Channel <----(TCP/UDP)----> Channel #[derive(Debug, PartialEq, Clone)] -pub /* should be crate only */ enum InitFrame { +pub enum InitFrame { Handshake { magic_number: [u8; 7], version: [u32; 3], @@ -24,14 +24,14 @@ pub /* should be crate only */ enum InitFrame { pid: Pid, secret: u128, }, - /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API - * against veloren Server! 
*/ + /// WARNING: sending RAW is only for debug purposes and will drop the + /// connection Raw(Vec), } -/// Used for Communication between Channel <----(TCP/UDP)----> Channel +/// Used for OUT TCP Communication between Channel --(TCP)--> Channel #[derive(Debug, PartialEq, Clone)] -pub enum Frame { +pub enum OTFrame { Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully), * Participant is deleted */ OpenStream { @@ -49,8 +49,33 @@ pub enum Frame { }, Data { mid: Mid, - start: u64, - data: Vec, + start: u64, /* remove */ + data: Bytes, + }, +} + +/// Used for IN TCP Communication between Channel <--(TCP)-- Channel +#[derive(Debug, PartialEq, Clone)] +pub enum ITFrame { + Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully), + * Participant is deleted */ + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + }, + CloseStream { + sid: Sid, + }, + DataHeader { + mid: Mid, + sid: Sid, + length: u64, + }, + Data { + mid: Mid, + start: u64, /* remove */ + data: BytesMut, }, } @@ -62,7 +87,7 @@ impl InitFrame { pub(crate) const RAW_CNS: usize = 2; //provide an appropriate buffer size. 
> 1500 - pub(crate) fn to_bytes(self, bytes: &mut BytesMut) { + pub(crate) fn write_bytes(self, bytes: &mut BytesMut) { match self { InitFrame::Handshake { magic_number, @@ -87,7 +112,7 @@ impl InitFrame { } } - pub(crate) fn to_frame(bytes: &mut BytesMut) -> Option { + pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option { let frame_no = match bytes.get(0) { Some(&f) => f, None => return None, @@ -124,8 +149,6 @@ impl InitFrame { let length = bytes.get_u16_le() as usize; // lower length is allowed let max_length = length.min(bytes.len()); - println!("dasdasd {:?}", length); - println!("aaaaa {:?}", max_length); let mut data = vec![0; max_length]; data.copy_from_slice(&bytes[..max_length]); InitFrame::Raw(data) @@ -136,71 +159,67 @@ impl InitFrame { } } -impl Frame { - pub(crate) const CLOSE_STREAM_CNS: usize = 8; - /// const part of the DATA frame, actual size is variable - pub(crate) const DATA_CNS: usize = 18; - pub(crate) const DATA_HEADER_CNS: usize = 24; - pub(crate) const OPEN_STREAM_CNS: usize = 10; - // Size WITHOUT the 1rst indicating byte - pub(crate) const SHUTDOWN_CNS: usize = 0; +pub(crate) const TCP_CLOSE_STREAM_CNS: usize = 8; +/// const part of the DATA frame, actual size is variable +pub(crate) const TCP_DATA_CNS: usize = 18; +pub(crate) const TCP_DATA_HEADER_CNS: usize = 24; +pub(crate) const TCP_OPEN_STREAM_CNS: usize = 10; +// Size WITHOUT the 1rst indicating byte +pub(crate) const TCP_SHUTDOWN_CNS: usize = 0; - //provide an appropriate buffer size. 
> 1500 - pub fn to_bytes(self, bytes: &mut BytesMut) -> u64 { +impl OTFrame { + pub fn write_bytes(self, bytes: &mut BytesMut) { match self { - Frame::Shutdown => { + Self::Shutdown => { bytes.put_u8(FRAME_SHUTDOWN); - 0 }, - Frame::OpenStream { + Self::OpenStream { sid, prio, promises, } => { bytes.put_u8(FRAME_OPEN_STREAM); - bytes.put_slice(&sid.to_le_bytes()); + sid.to_bytes(bytes); bytes.put_u8(prio); bytes.put_u8(promises.to_le_bytes()[0]); - 0 }, - Frame::CloseStream { sid } => { + Self::CloseStream { sid } => { bytes.put_u8(FRAME_CLOSE_STREAM); - bytes.put_slice(&sid.to_le_bytes()); - 0 + sid.to_bytes(bytes); }, - Frame::DataHeader { mid, sid, length } => { + Self::DataHeader { mid, sid, length } => { bytes.put_u8(FRAME_DATA_HEADER); bytes.put_u64_le(mid); - bytes.put_slice(&sid.to_le_bytes()); + sid.to_bytes(bytes); bytes.put_u64_le(length); - 0 }, - Frame::Data { mid, start, data } => { + Self::Data { mid, start, data } => { bytes.put_u8(FRAME_DATA); bytes.put_u64_le(mid); bytes.put_u64_le(start); bytes.put_u16_le(data.len() as u16); bytes.put_slice(&data); - data.len() as u64 }, } } +} - pub(crate) fn to_frame(bytes: &mut BytesMut) -> Option { +impl ITFrame { + pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option { let frame_no = match bytes.first() { Some(&f) => f, None => return None, }; let size = match frame_no { - FRAME_SHUTDOWN => Self::SHUTDOWN_CNS, - FRAME_OPEN_STREAM => Self::OPEN_STREAM_CNS, - FRAME_CLOSE_STREAM => Self::CLOSE_STREAM_CNS, - FRAME_DATA_HEADER => Self::DATA_HEADER_CNS, + FRAME_SHUTDOWN => TCP_SHUTDOWN_CNS, + FRAME_OPEN_STREAM => TCP_OPEN_STREAM_CNS, + FRAME_CLOSE_STREAM => TCP_CLOSE_STREAM_CNS, + FRAME_DATA_HEADER => TCP_DATA_HEADER_CNS, FRAME_DATA => { if bytes.len() < 17 + 1 + 1 { return None; } - u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + Self::DATA_CNS + u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + TCP_DATA_CNS }, _ => return None, }; @@ -212,13 +231,13 @@ impl Frame { let frame = match 
frame_no { FRAME_SHUTDOWN => { let _ = bytes.split_to(size + 1); - Frame::Shutdown + Self::Shutdown }, FRAME_OPEN_STREAM => { let mut bytes = bytes.split_to(size + 1); bytes.advance(1); - Frame::OpenStream { - sid: Sid::new(bytes.get_u64_le()), + Self::OpenStream { + sid: Sid::from_bytes(&mut bytes), prio: bytes.get_u8(), promises: Promises::from_bits_truncate(bytes.get_u8()), } @@ -226,29 +245,27 @@ impl Frame { FRAME_CLOSE_STREAM => { let mut bytes = bytes.split_to(size + 1); bytes.advance(1); - Frame::CloseStream { - sid: Sid::new(bytes.get_u64_le()), + Self::CloseStream { + sid: Sid::from_bytes(&mut bytes), } }, FRAME_DATA_HEADER => { let mut bytes = bytes.split_to(size + 1); bytes.advance(1); - Frame::DataHeader { + Self::DataHeader { mid: bytes.get_u64_le(), - sid: Sid::new(bytes.get_u64_le()), + sid: Sid::from_bytes(&mut bytes), length: bytes.get_u64_le(), } }, FRAME_DATA => { - let mut info = bytes.split_to(Self::DATA_CNS + 1); - info.advance(1); - let mid = info.get_u64_le(); - let start = info.get_u64_le(); - let length = info.get_u16_le(); - debug_assert_eq!(length as usize, size - Self::DATA_CNS); + bytes.advance(1); + let mid = bytes.get_u64_le(); + let start = bytes.get_u64_le(); + let length = bytes.get_u16_le(); + debug_assert_eq!(length as usize, size - TCP_DATA_CNS); let data = bytes.split_to(length as usize); - let data = data.to_vec(); - Frame::Data { mid, start, data } + Self::Data { mid, start, data } }, _ => unreachable!("Frame::to_frame should be handled before!"), }; @@ -256,6 +273,29 @@ impl Frame { } } +#[allow(unused_variables)] +impl PartialEq for OTFrame { + fn eq(&self, other: &ITFrame) -> bool { + match self { + Self::Shutdown => matches!(other, ITFrame::Shutdown), + Self::OpenStream { + sid, + prio, + promises, + } => matches!(other, ITFrame::OpenStream { + sid, + prio, + promises + }), + Self::CloseStream { sid } => matches!(other, ITFrame::CloseStream { sid }), + Self::DataHeader { mid, sid, length } => { + matches!(other, 
ITFrame::DataHeader { mid, sid, length }) + }, + Self::Data { mid, start, data } => matches!(other, ITFrame::Data { mid, start, data }), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -275,32 +315,32 @@ mod tests { ] } - fn get_frames() -> Vec { + fn get_otframes() -> Vec { vec![ - Frame::OpenStream { + OTFrame::OpenStream { sid: Sid::new(1337), prio: 14, promises: Promises::GUARANTEED_DELIVERY, }, - Frame::DataHeader { + OTFrame::DataHeader { sid: Sid::new(1337), mid: 0, length: 36, }, - Frame::Data { + OTFrame::Data { mid: 0, start: 0, - data: vec![77u8; 20], + data: Bytes::from(&[77u8; 20][..]), }, - Frame::Data { + OTFrame::Data { mid: 0, start: 20, - data: vec![42u8; 16], + data: Bytes::from(&[42u8; 16][..]), }, - Frame::CloseStream { + OTFrame::CloseStream { sid: Sid::new(1337), }, - Frame::Shutdown, + OTFrame::Shutdown, ] } @@ -308,8 +348,8 @@ mod tests { fn initframe_individual() { let dupl = |frame: InitFrame| { let mut buffer = BytesMut::with_capacity(1500); - InitFrame::to_bytes(frame.clone(), &mut buffer); - InitFrame::to_frame(&mut buffer) + InitFrame::write_bytes(frame, &mut buffer); + InitFrame::read_frame(&mut buffer) }; for frame in get_initframes() { @@ -325,13 +365,13 @@ mod tests { let mut frames = get_initframes(); // to string for f in &frames { - InitFrame::to_bytes(f.clone(), &mut buffer); + InitFrame::write_bytes(f.clone(), &mut buffer); } // from string let mut framesd = frames .iter() - .map(|&_| InitFrame::to_frame(&mut buffer)) + .map(|&_| InitFrame::read_frame(&mut buffer)) .collect::>(); // compare @@ -343,15 +383,15 @@ mod tests { #[test] fn frame_individual() { - let dupl = |frame: Frame| { + let dupl = |frame: OTFrame| { let mut buffer = BytesMut::with_capacity(1500); - Frame::to_bytes(frame.clone(), &mut buffer); - Frame::to_frame(&mut buffer) + OTFrame::write_bytes(frame, &mut buffer); + ITFrame::read_frame(&mut buffer) }; - for frame in get_frames() { + for frame in get_otframes() { println!("frame: {:?}", &frame); - 
assert_eq!(Some(frame.clone()), dupl(frame)); + assert_eq!(frame.clone(), dupl(frame).expect("NONE")); } } @@ -359,36 +399,36 @@ mod tests { fn frame_multiple() { let mut buffer = BytesMut::with_capacity(3000); - let mut frames = get_frames(); + let mut frames = get_otframes(); // to string for f in &frames { - Frame::to_bytes(f.clone(), &mut buffer); + OTFrame::write_bytes(f.clone(), &mut buffer); } // from string let mut framesd = frames .iter() - .map(|&_| Frame::to_frame(&mut buffer)) + .map(|&_| ITFrame::read_frame(&mut buffer)) .collect::>(); // compare for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { println!("frame: {:?}", &f); - assert_eq!(Some(f), fd); + assert_eq!(f, fd.expect("NONE")); } } #[test] fn frame_exact_size() { - const SIZE: usize = Frame::CLOSE_STREAM_CNS+1/*first byte*/; + const SIZE: usize = TCP_CLOSE_STREAM_CNS+1/*first byte*/; let mut buffer = BytesMut::with_capacity(SIZE); - let frame1 = Frame::CloseStream { sid: Sid::new(2) }; - Frame::to_bytes(frame1.clone(), &mut buffer); + let frame1 = OTFrame::CloseStream { sid: Sid::new(2) }; + OTFrame::write_bytes(frame1.clone(), &mut buffer); assert_eq!(buffer.len(), SIZE); - let mut deque = buffer.iter().map(|b| *b).collect(); - let frame2 = Frame::to_frame(&mut deque); - assert_eq!(Some(frame1), frame2); + let mut deque = buffer.iter().copied().collect(); + let frame2 = ITFrame::read_frame(&mut deque); + assert_eq!(frame1, frame2.expect("NONE")); } #[test] @@ -399,7 +439,7 @@ mod tests { magic_number: VELOREN_MAGIC_NUMBER, version: VELOREN_NETWORK_VERSION, }; - InitFrame::to_bytes(frame1.clone(), &mut buffer); + InitFrame::write_bytes(frame1, &mut buffer); } #[test] @@ -410,9 +450,9 @@ mod tests { magic_number: VELOREN_MAGIC_NUMBER, version: VELOREN_NETWORK_VERSION, }; - let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + let _ = InitFrame::write_bytes(frame1, &mut buffer); buffer.truncate(6); // simulate partial retrieve - let frame1d = InitFrame::to_frame(&mut buffer); + let 
frame1d = InitFrame::read_frame(&mut buffer); assert_eq!(frame1d, None); } @@ -420,7 +460,7 @@ mod tests { fn initframe_rubish() { let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); assert_eq!( - InitFrame::to_frame(&mut buffer), + InitFrame::read_frame(&mut buffer), Some(InitFrame::Raw(b"dtrgwcser".to_vec())) ); } @@ -430,9 +470,9 @@ mod tests { let mut buffer = BytesMut::with_capacity(50); let frame1 = InitFrame::Raw(b"foobar".to_vec()); - let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + let _ = InitFrame::write_bytes(frame1.clone(), &mut buffer); buffer[1] = 255; - let framed = InitFrame::to_frame(&mut buffer); + let framed = InitFrame::read_frame(&mut buffer); assert_eq!(framed, Some(frame1)); } @@ -441,9 +481,9 @@ mod tests { let mut buffer = BytesMut::with_capacity(50); let frame1 = InitFrame::Raw(b"foobar".to_vec()); - let _ = InitFrame::to_bytes(frame1.clone(), &mut buffer); + let _ = InitFrame::write_bytes(frame1, &mut buffer); buffer[1] = 3; - let framed = InitFrame::to_frame(&mut buffer); + let framed = InitFrame::read_frame(&mut buffer); // we accept a different frame here, as it's RAW and debug only! 
assert_eq!(framed, Some(InitFrame::Raw(b"foo".to_vec()))); } @@ -452,48 +492,48 @@ mod tests { fn frame_too_short_buffer() { let mut buffer = BytesMut::with_capacity(10); - let frame1 = Frame::OpenStream { + let frame1 = OTFrame::OpenStream { sid: Sid::new(88), promises: Promises::ENCRYPTED, prio: 88, }; - Frame::to_bytes(frame1.clone(), &mut buffer); + OTFrame::write_bytes(frame1, &mut buffer); } #[test] fn frame_too_less_data() { let mut buffer = BytesMut::with_capacity(20); - let frame1 = Frame::OpenStream { + let frame1 = OTFrame::OpenStream { sid: Sid::new(88), promises: Promises::ENCRYPTED, prio: 88, }; - Frame::to_bytes(frame1.clone(), &mut buffer); + OTFrame::write_bytes(frame1, &mut buffer); buffer.truncate(6); // simulate partial retrieve - let frame1d = Frame::to_frame(&mut buffer); + let frame1d = ITFrame::read_frame(&mut buffer); assert_eq!(frame1d, None); } #[test] fn frame_rubish() { let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); - assert_eq!(Frame::to_frame(&mut buffer), None); + assert_eq!(ITFrame::read_frame(&mut buffer), None); } #[test] fn frame_attack_too_much_length() { let mut buffer = BytesMut::with_capacity(50); - let frame1 = Frame::Data { + let frame1 = OTFrame::Data { mid: 7u64, start: 1u64, - data: b"foobar".to_vec(), + data: Bytes::from(&b"foobar"[..]), }; - Frame::to_bytes(frame1.clone(), &mut buffer); + OTFrame::write_bytes(frame1, &mut buffer); buffer[17] = 255; - let framed = Frame::to_frame(&mut buffer); + let framed = ITFrame::read_frame(&mut buffer); assert_eq!(framed, None); } @@ -501,25 +541,25 @@ mod tests { fn frame_attack_too_low_length() { let mut buffer = BytesMut::with_capacity(50); - let frame1 = Frame::Data { + let frame1 = OTFrame::Data { mid: 7u64, start: 1u64, - data: b"foobar".to_vec(), + data: Bytes::from(&b"foobar"[..]), }; - Frame::to_bytes(frame1.clone(), &mut buffer); + OTFrame::write_bytes(frame1, &mut buffer); buffer[17] = 3; - let framed = Frame::to_frame(&mut buffer); + let framed = 
ITFrame::read_frame(&mut buffer); assert_eq!( framed, - Some(Frame::Data { + Some(ITFrame::Data { mid: 7u64, start: 1u64, - data: b"foo".to_vec(), + data: BytesMut::from(&b"foo"[..]), }) ); //next = Invalid => Empty - let framed = Frame::to_frame(&mut buffer); + let framed = ITFrame::read_frame(&mut buffer); assert_eq!(framed, None); } } diff --git a/network/protocol/src/handshake.rs b/network/protocol/src/handshake.rs index cc46791fc6..fda3893d72 100644 --- a/network/protocol/src/handshake.rs +++ b/network/protocol/src/handshake.rs @@ -9,13 +9,24 @@ use crate::{ use async_trait::async_trait; use tracing::{debug, error, info, trace}; -// Protocols might define a Reliable Variant for auto Handshake discovery -// this doesn't need to be effective +/// Implement this for auto Handshake with [`ReliableSink`]. +/// You must make sure that EVERY message send this way actually is received on +/// the receiving site: +/// - exactly once +/// - in the correct order +/// - correctly +/// +/// [`ReliableSink`]: crate::ReliableSink +/// [`RecvProtocol`]: crate::RecvProtocol #[async_trait] pub trait ReliableDrain { async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>; } +/// Implement this for auto Handshake with [`ReliableDrain`]. See +/// [`ReliableDrain`]. 
+/// +/// [`ReliableDrain`]: crate::ReliableDrain #[async_trait] pub trait ReliableSink { async fn recv(&mut self) -> Result; @@ -34,14 +45,13 @@ where local_secret: u128, ) -> Result<(Pid, Sid, u128), InitProtocolError> { #[cfg(debug_assertions)] - const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required \ - by veloren server.\nWe are not sure if you are a \ - valid veloren client.\nClosing the connection" - .as_bytes(); + const WRONG_NUMBER: &str = "Handshake does not contain the magic number required by \ + veloren server.\nWe are not sure if you are a valid veloren \ + client.\nClosing the connection"; #[cfg(debug_assertions)] - const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ - invalid version.\nWe don't know how to communicate \ - with you.\nClosing the connection"; + const WRONG_VERSION: &str = "Handshake does contain a correct magic number, but invalid \ + version.\nWe don't know how to communicate with \ + you.\nClosing the connection"; const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ something went wrong on network layer and connection will be closed"; @@ -66,7 +76,9 @@ where if magic_number != VELOREN_MAGIC_NUMBER { error!(?magic_number, "Connection with invalid magic_number"); #[cfg(debug_assertions)] - drain.send(InitFrame::Raw(WRONG_NUMBER.to_vec())).await?; + drain + .send(InitFrame::Raw(WRONG_NUMBER.as_bytes().to_vec())) + .await?; Err(InitProtocolError::WrongMagicNumber(magic_number)) } else if version != VELOREN_NETWORK_VERSION { error!(?version, "Connection with wrong network version"); diff --git a/network/protocol/src/io.rs b/network/protocol/src/io.rs deleted file mode 100644 index 6ccf40e7d8..0000000000 --- a/network/protocol/src/io.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::ProtocolError; -use async_trait::async_trait; -use bytes::BytesMut; -use std::collections::VecDeque; -///! 
I/O-Free (Sans-I/O) protocol https://sans-io.readthedocs.io/how-to-sans-io.html - -// Protocols should base on the Unrealiable variants to get something effective! -#[async_trait] -pub trait UnreliableDrain: Send { - type DataFormat; - async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; -} - -#[async_trait] -pub trait UnreliableSink: Send { - type DataFormat; - async fn recv(&mut self) -> Result; -} - -pub struct BaseDrain { - data: VecDeque, -} - -pub struct BaseSink { - data: VecDeque, -} - -impl BaseDrain { - pub fn new() -> Self { - Self { - data: VecDeque::new(), - } - } -} - -impl BaseSink { - pub fn new() -> Self { - Self { - data: VecDeque::new(), - } - } -} - -//TODO: Test Sinks that drop 20% by random and log that - -#[async_trait] -impl UnreliableDrain for BaseDrain { - type DataFormat = BytesMut; - - async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { - self.data.push_back(data); - Ok(()) - } -} - -#[async_trait] -impl UnreliableSink for BaseSink { - type DataFormat = BytesMut; - - async fn recv(&mut self) -> Result { - self.data.pop_front().ok_or(ProtocolError::Closed) - } -} diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs index 295d292881..fc22b9b711 100644 --- a/network/protocol/src/lib.rs +++ b/network/protocol/src/lib.rs @@ -1,7 +1,57 @@ +//! Network Protocol +//! +//! a I/O-Free protocol for the veloren network crate. +//! This crate defines multiple different protocols over [`UnreliableDrain`] and +//! [`UnreliableSink`] traits, which allows it to define the behavior of a +//! protocol separated from the actual io. +//! +//! For example we define the TCP protocol on top of Drains and Sinks that can +//! send chunks of bytes. You can now implement your own Drain And Sink that +//! sends the data via tokio's or std's implementation. Or you just use a +//! std::mpsc::channel for unit tests without needing a actual tcp socket. +//! +//! This crate currently defines: +//! 
- TCP +//! - MPSC +//! +//! a UDP implementation will quickly follow, and it's also possible to abstract +//! over QUIC. +//! +//! warning: don't mix protocol, using the TCP variant for actual UDP socket +//! will result in dropped data using UDP with a TCP socket will be a waste of +//! resources. +//! +//! A *channel* in this crate is defined as a combination of *read* and *write* +//! protocol. +//! +//! # adding a protocol +//! +//! We start by defining our DataFormat. For most this is prob [`Vec`] or +//! [`Bytes`]. MPSC can directly send a msg without serialisation. +//! +//! Create 2 structs, one for the receiving and sending end. Based on a generic +//! Drain/Sink with your required DataFormat. +//! Implement the [`SendProtocol`] and [`RecvProtocol`] traits respectively. +//! +//! Implement the Handshake: [`InitProtocol`], alternatively you can also +//! implement `ReliableDrain` and `ReliableSink`, by this, you use the default +//! Handshake. +//! +//! This crate also contains consts and definitions for the network protocol. +//! +//! For an *example* see `TcpDrain` and `TcpSink` in the [tcp.rs](tcp.rs) +//! +//! [`UnreliableDrain`]: crate::UnreliableDrain +//! [`UnreliableSink`]: crate::UnreliableSink +//! [`Vec`]: std::vec::Vec +//! [`Bytes`]: bytes::Bytes +//! [`SendProtocol`]: crate::SendProtocol +//! [`RecvProtocol`]: crate::RecvProtocol +//! 
[`InitProtocol`]: crate::InitProtocol + mod event; mod frame; mod handshake; -mod io; mod message; mod metrics; mod mpsc; @@ -10,22 +60,23 @@ mod tcp; mod types; pub use event::ProtocolEvent; -pub use io::{BaseDrain, BaseSink, UnreliableDrain, UnreliableSink}; -pub use message::MessageBuffer; pub use metrics::ProtocolMetricCache; #[cfg(feature = "metrics")] pub use metrics::ProtocolMetrics; -pub use mpsc::{MpscMsg, MpscRecvProtcol, MpscSendProtcol}; -pub use tcp::{TcpRecvProtcol, TcpSendProtcol}; -pub use types::{Bandwidth, Cid, Mid, Pid, Prio, Promises, Sid, VELOREN_NETWORK_VERSION}; +pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol}; +pub use tcp::{TcpRecvProtocol, TcpSendProtocol}; +pub use types::{ + Bandwidth, Cid, Mid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION, +}; ///use at own risk, might change any time, for internal benchmarks pub mod _internal { - pub use crate::frame::Frame; + pub use crate::frame::{ITFrame, OTFrame}; } use async_trait::async_trait; +/// Handshake: Used to connect 2 Channels. #[async_trait] pub trait InitProtocol { async fn initialize( @@ -36,14 +87,32 @@ pub trait InitProtocol { ) -> Result<(Pid, Sid, u128), InitProtocolError>; } +/// Generic Network Send Protocol. +/// Implement this for your Protocol of choice ( tcp, udp, mpsc, quic) +/// Allows the creation/deletions of `Streams` and sending messages via +/// [`ProtocolEvent`]. +/// +/// A `Stream` MUST be bound to a specific Channel. You MUST NOT switch the +/// channel to send a stream mid air. We will provide takeover options for +/// Channel closure in the future to allow keeping a `Stream` over a broker +/// Channel. 
+/// +/// [`ProtocolEvent`]: crate::ProtocolEvent #[async_trait] pub trait SendProtocol { - //a stream MUST be bound to a specific Protocol, there will be a failover - // feature comming for the case where a Protocol fails completly - /// use this to notify the sending side of streams that were created/remove - /// from remote + /// YOU MUST inform the `SendProtocol` by any Stream Open BEFORE using it in + /// `send` and Stream Close AFTER using it in `send` via this fn. fn notify_from_recv(&mut self, event: ProtocolEvent); + /// Send a Event via this Protocol. The `SendProtocol` MAY require `flush` + /// to be called before actual data is send to the respective `Sink`. async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; + /// Flush all buffered messages according to their [`Prio`] and + /// [`Bandwidth`]. provide the current bandwidth budget (per second) as + /// well as the `dt` since last call. According to the budget the + /// respective messages will be flushed. + /// + /// [`Prio`]: crate::Prio + /// [`Bandwidth`]: crate::Bandwidth async fn flush( &mut self, bandwidth: Bandwidth, @@ -51,11 +120,42 @@ pub trait SendProtocol { ) -> Result<(), ProtocolError>; } +/// Generic Network Recv Protocol. See: [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol #[async_trait] pub trait RecvProtocol { + /// Either recv an event or fail the Protocol, once the Recv side is closed + /// it cannot recover from the error. async fn recv(&mut self) -> Result; } +/// This crate makes use of UnreliableDrains, they are expected to provide the +/// same guarantees like their IO-counterpart. E.g. ordered messages for TCP and +/// nothing for UDP. The respective Protocol needs then to handle this. +/// This trait is an abstraction above multiple Drains, e.g. 
[`tokio`](https://tokio.rs) [`async-std`] [`std`] or even [`async-channel`] +/// +/// [`async-std`]: async-std +/// [`std`]: std +/// [`async-channel`]: async-channel +#[async_trait] +pub trait UnreliableDrain: Send { + type DataFormat; + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; +} + +/// Sink counterpart of [`UnreliableDrain`] +/// +/// [`UnreliableDrain`]: crate::UnreliableDrain +#[async_trait] +pub trait UnreliableSink: Send { + type DataFormat; + async fn recv(&mut self) -> Result; +} + +/// All possible Errors that can happen during Handshake [`InitProtocol`] +/// +/// [`InitProtocol`]: crate::InitProtocol #[derive(Debug, PartialEq)] pub enum InitProtocolError { Closed, @@ -63,8 +163,8 @@ pub enum InitProtocolError { WrongVersion([u32; 3]), } -#[derive(Debug, PartialEq)] /// When you return closed you must stay closed! +#[derive(Debug, PartialEq)] pub enum ProtocolError { Closed, } diff --git a/network/protocol/src/message.rs b/network/protocol/src/message.rs index 1bda1325ad..c71c0b5515 100644 --- a/network/protocol/src/message.rs +++ b/network/protocol/src/message.rs @@ -1,39 +1,100 @@ use crate::{ - frame::Frame, + frame::OTFrame, types::{Mid, Sid}, }; -use std::{collections::VecDeque, sync::Arc}; +use bytes::{Bytes, BytesMut}; -//Todo: Evaluate switching to VecDeque for quickly adding and removing data -// from front, back. -// - It would prob require custom bincode code but thats possible. -#[cfg_attr(test, derive(PartialEq))] -pub struct MessageBuffer { - pub data: Vec, +pub(crate) const ALLOC_BLOCK: usize = 16_777_216; + +/// Contains a outgoing message for TCP protocol +/// All Chunks have the same size, except for the last chunk which can end +/// earlier. E.g. 
+/// ```ignore +/// msg = OTMessage::new(); +/// msg.next(); +/// msg.next(); +/// ``` +#[derive(Debug)] +pub(crate) struct OTMessage { + data: Bytes, + original_length: u64, + send_header: bool, + mid: Mid, + sid: Sid, + start: u64, /* remove */ } -impl std::fmt::Debug for MessageBuffer { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //TODO: small messages! - let len = self.data.len(); - if len > 20 { - write!( - f, - "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", - len, - u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), - u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), - u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), - &self.data[13..16], - &self.data[len - 8..len] - ) +#[derive(Debug)] +pub(crate) struct ITMessage { + pub data: BytesMut, + pub sid: Sid, + pub length: u64, +} + +impl OTMessage { + pub(crate) const FRAME_DATA_SIZE: u64 = 1400; + + pub(crate) fn new(data: Bytes, mid: Mid, sid: Sid) -> Self { + let original_length = data.len() as u64; + Self { + data, + original_length, + send_header: false, + mid, + sid, + start: 0, + } + } + + fn get_header(&self) -> OTFrame { + OTFrame::DataHeader { + mid: self.mid, + sid: self.sid, + length: self.data.len() as u64, + } + } + + fn get_next_data(&mut self) -> OTFrame { + let to_send = std::cmp::min(self.data.len(), Self::FRAME_DATA_SIZE as usize); + let data = self.data.split_to(to_send); + let start = self.start; + self.start += Self::FRAME_DATA_SIZE; + + OTFrame::Data { + mid: self.mid, + start, + data, + } + } + + /// returns if something was added + pub(crate) fn next(&mut self) -> Option { + if !self.send_header { + self.send_header = true; + Some(self.get_header()) + } else if !self.data.is_empty() { + Some(self.get_next_data()) } else { - write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) + None + } + } + + pub(crate) fn get_sid_len(&self) -> (Sid, u64) { 
(self.sid, self.original_length) } +} + +impl ITMessage { + pub(crate) fn new(sid: Sid, length: u64, _allocator: &mut BytesMut) -> Self { + //allocator.reserve(ALLOC_BLOCK); + //TODO: grab mem from the allocatior, but this is only possible with unsafe + Self { + sid, + length, + data: BytesMut::with_capacity((length as usize).min(ALLOC_BLOCK /* anti-ddos */)), } } } +/* /// Contains a outgoing message and store what was *send* and *confirmed* /// All Chunks have the same size, except for the last chunk which can end /// earlier. E.g. @@ -45,7 +106,8 @@ impl std::fmt::Debug for MessageBuffer { /// msg.confirm(2); /// ``` #[derive(Debug)] -pub(crate) struct OutgoingMessage { +#[allow(dead_code)] +pub(crate) struct OUMessage { buffer: Arc, send_index: u64, // 3 => 4200 (3*FRAME_DATA_SIZE) send_header: bool, @@ -56,7 +118,8 @@ pub(crate) struct OutgoingMessage { missing_indices: VecDeque, } -impl OutgoingMessage { +#[allow(dead_code)] +impl OUMessage { pub(crate) const FRAME_DATA_SIZE: u64 = 1400; pub(crate) fn new(buffer: Arc, mid: Mid, sid: Sid) -> Self { @@ -125,3 +188,4 @@ impl OutgoingMessage { pub(crate) fn get_sid_len(&self) -> (Sid, u64) { (self.sid, self.buffer.data.len() as u64) } } +*/ diff --git a/network/protocol/src/metrics.rs b/network/protocol/src/metrics.rs index 0b5c66872a..3ecacfa1a2 100644 --- a/network/protocol/src/metrics.rs +++ b/network/protocol/src/metrics.rs @@ -5,7 +5,8 @@ use prometheus::{ IntCounterVec, IntGaugeVec, Opts, Registry, }; #[cfg(feature = "metrics")] -use std::{collections::HashMap, error::Error, sync::Arc}; +use std::collections::HashMap; +use std::{error::Error, sync::Arc}; #[allow(dead_code)] pub enum RemoveReason { @@ -13,6 +14,10 @@ pub enum RemoveReason { Dropped, } +/// Use 1 `ProtocolMetrics` per `Network`. 
+/// I will contain all protocol related [`prometheus`] information +/// +/// [`prometheus`]: prometheus #[cfg(feature = "metrics")] pub struct ProtocolMetrics { // smsg=send_msg rdata=receive_data @@ -55,6 +60,10 @@ pub struct ProtocolMetrics { ping: IntGaugeVec, } +/// Cache for [`ProtocolMetrics`], more optimized and cleared up after channel +/// disconnect. +/// +/// [`ProtocolMetrics`]: crate::ProtocolMetrics #[cfg(feature = "metrics")] #[derive(Debug, Clone)] pub struct ProtocolMetricCache { @@ -201,17 +210,20 @@ impl ProtocolMetrics { } } +#[cfg(not(feature = "metrics"))] +pub struct ProtocolMetrics {} + #[cfg(feature = "metrics")] #[derive(Debug, Clone)] pub(crate) struct CacheLine { - smsg_it: GenericCounter, - smsg_ib: GenericCounter, - smsg_ot: [GenericCounter; 2], - smsg_ob: [GenericCounter; 2], - rmsg_it: GenericCounter, - rmsg_ib: GenericCounter, - rmsg_ot: [GenericCounter; 2], - rmsg_ob: [GenericCounter; 2], + pub smsg_it: GenericCounter, + pub smsg_ib: GenericCounter, + pub smsg_ot: [GenericCounter; 2], + pub smsg_ob: [GenericCounter; 2], + pub rmsg_it: GenericCounter, + pub rmsg_ib: GenericCounter, + pub rmsg_ot: [GenericCounter; 2], + pub rmsg_ob: [GenericCounter; 2], } #[cfg(feature = "metrics")] @@ -279,8 +291,8 @@ impl ProtocolMetricCache { line.smsg_ob[reason.i()].inc_by(bytes); } - pub(crate) fn sdata_frames_b(&mut self, bytes: u64) { - self.sdata_frames_t.inc(); + pub(crate) fn sdata_frames_b(&mut self, cnt: u64, bytes: u64) { + self.sdata_frames_t.inc_by(cnt); self.sdata_frames_b.inc_by(bytes); } @@ -332,6 +344,31 @@ impl ProtocolMetricCache { } } +#[cfg(feature = "metrics")] +impl Drop for ProtocolMetricCache { + fn drop(&mut self) { + let cid = &self.cid; + let m = &self.m; + let finished = RemoveReason::Finished.to_str(); + let dropped = RemoveReason::Dropped.to_str(); + for (sid, _) in self.cache.drain() { + let s = sid.to_string(); + let _ = m.smsg_it.remove_label_values(&[&cid, &s]); + let _ = m.smsg_ib.remove_label_values(&[&cid, 
&s]); + let _ = m.smsg_ot.remove_label_values(&[&cid, &s, &finished]); + let _ = m.smsg_ot.remove_label_values(&[&cid, &s, &dropped]); + let _ = m.smsg_ob.remove_label_values(&[&cid, &s, &finished]); + let _ = m.smsg_ob.remove_label_values(&[&cid, &s, &dropped]); + let _ = m.rmsg_it.remove_label_values(&[&cid, &s]); + let _ = m.rmsg_ib.remove_label_values(&[&cid, &s]); + let _ = m.rmsg_ot.remove_label_values(&[&cid, &s, &finished]); + let _ = m.rmsg_ot.remove_label_values(&[&cid, &s, &dropped]); + let _ = m.rmsg_ob.remove_label_values(&[&cid, &s, &finished]); + let _ = m.rmsg_ob.remove_label_values(&[&cid, &s, &dropped]); + } + } +} + #[cfg(feature = "metrics")] impl std::fmt::Debug for ProtocolMetrics { #[inline] @@ -342,45 +379,40 @@ impl std::fmt::Debug for ProtocolMetrics { #[cfg(not(feature = "metrics"))] impl ProtocolMetricCache { - pub(crate) fn smsg_it(&mut self, _sid: Sid) {} + pub fn new(_channel_key: &str, _metrics: Arc) -> Self { Self {} } pub(crate) fn smsg_ib(&mut self, _sid: Sid, _b: u64) {} - pub(crate) fn smsg_ot(&mut self, _sid: Sid, _reason: RemoveReason) {} - pub(crate) fn smsg_ob(&mut self, _sid: Sid, _reason: RemoveReason, _b: u64) {} - pub(crate) fn sdata_frames_t(&mut self) {} - - pub(crate) fn sdata_frames_b(&mut self, _b: u64) {} - - pub(crate) fn rmsg_it(&mut self, _sid: Sid) {} + pub(crate) fn sdata_frames_b(&mut self, _cnt: u64, _b: u64) {} pub(crate) fn rmsg_ib(&mut self, _sid: Sid, _b: u64) {} - pub(crate) fn rmsg_ot(&mut self, _sid: Sid, _reason: RemoveReason) {} - pub(crate) fn rmsg_ob(&mut self, _sid: Sid, _reason: RemoveReason, _b: u64) {} - pub(crate) fn rdata_frames_t(&mut self) {} - pub(crate) fn rdata_frames_b(&mut self, _b: u64) {} } +#[cfg(not(feature = "metrics"))] +impl ProtocolMetrics { + pub fn new() -> Result> { Ok(Self {}) } +} + impl RemoveReason { #[cfg(feature = "metrics")] fn to_str(&self) -> &str { match self { - RemoveReason::Dropped => "Dropped", RemoveReason::Finished => "Finished", + RemoveReason::Dropped => 
"Dropped", } } #[cfg(feature = "metrics")] - fn i(&self) -> usize { + pub(crate) fn i(&self) -> usize { match self { - RemoveReason::Dropped => 0, - RemoveReason::Finished => 1, + RemoveReason::Finished => 0, + RemoveReason::Dropped => 1, } } } diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs index 0fbbee6300..1eb987c75c 100644 --- a/network/protocol/src/mpsc.rs +++ b/network/protocol/src/mpsc.rs @@ -1,25 +1,30 @@ +#[cfg(feature = "metrics")] +use crate::metrics::RemoveReason; use crate::{ event::ProtocolEvent, frame::InitFrame, handshake::{ReliableDrain, ReliableSink}, - io::{UnreliableDrain, UnreliableSink}, - metrics::{ProtocolMetricCache, RemoveReason}, + metrics::ProtocolMetricCache, types::Bandwidth, - ProtocolError, RecvProtocol, SendProtocol, + ProtocolError, RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, }; use async_trait::async_trait; use std::time::{Duration, Instant}; #[cfg(feature = "trace_pedantic")] use tracing::trace; +/// used for implementing your own MPSC `Sink` and `Drain` #[derive(Debug)] -pub /* should be private */ enum MpscMsg { +pub enum MpscMsg { Event(ProtocolEvent), InitFrame(InitFrame), } +/// MPSC implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol #[derive(Debug)] -pub struct MpscSendProtcol +pub struct MpscSendProtocol where D: UnreliableDrain, { @@ -28,8 +33,11 @@ where metrics: ProtocolMetricCache, } +/// MPSC implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol #[derive(Debug)] -pub struct MpscRecvProtcol +pub struct MpscRecvProtocol where S: UnreliableSink, { @@ -37,7 +45,7 @@ where metrics: ProtocolMetricCache, } -impl MpscSendProtcol +impl MpscSendProtocol where D: UnreliableDrain, { @@ -50,7 +58,7 @@ where } } -impl MpscRecvProtcol +impl MpscRecvProtocol where S: UnreliableSink, { @@ -58,7 +66,7 @@ where } #[async_trait] -impl SendProtocol for MpscSendProtcol +impl SendProtocol for MpscSendProtocol where D: UnreliableDrain, { @@ 
-69,15 +77,25 @@ where trace!(?event, "send"); match &event { ProtocolEvent::Message { - buffer, + data: _data, mid: _, - sid, + sid: _sid, } => { - let sid = *sid; - let bytes = buffer.data.len() as u64; - self.metrics.smsg_ib(sid, bytes); + #[cfg(feature = "metrics")] + let (bytes, line) = { + let sid = *_sid; + let bytes = _data.len() as u64; + let line = self.metrics.init_sid(sid); + line.smsg_it.inc(); + line.smsg_ib.inc_by(bytes); + (bytes, line) + }; let r = self.drain.send(MpscMsg::Event(event)).await; - self.metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + #[cfg(feature = "metrics")] + { + line.smsg_ot[RemoveReason::Finished.i()].inc(); + line.smsg_ob[RemoveReason::Finished.i()].inc_by(bytes); + } r }, _ => self.drain.send(MpscMsg::Event(event)).await, @@ -88,7 +106,7 @@ where } #[async_trait] -impl RecvProtocol for MpscRecvProtcol +impl RecvProtocol for MpscRecvProtocol where S: UnreliableSink, { @@ -98,16 +116,17 @@ where trace!(?event, "recv"); match event { MpscMsg::Event(e) => { - if let ProtocolEvent::Message { - buffer, - mid: _, - sid, - } = &e + #[cfg(feature = "metrics")] { - let sid = *sid; - let bytes = buffer.data.len() as u64; - self.metrics.rmsg_ib(sid, bytes); - self.metrics.rmsg_ob(sid, RemoveReason::Finished, bytes); + if let ProtocolEvent::Message { data, mid: _, sid } = &e { + let sid = *sid; + let bytes = data.len() as u64; + let line = self.metrics.init_sid(sid); + line.rmsg_it.inc(); + line.rmsg_ib.inc_by(bytes); + line.rmsg_ot[RemoveReason::Finished.i()].inc(); + line.rmsg_ob[RemoveReason::Finished.i()].inc_by(bytes); + } } Ok(e) }, @@ -117,7 +136,7 @@ where } #[async_trait] -impl ReliableDrain for MpscSendProtcol +impl ReliableDrain for MpscSendProtocol where D: UnreliableDrain, { @@ -127,7 +146,7 @@ where } #[async_trait] -impl ReliableSink for MpscRecvProtcol +impl ReliableSink for MpscRecvProtocol where S: UnreliableSink, { @@ -142,10 +161,7 @@ where #[cfg(test)] pub mod test_utils { use super::*; - use crate::{ - io::*, - 
metrics::{ProtocolMetricCache, ProtocolMetrics}, - }; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; use async_channel::*; use std::sync::Arc; @@ -160,7 +176,7 @@ pub mod test_utils { pub fn ac_bound( cap: usize, metrics: Option, - ) -> [(MpscSendProtcol, MpscRecvProtcol); 2] { + ) -> [(MpscSendProtocol, MpscRecvProtocol); 2] { let (s1, r1) = async_channel::bounded(cap); let (s2, r2) = async_channel::bounded(cap); let m = metrics.unwrap_or_else(|| { @@ -168,12 +184,12 @@ pub mod test_utils { }); [ ( - MpscSendProtcol::new(ACDrain { sender: s1 }, m.clone()), - MpscRecvProtcol::new(ACSink { receiver: r2 }, m.clone()), + MpscSendProtocol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r2 }, m.clone()), ), ( - MpscSendProtcol::new(ACDrain { sender: s2 }, m.clone()), - MpscRecvProtcol::new(ACSink { receiver: r1 }, m.clone()), + MpscSendProtocol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r1 }, m), ), ] } diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs index 2348028d15..374a1ac216 100644 --- a/network/protocol/src/prio.rs +++ b/network/protocol/src/prio.rs @@ -1,12 +1,12 @@ use crate::{ - frame::Frame, - message::{MessageBuffer, OutgoingMessage}, + frame::OTFrame, + message::OTMessage, metrics::{ProtocolMetricCache, RemoveReason}, - types::{Bandwidth, Mid, Prio, Promises, Sid}, + types::{Bandwidth, Mid, Prio, Promises, Sid, HIGHEST_PRIO}, }; +use bytes::Bytes; use std::{ collections::{HashMap, VecDeque}, - sync::Arc, time::Duration, }; @@ -15,7 +15,7 @@ struct StreamInfo { pub(crate) guaranteed_bandwidth: Bandwidth, pub(crate) prio: Prio, pub(crate) promises: Promises, - pub(crate) messages: VecDeque, + pub(crate) messages: VecDeque, } /// Responsible for queueing messages. 
@@ -31,8 +31,6 @@ pub(crate) struct PrioManager { // Send everything ONCE, then keep it till it's confirmed impl PrioManager { - const HIGHEST_PRIO: u8 = 7; - pub fn new(metrics: ProtocolMetricCache) -> Self { Self { streams: HashMap::new(), @@ -67,34 +65,34 @@ impl PrioManager { pub fn is_empty(&self) -> bool { self.streams.is_empty() } - pub fn add(&mut self, buffer: Arc, mid: Mid, sid: Sid) { + pub fn add(&mut self, buffer: Bytes, mid: Mid, sid: Sid) { self.streams .get_mut(&sid) .unwrap() .messages - .push_back(OutgoingMessage::new(buffer, mid, sid)); + .push_back(OTMessage::new(buffer, mid, sid)); } /// bandwidth might be extended, as for technical reasons /// guaranteed_bandwidth is used and frames are always 1400 bytes. - pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> Vec { + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec, Bandwidth) { let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; let mut cur_bytes = 0u64; let mut frames = vec![]; - let mut prios = [0u64; (Self::HIGHEST_PRIO + 1) as usize]; + let mut prios = [0u64; (HIGHEST_PRIO + 1) as usize]; let metrics = &mut self.metrics; let mut process_stream = |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { - let mut finished = vec![]; + let mut finished = None; 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { while let Some(frame) = msg.next() { - let b = if matches!(frame, Frame::DataHeader { .. }) { - 25 + let b = if let OTFrame::Data { data, .. 
} = &frame { + crate::frame::TCP_DATA_CNS + 1 + data.len() } else { - 19 + OutgoingMessage::FRAME_DATA_SIZE - }; + crate::frame::TCP_DATA_HEADER_CNS + 1 + } as u64; bandwidth -= b as i64; *cur_bytes += b; frames.push(frame); @@ -102,41 +100,38 @@ impl PrioManager { break 'outer; } } - finished.push(i); - } - - //cleanup - for i in finished.iter().rev() { - let msg = stream.messages.remove(*i).unwrap(); let (sid, bytes) = msg.get_sid_len(); metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + finished = Some(i); + } + if let Some(i) = finished { + //cleanup + stream.messages.drain(..=i); } }; // Add guaranteed bandwidth - for (_, stream) in &mut self.streams { - prios[stream.prio.min(Self::HIGHEST_PRIO) as usize] += 1; + for stream in self.streams.values_mut() { + prios[stream.prio as usize] += 1; let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); } if cur_bytes < total_bytes { // Add optional bandwidth - for prio in 0..=Self::HIGHEST_PRIO { + for prio in 0..=HIGHEST_PRIO { if prios[prio as usize] == 0 { continue; } - let per_stream_bytes = (total_bytes - cur_bytes) / prios[prio as usize]; - - for (_, stream) in &mut self.streams { + let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64; + for stream in self.streams.values_mut() { if stream.prio != prio { continue; } - process_stream(stream, per_stream_bytes as i64, &mut cur_bytes); + process_stream(stream, per_stream_bytes, &mut cur_bytes); } } } - - frames + (frames, cur_bytes) } } diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index dd84d5b013..07a44ab0de 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -1,26 +1,28 @@ use crate::{ event::ProtocolEvent, - frame::{Frame, InitFrame}, + frame::{ITFrame, InitFrame, OTFrame}, handshake::{ReliableDrain, ReliableSink}, - io::{UnreliableDrain, UnreliableSink}, + message::{ITMessage, 
ALLOC_BLOCK}, metrics::{ProtocolMetricCache, RemoveReason}, prio::PrioManager, - types::Bandwidth, - ProtocolError, RecvProtocol, SendProtocol, + types::{Bandwidth, Mid, Sid}, + ProtocolError, RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, }; use async_trait::async_trait; use bytes::BytesMut; use std::{ collections::HashMap, - sync::Arc, time::{Duration, Instant}, }; use tracing::info; #[cfg(feature = "trace_pedantic")] use tracing::trace; +/// TCP implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol #[derive(Debug)] -pub struct TcpSendProtcol +pub struct TcpSendProtocol where D: UnreliableDrain, { @@ -34,18 +36,22 @@ where metrics: ProtocolMetricCache, } +/// TCP implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol #[derive(Debug)] -pub struct TcpRecvProtcol +pub struct TcpRecvProtocol where S: UnreliableSink, { buffer: BytesMut, - incoming: HashMap, + itmsg_allocator: BytesMut, + incoming: HashMap, sink: S, metrics: ProtocolMetricCache, } -impl TcpSendProtcol +impl TcpSendProtocol where D: UnreliableDrain, { @@ -63,13 +69,14 @@ where } } -impl TcpRecvProtcol +impl TcpRecvProtocol where S: UnreliableSink, { pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { buffer: BytesMut::new(), + itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK), incoming: HashMap::new(), sink, metrics, @@ -78,7 +85,7 @@ where } #[async_trait] -impl SendProtocol for TcpSendProtcol +impl SendProtocol for TcpSendProtocol where D: UnreliableDrain, { @@ -116,12 +123,12 @@ where } => { self.store .open_stream(sid, prio, promises, guaranteed_bandwidth); - event.to_frame().to_bytes(&mut self.buffer); + event.to_frame().write_bytes(&mut self.buffer); self.drain.send(self.buffer.split()).await?; }, ProtocolEvent::CloseStream { sid } => { if self.store.try_close_stream(sid) { - event.to_frame().to_bytes(&mut self.buffer); + event.to_frame().write_bytes(&mut self.buffer); 
self.drain.send(self.buffer.split()).await?; } else { #[cfg(feature = "trace_pedantic")] @@ -131,7 +138,7 @@ where }, ProtocolEvent::Shutdown => { if self.store.is_empty() { - event.to_frame().to_bytes(&mut self.buffer); + event.to_frame().write_bytes(&mut self.buffer); self.drain.send(self.buffer.split()).await?; } else { #[cfg(feature = "trace_pedantic")] @@ -139,35 +146,41 @@ where self.pending_shutdown = true; } }, - ProtocolEvent::Message { buffer, mid, sid } => { - self.metrics.smsg_ib(sid, buffer.data.len() as u64); - self.store.add(buffer, mid, sid); + ProtocolEvent::Message { data, mid, sid } => { + self.metrics.smsg_ib(sid, data.len() as u64); + self.store.add(data, mid, sid); }, } Ok(()) } async fn flush(&mut self, bandwidth: Bandwidth, dt: Duration) -> Result<(), ProtocolError> { - let frames = self.store.grab(bandwidth, dt); + let (frames, total_bytes) = self.store.grab(bandwidth, dt); + self.buffer.reserve(total_bytes as usize); + let mut data_frames = 0; + let mut data_bandwidth = 0; for frame in frames { - if let Frame::Data { + if let OTFrame::Data { mid: _, start: _, data, } = &frame { - self.metrics.sdata_frames_b(data.len() as u64); + data_bandwidth += data.len(); + data_frames += 1; } - frame.to_bytes(&mut self.buffer); - self.drain.send(self.buffer.split()).await?; + frame.write_bytes(&mut self.buffer); } + self.drain.send(self.buffer.split()).await?; + self.metrics + .sdata_frames_b(data_frames, data_bandwidth as u64); let mut finished_streams = vec![]; for (i, &sid) in self.closing_streams.iter().enumerate() { if self.store.try_close_stream(sid) { #[cfg(feature = "trace_pedantic")] trace!(?sid, "close stream, as it's now empty"); - Frame::CloseStream { sid }.to_bytes(&mut self.buffer); + OTFrame::CloseStream { sid }.write_bytes(&mut self.buffer); self.drain.send(self.buffer.split()).await?; finished_streams.push(i); } @@ -191,7 +204,7 @@ where if self.pending_shutdown && self.store.is_empty() { #[cfg(feature = "trace_pedantic")] 
trace!("shutdown, as it's now empty"); - Frame::Shutdown {}.to_bytes(&mut self.buffer); + OTFrame::Shutdown {}.write_bytes(&mut self.buffer); self.drain.send(self.buffer.split()).await?; self.pending_shutdown = false; } @@ -199,58 +212,42 @@ where } } -use crate::{ - message::MessageBuffer, - types::{Mid, Sid}, -}; - -#[derive(Debug)] -struct IncomingMsg { - sid: Sid, - length: u64, - data: MessageBuffer, -} - #[async_trait] -impl RecvProtocol for TcpRecvProtcol +impl RecvProtocol for TcpRecvProtocol where S: UnreliableSink, { async fn recv(&mut self) -> Result { 'outer: loop { - while let Some(frame) = Frame::to_frame(&mut self.buffer) { + while let Some(frame) = ITFrame::read_frame(&mut self.buffer) { #[cfg(feature = "trace_pedantic")] trace!(?frame, "recv"); match frame { - Frame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), - Frame::OpenStream { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { sid, prio, promises, } => { break 'outer Ok(ProtocolEvent::OpenStream { sid, - prio, + prio: prio.min(crate::types::HIGHEST_PRIO), promises, guaranteed_bandwidth: 1_000_000, }); }, - Frame::CloseStream { sid } => { + ITFrame::CloseStream { sid } => { break 'outer Ok(ProtocolEvent::CloseStream { sid }); }, - Frame::DataHeader { sid, mid, length } => { - let m = IncomingMsg { - sid, - length, - data: MessageBuffer { data: vec![] }, - }; + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); self.metrics.rmsg_ib(sid, length); self.incoming.insert(mid, m); }, - Frame::Data { + ITFrame::Data { mid, start: _, - mut data, + data, } => { self.metrics.rdata_frames_b(data.len() as u64); let m = match self.incoming.get_mut(&mid) { @@ -263,45 +260,48 @@ where break 'outer Err(ProtocolError::Closed); }, }; - m.data.data.append(&mut data); - if m.data.data.len() == m.length as usize { + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { // finished, yay - 
drop(m); let m = self.incoming.remove(&mid).unwrap(); self.metrics.rmsg_ob( m.sid, RemoveReason::Finished, - m.data.data.len() as u64, + m.data.len() as u64, ); break 'outer Ok(ProtocolEvent::Message { sid: m.sid, mid, - buffer: Arc::new(m.data), + data: m.data.freeze(), }); } }, }; } let chunk = self.sink.recv().await?; - self.buffer.extend_from_slice(&chunk); + if self.buffer.is_empty() { + self.buffer = chunk; + } else { + self.buffer.extend_from_slice(&chunk); + } } } } #[async_trait] -impl ReliableDrain for TcpSendProtcol +impl ReliableDrain for TcpSendProtocol where D: UnreliableDrain, { async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { let mut buffer = BytesMut::with_capacity(500); - frame.to_bytes(&mut buffer); + frame.write_bytes(&mut buffer); self.drain.send(buffer).await } } #[async_trait] -impl ReliableSink for TcpRecvProtcol +impl ReliableSink for TcpRecvProtocol where S: UnreliableSink, { @@ -309,7 +309,7 @@ where while self.buffer.len() < 100 { let chunk = self.sink.recv().await?; self.buffer.extend_from_slice(&chunk); - if let Some(frame) = InitFrame::to_frame(&mut self.buffer) { + if let Some(frame) = InitFrame::read_frame(&mut self.buffer) { return Ok(frame); } } @@ -321,11 +321,9 @@ where mod test_utils { //TCP protocol based on Channel use super::*; - use crate::{ - io::*, - metrics::{ProtocolMetricCache, ProtocolMetrics}, - }; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; use async_channel::*; + use std::sync::Arc; pub struct TcpDrain { pub sender: Sender, @@ -339,7 +337,7 @@ mod test_utils { pub fn tcp_bound( cap: usize, metrics: Option, - ) -> [(TcpSendProtcol, TcpRecvProtcol); 2] { + ) -> [(TcpSendProtocol, TcpRecvProtocol); 2] { let (s1, r1) = async_channel::bounded(cap); let (s2, r2) = async_channel::bounded(cap); let m = metrics.unwrap_or_else(|| { @@ -347,12 +345,12 @@ mod test_utils { }); [ ( - TcpSendProtcol::new(TcpDrain { sender: s1 }, m.clone()), - TcpRecvProtcol::new(TcpSink { receiver: r2 }, 
m.clone()), + TcpSendProtocol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r2 }, m.clone()), ), ( - TcpSendProtcol::new(TcpDrain { sender: s2 }, m.clone()), - TcpRecvProtcol::new(TcpSink { receiver: r1 }, m.clone()), + TcpSendProtocol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r1 }, m), ), ] } @@ -385,12 +383,13 @@ mod test_utils { #[cfg(test)] mod tests { use crate::{ + frame::OTFrame, metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, tcp::test_utils::*, types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, - InitProtocol, MessageBuffer, ProtocolEvent, RecvProtocol, SendProtocol, + InitProtocol, ProtocolError, ProtocolEvent, RecvProtocol, SendProtocol, }; - use bytes::BytesMut; + use bytes::{Bytes, BytesMut}; use std::{sync::Arc, time::Duration}; #[tokio::test] @@ -409,7 +408,7 @@ mod tests { let (mut s, mut r) = (p1.0, p2.1); let event = ProtocolEvent::OpenStream { sid: Sid::new(10), - prio: 9u8, + prio: 0u8, promises: Promises::ORDERED, guaranteed_bandwidth: 1_000_000, }; @@ -433,9 +432,7 @@ mod tests { let event = ProtocolEvent::Message { sid: Sid::new(10), mid: 0, - buffer: Arc::new(MessageBuffer { - data: vec![188u8; 600], - }), + data: Bytes::from(&[188u8; 600][..]), }; s.send(event.clone()).await.unwrap(); s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); @@ -445,9 +442,7 @@ mod tests { let event = ProtocolEvent::Message { sid: Sid::new(10), mid: 1, - buffer: Arc::new(MessageBuffer { - data: vec![7u8; 30], - }), + data: Bytes::from(&[7u8; 30][..]), }; s.send(event.clone()).await.unwrap(); s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); @@ -473,9 +468,7 @@ mod tests { let event = ProtocolEvent::Message { sid, mid: 77, - buffer: Arc::new(MessageBuffer { - data: vec![99u8; 500_000], - }), + data: Bytes::from(&[99u8; 500_000][..]), }; s.send(event.clone()).await.unwrap(); s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); @@ 
-503,9 +496,7 @@ mod tests { let event = ProtocolEvent::Message { sid, mid: 77, - buffer: Arc::new(MessageBuffer { - data: vec![99u8; 500_000], - }), + data: Bytes::from(&[99u8; 500_000][..]), }; s.send(event).await.unwrap(); let event = ProtocolEvent::CloseStream { sid }; @@ -534,9 +525,7 @@ mod tests { let event = ProtocolEvent::Message { sid, mid: 77, - buffer: Arc::new(MessageBuffer { - data: vec![99u8; 500_000], - }), + data: Bytes::from(&[99u8; 500_000][..]), }; s.send(event).await.unwrap(); let event = ProtocolEvent::Shutdown {}; @@ -553,46 +542,80 @@ mod tests { assert!(matches!(e, ProtocolEvent::Shutdown { .. })); } + #[tokio::test] + async fn msg_finishes_after_drop() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 78, + data: Bytes::from(&[100u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. 
})); + } + #[tokio::test] async fn header_and_data_in_seperate_msg() { let sid = Sid::new(1); let (s, r) = async_channel::bounded(10); let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); let mut r = - super::TcpRecvProtcol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + super::TcpRecvProtocol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); const DATA1: &[u8; 69] = b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; let mut bytes = BytesMut::with_capacity(1500); - use crate::frame::Frame; - Frame::OpenStream { + OTFrame::OpenStream { sid, prio: 5u8, promises: Promises::COMPRESSED, } - .to_bytes(&mut bytes); - Frame::DataHeader { + .write_bytes(&mut bytes); + OTFrame::DataHeader { mid: 99, sid, length: (DATA1.len() + DATA2.len()) as u64, } - .to_bytes(&mut bytes); + .write_bytes(&mut bytes); s.send(bytes.split()).await.unwrap(); - Frame::Data { + OTFrame::Data { mid: 99, start: 0, - data: DATA1.to_vec(), + data: Bytes::from(&DATA1[..]), } - .to_bytes(&mut bytes); - Frame::Data { + .write_bytes(&mut bytes); + OTFrame::Data { mid: 99, start: DATA1.len() as u64, - data: DATA2.to_vec(), + data: Bytes::from(&DATA2[..]), } - .to_bytes(&mut bytes); - Frame::CloseStream { sid }.to_bytes(&mut bytes); + .write_bytes(&mut bytes); + OTFrame::CloseStream { sid }.write_bytes(&mut bytes); s.send(bytes.split()).await.unwrap(); let e = r.recv().await.unwrap(); @@ -605,6 +628,32 @@ mod tests { assert!(matches!(e, ProtocolEvent::CloseStream { .. 
})); } + #[tokio::test] + async fn drop_sink_while_recv() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::TcpRecvProtocol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + } + .write_bytes(&mut bytes); + s.send(bytes.split()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + + let e = e.await.unwrap(); + assert_eq!(e, Err(ProtocolError::Closed)); + } + #[tokio::test] #[should_panic] async fn send_on_stream_from_remote_without_notify() { @@ -622,9 +671,7 @@ mod tests { let event = ProtocolEvent::Message { sid: Sid::new(10), mid: 0, - buffer: Arc::new(MessageBuffer { - data: vec![188u8; 600], - }), + data: Bytes::from(&[188u8; 600][..]), }; p2.0.send(event.clone()).await.unwrap(); p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); @@ -649,9 +696,7 @@ mod tests { let event = ProtocolEvent::Message { sid: Sid::new(10), mid: 0, - buffer: Arc::new(MessageBuffer { - data: vec![188u8; 600], - }), + data: Bytes::from(&[188u8; 600][..]), }; p2.0.send(event.clone()).await.unwrap(); p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs index ba9348a16d..afa5f0d866 100644 --- a/network/protocol/src/types.rs +++ b/network/protocol/src/types.rs @@ -2,9 +2,21 @@ use bitflags::bitflags; use bytes::{Buf, BufMut, BytesMut}; use rand::Rng; +/// MessageID, unique ID per Message. pub type Mid = u64; +/// ChannelID, unique ID per Channel (Protocol) pub type Cid = u64; +/// Every Stream has a `Prio` and guaranteed [`Bandwidth`]. +/// Every send, the guarantees part is used first. 
+/// If there is still bandwidth left, it will be shared by all Streams with the +/// same priority. Prio 0 will be send first, then 1, ... till the last prio 7 +/// is send. Prio must be < 8! +/// +/// [`Bandwidth`]: crate::Bandwidth pub type Prio = u8; +/// guaranteed `Bandwidth`. See [`Prio`] +/// +/// [`Prio`]: crate::Prio pub type Bandwidth = u64; bitflags! { @@ -36,20 +48,23 @@ impl Promises { } pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = *b"VELOREN"; +/// When this semver differs, 2 Networks can't communicate. pub const VELOREN_NETWORK_VERSION: [u32; 3] = [0, 5, 0]; pub(crate) const STREAM_ID_OFFSET1: Sid = Sid::new(0); pub(crate) const STREAM_ID_OFFSET2: Sid = Sid::new(u64::MAX / 2); +/// Maximal possible Prio to choose (for performance reasons) +pub const HIGHEST_PRIO: u8 = 7; -/// Support struct used for uniquely identifying [`Participant`] over the -/// [`Network`]. -/// -/// [`Participant`]: crate::api::Participant -/// [`Network`]: crate::api::Network +/// Support struct used for uniquely identifying `Participant` over the +/// `Network`. #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub struct Pid { internal: u128, } +/// Unique ID per Stream, in one Channel. +/// one side will always start with 0, while the other start with u64::MAX / 2. +/// number increases for each created Stream. 
#[derive(PartialEq, Eq, Hash, Clone, Copy)] pub struct Sid { internal: u64, @@ -89,19 +104,29 @@ impl Pid { } } + #[inline] pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { Self { internal: bytes.get_u128_le(), } } + #[inline] pub(crate) fn to_bytes(&self, bytes: &mut BytesMut) { bytes.put_u128_le(self.internal) } } impl Sid { pub const fn new(internal: u64) -> Self { Self { internal } } - pub(crate) fn to_le_bytes(&self) -> [u8; 8] { self.internal.to_le_bytes() } + #[inline] + pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { + Self { + internal: bytes.get_u64_le(), + } + } + + #[inline] + pub(crate) fn to_bytes(&self, bytes: &mut BytesMut) { bytes.put_u64_le(self.internal) } } impl std::fmt::Debug for Pid { diff --git a/network/src/api.rs b/network/src/api.rs index bd06f1e472..f60b4c4743 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -1,15 +1,12 @@ -//! -//! -//! -//! (cd network/examples/async_recv && RUST_BACKTRACE=1 cargo run) use crate::{ message::{partial_eq_bincode, Message}, participant::{A2bStreamOpen, S2bShutdownBparticipant}, scheduler::Scheduler, }; +use bytes::Bytes; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; -use network_protocol::{Bandwidth, MessageBuffer, Pid, Prio, Promises, Sid}; +use network_protocol::{Bandwidth, Pid, Prio, Promises, Sid}; #[cfg(feature = "metrics")] use prometheus::Registry; use serde::{de::DeserializeOwned, Serialize}; @@ -76,8 +73,8 @@ pub struct Stream { promises: Promises, guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, - b2a_msg_recv_r: Option>, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -125,17 +122,17 @@ pub enum StreamError { /// /// # Examples /// ```rust +/// # use std::sync::Arc; +/// use tokio::runtime::Runtime; /// use veloren_network::{Network, ProtocolAddr, Pid}; -/// use futures::executor::block_on; /// /// # fn main() -> 
std::result::Result<(), Box> { /// // Create a Network, listen on port `2999` to accept connections and connect to port `8080` to connect to a (pseudo) database Application -/// let (network, f) = Network::new(Pid::new()); -/// std::thread::spawn(f); -/// block_on(async{ +/// let runtime = Arc::new(Runtime::new().unwrap()); +/// let network = Network::new(Pid::new(), Arc::clone(&runtime)); +/// runtime.block_on(async{ /// # //setup pseudo database! -/// # let (database, fd) = Network::new(Pid::new()); -/// # std::thread::spawn(fd); +/// # let database = Network::new(Pid::new(), Arc::clone(&runtime)); /// # database.listen(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2999".parse().unwrap())).await?; /// let database = network.connect(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; @@ -179,24 +176,20 @@ impl Network { /// /// # Examples /// ```rust - /// //Example with tokio /// use std::sync::Arc; /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// - /// let runtime = Runtime::new(); + /// let runtime = Runtime::new().unwrap(); /// let network = Network::new(Pid::new(), Arc::new(runtime)); /// ``` /// - /// /// Usually you only create a single `Network` for an application, /// except when client and server are in the same application, then you /// will want 2. However there are no technical limitations from /// creating more. 
/// - /// [`Pid::new()`]: crate::types::Pid::new - /// [`ThreadPool`]: https://docs.rs/uvth/newest/uvth/struct.ThreadPool.html - /// [`uvth`]: https://docs.rs/uvth + /// [`Pid::new()`]: network_protocol::Pid::new pub fn new(participant_id: Pid, runtime: Arc) -> Self { Self::internal_new( participant_id, @@ -215,12 +208,14 @@ impl Network { /// /// # Examples /// ```rust + /// # use std::sync::Arc; /// use prometheus::Registry; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// + /// let runtime = Runtime::new().unwrap(); /// let registry = Registry::new(); - /// let (network, f) = Network::new_with_registry(Pid::new(), ®istry); - /// std::thread::spawn(f); + /// let network = Network::new_with_registry(Pid::new(), Arc::new(runtime), ®istry); /// ``` /// [`new`]: crate::api::Network::new #[cfg(feature = "metrics")] @@ -243,7 +238,6 @@ impl Network { let (scheduler, listen_sender, connect_sender, connected_receiver, shutdown_sender) = Scheduler::new( participant_id, - Arc::clone(&runtime), #[cfg(feature = "metrics")] registry, ); @@ -274,15 +268,16 @@ impl Network { /// support multiple Protocols or NICs. 
/// /// # Examples - /// ```rust - /// use futures::executor::block_on; + /// ```ignore + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network /// .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .await?; @@ -315,17 +310,17 @@ impl Network { /// When the method returns the Network either returns a [`Participant`] /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// can't connect, or invalid Handshake) # Examples - /// ```rust - /// use futures::executor::block_on; + /// ```ignore + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; /// # remote.listen(ProtocolAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// let p1 = network @@ -379,16 +374,16 @@ impl Network { /// /// # Examples /// ```rust - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use 
tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2020` TCP and opens returns their Pid - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network /// .listen(ProtocolAddr::Tcp("127.0.0.1:2020".parse().unwrap())) /// .await?; @@ -437,10 +432,8 @@ impl Participant { /// [`Promises`] /// /// # Arguments - /// * `prio` - valid between 0-63. The priority rates the throughput for - /// messages of the [`Stream`] e.g. prio 5 messages will get 1/2 the speed - /// prio0 messages have. Prio10 messages only 1/4 and Prio 15 only 1/8, - /// etc... + /// * `prio` - defines which stream is processed first when limited on + /// bandwidth. See [`Prio`] for documentation. /// * `promises` - use a combination of you prefered [`Promises`], see the /// link for further documentation. You can combine them, e.g. 
/// `Promises::ORDERED | Promises::CONSISTENCY` The Stream will then @@ -452,36 +445,39 @@ impl Participant { /// /// # Examples /// ```rust - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, Promises, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port 2100 and open a stream - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())).await?; /// let p1 = network /// .connect(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())) /// .await?; /// let _s1 = p1 - /// .open(16, Promises::ORDERED | Promises::CONSISTENCY) + /// .open(4, Promises::ORDERED | Promises::CONSISTENCY) /// .await?; /// # Ok(()) /// }) /// # } /// ``` /// + /// [`Prio`]: network_protocol::Prio + /// [`Promises`]: network_protocol::Promises /// [`Streams`]: crate::api::Stream #[instrument(name="network", skip(self, prio, promises), fields(p = %self.local_pid))] pub async fn open(&self, prio: u8, promises: Promises) -> Result { + debug_assert!(prio <= network_protocol::HIGHEST_PRIO, "invalid prio"); let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel::(); if let Err(e) = self.a2b_open_stream_s.lock().await.send(( prio, promises, - 100000u64, + 1_000_000, p2a_return_stream_s, )) { debug!(?e, "bParticipant is already closed, notifying"); @@ -509,21 +505,21 @@ impl Participant { /// /// # Examples /// ```rust + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr, 
Promises}; - /// use futures::executor::block_on; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port 2110 and wait for the other side to open a stream /// // Note: It's quite unusual to actively connect, but then wait on a stream to be connected, usually the Application taking initiative want's to also create the first Stream. - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; /// let p1 = network.connect(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; /// # let p2 = remote.connected().await?; - /// # p2.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # p2.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// let _s1 = p1.opened().await?; /// # Ok(()) /// }) @@ -565,16 +561,16 @@ impl Participant { /// /// # Examples /// ```rust - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2030` TCP and opens returns their Pid and close connection. 
- /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network /// .listen(ProtocolAddr::Tcp("127.0.0.1:2030".parse().unwrap())) /// .await?; @@ -636,7 +632,7 @@ impl Participant { } } - /// Returns the remote [`Pid`] + /// Returns the remote [`Pid`](network_protocol::Pid) pub fn remote_pid(&self) -> Pid { self.remote_pid } } @@ -650,8 +646,8 @@ impl Stream { promises: Promises, guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, - b2a_msg_recv_r: async_channel::Receiver, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, + b2a_msg_recv_r: async_channel::Receiver, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { @@ -694,21 +690,21 @@ impl Stream { /// /// # Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2200` and wait for a Stream to be opened, then answer `Hello World` - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// 
network.listen(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; /// # // keep it alive - /// # let _stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let _stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; /// //Send Message @@ -734,26 +730,24 @@ impl Stream { /// /// # Example /// ```rust - /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use bincode; - /// use std::sync::Arc; + /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; /// /// # fn main() -> std::result::Result<(), Box> { - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote1, fr1) = Network::new(Pid::new()); - /// # std::thread::spawn(fr1); - /// # let (remote2, fr2) = Network::new(Pid::new()); - /// # std::thread::spawn(fr2); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote1 = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote2 = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # let remote1_p = remote1.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # let remote2_p = remote2.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # assert_eq!(remote1_p.remote_pid(), remote2_p.remote_pid()); - /// # remote1_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; - /// # remote2_p.open(16, Promises::ORDERED | 
Promises::CONSISTENCY).await?; + /// # remote1_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # remote2_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// let participant_a = network.connected().await?; /// let participant_b = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -779,8 +773,7 @@ impl Stream { } #[cfg(debug_assertions)] message.verify(&self); - self.a2b_msg_s - .send((self.sid, Arc::clone(&message.buffer)))?; + self.a2b_msg_s.send((self.sid, message.data.clone()))?; Ok(()) } @@ -795,20 +788,20 @@ impl Stream { /// /// # Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2220` and wait for a Stream to be opened, then listen on it - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -828,20 +821,20 @@ impl Stream { /// /// 
# Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2230` and wait for a Stream to be opened, then listen on it - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -861,8 +854,8 @@ impl Stream { match &mut self.b2a_msg_recv_r { Some(b2a_msg_recv_r) => { match b2a_msg_recv_r.recv().await { - Ok(msg) => Ok(Message { - buffer: Arc::new(msg), + Ok(data) => Ok(Message { + data, #[cfg(feature = "compression")] compressed: self.promises.contains(Promises::COMPRESSED), }), @@ -883,20 +876,20 @@ impl Stream { /// /// # Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> 
std::result::Result<(), Box> { /// // Create a Network, listen on Port `2240` and wait for a Stream to be opened, then listen on it - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// # std::thread::sleep(std::time::Duration::from_secs(1)); /// let participant_a = network.connected().await?; @@ -913,9 +906,9 @@ impl Stream { pub fn try_recv(&mut self) -> Result, StreamError> { match &mut self.b2a_msg_recv_r { Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_recv() { - Ok(msg) => Ok(Some( + Ok(data) => Ok(Some( Message { - buffer: Arc::new(msg), + data, #[cfg(feature = "compression")] compressed: self.promises().contains(Promises::COMPRESSED), } @@ -954,7 +947,6 @@ impl Drop for Network { } tokio::task::block_in_place(|| { - /* This context prevents panic if Dropped in a async fn */ self.runtime.block_on(async { for (remote_pid, a2s_disconnect_s) in self.participant_disconnect_sender.lock().await.drain() diff --git a/network/src/channel.rs b/network/src/channel.rs index 9b6472268e..22265a7f8e 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,12 +1,11 @@ use async_trait::async_trait; use bytes::BytesMut; use network_protocol::{ - InitProtocolError, MpscMsg, MpscRecvProtcol, MpscSendProtcol, 
Pid, ProtocolError, - ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtcol, TcpSendProtcol, + Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; -#[cfg(feature = "metrics")] use std::sync::Arc; -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, @@ -15,40 +14,38 @@ use tokio::{ #[derive(Debug)] pub(crate) enum Protocols { - Tcp((TcpSendProtcol, TcpRecvProtcol)), - Mpsc((MpscSendProtcol, MpscRecvProtcol)), + Tcp((TcpSendProtocol, TcpRecvProtocol)), + Mpsc((MpscSendProtocol, MpscRecvProtocol)), } #[derive(Debug)] pub(crate) enum SendProtocols { - Tcp(TcpSendProtcol), - Mpsc(MpscSendProtcol), + Tcp(TcpSendProtocol), + Mpsc(MpscSendProtocol), } #[derive(Debug)] pub(crate) enum RecvProtocols { - Tcp(TcpRecvProtcol), - Mpsc(MpscRecvProtcol), + Tcp(TcpRecvProtocol), + Mpsc(MpscRecvProtocol), } impl Protocols { - pub(crate) fn new_tcp(stream: tokio::net::TcpStream) -> Self { + pub(crate) fn new_tcp( + stream: tokio::net::TcpStream, + cid: Cid, + metrics: Arc, + ) -> Self { let (r, w) = stream.into_split(); - #[cfg(feature = "metrics")] - let metrics = ProtocolMetricCache::new( - "foooobaaaarrrrrrrr", - Arc::new(ProtocolMetrics::new().unwrap()), - ); - #[cfg(not(feature = "metrics"))] - let metrics = ProtocolMetricCache {}; + let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let sp = TcpSendProtcol::new(TcpDrain { half: w }, metrics.clone()); - let rp = TcpRecvProtcol::new( + let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone()); + let rp = TcpRecvProtocol::new( TcpSink { half: r, buffer: BytesMut::new(), }, - metrics.clone(), + metrics, ); Protocols::Tcp((sp, rp)) } @@ -56,15 +53,13 @@ impl Protocols { pub(crate) fn new_mpsc( sender: mpsc::Sender, 
receiver: mpsc::Receiver, + cid: Cid, + metrics: Arc, ) -> Self { - #[cfg(feature = "metrics")] - let metrics = - ProtocolMetricCache::new("mppppsssscccc", Arc::new(ProtocolMetrics::new().unwrap())); - #[cfg(not(feature = "metrics"))] - let metrics = ProtocolMetricCache {}; + let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let sp = MpscSendProtcol::new(MpscDrain { sender }, metrics.clone()); - let rp = MpscRecvProtcol::new(MpscSink { receiver }, metrics.clone()); + let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone()); + let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics); Protocols::Mpsc((sp, rp)) } @@ -157,6 +152,7 @@ impl UnreliableSink for TcpSink { async fn recv(&mut self) -> Result { self.buffer.resize(1500, 0u8); match self.half.read(&mut self.buffer).await { + Ok(0) => Err(ProtocolError::Closed), Ok(n) => Ok(self.buffer.split_to(n)), Err(_) => Err(ProtocolError::Closed), } @@ -199,6 +195,7 @@ impl UnreliableSink for MpscSink { #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; use network_protocol::{Promises, RecvProtocol, SendProtocol}; use tokio::net::{TcpListener, TcpStream}; @@ -211,8 +208,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let client = Protocols::new_tcp(client); - let server = Protocols::new_tcp(server); + let metrics = Arc::new(ProtocolMetrics::new().unwrap()); + let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); + let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); let (mut s, _) = client.split(); let (_, mut r) = server.split(); let event = ProtocolEvent::OpenStream { @@ -222,8 +220,18 @@ mod tests { guaranteed_bandwidth: 1_000, }; s.send(event.clone()).await.unwrap(); - let r = r.recv().await; - match r { + s.send(ProtocolEvent::Message { + sid: Sid::new(1), + mid: 0, + data: Bytes::from(&[8u8; 8][..]), + }) + .await + .unwrap(); + s.flush(1_000_000, 
Duration::from_secs(1)).await.unwrap(); + drop(s); // recv must work even after shutdown of send! + tokio::time::sleep(Duration::from_secs(1)).await; + let res = r.recv().await; + match res { Ok(ProtocolEvent::OpenStream { sid, prio, @@ -235,8 +243,30 @@ mod tests { assert_eq!(promises, Promises::GUARANTEED_DELIVERY); }, _ => { - panic!("wrong type {:?}", r); + panic!("wrong type {:?}", res); }, } + r.recv().await.unwrap(); + } + + #[tokio::test] + async fn tokio_sink_stop_after_drop() { + let listener = TcpListener::bind("127.0.0.1:5001").await.unwrap(); + let r1 = tokio::spawn(async move { + let (server, _) = listener.accept().await.unwrap(); + (listener, server) + }); + let client = TcpStream::connect("127.0.0.1:5001").await.unwrap(); + let (_listener, server) = r1.await.unwrap(); + let metrics = Arc::new(ProtocolMetrics::new().unwrap()); + let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); + let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let (s, _) = client.split(); + let (_, mut r) = server.split(); + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + let e = e.await.unwrap(); + assert!(e.is_err()); + assert_eq!(e.unwrap_err(), ProtocolError::Closed); } } diff --git a/network/src/lib.rs b/network/src/lib.rs index 7593f1edd8..9981a9c987 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -39,29 +39,27 @@ //! //! # Examples //! ```rust -//! use futures::{executor::block_on, join}; -//! use tokio::task::sleep; +//! use std::sync::Arc; +//! use tokio::{join, runtime::Runtime, time::sleep}; //! use veloren_network::{Network, Pid, Promises, ProtocolAddr}; //! //! // Client -//! async fn client() -> std::result::Result<(), Box> { +//! async fn client(runtime: Arc) -> std::result::Result<(), Box> { //! sleep(std::time::Duration::from_secs(1)).await; // `connect` MUST be after `listen` -//! let (client_network, f) = Network::new(Pid::new()); -//! std::thread::spawn(f); +//! 
let client_network = Network::new(Pid::new(), runtime); //! let server = client_network //! .connect(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; //! let mut stream = server -//! .open(10, Promises::ORDERED | Promises::CONSISTENCY) +//! .open(4, Promises::ORDERED | Promises::CONSISTENCY) //! .await?; //! stream.send("Hello World")?; //! Ok(()) //! } //! //! // Server -//! async fn server() -> std::result::Result<(), Box> { -//! let (server_network, f) = Network::new(Pid::new()); -//! std::thread::spawn(f); +//! async fn server(runtime: Arc) -> std::result::Result<(), Box> { +//! let server_network = Network::new(Pid::new(), runtime); //! server_network //! .listen(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; @@ -74,8 +72,10 @@ //! } //! //! fn main() -> std::result::Result<(), Box> { -//! block_on(async { -//! let (result_c, result_s) = join!(client(), server(),); +//! let runtime = Arc::new(Runtime::new().unwrap()); +//! runtime.block_on(async { +//! let (result_c, result_s) = +//! join!(client(Arc::clone(&runtime)), server(Arc::clone(&runtime)),); //! result_c?; //! result_s?; //! Ok(()) @@ -95,14 +95,14 @@ //! [`Streams`]: crate::api::Stream //! [`send`]: crate::api::Stream::send //! [`recv`]: crate::api::Stream::recv -//! [`Pid`]: crate::types::Pid +//! [`Pid`]: network_protocol::Pid //! [`ProtocolAddr`]: crate::api::ProtocolAddr -//! [`Promises`]: crate::types::Promises +//! 
[`Promises`]: network_protocol::Promises mod api; mod channel; mod message; -#[cfg(feature = "metrics")] mod metrics; +mod metrics; mod participant; mod scheduler; diff --git a/network/src/message.rs b/network/src/message.rs index 1969854a7a..27c50abf8d 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -1,10 +1,9 @@ -use serde::{de::DeserializeOwned, Serialize}; -//use std::collections::VecDeque; use crate::api::{Stream, StreamError}; -use network_protocol::MessageBuffer; +use bytes::Bytes; #[cfg(feature = "compression")] use network_protocol::Promises; -use std::{io, sync::Arc}; +use serde::{de::DeserializeOwned, Serialize}; +use std::io; #[cfg(all(feature = "compression", debug_assertions))] use tracing::warn; @@ -16,7 +15,7 @@ use tracing::warn; /// [`Stream`]: crate::api::Stream /// [`send_raw`]: crate::api::Stream::send_raw pub struct Message { - pub(crate) buffer: Arc, + pub(crate) data: Bytes, #[cfg(feature = "compression")] pub(crate) compressed: bool, } @@ -58,7 +57,7 @@ impl Message { let _stream = stream; Self { - buffer: Arc::new(MessageBuffer { data }), + data: Bytes::from(data), #[cfg(feature = "compression")] compressed, } @@ -73,18 +72,18 @@ impl Message { /// ``` /// # use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// # use futures::executor::block_on; + /// # use tokio::runtime::Runtime; + /// # use std::sync::Arc; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2300` and wait for a Stream to be opened, then listen on it - /// # let (network, f) = Network::new(Pid::new()); - /// # std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// # block_on(async { + /// # let runtime = Arc::new(Runtime::new().unwrap()); + /// # let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # runtime.block_on(async { /// 
# network.listen(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// # let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -99,33 +98,27 @@ impl Message { /// [`recv_raw`]: crate::api::Stream::recv_raw pub fn deserialize(self) -> Result { #[cfg(not(feature = "compression"))] - let uncompressed_data = match Arc::try_unwrap(self.buffer) { - Ok(d) => d.data, - Err(b) => b.data.clone(), - }; + let uncompressed_data = self.data; #[cfg(feature = "compression")] let uncompressed_data = if self.compressed { { - let mut uncompressed_data = Vec::with_capacity(self.buffer.data.len() * 2); + let mut uncompressed_data = Vec::with_capacity(self.data.len() * 2); if let Err(e) = lz_fear::raw::decompress_raw( - &self.buffer.data, + &self.data, &[0; 0], &mut uncompressed_data, usize::MAX, ) { return Err(StreamError::Compression(e)); } - uncompressed_data + Bytes::from(uncompressed_data) } } else { - match Arc::try_unwrap(self.buffer) { - Ok(d) => d.data, - Err(b) => b.data.clone(), - } + self.data }; - match bincode::deserialize(uncompressed_data.as_slice()) { + match bincode::deserialize(&uncompressed_data) { Ok(m) => Ok(m), Err(e) => Err(StreamError::Deserialize(e)), } @@ -215,25 +208,25 @@ mod tests { #[test] fn serialize_test() { let msg = Message::serialize("abc", &stub_stream(false)); - assert_eq!(msg.buffer.data.len(), 11); - assert_eq!(msg.buffer.data[0], 3); - assert_eq!(msg.buffer.data[1..7], [0, 0, 0, 0, 0, 0]); - assert_eq!(msg.buffer.data[8], b'a'); - assert_eq!(msg.buffer.data[9], b'b'); - assert_eq!(msg.buffer.data[10], b'c'); + assert_eq!(msg.data.len(), 11); + 
assert_eq!(msg.data[0], 3); + assert_eq!(msg.data[1..7], [0, 0, 0, 0, 0, 0]); + assert_eq!(msg.data[8], b'a'); + assert_eq!(msg.data[9], b'b'); + assert_eq!(msg.data[10], b'c'); } #[cfg(feature = "compression")] #[test] fn serialize_compress_small() { let msg = Message::serialize("abc", &stub_stream(true)); - assert_eq!(msg.buffer.data.len(), 12); - assert_eq!(msg.buffer.data[0], 176); - assert_eq!(msg.buffer.data[1], 3); - assert_eq!(msg.buffer.data[2..8], [0, 0, 0, 0, 0, 0]); - assert_eq!(msg.buffer.data[9], b'a'); - assert_eq!(msg.buffer.data[10], b'b'); - assert_eq!(msg.buffer.data[11], b'c'); + assert_eq!(msg.data.len(), 12); + assert_eq!(msg.data[0], 176); + assert_eq!(msg.data[1], 3); + assert_eq!(msg.data[2..8], [0, 0, 0, 0, 0, 0]); + assert_eq!(msg.data[9], b'a'); + assert_eq!(msg.data[10], b'b'); + assert_eq!(msg.data[11], b'c'); } #[cfg(feature = "compression")] @@ -251,14 +244,14 @@ mod tests { "assets/data/plants/flowers/greenrose.ron", ); let msg = Message::serialize(&msg, &stub_stream(true)); - assert_eq!(msg.buffer.data.len(), 79); - assert_eq!(msg.buffer.data[0], 34); - assert_eq!(msg.buffer.data[1], 5); - assert_eq!(msg.buffer.data[2], 0); - assert_eq!(msg.buffer.data[3], 1); - assert_eq!(msg.buffer.data[20], 20); - assert_eq!(msg.buffer.data[40], 115); - assert_eq!(msg.buffer.data[60], 111); + assert_eq!(msg.data.len(), 79); + assert_eq!(msg.data[0], 34); + assert_eq!(msg.data[1], 5); + assert_eq!(msg.data[2], 0); + assert_eq!(msg.data[3], 1); + assert_eq!(msg.data[20], 20); + assert_eq!(msg.data[40], 115); + assert_eq!(msg.data[60], 111); } #[cfg(feature = "compression")] @@ -281,6 +274,6 @@ mod tests { } } let msg = Message::serialize(&msg, &stub_stream(true)); - assert_eq!(msg.buffer.data.len(), 1331); + assert_eq!(msg.data.len(), 1331); } } diff --git a/network/src/metrics.rs b/network/src/metrics.rs index 60650f68fe..d1b77d76d0 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,12 +1,10 @@ -use network_protocol::Pid; +use 
network_protocol::{Cid, Pid}; +#[cfg(feature = "metrics")] use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; use std::error::Error; /// 1:1 relation between NetworkMetrics and Network -/// use 2NF here and avoid redundant data like CHANNEL AND PARTICIPANT encoding. -/// as this will cause a matrix that is full of 0 but needs alot of bandwith and -/// storage -#[allow(dead_code)] +#[cfg(feature = "metrics")] pub struct NetworkMetrics { pub listen_requests_total: IntCounterVec, pub connect_requests_total: IntCounterVec, @@ -23,8 +21,11 @@ pub struct NetworkMetrics { pub network_info: IntGauge, } +#[cfg(not(feature = "metrics"))] +pub struct NetworkMetrics {} + +#[cfg(feature = "metrics")] impl NetworkMetrics { - #[allow(dead_code)] pub fn new(local_pid: &Pid) -> Result> { let listen_requests_total = IntCounterVec::new( Opts::new( @@ -123,6 +124,46 @@ impl NetworkMetrics { registry.register(Box::new(self.network_info.clone()))?; Ok(()) } + + pub(crate) fn channels_connected(&self, remote_p: &str, no: usize, cid: Cid) { + self.channels_connected_total + .with_label_values(&[remote_p]) + .inc(); + self.participants_channel_ids + .with_label_values(&[remote_p, &no.to_string()]) + .set(cid as i64); + } + + pub(crate) fn channels_disconnected(&self, remote_p: &str) { + self.channels_disconnected_total + .with_label_values(&[remote_p]) + .inc(); + } + + pub(crate) fn streams_opened(&self, remote_p: &str) { + self.streams_opened_total + .with_label_values(&[remote_p]) + .inc(); + } + + pub(crate) fn streams_closed(&self, remote_p: &str) { + self.streams_closed_total + .with_label_values(&[remote_p]) + .inc(); + } +} + +#[cfg(not(feature = "metrics"))] +impl NetworkMetrics { + pub fn new(_local_pid: &Pid) -> Result> { Ok(Self {}) } + + pub(crate) fn channels_connected(&self, _remote_p: &str, _no: usize, _cid: Cid) {} + + pub(crate) fn channels_disconnected(&self, _remote_p: &str) {} + + pub(crate) fn streams_opened(&self, _remote_p: &str) {} 
+ + pub(crate) fn streams_closed(&self, _remote_p: &str) {} } impl std::fmt::Debug for NetworkMetrics { diff --git a/network/src/participant.rs b/network/src/participant.rs index 7d07a5c5ac..c7d6a9b64d 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,13 +1,12 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; use crate::{ api::{ParticipantError, Stream}, channel::{Protocols, RecvProtocols, SendProtocols}, + metrics::NetworkMetrics, }; +use bytes::Bytes; use futures_util::{FutureExt, StreamExt}; use network_protocol::{ - Bandwidth, Cid, MessageBuffer, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, - Sid, + Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid, }; use std::{ collections::HashMap, @@ -41,7 +40,7 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: Mutex>, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] @@ -68,7 +67,6 @@ pub struct BParticipant { streams: RwLock>, run_channels: Option, shutdown_barrier: AtomicI32, - #[cfg(feature = "metrics")] metrics: Arc, no_channel_error_info: RwLock<(Instant, u64)>, } @@ -86,7 +84,7 @@ impl BParticipant { local_pid: Pid, remote_pid: Pid, offset_sid: Sid, - #[cfg(feature = "metrics")] metrics: Arc, + metrics: Arc, ) -> ( Self, mpsc::UnboundedSender, @@ -118,7 +116,6 @@ impl BParticipant { Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV, ), run_channels, - #[cfg(feature = "metrics")] metrics, no_channel_error_info: RwLock::new((Instant::now(), 0)), }, @@ -139,12 +136,11 @@ impl BParticipant { let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = async_channel::unbounded::(); let (b2b_notify_send_of_recv_s, b2b_notify_send_of_recv_r) = - mpsc::unbounded_channel::(); + crossbeam_channel::unbounded::(); let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); const STREAM_BOUND: usize = 10_000; - let (a2b_msg_s, a2b_msg_r) = - 
crossbeam_channel::bounded::<(Sid, Arc)>(STREAM_BOUND); + let (a2b_msg_s, a2b_msg_r) = crossbeam_channel::bounded::<(Sid, Bytes)>(STREAM_BOUND); let run_channels = self.run_channels.take().unwrap(); trace!("start all managers"); @@ -185,20 +181,22 @@ impl BParticipant { } //TODO: local stream_cid: HashMap to know the respective protocol + #[allow(clippy::too_many_arguments)] async fn send_mgr( &self, mut a2b_open_stream_r: mpsc::UnboundedReceiver, mut a2b_close_stream_r: mpsc::UnboundedReceiver, - a2b_msg_r: crossbeam_channel::Receiver<(Sid, Arc)>, + a2b_msg_r: crossbeam_channel::Receiver<(Sid, Bytes)>, mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, b2b_close_send_protocol_r: async_channel::Receiver, - mut b2b_notify_send_of_recv_r: mpsc::UnboundedReceiver, + b2b_notify_send_of_recv_r: crossbeam_channel::Receiver, _b2s_prio_statistic_s: mpsc::UnboundedSender, - a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, a2b_close_stream_s: mpsc::UnboundedSender, ) { let mut send_protocols: HashMap = HashMap::new(); let mut interval = tokio::time::interval(Self::TICK_TIME); + let mut last_instant = Instant::now(); let mut stream_ids = self.offset_sid; let mut fake_mid = 0; //TODO: move MID to protocol, should be inc per stream ? or ? 
trace!("workaround, actively wait for first protocol"); @@ -207,41 +205,21 @@ impl BParticipant { .await .map(|(c, p)| send_protocols.insert(c, p)); loop { - let (open, close, r_event, _, addp, remp) = select!( - n = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None, None), - n = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None, None), - n = b2b_notify_send_of_recv_r.recv().fuse() => (None, None, Some(n), None, None, None), - _ = interval.tick() => (None, None, None, Some(()), None, None), - n = b2b_add_protocol_r.recv().fuse() => (None, None, None, None, Some(n), None), - n = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, None, Some(n)), + let (open, close, _, addp, remp) = select!( + Some(n) = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None), + Some(n) = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None), + _ = interval.tick() => (None, None, Some(()), None, None), + Some(n) = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(n), None), + Ok(n) = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(n)), ); - addp.flatten().map(|(cid, p)| { + addp.map(|(cid, p)| { debug!(?cid, "add protocol"); send_protocols.insert(cid, p) }); - match remp { - Some(Ok(cid)) => { - debug!(?cid, "remove protocol"); - match send_protocols.remove(&cid) { - Some(mut prot) => { - trace!("blocking flush"); - let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; - trace!("shutdown prot"); - let _ = prot.send(ProtocolEvent::Shutdown).await; - }, - None => trace!("tried to remove protocol twice"), - }; - if send_protocols.is_empty() { - break; - } - }, - _ => (), - }; - let cid = 0; - let active = match send_protocols.get_mut(&cid) { - Some(a) => a, + let (cid, active) = match send_protocols.iter_mut().next() { + Some((cid, a)) => (*cid, a), None => { warn!("no channel"); continue; @@ -249,11 +227,7 @@ impl BParticipant { }; let active_err = async { - if let 
Some(Some(event)) = r_event { - active.notify_from_recv(event); - } - - if let Some(Some((prio, promises, guaranteed_bandwidth, return_s))) = open { + if let Some((prio, promises, guaranteed_bandwidth, return_s)) = open { let sid = stream_ids; trace!(?sid, "open stream"); stream_ids += Sid::from(1); @@ -279,19 +253,39 @@ impl BParticipant { active.send(event).await?; } + // process recv content first + let mut closeevents = b2b_notify_send_of_recv_r + .try_iter() + .map(|e| { + if matches!(e, ProtocolEvent::OpenStream { .. }) { + active.notify_from_recv(e); + None + } else { + Some(e) + } + }) + .collect::>(); + // get all messages and assign it to a channel for (sid, buffer) in a2b_msg_r.try_iter() { fake_mid += 1; active .send(ProtocolEvent::Message { - buffer, + data: buffer, mid: fake_mid, sid, }) .await? } - if let Some(Some(sid)) = close { + // process recv content afterwards + let _ = closeevents.drain(..).map(|e| { + if let Some(e) = e { + active.notify_from_recv(e); + } + }); + + if let Some(sid) = close { trace!(?stream_ids, "delete stream"); self.delete_stream(sid).await; // Fire&Forget the protocol will take care to verify that this Frame is delayed @@ -299,9 +293,10 @@ impl BParticipant { active.send(ProtocolEvent::CloseStream { sid }).await?; } - active - .flush(1_000_000, Duration::from_secs(1) /* TODO */) - .await?; //this actually blocks, so we cant set streams whilte it. + let send_time = Instant::now(); + let diff = send_time.duration_since(last_instant); + last_instant = send_time; + active.flush(1_000_000_000, diff).await?; //this actually blocks, so we cant set streams while it. 
let r: Result<(), network_protocol::ProtocolError> = Ok(()); r } @@ -311,6 +306,24 @@ impl BParticipant { // remote recv will now fail, which will trigger remote send which will trigger // recv send_protocols.remove(&cid).unwrap(); + self.metrics.channels_disconnected(&self.remote_pid_string); + } + + if let Some(cid) = remp { + debug!(?cid, "remove protocol"); + match send_protocols.remove(&cid) { + Some(mut prot) => { + self.metrics.channels_disconnected(&self.remote_pid_string); + trace!("blocking flush"); + let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; + trace!("shutdown prot"); + let _ = prot.send(ProtocolEvent::Shutdown).await; + }, + None => trace!("tried to remove protocol twice"), + }; + if send_protocols.is_empty() { + break; + } } } trace!("Stop send_mgr"); @@ -318,14 +331,15 @@ impl BParticipant { .fetch_sub(Self::BARR_SEND, Ordering::Relaxed); } + #[allow(clippy::too_many_arguments)] async fn recv_mgr( &self, b2a_stream_opened_s: mpsc::UnboundedSender, mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, b2b_force_close_recv_protocol_r: async_channel::Receiver, b2b_close_send_protocol_s: async_channel::Sender, - b2b_notify_send_of_recv_s: mpsc::UnboundedSender, - a2b_msg_s: crossbeam_channel::Sender<(Sid, Arc)>, + b2b_notify_send_of_recv_s: crossbeam_channel::Sender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, a2b_close_stream_s: mpsc::UnboundedSender, ) { let mut recv_protocols: HashMap> = HashMap::new(); @@ -355,23 +369,27 @@ impl BParticipant { loop { let (event, addp, remp) = select!( - next = hacky_recv_r.recv().fuse() => (Some(next), None, None), - Some(next) = b2b_add_protocol_r.recv().fuse() => (None, Some(next), None), - next = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(next)), + Some(n) = hacky_recv_r.recv().fuse() => (Some(n), None, None), + Some(n) = b2b_add_protocol_r.recv().fuse() => (None, Some(n), None), + Ok(n) = b2b_force_close_recv_protocol_r.recv().fuse() => (None, 
None, Some(n)), + else => { + error!("recv_mgr -> something is seriously wrong!, end recv_mgr"); + break; + } ); - addp.map(|(cid, p)| { + if let Some((cid, p)) = addp { debug!(?cid, "add protocol"); retrigger(cid, p, &mut recv_protocols); - }); - if let Some(Ok(cid)) = remp { + }; + if let Some(cid) = remp { // no need to stop the send_mgr here as it has been canceled before if remove_c(&mut recv_protocols, &cid) { break; } }; - if let Some(Some((cid, r, p))) = event { + if let Some((cid, r, p)) = event { match r { Ok(ProtocolEvent::OpenStream { sid, @@ -381,6 +399,8 @@ impl BParticipant { }) => { trace!(?sid, "open stream"); let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); + // waiting for receiving is not necessary, because the send_mgr will first + // process this before process messages! let stream = self .create_stream( sid, @@ -400,22 +420,11 @@ impl BParticipant { self.delete_stream(sid).await; retrigger(cid, p, &mut recv_protocols); }, - Ok(ProtocolEvent::Message { - buffer, - mid: _, - sid, - }) => { - let buffer = Arc::try_unwrap(buffer).unwrap(); + Ok(ProtocolEvent::Message { data, mid: _, sid }) => { let lock = self.streams.read().await; match lock.get(&sid) { Some(stream) => { - stream - .b2a_msg_recv_s - .lock() - .await - .send(buffer) - .await - .unwrap(); + let _ = stream.b2a_msg_recv_s.lock().await.send(data).await; }, None => warn!("recv a msg with orphan stream"), }; @@ -442,7 +451,11 @@ impl BParticipant { } } } - + trace!("receiving no longer possible, closing all streams"); + for (_, si) in self.streams.write().await.drain() { + si.send_closed.store(true, Ordering::Relaxed); + self.metrics.streams_closed(&self.remote_pid_string); + } trace!("Stop recv_mgr"); self.shutdown_barrier .fetch_sub(Self::BARR_RECV, Ordering::Relaxed); @@ -459,13 +472,11 @@ impl BParticipant { .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { // This channel is now configured, and we are running it in scope of the // participant. 
- //let w2b_frames_s = w2b_frames_s.clone(); let channels = Arc::clone(&self.channels); let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); async move { let mut lock = channels.write().await; - #[cfg(feature = "metrics")] let mut channel_no = lock.len(); lock.insert( cid, @@ -479,21 +490,12 @@ impl BParticipant { b2b_add_send_protocol_s.send((cid, send)).unwrap(); b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); b2s_create_channel_done_s.send(()).unwrap(); - #[cfg(feature = "metrics")] - { - self.metrics - .channels_connected_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - if channel_no > 5 { - debug!(?channel_no, "metrics will overwrite channel #5"); - channel_no = 5; - } - self.metrics - .participants_channel_ids - .with_label_values(&[&self.remote_pid_string, &channel_no.to_string()]) - .set(cid as i64); + if channel_no > 5 { + debug!(?channel_no, "metrics will overwrite channel #5"); + channel_no = 5; } + self.metrics + .channels_connected(&self.remote_pid_string, channel_no, cid); } }) .await; @@ -544,6 +546,7 @@ impl BParticipant { } } }; + let (timeout_time, sender) = s2b_shutdown_bparticipant_r.await.unwrap(); debug!("participant_shutdown_mgr triggered. 
Closing all streams for send"); { @@ -602,11 +605,7 @@ impl BParticipant { /// Stopping API and participant usage /// Protocol will take care of the order of the frame - async fn delete_stream( - &self, - sid: Sid, - /* #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, */ - ) { + async fn delete_stream(&self, sid: Sid) { let stream = { self.streams.write().await.remove(&sid) }; match stream { Some(si) => { @@ -617,12 +616,7 @@ impl BParticipant { trace!("Couldn't find the stream, might be simultaneous close from local/remote") }, } - /* - #[cfg(feature = "metrics")] - self.metrics - .streams_closed_total - .with_label_values(&[&self.remote_pid_string]) - .inc();*/ + self.metrics.streams_closed(&self.remote_pid_string); } async fn create_stream( @@ -631,10 +625,10 @@ impl BParticipant { prio: Prio, promises: Promises, guaranteed_bandwidth: Bandwidth, - a2b_msg_s: &crossbeam_channel::Sender<(Sid, Arc)>, + a2b_msg_s: &crossbeam_channel::Sender<(Sid, Bytes)>, a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { - let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); + let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, @@ -642,11 +636,7 @@ impl BParticipant { send_closed: Arc::clone(&send_closed), b2a_msg_recv_s: Mutex::new(b2a_msg_recv_s), }); - #[cfg(feature = "metrics")] - self.metrics - .streams_opened_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); + self.metrics.streams_opened(&self.remote_pid_string); Stream::new( self.local_pid, self.remote_pid, @@ -665,12 +655,14 @@ impl BParticipant { #[cfg(test)] mod tests { use super::*; + use network_protocol::ProtocolMetrics; use tokio::{ runtime::Runtime, sync::{mpsc, oneshot}, task::JoinHandle, }; + #[allow(clippy::type_complexity)] fn mock_bparticipant() -> ( Arc, mpsc::UnboundedSender, @@ -720,13 +712,14 @@ mod tests { ) -> 
Protocols { let (s1, r1) = mpsc::channel(100); let (s2, r2) = mpsc::channel(100); - let p1 = Protocols::new_mpsc(s1, r2); + let metrics = Arc::new(ProtocolMetrics::new().unwrap()); + let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics)); let (complete_s, complete_r) = oneshot::channel(); create_channel .send((cid, Sid::new(0), p1, complete_s)) .unwrap(); complete_r.await.unwrap(); - Protocols::new_mpsc(s2, r1) + Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics)) } #[test] diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index a1f31fc384..c2ac365b0f 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -1,12 +1,11 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; use crate::{ api::{Participant, ProtocolAddr}, channel::Protocols, + metrics::NetworkMetrics, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, }; use futures_util::{FutureExt, StreamExt}; -use network_protocol::{MpscMsg, Pid}; +use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -19,9 +18,7 @@ use std::{ time::Duration, }; use tokio::{ - io, net, - runtime::Runtime, - select, + io, net, select, sync::{mpsc, oneshot, Mutex}, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -37,7 +34,7 @@ use tracing::*; // - c: channel/handshake lazy_static::lazy_static! 
{ - static ref MPSC_POOL: Mutex, oneshot::Sender>)>>> = { + static ref MPSC_POOL: Mutex>> = { Mutex::new(HashMap::new()) }; } @@ -52,6 +49,10 @@ struct ParticipantInfo { type A2sListen = (ProtocolAddr, oneshot::Sender>); type A2sConnect = (ProtocolAddr, oneshot::Sender>); type A2sDisconnect = (Pid, S2bShutdownBparticipant); +type S2sMpscConnect = ( + mpsc::Sender, + oneshot::Sender>, +); #[derive(Debug)] struct ControlChannels { @@ -72,7 +73,6 @@ struct ParticipantChannels { #[derive(Debug)] pub struct Scheduler { local_pid: Pid, - runtime: Arc, local_secret: u128, closed: AtomicBool, run_channels: Option, @@ -80,8 +80,8 @@ pub struct Scheduler { participants: Arc>>, channel_ids: Arc, channel_listener: Mutex>>, - #[cfg(feature = "metrics")] metrics: Arc, + protocol_metrics: Arc, } impl Scheduler { @@ -89,7 +89,6 @@ impl Scheduler { pub fn new( local_pid: Pid, - runtime: Arc, #[cfg(feature = "metrics")] registry: Option<&Registry>, ) -> ( Self, @@ -120,13 +119,14 @@ impl Scheduler { b2s_prio_statistic_s, }; - #[cfg(feature = "metrics")] let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap()); + let protocol_metrics = Arc::new(ProtocolMetrics::new().unwrap()); #[cfg(feature = "metrics")] { if let Some(registry) = registry { metrics.register(registry).unwrap(); + protocol_metrics.register(registry).unwrap(); } } @@ -136,7 +136,6 @@ impl Scheduler { ( Self { local_pid, - runtime, local_secret, closed: AtomicBool::new(false), run_channels, @@ -144,8 +143,8 @@ impl Scheduler { participants: Arc::new(Mutex::new(HashMap::new())), channel_ids: Arc::new(AtomicU64::new(0)), channel_listener: Mutex::new(HashMap::new()), - #[cfg(feature = "metrics")] metrics, + protocol_metrics, }, a2s_listen_s, a2s_connect_s, @@ -206,7 +205,7 @@ impl Scheduler { ) { trace!("Start connect_mgr"); while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { - let (protocol, handshake) = match addr { + let (protocol, cid, handshake) = match addr { ProtocolAddr::Tcp(addr) => { 
#[cfg(feature = "metrics")] self.metrics @@ -220,8 +219,13 @@ impl Scheduler { continue; }, }; + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - (Protocols::new_tcp(stream), false) + ( + Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), + cid, + false, + ) }, ProtocolAddr::Mpsc(addr) => { let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { @@ -244,9 +248,16 @@ impl Scheduler { .unwrap(); let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap(); + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); info!(?addr, "Connecting Mpsc"); ( - Protocols::new_mpsc(local_to_remote_s, remote_to_local_r), + Protocols::new_mpsc( + local_to_remote_s, + remote_to_local_r, + cid, + Arc::clone(&self.protocol_metrics), + ), + cid, false, ) }, @@ -285,7 +296,7 @@ impl Scheduler { //}, _ => unimplemented!(), }; - self.init_protocol(protocol, Some(pid_sender), handshake) + self.init_protocol(protocol, cid, Some(pid_sender), handshake) .await; } trace!("Stop connect_mgr"); @@ -422,7 +433,8 @@ impl Scheduler { }, }; info!("Accepting Tcp from: {}", remote_addr); - self.init_protocol(Protocols::new_tcp(stream), None, true) + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); + self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) .await; } }, @@ -440,7 +452,8 @@ impl Scheduler { let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); local_remote_to_local_s.send(remote_to_local_s).unwrap(); info!(?addr, "Accepting Mpsc from"); - self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r), None, true) + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); + self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) .await; } warn!("MpscStream Failed, stopping"); @@ -529,6 +542,7 @@ impl 
Scheduler { async fn init_protocol( &self, mut protocol: Protocols, + cid: Cid, s2a_return_pid_s: Option>>, send_handshake: bool, ) { @@ -543,15 +557,12 @@ impl Scheduler { // participant can be in handshake phase ever! Someone could deadlock // the whole server easily for new clients UDP doesnt work at all, as // the UDP listening is done in another place. - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let participants = Arc::clone(&self.participants); - let runtime = Arc::clone(&self.runtime); - #[cfg(feature = "metrics")] let metrics = Arc::clone(&self.metrics); let local_pid = self.local_pid; let local_secret = self.local_secret; // this is necessary for UDP to work at all and to remove code duplication - self.runtime.spawn( + tokio::spawn( async move { trace!(?cid, "Open channel and be ready for Handshake"); use network_protocol::InitProtocol; @@ -575,13 +586,7 @@ impl Scheduler { b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, - ) = BParticipant::new( - local_pid, - pid, - sid, - #[cfg(feature = "metrics")] - Arc::clone(&metrics), - ); + ) = BParticipant::new(local_pid, pid, sid, Arc::clone(&metrics)); let participant = Participant::new( local_pid, @@ -601,7 +606,7 @@ impl Scheduler { drop(participants); trace!("dropped participants lock"); let p = pid; - runtime.spawn( + tokio::spawn( bparticipant .run(participant_channels.b2s_prio_statistic_s) .instrument(tracing::info_span!("remote", ?p)), diff --git a/network/tests/closing.rs b/network/tests/closing.rs index b0e8e180a9..8606879174 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -230,7 +230,7 @@ fn close_network_then_disconnect_part() { fn opened_stream_before_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); - let mut s2_a = r.block_on(p_a.open(10, Promises::empty())).unwrap(); + let mut s2_a = r.block_on(p_a.open(4, Promises::empty())).unwrap(); 
s2_a.send("HelloWorld").unwrap(); let mut s2_b = r.block_on(p_b.opened()).unwrap(); drop(p_a); @@ -243,7 +243,7 @@ fn opened_stream_before_remote_part_is_closed() { fn opened_stream_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); - let mut s2_a = r.block_on(p_a.open(10, Promises::empty())).unwrap(); + let mut s2_a = r.block_on(p_a.open(3, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); @@ -260,14 +260,14 @@ fn opened_stream_after_remote_part_is_closed() { fn open_stream_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); - let mut s2_a = r.block_on(p_a.open(10, Promises::empty())).unwrap(); + let mut s2_a = r.block_on(p_a.open(4, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); let mut s2_b = r.block_on(p_b.opened()).unwrap(); assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); assert_eq!( - r.block_on(p_b.open(20, Promises::empty())).unwrap_err(), + r.block_on(p_b.open(5, Promises::empty())).unwrap_err(), ParticipantError::ParticipantDisconnected ); drop((_n_a, _n_b, p_b)); //clean teardown @@ -294,7 +294,7 @@ fn open_participant_before_remote_part_is_closed() { let addr = tcp(); r.block_on(n_a.listen(addr.clone())).unwrap(); let p_b = r.block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = r.block_on(p_b.open(10, Promises::empty())).unwrap(); + let mut s1_b = r.block_on(p_b.open(4, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); let p_a = r.block_on(n_a.connected()).unwrap(); drop(s1_b); @@ -314,7 +314,7 @@ fn open_participant_after_remote_part_is_closed() { let addr = tcp(); r.block_on(n_a.listen(addr.clone())).unwrap(); let p_b = r.block_on(n_b.connect(addr)).unwrap(); - let 
mut s1_b = r.block_on(p_b.open(10, Promises::empty())).unwrap(); + let mut s1_b = r.block_on(p_b.open(4, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); drop(s1_b); drop(p_b); @@ -334,7 +334,7 @@ fn close_network_scheduler_completely() { let addr = tcp(); r.block_on(n_a.listen(addr.clone())).unwrap(); let p_b = r.block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = r.block_on(p_b.open(10, Promises::empty())).unwrap(); + let mut s1_b = r.block_on(p_b.open(4, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); let p_a = r.block_on(n_a.connected()).unwrap(); diff --git a/network/tests/helper.rs b/network/tests/helper.rs index a06b59578c..eb5806190b 100644 --- a/network/tests/helper.rs +++ b/network/tests/helper.rs @@ -67,7 +67,7 @@ pub fn network_participant_stream( let p1_b = n_b.connect(addr).await.unwrap(); let p1_a = n_a.connected().await.unwrap(); - let s1_a = p1_a.open(10, Promises::empty()).await.unwrap(); + let s1_a = p1_a.open(4, Promises::empty()).await.unwrap(); let s1_b = p1_b.opened().await.unwrap(); (n_a, p1_a, s1_a, n_b, p1_b, s1_b) @@ -76,28 +76,28 @@ pub fn network_participant_stream( } #[allow(dead_code)] -pub fn tcp() -> veloren_network::ProtocolAddr { +pub fn tcp() -> ProtocolAddr { lazy_static! { static ref PORTS: AtomicU16 = AtomicU16::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - veloren_network::ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))) + ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))) } #[allow(dead_code)] -pub fn udp() -> veloren_network::ProtocolAddr { +pub fn udp() -> ProtocolAddr { lazy_static! { static ref PORTS: AtomicU16 = AtomicU16::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - veloren_network::ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) + ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) } #[allow(dead_code)] -pub fn mpsc() -> veloren_network::ProtocolAddr { +pub fn mpsc() -> ProtocolAddr { lazy_static! 
{ static ref PORTS: AtomicU64 = AtomicU64::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - veloren_network::ProtocolAddr::Mpsc(port) + ProtocolAddr::Mpsc(port) } diff --git a/network/tests/integration.rs b/network/tests/integration.rs index fd33dab8e3..af30b1c89f 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -177,7 +177,7 @@ fn api_stream_send_main() -> std::result::Result<(), Box> .await?; // keep it alive let _stream_p = remote_p - .open(16, Promises::ORDERED | Promises::CONSISTENCY) + .open(4, Promises::ORDERED | Promises::CONSISTENCY) .await?; let participant_a = network.connected().await?; let mut stream_a = participant_a.opened().await?; @@ -205,7 +205,7 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> .connect(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; let mut stream_p = remote_p - .open(16, Promises::ORDERED | Promises::CONSISTENCY) + .open(4, Promises::ORDERED | Promises::CONSISTENCY) .await?; stream_p.send("Hello World")?; let participant_a = network.connected().await?; diff --git a/server-cli/src/logging.rs b/server-cli/src/logging.rs index 18c0952780..74eb2dc304 100644 --- a/server-cli/src/logging.rs +++ b/server-cli/src/logging.rs @@ -17,7 +17,9 @@ pub fn init(basic: bool) { env.add_directive("veloren_world::sim=info".parse().unwrap()) .add_directive("veloren_world::civ=info".parse().unwrap()) .add_directive("uvth=warn".parse().unwrap()) - .add_directive("tiny_http=warn".parse().unwrap()) + .add_directive("hyper=info".parse().unwrap()) + .add_directive("prometheus_hyper=info".parse().unwrap()) + .add_directive("mio::pool=info".parse().unwrap()) .add_directive("mio::sys::windows=debug".parse().unwrap()) .add_directive("veloren_network_protocol=info".parse().unwrap()) .add_directive( diff --git a/server/src/connection_handler.rs b/server/src/connection_handler.rs index 4e4dc02e0a..4128d3115c 100644 --- a/server/src/connection_handler.rs +++ 
b/server/src/connection_handler.rs @@ -104,11 +104,11 @@ impl ConnectionHandler { let reliable = Promises::ORDERED | Promises::CONSISTENCY; let reliablec = reliable | Promises::COMPRESSED; - let general_stream = participant.open(10, reliablec).await?; - let ping_stream = participant.open(5, reliable).await?; - let mut register_stream = participant.open(10, reliablec).await?; - let character_screen_stream = participant.open(10, reliablec).await?; - let in_game_stream = participant.open(10, reliablec).await?; + let general_stream = participant.open(3, reliablec).await?; + let ping_stream = participant.open(2, reliable).await?; + let mut register_stream = participant.open(3, reliablec).await?; + let character_screen_stream = participant.open(3, reliablec).await?; + let in_game_stream = participant.open(3, reliablec).await?; let server_data = receiver.recv()?;