diff --git a/CHANGELOG.md b/CHANGELOG.md index 1430a88635..42bff7edcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Cave scatter now includes all 6 gems. - Adjusted Stonework Defender loot table to remove mindflayer drops (bag, staff, glider). - Changed default controller key bindings +- Improved network efficiency by ≈ factor 10 by using tokio. ### Removed diff --git a/Cargo.lock b/Cargo.lock index d2595f8d09..cfdd5ebe19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,12 +225,6 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e" -[[package]] -name = "ascii" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf56136a5198c7b01a49e3afcbef6cf84597273d298f54432926024107b0109" - [[package]] name = "assets_manager" version = "0.4.3" @@ -249,38 +243,25 @@ dependencies = [ ] [[package]] -name = "async-std" -version = "1.5.0" +name = "async-channel" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "538ecb01eb64eecd772087e5b6f7540cbc917f047727339a472dafed2185b267" +checksum = "59740d83946db6a5af71ae25ddf9562c2b176b2ca42cf99a455f09f4a220d6b9" dependencies = [ - "async-task", - "crossbeam-channel 0.4.4", - "crossbeam-deque 0.7.3", - "crossbeam-utils 0.7.2", + "concurrent-queue", + "event-listener", "futures-core", - "futures-io", - "futures-timer 2.0.2", - "kv-log-macro", - "log", - "memchr", - "mio 0.6.23", - "mio-uds", - "num_cpus", - "once_cell", - "pin-project-lite 0.1.11", - "pin-utils", - "slab", ] [[package]] -name = "async-task" -version = "1.3.1" +name = "async-trait" +version = "0.1.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ac2c016b079e771204030951c366db398864f5026f84a44dafb0ff20f02085d" +checksum = 
"8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d" dependencies = [ - "libc", - "winapi 0.3.9", + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.60", ] [[package]] @@ -487,6 +468,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +[[package]] +name = "cache-padded" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "631ae5198c9be5e753e5cc215e1bd73c2b466a3565173db433f52bb9d3e66dba" + [[package]] name = "calloop" version = "0.6.5" @@ -730,7 +717,7 @@ version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da3da6baa321ec19e1cc41d31bf599f00c783d0517095cdaf0332e3fe8d20680" dependencies = [ - "ascii 0.9.3", + "ascii", "byteorder", "either", "memchr", @@ -747,6 +734,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "concurrent-queue" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ed07550be01594c6026cff2a1d7fe9c8f683caa798e12b68694ac9e88286a3" +dependencies = [ + "cache-padded", +] + [[package]] name = "conrod_core" version = "0.63.0" @@ -1066,6 +1062,7 @@ dependencies = [ "clap", "criterion-plot", "csv", + "futures", "itertools 0.10.0", "lazy_static", "num-traits", @@ -1078,6 +1075,7 @@ dependencies = [ "serde_derive", "serde_json", "tinytemplate", + "tokio 1.2.0", "walkdir 2.3.1", ] @@ -1114,16 +1112,6 @@ dependencies = [ "crossbeam-utils 0.6.6", ] -[[package]] -name = "crossbeam-channel" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87" -dependencies = [ - "crossbeam-utils 0.7.2", - "maybe-uninit", -] - [[package]] name = "crossbeam-channel" version = "0.5.0" @@ -1290,16 +1278,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "ctor" -version = "0.1.19" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8f45d9ad417bcef4817d614a501ab55cdd96a6fdb24f49aab89a54acfd66b19" -dependencies = [ - "quote 1.0.9", - "syn 1.0.60", -] - [[package]] name = "daggy" version = "0.5.0" @@ -1621,6 +1599,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "event-listener" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7531096570974c3a9dcf9e4b8e1cede1ec26cf5046219fb3b9d897503b9be59" + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1861,12 +1845,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "futures-timer" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1de7508b218029b0f01662ed8f61b1c964b3ae99d6f25462d0f55a595109df6" - [[package]] name = "futures-timer" version = "3.0.2" @@ -2242,7 +2220,7 @@ dependencies = [ "http", "indexmap", "slab", - "tokio", + "tokio 0.2.25", "tokio-util", "tracing", "tracing-futures", @@ -2357,6 +2335,16 @@ dependencies = [ "http", ] +[[package]] +name = "http-body" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2861bd27ee074e5ee891e8b539837a9430012e249d7f0ca2d795650f579c1994" +dependencies = [ + "bytes 1.0.1", + "http", +] + [[package]] name = "httparse" version = "1.3.5" @@ -2387,13 +2375,36 @@ dependencies = [ "futures-util", "h2", "http", - "http-body", + "http-body 0.3.1", "httparse", "httpdate", "itoa", "pin-project 1.0.5", "socket2", - "tokio", + "tokio 0.2.25", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8e946c2b1349055e0b72ae281b238baf1a3ea7307c7e9f9d64673bdd9c26ac7" +dependencies = [ + "bytes 1.0.1", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body 0.4.0", + "httparse", + "httpdate", + "itoa", + "pin-project 1.0.5", + "socket2", + "tokio 
1.2.0", "tower-service", "tracing", "want", @@ -2407,10 +2418,10 @@ checksum = "37743cc83e8ee85eacfce90f2f4102030d9ff0a95244098d781e9bee4a90abb6" dependencies = [ "bytes 0.5.6", "futures-util", - "hyper", + "hyper 0.13.10", "log", "rustls 0.18.1", - "tokio", + "tokio 0.2.25", "tokio-rustls", "webpki", ] @@ -2687,15 +2698,6 @@ version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" -[[package]] -name = "kv-log-macro" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f" -dependencies = [ - "log", -] - [[package]] name = "lazy-bytes-cast" version = "5.0.1" @@ -2863,7 +2865,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" dependencies = [ "cfg-if 1.0.0", - "value-bag", ] [[package]] @@ -3088,17 +3089,6 @@ dependencies = [ "slab", ] -[[package]] -name = "mio-uds" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" -dependencies = [ - "iovec", - "libc", - "mio 0.6.23", -] - [[package]] name = "miow" version = "0.2.2" @@ -3989,6 +3979,18 @@ dependencies = [ "thiserror", ] +[[package]] +name = "prometheus-hyper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc47fa532a12d544229015dd3fae32394949af098b8fe9a327b8c1e4c911d1c8" +dependencies = [ + "hyper 0.14.4", + "prometheus", + "tokio 1.2.0", + "tracing", +] + [[package]] name = "publicsuffix" version = "1.5.4" @@ -4267,8 +4269,8 @@ dependencies = [ "futures-core", "futures-util", "http", - "http-body", - "hyper", + "http-body 0.3.1", + "hyper 0.13.10", "hyper-rustls", "ipnet", "js-sys", @@ -4282,7 +4284,7 @@ dependencies = [ "serde", "serde_json", 
"serde_urlencoded", - "tokio", + "tokio 0.2.25", "tokio-rustls", "url", "wasm-bindgen", @@ -5106,19 +5108,6 @@ dependencies = [ "syn 1.0.60", ] -[[package]] -name = "tiny_http" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded47106b8e52d8ed8119f0ea6e8c0f5881e69783e0297b5a8462958f334bc1" -dependencies = [ - "ascii 1.0.0", - "chrono", - "chunked_transfer", - "log", - "url", -] - [[package]] name = "tinytemplate" version = "1.2.0" @@ -5162,6 +5151,36 @@ dependencies = [ "slab", ] +[[package]] +name = "tokio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8190d04c665ea9e6b6a0dc45523ade572c088d2e6566244c1122671dbf4ae3a" +dependencies = [ + "autocfg", + "bytes 1.0.1", + "libc", + "memchr", + "mio 0.7.7", + "num_cpus", + "once_cell", + "pin-project-lite 0.2.4", + "signal-hook-registry", + "tokio-macros", + "winapi 0.3.9", +] + +[[package]] +name = "tokio-macros" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.60", +] + [[package]] name = "tokio-rustls" version = "0.14.1" @@ -5170,10 +5189,21 @@ checksum = "e12831b255bcfa39dc0436b01e19fea231a37db570686c06ee72c423479f889a" dependencies = [ "futures-core", "rustls 0.18.1", - "tokio", + "tokio 0.2.25", "webpki", ] +[[package]] +name = "tokio-stream" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1981ad97df782ab506a1f43bf82c967326960d278acf3bf8279809648c3ff3ea" +dependencies = [ + "futures-core", + "pin-project-lite 0.2.4", + "tokio 1.2.0", +] + [[package]] name = "tokio-util" version = "0.3.1" @@ -5185,7 +5215,7 @@ dependencies = [ "futures-sink", "log", "pin-project-lite 0.1.11", - "tokio", + "tokio 0.2.25", ] [[package]] @@ -5517,26 +5547,6 @@ dependencies = [ "num_cpus", ] -[[package]] 
-name = "uvth" -version = "4.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e5910f9106b96334c6cae1f1d77a764bda66ac4ca9f507f73259f184fe1bb6b" -dependencies = [ - "crossbeam-channel 0.3.9", - "log", - "num_cpus", -] - -[[package]] -name = "value-bag" -version = "1.0.0-alpha.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b676010e055c99033117c2343b33a40a30b91fecd6c49055ac9cd2d6c305ab1" -dependencies = [ - "ctor", -] - [[package]] name = "vcpkg" version = "0.2.11" @@ -5596,7 +5606,7 @@ dependencies = [ "authc", "byteorder", "futures-executor", - "futures-timer 3.0.2", + "futures-timer", "futures-util", "hashbrown 0.9.1", "image", @@ -5604,14 +5614,15 @@ dependencies = [ "num_cpus", "rayon", "specs", + "tokio 1.2.0", "tracing", "tracing-subscriber", - "uvth 3.1.1", + "uvth", "vek 0.12.0", "veloren-common", "veloren-common-net", "veloren-common-sys", - "veloren_network", + "veloren-network", ] [[package]] @@ -5692,6 +5703,49 @@ dependencies = [ "wasmer", ] +[[package]] +name = "veloren-network" +version = "0.3.0" +dependencies = [ + "async-channel", + "async-trait", + "bincode", + "bitflags", + "bytes 1.0.1", + "clap", + "criterion", + "crossbeam-channel 0.5.0", + "futures-core", + "futures-util", + "lazy_static", + "lz-fear", + "prometheus", + "prometheus-hyper", + "rand 0.8.3", + "serde", + "shellexpand", + "tokio 1.2.0", + "tokio-stream", + "tracing", + "tracing-subscriber", + "veloren-network-protocol", +] + +[[package]] +name = "veloren-network-protocol" +version = "0.5.0" +dependencies = [ + "async-channel", + "async-trait", + "bitflags", + "bytes 1.0.1", + "criterion", + "prometheus", + "rand 0.8.3", + "tokio 1.2.0", + "tracing", +] + [[package]] name = "veloren-plugin-api" version = "0.1.0" @@ -5731,7 +5785,7 @@ dependencies = [ "dotenv", "futures-channel", "futures-executor", - "futures-timer 3.0.2", + "futures-timer", "futures-util", "hashbrown 0.9.1", "itertools 0.9.0", @@ -5739,6 +5793,7 
@@ dependencies = [ "libsqlite3-sys", "portpicker", "prometheus", + "prometheus-hyper", "rand 0.8.3", "rayon", "ron", @@ -5748,16 +5803,16 @@ dependencies = [ "slab", "specs", "specs-idvs", - "tiny_http", + "tokio 1.2.0", "tracing", - "uvth 3.1.1", + "uvth", "vek 0.12.0", "veloren-common", "veloren-common-net", "veloren-common-sys", + "veloren-network", "veloren-plugin-api", "veloren-world", - "veloren_network", ] [[package]] @@ -5772,6 +5827,7 @@ dependencies = [ "serde", "signal-hook 0.2.3", "termcolor", + "tokio 1.2.0", "tracing", "tracing-subscriber", "tracing-tracy", @@ -5818,6 +5874,7 @@ dependencies = [ "lazy_static", "native-dialog", "num 0.3.1", + "num_cpus", "old_school_gfx_glutin_ext", "ordered-float 2.1.1", "rand 0.8.3", @@ -5827,13 +5884,14 @@ dependencies = [ "specs", "specs-idvs", "termcolor", + "tokio 1.2.0", "tracing", "tracing-appender", "tracing-log", "tracing-subscriber", "tracing-tracy", "treeculler", - "uvth 3.1.1", + "uvth", "vek 0.12.0", "veloren-client", "veloren-common", @@ -5891,29 +5949,6 @@ dependencies = [ "veloren-common-net", ] -[[package]] -name = "veloren_network" -version = "0.2.0" -dependencies = [ - "async-std", - "bincode", - "bitflags", - "clap", - "crossbeam-channel 0.5.0", - "futures", - "lazy_static", - "lz-fear", - "prometheus", - "rand 0.8.3", - "serde", - "shellexpand", - "tiny_http", - "tracing", - "tracing-futures", - "tracing-subscriber", - "uvth 4.0.1", -] - [[package]] name = "version-compare" version = "0.0.10" diff --git a/Cargo.toml b/Cargo.toml index 6bd656493d..e8104a7195 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "voxygen/anim", "world", "network", + "network/protocol" ] # default profile for devs, fast to compile, okay enough to run, no debug information @@ -30,8 +31,10 @@ incremental = true # All dependencies (but not this crate itself) [profile.dev.package."*"] opt-level = 3 -[profile.dev.package."veloren_network"] +[profile.dev.package."veloren-network"] opt-level = 2 
+[profile.dev.package."veloren-network-protocol"] +opt-level = 3 [profile.dev.package."veloren-common"] opt-level = 2 [profile.dev.package."veloren-client"] diff --git a/client/Cargo.toml b/client/Cargo.toml index c775dcfc79..1578350d29 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -14,13 +14,14 @@ default = ["simd"] common = { package = "veloren-common", path = "../common", features = ["no-assets"] } common-sys = { package = "veloren-common-sys", path = "../common/sys", default-features = false } common-net = { package = "veloren-common-net", path = "../common/net" } -network = { package = "veloren_network", path = "../network", features = ["compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["compression"], default-features = false } byteorder = "1.3.2" uvth = "3.1.1" futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } image = { version = "0.23.12", default-features = false, features = ["png"] } num = "0.3.1" num_cpus = "1.10.1" diff --git a/client/examples/chat-cli/main.rs b/client/examples/chat-cli/main.rs index 3ffaa7d5b9..115d9ae50c 100644 --- a/client/examples/chat-cli/main.rs +++ b/client/examples/chat-cli/main.rs @@ -3,7 +3,14 @@ #![deny(clippy::clone_on_ref_ptr)] use common::{clock::Clock, comp}; -use std::{io, net::ToSocketAddrs, sync::mpsc, thread, time::Duration}; +use std::{ + io, + net::ToSocketAddrs, + sync::{mpsc, Arc}, + thread, + time::Duration, +}; +use tokio::runtime::Runtime; use tracing::{error, info}; use veloren_client::{Client, Event}; @@ -37,6 +44,8 @@ fn main() { println!("Enter your password"); let password = read_input(); + let runtime = Arc::new(Runtime::new().unwrap()); + // Create a client. 
let mut client = Client::new( server_addr @@ -45,6 +54,7 @@ fn main() { .next() .unwrap(), None, + runtime, ) .expect("Failed to create client instance"); diff --git a/client/src/lib.rs b/client/src/lib.rs index 1f15a66bfe..2c26d00f86 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -63,6 +63,7 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use tokio::runtime::Runtime; use tracing::{debug, error, trace, warn}; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -129,6 +130,7 @@ impl WorldData { pub struct Client { registered: bool, presence: Option, + runtime: Arc, thread_pool: ThreadPool, server_info: ServerInfo, world_data: WorldData, @@ -185,15 +187,18 @@ pub struct CharacterList { impl Client { /// Create a new `Client`. - pub fn new>(addr: A, view_distance: Option) -> Result { + pub fn new>( + addr: A, + view_distance: Option, + runtime: Arc, + ) -> Result { let mut thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".into()) .build(); // We reduce the thread count by 1 to keep rendering smooth thread_pool.set_num_threads((num_cpus::get() - 1).max(1)); - let (network, scheduler) = Network::new(Pid::new()); - thread_pool.execute(scheduler); + let network = Network::new(Pid::new(), Arc::clone(&runtime)); let participant = block_on(network.connect(ProtocolAddr::Tcp(addr.into())))?; let stream = block_on(participant.opened())?; @@ -417,6 +422,7 @@ impl Client { Ok(Self { registered: false, presence: None, + runtime, thread_pool, server_info, world_data: WorldData { @@ -1733,6 +1739,8 @@ impl Client { /// exempt). pub fn thread_pool(&self) -> &ThreadPool { &self.thread_pool } + pub fn runtime(&self) -> &Arc { &self.runtime } + /// Get a reference to the client's game state. 
pub fn state(&self) -> &State { &self.state } @@ -2058,7 +2066,8 @@ mod tests { let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9000); let view_distance: Option = None; - let veloren_client: Result = Client::new(socket, view_distance); + let runtime = Arc::new(Runtime::new().unwrap()); + let veloren_client: Result = Client::new(socket, view_distance, runtime); let _ = veloren_client.map(|mut client| { //register diff --git a/network/Cargo.toml b/network/Cargo.toml index 49caa4d62d..b5aeb4be3b 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -1,44 +1,67 @@ [package] -name = "veloren_network" -version = "0.2.0" +name = "veloren-network" +version = "0.3.0" authors = ["Marcel Märtens "] edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -metrics = ["prometheus"] +metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] default = ["metrics","compression"] [dependencies] +network-protocol = { package = "veloren-network-protocol", path = "protocol" } + #serialisation bincode = "1.3.1" serde = { version = "1.0" } #sending crossbeam-channel = "0.5" -# NOTE: Upgrading async-std can trigger spontanious crashes for `network`ing. 
Consider elaborate tests before upgrading -async-std = { version = "~1.5", default-features = false, features = ["std", "async-task", "default"] } +tokio = { version = "1.2", default-features = false, features = ["io-util", "macros", "rt", "net", "time"] } +tokio-stream = { version = "0.1.2", default-features = false } #tracing and metrics -tracing = { version = "0.1", default-features = false } -tracing-futures = "0.2" +tracing = { version = "0.1", default-features = false, features = ["attributes"]} prometheus = { version = "0.11", default-features = false, optional = true } #async -futures = { version = "0.3", features = ["thread-pool"] } +futures-core = { version = "0.3", default-features = false } +futures-util = { version = "0.3", default-features = false, features = ["std"] } +async-channel = "1.5.1" #use for .close() channels #mpsc channel registry lazy_static = { version = "1.4", default-features = false } rand = { version = "0.8" } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } +# async traits +async-trait = "0.1.42" +bytes = "^1" [dev-dependencies] tracing-subscriber = { version = "0.2.3", default-features = false, features = ["env-filter", "fmt", "chrono", "ansi", "smallvec"] } -# `uvth` needed for doc tests -uvth = { version = ">= 3.0, <= 4.0", default-features = false } +tokio = { version = "1.2", default-features = false, features = ["io-std", "fs", "rt-multi-thread"] } +futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } clap = { version = "2.33", default-features = false } shellexpand = "2.0.0" -tiny_http = "0.8.0" serde = { version = "1.0", features = ["derive"] } +prometheus-hyper = "0.1.1" +criterion = { version = "0.3.4", features = ["default", "async_tokio"] } + +[[bench]] +name = "speed" +harness = false + +[[example]] +name = "fileshare" + +[[example]] +name = "network-speed" + +[[example]] +name = "chat" + +[[example]] +name = "tcp_loadtest" diff --git 
a/network/benches/speed.rs b/network/benches/speed.rs new file mode 100644 index 0000000000..8de5c78335 --- /dev/null +++ b/network/benches/speed.rs @@ -0,0 +1,143 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::{runtime::Runtime, sync::Mutex}; +use veloren_network::{Message, Network, Participant, Pid, Promises, ProtocolAddr, Stream}; + +fn serialize(data: &[u8], stream: &Stream) { let _ = Message::serialize(data, &stream); } + +async fn stream_msg(s1_a: Arc>, s1_b: Arc>, data: &[u8], cnt: usize) { + let mut s1_b = s1_b.lock().await; + let m = Message::serialize(&data, &s1_b); + std::thread::spawn(move || { + let mut s1_a = s1_a.try_lock().unwrap(); + for _ in 0..cnt { + s1_a.send_raw(&m).unwrap(); + } + }); + for _ in 0..cnt { + s1_b.recv_raw().await.unwrap(); + } +} + +fn rt() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +fn criterion_util(c: &mut Criterion) { + let mut c = c.benchmark_group("net_util"); + c.significance_level(0.1).sample_size(100); + + let (r, _n_a, p_a, s1_a, _n_b, _p_b, _s1_b) = + network_participant_stream(ProtocolAddr::Mpsc(5000)); + let s2_a = r.block_on(p_a.open(4, Promises::COMPRESSED)).unwrap(); + + c.throughput(Throughput::Bytes(1000)) + .bench_function("message_serialize", |b| { + let data = vec![0u8; 1000]; + b.iter(|| serialize(&data, &s1_a)) + }); + c.throughput(Throughput::Bytes(1000)) + .bench_function("message_serialize_compress", |b| { + let data = vec![0u8; 1000]; + b.iter(|| serialize(&data, &s2_a)) + }); +} + +fn criterion_mpsc(c: &mut Criterion) { + let mut c = c.benchmark_group("net_mpsc"); + c.significance_level(0.1).sample_size(10); + + let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = + network_participant_stream(ProtocolAddr::Mpsc(5000)); + let s1_a = Arc::new(Mutex::new(s1_a)); + let s1_b = Arc::new(Mutex::new(s1_b)); + + c.throughput(Throughput::Bytes(100000000)).bench_function( + 
BenchmarkId::new("100MB_in_10000_msg", ""), + |b| { + let data = vec![155u8; 100_000]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 1_000), + ) + }, + ); + c.throughput(Throughput::Elements(100000)).bench_function( + BenchmarkId::new("100000_tiny_msg", ""), + |b| { + let data = vec![3u8; 5]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 100_000), + ) + }, + ); + c.finish(); + drop((_n_a, _p_a, _n_b, _p_b)); +} + +fn criterion_tcp(c: &mut Criterion) { + let mut c = c.benchmark_group("net_tcp"); + c.significance_level(0.1).sample_size(10); + + let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = + network_participant_stream(ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], 5000)))); + let s1_a = Arc::new(Mutex::new(s1_a)); + let s1_b = Arc::new(Mutex::new(s1_b)); + + c.throughput(Throughput::Bytes(100000000)).bench_function( + BenchmarkId::new("100MB_in_1000_msg", ""), + |b| { + let data = vec![155u8; 100_000]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 1_000), + ) + }, + ); + c.throughput(Throughput::Elements(100000)).bench_function( + BenchmarkId::new("100000_tiny_msg", ""), + |b| { + let data = vec![3u8; 5]; + b.to_async(rt()).iter_with_setup( + || (Arc::clone(&s1_a), Arc::clone(&s1_b)), + |(s1_a, s1_b)| stream_msg(s1_a, s1_b, &data, 100_000), + ) + }, + ); + c.finish(); + drop((_n_a, _p_a, _n_b, _p_b)); +} + +criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +criterion_main!(benches); + +pub fn network_participant_stream( + addr: ProtocolAddr, +) -> ( + Arc, + Network, + Participant, + Stream, + Network, + Participant, + Stream, +) { + let runtime = Arc::new(Runtime::new().unwrap()); + let (n_a, p1_a, s1_a, n_b, p1_b, s1_b) = runtime.block_on(async { + let n_a = Network::new(Pid::fake(0), 
Arc::clone(&runtime)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&runtime)); + + n_a.listen(addr.clone()).await.unwrap(); + let p1_b = n_b.connect(addr).await.unwrap(); + let p1_a = n_a.connected().await.unwrap(); + + let s1_a = p1_a.open(4, Promises::empty()).await.unwrap(); + let s1_b = p1_b.opened().await.unwrap(); + + (n_a, p1_a, s1_a, n_b, p1_b, s1_b) + }); + (runtime, n_a, p1_a, s1_a, n_b, p1_b, s1_b) +} diff --git a/network/examples/chat.rs b/network/examples/chat.rs index 91fcdea733..e5c7737531 100644 --- a/network/examples/chat.rs +++ b/network/examples/chat.rs @@ -3,10 +3,9 @@ //! RUST_BACKTRACE=1 cargo run --example chat -- --trace=info --port 15006 //! RUST_BACKTRACE=1 cargo run --example chat -- --trace=info --port 15006 --mode=client //! ``` -use async_std::{io, sync::RwLock}; use clap::{App, Arg}; -use futures::executor::{block_on, ThreadPool}; use std::{sync::Arc, thread, time::Duration}; +use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::RwLock}; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr}; @@ -100,18 +99,17 @@ fn main() { } fn server(address: ProtocolAddr) { - let (server, f) = Network::new(Pid::new()); + let r = Arc::new(Runtime::new().unwrap()); + let server = Network::new(Pid::new(), Arc::clone(&r)); let server = Arc::new(server); - std::thread::spawn(f); - let pool = ThreadPool::new().unwrap(); let participants = Arc::new(RwLock::new(Vec::new())); - block_on(async { + r.block_on(async { server.listen(address).await.unwrap(); loop { let p1 = Arc::new(server.connected().await.unwrap()); let server1 = server.clone(); participants.write().await.push(p1.clone()); - pool.spawn_ok(client_connection(server1, p1, participants.clone())); + tokio::spawn(client_connection(server1, p1, participants.clone())); } }); } @@ -132,7 +130,7 @@ async fn client_connection( Ok(msg) => { println!("[{}]: {}", username, msg); for p in participants.read().await.iter() { - 
match p.open(32, Promises::ORDERED | Promises::CONSISTENCY).await { + match p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await { Err(_) => info!("error talking to client, //TODO drop it"), Ok(mut s) => s.send((username.clone(), msg.clone())).unwrap(), }; @@ -144,27 +142,27 @@ async fn client_connection( } fn client(address: ProtocolAddr) { - let (client, f) = Network::new(Pid::new()); - std::thread::spawn(f); - let pool = ThreadPool::new().unwrap(); + let r = Arc::new(Runtime::new().unwrap()); + let client = Network::new(Pid::new(), Arc::clone(&r)); - block_on(async { + r.block_on(async { let p1 = client.connect(address.clone()).await.unwrap(); //remote representation of p1 let mut s1 = p1 - .open(16, Promises::ORDERED | Promises::CONSISTENCY) + .open(4, Promises::ORDERED | Promises::CONSISTENCY) .await .unwrap(); //remote representation of s1 + let mut input_lines = io::BufReader::new(io::stdin()); println!("Enter your username:"); let mut username = String::new(); - io::stdin().read_line(&mut username).await.unwrap(); + input_lines.read_line(&mut username).await.unwrap(); username = username.split_whitespace().collect(); println!("Your username is: {}", username); println!("write /quit to close"); - pool.spawn_ok(read_messages(p1)); + tokio::spawn(read_messages(p1)); s1.send(username).unwrap(); loop { let mut line = String::new(); - io::stdin().read_line(&mut line).await.unwrap(); + input_lines.read_line(&mut line).await.unwrap(); line = line.split_whitespace().collect(); if line.as_str() == "/quit" { println!("goodbye"); diff --git a/network/examples/fileshare/commands.rs b/network/examples/fileshare/commands.rs index 3967631a56..a18c90b38e 100644 --- a/network/examples/fileshare/commands.rs +++ b/network/examples/fileshare/commands.rs @@ -1,9 +1,7 @@ -use async_std::{ - fs, - path::{Path, PathBuf}, -}; use rand::Rng; use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use tokio::fs; use veloren_network::{Participant, ProtocolAddr, 
Stream}; use std::collections::HashMap; diff --git a/network/examples/fileshare/main.rs b/network/examples/fileshare/main.rs index d4b7b832e7..f000f371e0 100644 --- a/network/examples/fileshare/main.rs +++ b/network/examples/fileshare/main.rs @@ -4,14 +4,9 @@ //! --profile=release -Z unstable-options -- --trace=info --port 15006) //! (cd network/examples/fileshare && RUST_BACKTRACE=1 cargo run //! --profile=release -Z unstable-options -- --trace=info --port 15007) ``` -use async_std::{io, path::PathBuf}; use clap::{App, Arg, SubCommand}; -use futures::{ - channel::mpsc, - executor::{block_on, ThreadPool}, - sink::SinkExt, -}; -use std::{thread, time::Duration}; +use std::{path::PathBuf, sync::Arc, thread, time::Duration}; +use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::mpsc}; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::ProtocolAddr; @@ -56,14 +51,14 @@ fn main() { let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); let address = ProtocolAddr::Tcp(format!("{}:{}", "127.0.0.1", port).parse().unwrap()); + let runtime = Arc::new(Runtime::new().unwrap()); - let (server, cmd_sender) = Server::new(); - let pool = ThreadPool::new().unwrap(); - pool.spawn_ok(server.run(address)); + let (server, cmd_sender) = Server::new(Arc::clone(&runtime)); + runtime.spawn(server.run(address)); thread::sleep(Duration::from_millis(50)); //just for trace - block_on(client(cmd_sender)); + runtime.block_on(client(cmd_sender)); } fn file_exists(file: String) -> Result<(), String> { @@ -130,14 +125,15 @@ fn get_options<'a, 'b>() -> App<'a, 'b> { ) } -async fn client(mut cmd_sender: mpsc::UnboundedSender) { +async fn client(cmd_sender: mpsc::UnboundedSender) { use std::io::Write; loop { let mut line = String::new(); + let mut input_lines = io::BufReader::new(io::stdin()); print!("==> "); std::io::stdout().flush().unwrap(); - io::stdin().read_line(&mut line).await.unwrap(); + input_lines.read_line(&mut line).await.unwrap(); let matches = 
match get_options().get_matches_from_safe(line.split_whitespace()) { Err(e) => { println!("{}", e.message); @@ -148,12 +144,12 @@ async fn client(mut cmd_sender: mpsc::UnboundedSender) { match matches.subcommand() { ("quit", _) => { - cmd_sender.send(LocalCommand::Shutdown).await.unwrap(); + cmd_sender.send(LocalCommand::Shutdown).unwrap(); println!("goodbye"); break; }, ("disconnect", _) => { - cmd_sender.send(LocalCommand::Disconnect).await.unwrap(); + cmd_sender.send(LocalCommand::Disconnect).unwrap(); }, ("connect", Some(connect_matches)) => { let socketaddr = connect_matches @@ -163,7 +159,6 @@ async fn client(mut cmd_sender: mpsc::UnboundedSender) { .unwrap(); cmd_sender .send(LocalCommand::Connect(ProtocolAddr::Tcp(socketaddr))) - .await .unwrap(); }, ("t", _) => { @@ -171,28 +166,23 @@ async fn client(mut cmd_sender: mpsc::UnboundedSender) { .send(LocalCommand::Connect(ProtocolAddr::Tcp( "127.0.0.1:1231".parse().unwrap(), ))) - .await .unwrap(); }, ("serve", Some(serve_matches)) => { let path = shellexpand::tilde(serve_matches.value_of("file").unwrap()); let path: PathBuf = path.parse().unwrap(); if let Some(fileinfo) = FileInfo::new(&path).await { - cmd_sender - .send(LocalCommand::Serve(fileinfo)) - .await - .unwrap(); + cmd_sender.send(LocalCommand::Serve(fileinfo)).unwrap(); } }, ("list", _) => { - cmd_sender.send(LocalCommand::List).await.unwrap(); + cmd_sender.send(LocalCommand::List).unwrap(); }, ("get", Some(get_matches)) => { let id: u32 = get_matches.value_of("id").unwrap().parse().unwrap(); let file = get_matches.value_of("file"); cmd_sender .send(LocalCommand::Get(id, file.map(|s| s.to_string()))) - .await .unwrap(); }, diff --git a/network/examples/fileshare/server.rs b/network/examples/fileshare/server.rs index 080b85fff5..b6cf6c38dd 100644 --- a/network/examples/fileshare/server.rs +++ b/network/examples/fileshare/server.rs @@ -1,11 +1,12 @@ use crate::commands::{Command, FileInfo, LocalCommand, RemoteInfo}; -use async_std::{ - fs, - 
path::PathBuf, - sync::{Mutex, RwLock}, +use futures_util::{FutureExt, StreamExt}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; +use tokio::{ + fs, join, + runtime::Runtime, + sync::{mpsc, Mutex, RwLock}, }; -use futures::{channel::mpsc, future::FutureExt, stream::StreamExt}; -use std::{collections::HashMap, sync::Arc}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; @@ -23,11 +24,10 @@ pub struct Server { } impl Server { - pub fn new() -> (Self, mpsc::UnboundedSender) { - let (command_sender, command_receiver) = mpsc::unbounded(); + pub fn new(runtime: Arc) -> (Self, mpsc::UnboundedSender) { + let (command_sender, command_receiver) = mpsc::unbounded_channel(); - let (network, f) = Network::new(Pid::new()); - std::thread::spawn(f); + let network = Network::new(Pid::new(), runtime); let run_channels = Some(ControlChannels { command_receiver }); ( @@ -47,7 +47,7 @@ impl Server { self.network.listen(address).await.unwrap(); - futures::join!( + join!( self.command_manager(run_channels.command_receiver,), self.connect_manager(), ); @@ -55,6 +55,7 @@ impl Server { async fn command_manager(&self, command_receiver: mpsc::UnboundedReceiver) { trace!("Start command_manager"); + let command_receiver = UnboundedReceiverStream::new(command_receiver); command_receiver .for_each_concurrent(None, async move |cmd| { match cmd { @@ -106,7 +107,7 @@ impl Server { async fn connect_manager(&self) { trace!("Start connect_manager"); - let iter = futures::stream::unfold((), |_| { + let iter = futures_util::stream::unfold((), |_| { self.network.connected().map(|r| r.ok().map(|v| (v, ()))) }); @@ -120,8 +121,8 @@ impl Server { #[allow(clippy::eval_order_dependence)] async fn loop_participant(&self, p: Participant) { if let (Ok(cmd_out), Ok(file_out), Ok(cmd_in), Ok(file_in)) = ( - p.open(15, Promises::ORDERED | Promises::CONSISTENCY).await, - p.open(40, 
Promises::CONSISTENCY).await, + p.open(3, Promises::ORDERED | Promises::CONSISTENCY).await, + p.open(6, Promises::CONSISTENCY).await, p.opened().await, p.opened().await, ) { @@ -129,7 +130,7 @@ impl Server { let id = p.remote_pid(); let ri = Arc::new(Mutex::new(RemoteInfo::new(cmd_out, file_out, p))); self.remotes.write().await.insert(id, ri.clone()); - futures::join!( + join!( self.handle_remote_cmd(cmd_in, ri.clone()), self.handle_files(file_in, ri.clone()), ); @@ -174,7 +175,7 @@ impl Server { let mut path = std::env::current_dir().unwrap(); path.push(fi.path().file_name().unwrap()); trace!("No path provided, saving down to {:?}", path); - PathBuf::from(path) + path }, }; debug!("Received file, going to save it under {:?}", path); diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index 5f0617ec68..bb6684658a 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -3,15 +3,17 @@ /// (cd network/examples/network-speed && RUST_BACKTRACE=1 cargo run --profile=debuginfo -Z unstable-options -- --trace=error --protocol=tcp --mode=server) /// (cd network/examples/network-speed && RUST_BACKTRACE=1 cargo run --profile=debuginfo -Z unstable-options -- --trace=error --protocol=tcp --mode=client) /// ``` -mod metrics; - use clap::{App, Arg}; -use futures::executor::block_on; +use prometheus::Registry; +use prometheus_hyper::Server; use serde::{Deserialize, Serialize}; use std::{ + net::SocketAddr, + sync::Arc, thread, time::{Duration, Instant}, }; +use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::{Message, Network, Pid, Promises, ProtocolAddr}; @@ -101,14 +103,16 @@ fn main() { }; let mut background = None; + let runtime = Arc::new(Runtime::new().unwrap()); match matches.value_of("mode") { - Some("server") => server(address), - Some("client") => client(address), + Some("server") => server(address, Arc::clone(&runtime)), + Some("client") 
=> client(address, Arc::clone(&runtime)), Some("both") => { let address1 = address.clone(); - background = Some(thread::spawn(|| server(address1))); + let runtime2 = Arc::clone(&runtime); + background = Some(thread::spawn(|| server(address1, runtime2))); thread::sleep(Duration::from_millis(200)); //start client after server - client(address); + client(address, Arc::clone(&runtime)); }, _ => panic!("Invalid mode, run --help!"), }; @@ -117,18 +121,22 @@ fn main() { } } -fn server(address: ProtocolAddr) { - let mut metrics = metrics::SimpleMetrics::new(); - let (server, f) = Network::new_with_registry(Pid::new(), metrics.registry()); - std::thread::spawn(f); - metrics.run("0.0.0.0:59112".parse().unwrap()).unwrap(); - block_on(server.listen(address)).unwrap(); +fn server(address: ProtocolAddr, runtime: Arc) { + let registry = Arc::new(Registry::new()); + let server = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), ®istry); + runtime.spawn(Server::run( + Arc::clone(®istry), + SocketAddr::from(([0; 4], 59112)), + futures_util::future::pending(), + )); + runtime.block_on(server.listen(address)).unwrap(); loop { + info!("----"); info!("Waiting for participant to connect"); - let p1 = block_on(server.connected()).unwrap(); //remote representation of p1 - let mut s1 = block_on(p1.opened()).unwrap(); //remote representation of s1 - block_on(async { + let p1 = runtime.block_on(server.connected()).unwrap(); //remote representation of p1 + let mut s1 = runtime.block_on(p1.opened()).unwrap(); //remote representation of s1 + runtime.block_on(async { let mut last = Instant::now(); let mut id = 0u64; while let Ok(_msg) = s1.recv_raw().await { @@ -145,14 +153,19 @@ fn server(address: ProtocolAddr) { } } -fn client(address: ProtocolAddr) { - let mut metrics = metrics::SimpleMetrics::new(); - let (client, f) = Network::new_with_registry(Pid::new(), metrics.registry()); - std::thread::spawn(f); - metrics.run("0.0.0.0:59111".parse().unwrap()).unwrap(); +fn client(address: 
ProtocolAddr, runtime: Arc) { + let registry = Arc::new(Registry::new()); + let client = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), ®istry); + runtime.spawn(Server::run( + Arc::clone(®istry), + SocketAddr::from(([0; 4], 59111)), + futures_util::future::pending(), + )); - let p1 = block_on(client.connect(address)).unwrap(); //remote representation of p1 - let mut s1 = block_on(p1.open(16, Promises::ORDERED | Promises::CONSISTENCY)).unwrap(); //remote representation of s1 + let p1 = runtime.block_on(client.connect(address)).unwrap(); //remote representation of p1 + let mut s1 = runtime + .block_on(p1.open(4, Promises::ORDERED | Promises::CONSISTENCY)) + .unwrap(); //remote representation of s1 let mut last = Instant::now(); let mut id = 0u64; let raw_msg = Message::serialize( @@ -173,16 +186,16 @@ fn client(address: ProtocolAddr) { } if id > 2000000 { println!("Stop"); - std::thread::sleep(std::time::Duration::from_millis(5000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); break; } } drop(s1); - std::thread::sleep(std::time::Duration::from_millis(5000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); info!("Closing participant"); - block_on(p1.disconnect()).unwrap(); - std::thread::sleep(std::time::Duration::from_millis(25000)); + runtime.block_on(p1.disconnect()).unwrap(); + std::thread::sleep(std::time::Duration::from_millis(2000)); info!("DROPPING! 
client"); drop(client); - std::thread::sleep(std::time::Duration::from_millis(25000)); + std::thread::sleep(std::time::Duration::from_millis(2000)); } diff --git a/network/examples/network-speed/metrics.rs b/network/examples/network-speed/metrics.rs deleted file mode 100644 index 657a00abf4..0000000000 --- a/network/examples/network-speed/metrics.rs +++ /dev/null @@ -1,92 +0,0 @@ -use prometheus::{Encoder, Registry, TextEncoder}; -use std::{ - error::Error, - net::SocketAddr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread, -}; -use tracing::*; - -pub struct SimpleMetrics { - running: Arc, - handle: Option>, - registry: Option, -} - -impl SimpleMetrics { - pub fn new() -> Self { - let running = Arc::new(AtomicBool::new(false)); - let registry = Some(Registry::new()); - - Self { - running, - handle: None, - registry, - } - } - - pub fn registry(&self) -> &Registry { - match self.registry { - Some(ref r) => r, - None => panic!("You cannot longer register new metrics after the server has started!"), - } - } - - pub fn run(&mut self, addr: SocketAddr) -> Result<(), Box> { - self.running.store(true, Ordering::Relaxed); - let running2 = self.running.clone(); - - let registry = self - .registry - .take() - .expect("ServerMetrics must be already started"); - - //TODO: make this a job - self.handle = Some(thread::spawn(move || { - let server = tiny_http::Server::http(addr).unwrap(); - const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1); - debug!("starting tiny_http server to serve metrics"); - while running2.load(Ordering::Relaxed) { - let request = match server.recv_timeout(TIMEOUT) { - Ok(Some(rq)) => rq, - Ok(None) => continue, - Err(e) => { - println!("Error: {}", e); - break; - }, - }; - let mf = registry.gather(); - let encoder = TextEncoder::new(); - let mut buffer = vec![]; - encoder - .encode(&mf, &mut buffer) - .expect("Failed to encoder metrics text."); - let response = tiny_http::Response::from_string( - 
String::from_utf8(buffer).expect("Failed to parse bytes as a string."), - ); - if let Err(e) = request.respond(response) { - error!( - ?e, - "The metrics HTTP server had encountered and error with answering" - ) - } - } - debug!("Stopping tiny_http server to serve metrics"); - })); - Ok(()) - } -} - -impl Drop for SimpleMetrics { - fn drop(&mut self) { - self.running.store(false, Ordering::Relaxed); - let handle = self.handle.take(); - handle - .expect("ServerMetrics worker handle does not exist.") - .join() - .expect("Error shutting down prometheus metric exporter"); - } -} diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml new file mode 100644 index 0000000000..a9bd701940 --- /dev/null +++ b/network/protocol/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "veloren-network-protocol" +description = "pure Protocol without any I/O itself" +version = "0.5.0" +authors = ["Marcel Märtens "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +metrics = ["prometheus"] +trace_pedantic = [] # use for debug only + +default = ["metrics"] + +[dependencies] + +#tracing and metrics +tracing = { version = "0.1", default-features = false } +prometheus = { version = "0.11", default-features = false, optional = true } +#stream flags +bitflags = "1.2.1" +rand = { version = "0.8" } +# async traits +async-trait = "0.1.42" +bytes = "^1" + +[dev-dependencies] +async-channel = "1.5.1" +tokio = { version = "1.2", default-features = false, features = ["rt", "macros"] } +criterion = { version = "0.3.4", features = ["default", "async_tokio"] } + +[[bench]] +name = "protocols" +harness = false \ No newline at end of file diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs new file mode 100644 index 0000000000..dfe6a57084 --- /dev/null +++ b/network/protocol/benches/protocols.rs @@ -0,0 +1,262 @@ +use async_channel::*; +use async_trait::async_trait; +use 
bytes::{Bytes, BytesMut}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use std::{sync::Arc, time::Duration}; +use tokio::runtime::Runtime; +use veloren_network_protocol::{ + InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid, + TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame, +}; + +fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); } + +async fn handshake(p: [(S, R); 2]) +where + S: SendProtocol, + R: RecvProtocol, + (S, R): InitProtocol, +{ + let [mut p1, mut p2] = p; + tokio::join!( + async { + p1.initialize(true, Pid::fake(2), 1337).await.unwrap(); + p1 + }, + async { + p2.initialize(false, Pid::fake(3), 42).await.unwrap(); + p2 + } + ); +} + +async fn send_msg(mut s: T, data: Bytes, cnt: usize) { + let bandwidth = data.len() as u64 + 100; + const SEC1: Duration = Duration::from_secs(1); + + s.send(ProtocolEvent::OpenStream { + sid: Sid::new(12), + prio: 0, + promises: Promises::ORDERED, + guaranteed_bandwidth: 100_000, + }) + .await + .unwrap(); + + for i in 0..cnt { + s.send(ProtocolEvent::Message { + sid: Sid::new(12), + mid: i as u64, + data: data.clone(), + }) + .await + .unwrap(); + if i.rem_euclid(50) == 0 { + s.flush(bandwidth * 50_u64, SEC1).await.unwrap(); + } + } + s.flush(bandwidth * 1000_u64, SEC1).await.unwrap(); +} + +async fn recv_msg(mut r: T, cnt: usize) { + r.recv().await.unwrap(); + + for _ in 0..cnt { + r.recv().await.unwrap(); + } +} + +async fn send_and_recv_msg( + p: [(S, R); 2], + data: Bytes, + cnt: usize, +) { + let [p1, p2] = p; + let (s, r) = (p1.0, p2.1); + + tokio::join!(send_msg(s, data, cnt), recv_msg(r, cnt)); +} + +fn rt() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +fn criterion_util(c: &mut Criterion) { + c.bench_function("mpsc_handshake", |b| { + 
b.to_async(rt()) + .iter_with_setup(|| utils::ac_bound(10, None), handshake) + }); + c.bench_function("frame_serialize_short", |b| { + let mut buffer = BytesMut::with_capacity(1500); + let frame = OTFrame::Data { + mid: 65, + start: 89u64, + data: Bytes::from(&b"hello_world"[..]), + }; + b.iter_with_setup( + || frame.clone(), + |frame| frame_serialize(frame, &mut buffer), + ) + }); +} + +fn criterion_mpsc(c: &mut Criterion) { + let mut c = c.benchmark_group("mpsc"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buffer = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buffer.clone(), utils::ac_bound(10, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buffer = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buffer.clone(), utils::ac_bound(10, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +fn criterion_tcp(c: &mut Criterion) { + let mut c = c.benchmark_group("tcp"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buf = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::tcp_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buf = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::tcp_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +criterion_main!(benches); + +mod utils { + use super::*; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + 
receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtocol, MpscRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtocol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtocol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r1 }, m), + ), + ] + } + + pub struct TcpDrain { + sender: Sender, + } + + pub struct TcpSink { + receiver: Receiver, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtocol, TcpRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtocol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtocol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = BytesMut; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + 
.map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = BytesMut; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} diff --git a/network/protocol/src/event.rs b/network/protocol/src/event.rs new file mode 100644 index 0000000000..cc332e5d3c --- /dev/null +++ b/network/protocol/src/event.rs @@ -0,0 +1,76 @@ +use crate::{ + frame::OTFrame, + types::{Bandwidth, Mid, Prio, Promises, Sid}, +}; +use bytes::Bytes; + +/// used for communication with [`SendProtocol`] and [`RecvProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ProtocolEvent { + Shutdown, + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + }, + CloseStream { + sid: Sid, + }, + Message { + data: Bytes, + mid: Mid, + sid: Sid, + }, +} + +impl ProtocolEvent { + pub(crate) fn to_frame(&self) -> OTFrame { + match self { + ProtocolEvent::Shutdown => OTFrame::Shutdown, + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + } => OTFrame::OpenStream { + sid: *sid, + prio: *prio, + promises: *promises, + }, + ProtocolEvent::CloseStream { sid } => OTFrame::CloseStream { sid: *sid }, + ProtocolEvent::Message { .. 
} => { + unimplemented!("Event::Message to OTFrame IS NOT supported") + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_to_frame() { + assert_eq!(ProtocolEvent::Shutdown.to_frame(), OTFrame::Shutdown); + assert_eq!( + ProtocolEvent::CloseStream { sid: Sid::new(42) }.to_frame(), + OTFrame::CloseStream { sid: Sid::new(42) } + ); + } + + #[test] + #[should_panic] + fn test_msg_buffer_panic() { + let _ = ProtocolEvent::Message { + data: Bytes::new(), + mid: 0, + sid: Sid::new(23), + } + .to_frame(); + } +} diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs new file mode 100644 index 0000000000..a490d67b4d --- /dev/null +++ b/network/protocol/src/frame.rs @@ -0,0 +1,565 @@ +use crate::types::{Mid, Pid, Prio, Promises, Sid}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +// const FRAME_RESERVED_1: u8 = 0; +const FRAME_HANDSHAKE: u8 = 1; +const FRAME_INIT: u8 = 2; +const FRAME_SHUTDOWN: u8 = 3; +const FRAME_OPEN_STREAM: u8 = 4; +const FRAME_CLOSE_STREAM: u8 = 5; +const FRAME_DATA_HEADER: u8 = 6; +const FRAME_DATA: u8 = 7; +const FRAME_RAW: u8 = 8; +//const FRAME_RESERVED_2: u8 = 10; +//const FRAME_RESERVED_3: u8 = 13; + +/// Used for Communication between Channel <----(TCP/UDP)----> Channel +#[derive(Debug, PartialEq, Clone)] +pub enum InitFrame { + Handshake { + magic_number: [u8; 7], + version: [u32; 3], + }, + Init { + pid: Pid, + secret: u128, + }, + /// WARNING: sending RAW is only for debug purposes and will drop the + /// connection + Raw(Vec), +} + +/// Used for OUT TCP Communication between Channel --(TCP)--> Channel +#[derive(Debug, PartialEq, Clone)] +pub enum OTFrame { + Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully), + * Participant is deleted */ + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + }, + CloseStream { + sid: Sid, + }, + DataHeader { + mid: Mid, + sid: Sid, + length: u64, + }, + Data { + mid: Mid, + start: u64, /* remove */ + data: 
Bytes, + }, +} + +/// Used for IN TCP Communication between Channel <--(TCP)-- Channel +#[derive(Debug, PartialEq, Clone)] +pub enum ITFrame { + Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown (gracefully), + * Participant is deleted */ + OpenStream { + sid: Sid, + prio: Prio, + promises: Promises, + }, + CloseStream { + sid: Sid, + }, + DataHeader { + mid: Mid, + sid: Sid, + length: u64, + }, + Data { + mid: Mid, + start: u64, /* remove */ + data: BytesMut, + }, +} + +impl InitFrame { + // Size WITHOUT the 1rst indicating byte + pub(crate) const HANDSHAKE_CNS: usize = 19; + pub(crate) const INIT_CNS: usize = 32; + /// const part of the RAW frame, actual size is variable + pub(crate) const RAW_CNS: usize = 2; + + //provide an appropriate buffer size. > 1500 + pub(crate) fn write_bytes(self, bytes: &mut BytesMut) { + match self { + InitFrame::Handshake { + magic_number, + version, + } => { + bytes.put_u8(FRAME_HANDSHAKE); + bytes.put_slice(&magic_number); + bytes.put_u32_le(version[0]); + bytes.put_u32_le(version[1]); + bytes.put_u32_le(version[2]); + }, + InitFrame::Init { pid, secret } => { + bytes.put_u8(FRAME_INIT); + pid.to_bytes(bytes); + bytes.put_u128_le(secret); + }, + InitFrame::Raw(data) => { + bytes.put_u8(FRAME_RAW); + bytes.put_u16_le(data.len() as u16); + bytes.put_slice(&data); + }, + } + } + + pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option { + let frame_no = match bytes.get(0) { + Some(&f) => f, + None => return None, + }; + let frame = match frame_no { + FRAME_HANDSHAKE => { + if bytes.len() < Self::HANDSHAKE_CNS + 1 { + return None; + } + bytes.advance(1); + let mut magic_number_bytes = bytes.copy_to_bytes(7); + let mut magic_number = [0u8; 7]; + magic_number_bytes.copy_to_slice(&mut magic_number); + InitFrame::Handshake { + magic_number, + version: [bytes.get_u32_le(), bytes.get_u32_le(), bytes.get_u32_le()], + } + }, + FRAME_INIT => { + if bytes.len() < Self::INIT_CNS + 1 { + return None; + } + 
bytes.advance(1); + InitFrame::Init { + pid: Pid::from_bytes(bytes), + secret: bytes.get_u128_le(), + } + }, + FRAME_RAW => { + if bytes.len() < Self::RAW_CNS + 1 { + return None; + } + bytes.advance(1); + let length = bytes.get_u16_le() as usize; + // lower length is allowed + let max_length = length.min(bytes.len()); + let mut data = vec![0; max_length]; + data.copy_from_slice(&bytes[..max_length]); + InitFrame::Raw(data) + }, + _ => InitFrame::Raw(bytes.to_vec()), + }; + Some(frame) + } +} + +pub(crate) const TCP_CLOSE_STREAM_CNS: usize = 8; +/// const part of the DATA frame, actual size is variable +pub(crate) const TCP_DATA_CNS: usize = 18; +pub(crate) const TCP_DATA_HEADER_CNS: usize = 24; +pub(crate) const TCP_OPEN_STREAM_CNS: usize = 10; +// Size WITHOUT the 1rst indicating byte +pub(crate) const TCP_SHUTDOWN_CNS: usize = 0; + +impl OTFrame { + pub fn write_bytes(self, bytes: &mut BytesMut) { + match self { + Self::Shutdown => { + bytes.put_u8(FRAME_SHUTDOWN); + }, + Self::OpenStream { + sid, + prio, + promises, + } => { + bytes.put_u8(FRAME_OPEN_STREAM); + sid.to_bytes(bytes); + bytes.put_u8(prio); + bytes.put_u8(promises.to_le_bytes()[0]); + }, + Self::CloseStream { sid } => { + bytes.put_u8(FRAME_CLOSE_STREAM); + sid.to_bytes(bytes); + }, + Self::DataHeader { mid, sid, length } => { + bytes.put_u8(FRAME_DATA_HEADER); + bytes.put_u64_le(mid); + sid.to_bytes(bytes); + bytes.put_u64_le(length); + }, + Self::Data { mid, start, data } => { + bytes.put_u8(FRAME_DATA); + bytes.put_u64_le(mid); + bytes.put_u64_le(start); + bytes.put_u16_le(data.len() as u16); + bytes.put_slice(&data); + }, + } + } +} + +impl ITFrame { + pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option { + let frame_no = match bytes.first() { + Some(&f) => f, + None => return None, + }; + let size = match frame_no { + FRAME_SHUTDOWN => TCP_SHUTDOWN_CNS, + FRAME_OPEN_STREAM => TCP_OPEN_STREAM_CNS, + FRAME_CLOSE_STREAM => TCP_CLOSE_STREAM_CNS, + FRAME_DATA_HEADER => TCP_DATA_HEADER_CNS, + 
FRAME_DATA => { + if bytes.len() < 17 + 1 + 1 { + return None; + } + u16::from_le_bytes([bytes[16 + 1], bytes[17 + 1]]) as usize + TCP_DATA_CNS + }, + _ => return None, + }; + + if bytes.len() < size + 1 { + return None; + } + + let frame = match frame_no { + FRAME_SHUTDOWN => { + let _ = bytes.split_to(size + 1); + Self::Shutdown + }, + FRAME_OPEN_STREAM => { + let mut bytes = bytes.split_to(size + 1); + bytes.advance(1); + Self::OpenStream { + sid: Sid::from_bytes(&mut bytes), + prio: bytes.get_u8(), + promises: Promises::from_bits_truncate(bytes.get_u8()), + } + }, + FRAME_CLOSE_STREAM => { + let mut bytes = bytes.split_to(size + 1); + bytes.advance(1); + Self::CloseStream { + sid: Sid::from_bytes(&mut bytes), + } + }, + FRAME_DATA_HEADER => { + let mut bytes = bytes.split_to(size + 1); + bytes.advance(1); + Self::DataHeader { + mid: bytes.get_u64_le(), + sid: Sid::from_bytes(&mut bytes), + length: bytes.get_u64_le(), + } + }, + FRAME_DATA => { + bytes.advance(1); + let mid = bytes.get_u64_le(); + let start = bytes.get_u64_le(); + let length = bytes.get_u16_le(); + debug_assert_eq!(length as usize, size - TCP_DATA_CNS); + let data = bytes.split_to(length as usize); + Self::Data { mid, start, data } + }, + _ => unreachable!("Frame::to_frame should be handled before!"), + }; + Some(frame) + } +} + +#[allow(unused_variables)] +impl PartialEq for OTFrame { + fn eq(&self, other: &ITFrame) -> bool { + match self { + Self::Shutdown => matches!(other, ITFrame::Shutdown), + Self::OpenStream { + sid, + prio, + promises, + } => matches!(other, ITFrame::OpenStream { + sid, + prio, + promises + }), + Self::CloseStream { sid } => matches!(other, ITFrame::CloseStream { sid }), + Self::DataHeader { mid, sid, length } => { + matches!(other, ITFrame::DataHeader { mid, sid, length }) + }, + Self::Data { mid, start, data } => matches!(other, ITFrame::Data { mid, start, data }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{VELOREN_MAGIC_NUMBER, 
VELOREN_NETWORK_VERSION}; + + fn get_initframes() -> Vec { + vec![ + InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }, + InitFrame::Init { + pid: Pid::fake(0), + secret: 0u128, + }, + InitFrame::Raw(vec![1, 2, 3]), + ] + } + + fn get_otframes() -> Vec { + vec![ + OTFrame::OpenStream { + sid: Sid::new(1337), + prio: 14, + promises: Promises::GUARANTEED_DELIVERY, + }, + OTFrame::DataHeader { + sid: Sid::new(1337), + mid: 0, + length: 36, + }, + OTFrame::Data { + mid: 0, + start: 0, + data: Bytes::from(&[77u8; 20][..]), + }, + OTFrame::Data { + mid: 0, + start: 20, + data: Bytes::from(&[42u8; 16][..]), + }, + OTFrame::CloseStream { + sid: Sid::new(1337), + }, + OTFrame::Shutdown, + ] + } + + #[test] + fn initframe_individual() { + let dupl = |frame: InitFrame| { + let mut buffer = BytesMut::with_capacity(1500); + InitFrame::write_bytes(frame, &mut buffer); + InitFrame::read_frame(&mut buffer) + }; + + for frame in get_initframes() { + println!("initframe: {:?}", &frame); + assert_eq!(Some(frame.clone()), dupl(frame)); + } + } + + #[test] + fn initframe_multiple() { + let mut buffer = BytesMut::with_capacity(3000); + + let mut frames = get_initframes(); + // to string + for f in &frames { + InitFrame::write_bytes(f.clone(), &mut buffer); + } + + // from string + let mut framesd = frames + .iter() + .map(|&_| InitFrame::read_frame(&mut buffer)) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("initframe: {:?}", &f); + assert_eq!(Some(f), fd); + } + } + + #[test] + fn frame_individual() { + let dupl = |frame: OTFrame| { + let mut buffer = BytesMut::with_capacity(1500); + OTFrame::write_bytes(frame, &mut buffer); + ITFrame::read_frame(&mut buffer) + }; + + for frame in get_otframes() { + println!("frame: {:?}", &frame); + assert_eq!(frame.clone(), dupl(frame).expect("NONE")); + } + } + + #[test] + fn frame_multiple() { + let mut buffer = BytesMut::with_capacity(3000); 
+ + let mut frames = get_otframes(); + // to string + for f in &frames { + OTFrame::write_bytes(f.clone(), &mut buffer); + } + + // from string + let mut framesd = frames + .iter() + .map(|&_| ITFrame::read_frame(&mut buffer)) + .collect::>(); + + // compare + for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { + println!("frame: {:?}", &f); + assert_eq!(f, fd.expect("NONE")); + } + } + + #[test] + fn frame_exact_size() { + const SIZE: usize = TCP_CLOSE_STREAM_CNS+1/*first byte*/; + let mut buffer = BytesMut::with_capacity(SIZE); + + let frame1 = OTFrame::CloseStream { sid: Sid::new(2) }; + OTFrame::write_bytes(frame1.clone(), &mut buffer); + assert_eq!(buffer.len(), SIZE); + let mut deque = buffer.iter().copied().collect(); + let frame2 = ITFrame::read_frame(&mut deque); + assert_eq!(frame1, frame2.expect("NONE")); + } + + #[test] + fn initframe_too_short_buffer() { + let mut buffer = BytesMut::with_capacity(10); + + let frame1 = InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }; + InitFrame::write_bytes(frame1, &mut buffer); + } + + #[test] + fn initframe_too_less_data() { + let mut buffer = BytesMut::with_capacity(20); + + let frame1 = InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }; + let _ = InitFrame::write_bytes(frame1, &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let frame1d = InitFrame::read_frame(&mut buffer); + assert_eq!(frame1d, None); + } + + #[test] + fn initframe_rubish() { + let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); + assert_eq!( + InitFrame::read_frame(&mut buffer), + Some(InitFrame::Raw(b"dtrgwcser".to_vec())) + ); + } + + #[test] + fn initframe_attack_too_much_length() { + let mut buffer = BytesMut::with_capacity(50); + + let frame1 = InitFrame::Raw(b"foobar".to_vec()); + let _ = InitFrame::write_bytes(frame1.clone(), &mut buffer); + buffer[1] = 255; + let framed = InitFrame::read_frame(&mut buffer); + 
assert_eq!(framed, Some(frame1)); + } + + #[test] + fn initframe_attack_too_low_length() { + let mut buffer = BytesMut::with_capacity(50); + + let frame1 = InitFrame::Raw(b"foobar".to_vec()); + let _ = InitFrame::write_bytes(frame1, &mut buffer); + buffer[1] = 3; + let framed = InitFrame::read_frame(&mut buffer); + // we accept a different frame here, as it's RAW and debug only! + assert_eq!(framed, Some(InitFrame::Raw(b"foo".to_vec()))); + } + + #[test] + fn frame_too_short_buffer() { + let mut buffer = BytesMut::with_capacity(10); + + let frame1 = OTFrame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + OTFrame::write_bytes(frame1, &mut buffer); + } + + #[test] + fn frame_too_less_data() { + let mut buffer = BytesMut::with_capacity(20); + + let frame1 = OTFrame::OpenStream { + sid: Sid::new(88), + promises: Promises::ENCRYPTED, + prio: 88, + }; + OTFrame::write_bytes(frame1, &mut buffer); + buffer.truncate(6); // simulate partial retrieve + let frame1d = ITFrame::read_frame(&mut buffer); + assert_eq!(frame1d, None); + } + + #[test] + fn frame_rubish() { + let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); + assert_eq!(ITFrame::read_frame(&mut buffer), None); + } + + #[test] + fn frame_attack_too_much_length() { + let mut buffer = BytesMut::with_capacity(50); + + let frame1 = OTFrame::Data { + mid: 7u64, + start: 1u64, + data: Bytes::from(&b"foobar"[..]), + }; + + OTFrame::write_bytes(frame1, &mut buffer); + buffer[17] = 255; + let framed = ITFrame::read_frame(&mut buffer); + assert_eq!(framed, None); + } + + #[test] + fn frame_attack_too_low_length() { + let mut buffer = BytesMut::with_capacity(50); + + let frame1 = OTFrame::Data { + mid: 7u64, + start: 1u64, + data: Bytes::from(&b"foobar"[..]), + }; + + OTFrame::write_bytes(frame1, &mut buffer); + buffer[17] = 3; + let framed = ITFrame::read_frame(&mut buffer); + assert_eq!( + framed, + Some(ITFrame::Data { + mid: 7u64, + start: 1u64, + data: BytesMut::from(&b"foo"[..]), + 
}) + ); + //next = Invalid => Empty + let framed = ITFrame::read_frame(&mut buffer); + assert_eq!(framed, None); + } +} diff --git a/network/protocol/src/handshake.rs b/network/protocol/src/handshake.rs new file mode 100644 index 0000000000..fda3893d72 --- /dev/null +++ b/network/protocol/src/handshake.rs @@ -0,0 +1,239 @@ +use crate::{ + frame::InitFrame, + types::{ + Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, + VELOREN_NETWORK_VERSION, + }, + InitProtocol, InitProtocolError, ProtocolError, +}; +use async_trait::async_trait; +use tracing::{debug, error, info, trace}; + +/// Implement this for auto Handshake with [`ReliableSink`]. +/// You must make sure that EVERY message send this way actually is received on +/// the receiving site: +/// - exactly once +/// - in the correct order +/// - correctly +/// +/// [`ReliableSink`]: crate::ReliableSink +/// [`RecvProtocol`]: crate::RecvProtocol +#[async_trait] +pub trait ReliableDrain { + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>; +} + +/// Implement this for auto Handshake with [`ReliableDrain`]. See +/// [`ReliableDrain`]. 
+/// +/// [`ReliableDrain`]: crate::ReliableDrain +#[async_trait] +pub trait ReliableSink { + async fn recv(&mut self) -> Result<InitFrame, ProtocolError>; +} + +#[async_trait] +impl<D, S> InitProtocol for (D, S) +where + D: ReliableDrain + Send, + S: ReliableSink + Send, +{ + async fn initialize( + &mut self, + initializer: bool, + local_pid: Pid, + local_secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + #[cfg(debug_assertions)] + const WRONG_NUMBER: &str = "Handshake does not contain the magic number required by \ + veloren server.\nWe are not sure if you are a valid veloren \ + client.\nClosing the connection"; + #[cfg(debug_assertions)] + const WRONG_VERSION: &str = "Handshake does contain a correct magic number, but invalid \ + version.\nWe don't know how to communicate with \ + you.\nClosing the connection"; + const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ + something went wrong on network layer and connection will be closed"; + + let drain = &mut self.0; + let sink = &mut self.1; + + if initializer { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + + match sink.recv().await? 
{ + InitFrame::Handshake { + magic_number, + version, + } => { + trace!(?magic_number, ?version, "Recv handshake"); + if magic_number != VELOREN_MAGIC_NUMBER { + error!(?magic_number, "Connection with invalid magic_number"); + #[cfg(debug_assertions)] + drain + .send(InitFrame::Raw(WRONG_NUMBER.as_bytes().to_vec())) + .await?; + Err(InitProtocolError::WrongMagicNumber(magic_number)) + } else if version != VELOREN_NETWORK_VERSION { + error!(?version, "Connection with wrong network version"); + #[cfg(debug_assertions)] + drain + .send(InitFrame::Raw( + format!( + "{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection", + WRONG_VERSION, VELOREN_NETWORK_VERSION, version, + ) + .as_bytes() + .to_vec(), + )) + .await?; + Err(InitProtocolError::WrongVersion(version)) + } else { + trace!("Handshake Frame completed"); + if initializer { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + } else { + drain + .send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + } + Ok(()) + } + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + }?; + + match sink.recv().await? 
{ + InitFrame::Init { pid, secret } => { + debug!(?pid, "Participant send their ID"); + let stream_id_offset = if initializer { + STREAM_ID_OFFSET1 + } else { + drain + .send(InitFrame::Init { + pid: local_pid, + secret: local_secret, + }) + .await?; + STREAM_ID_OFFSET2 + }; + info!(?pid, "This Handshake is now configured!"); + Ok((pid, stream_id_offset, secret)) + }, + InitFrame::Raw(bytes) => { + match std::str::from_utf8(bytes.as_slice()) { + Ok(string) => error!(?string, ERR_S), + _ => error!(?bytes, ERR_S), + } + Err(InitProtocolError::Closed) + }, + _ => { + info!("Handshake failed"); + Err(InitProtocolError::Closed) + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{mpsc::test_utils::*, InitProtocolError}; + + #[tokio::test] + async fn handshake_drop_start() { + let [mut p1, p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2; + }); + let (r1, _) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_wrong_magic_number() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: *b"woopsie", + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!( + r1.unwrap(), + Err(InitProtocolError::WrongMagicNumber(*b"woopsie")) + ); + assert_eq!(r2.unwrap(), Ok(())); + } + + #[tokio::test] + async fn handshake_wrong_version() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + 
p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: [0, 1, 2], + }) + .await?; + let _ = p2.1.recv().await?; + let _ = p2.1.recv().await?; //this should be closed now + Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2]))); + assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed)); + } + + #[tokio::test] + async fn handshake_unexpected_raw() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Handshake { + magic_number: VELOREN_MAGIC_NUMBER, + version: VELOREN_NETWORK_VERSION, + }) + .await?; + let _ = p2.1.recv().await?; + p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?; + Result::<(), InitProtocolError>::Ok(()) + }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); + assert_eq!(r2.unwrap(), Ok(())); + } +} diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs new file mode 100644 index 0000000000..fc22b9b711 --- /dev/null +++ b/network/protocol/src/lib.rs @@ -0,0 +1,178 @@ +//! Network Protocol +//! +//! a I/O-Free protocol for the veloren network crate. +//! This crate defines multiple different protocols over [`UnreliableDrain`] and +//! [`UnreliableSink`] traits, which allows it to define the behavior of a +//! protocol separated from the actual io. +//! +//! For example we define the TCP protocol on top of Drains and Sinks that can +//! send chunks of bytes. You can now implement your own Drain And Sink that +//! sends the data via tokio's or std's implementation. Or you just use a +//! std::mpsc::channel for unit tests without needing a actual tcp socket. +//! +//! This crate currently defines: +//! - TCP +//! - MPSC +//! +//! a UDP implementation will quickly follow, and it's also possible to abstract +//! 
over QUIC. +//! +//! warning: don't mix protocol, using the TCP variant for actual UDP socket +//! will result in dropped data using UDP with a TCP socket will be a waste of +//! resources. +//! +//! A *channel* in this crate is defined as a combination of *read* and *write* +//! protocol. +//! +//! # adding a protocol +//! +//! We start by defining our DataFormat. For most this is prob [`Vec`] or +//! [`Bytes`]. MPSC can directly send a msg without serialisation. +//! +//! Create 2 structs, one for the receiving and sending end. Based on a generic +//! Drain/Sink with your required DataFormat. +//! Implement the [`SendProtocol`] and [`RecvProtocol`] traits respectively. +//! +//! Implement the Handshake: [`InitProtocol`], alternatively you can also +//! implement `ReliableDrain` and `ReliableSink`, by this, you use the default +//! Handshake. +//! +//! This crate also contains consts and definitions for the network protocol. +//! +//! For an *example* see `TcpDrain` and `TcpSink` in the [tcp.rs](tcp.rs) +//! +//! [`UnreliableDrain`]: crate::UnreliableDrain +//! [`UnreliableSink`]: crate::UnreliableSink +//! [`Vec`]: std::vec::Vec +//! [`Bytes`]: bytes::Bytes +//! [`SendProtocol`]: crate::SendProtocol +//! [`RecvProtocol`]: crate::RecvProtocol +//! [`InitProtocol`]: crate::InitProtocol + +mod event; +mod frame; +mod handshake; +mod message; +mod metrics; +mod mpsc; +mod prio; +mod tcp; +mod types; + +pub use event::ProtocolEvent; +pub use metrics::ProtocolMetricCache; +#[cfg(feature = "metrics")] +pub use metrics::ProtocolMetrics; +pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol}; +pub use tcp::{TcpRecvProtocol, TcpSendProtocol}; +pub use types::{ + Bandwidth, Cid, Mid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION, +}; + +///use at own risk, might change any time, for internal benchmarks +pub mod _internal { + pub use crate::frame::{ITFrame, OTFrame}; +} + +use async_trait::async_trait; + +/// Handshake: Used to connect 2 Channels. 
+#[async_trait] +pub trait InitProtocol { + async fn initialize( + &mut self, + initializer: bool, + local_pid: Pid, + secret: u128, + ) -> Result<(Pid, Sid, u128), InitProtocolError>; +} + +/// Generic Network Send Protocol. +/// Implement this for your Protocol of choice ( tcp, udp, mpsc, quic) +/// Allows the creation/deletions of `Streams` and sending messages via +/// [`ProtocolEvent`]. +/// +/// A `Stream` MUST be bound to a specific Channel. You MUST NOT switch the +/// channel to send a stream mid air. We will provide takeover options for +/// Channel closure in the future to allow keeping a `Stream` over a broker +/// Channel. +/// +/// [`ProtocolEvent`]: crate::ProtocolEvent +#[async_trait] +pub trait SendProtocol { + /// YOU MUST inform the `SendProtocol` by any Stream Open BEFORE using it in + /// `send` and Stream Close AFTER using it in `send` via this fn. + fn notify_from_recv(&mut self, event: ProtocolEvent); + /// Send a Event via this Protocol. The `SendProtocol` MAY require `flush` + /// to be called before actual data is send to the respective `Sink`. + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; + /// Flush all buffered messages according to their [`Prio`] and + /// [`Bandwidth`]. provide the current bandwidth budget (per second) as + /// well as the `dt` since last call. According to the budget the + /// respective messages will be flushed. + /// + /// [`Prio`]: crate::Prio + /// [`Bandwidth`]: crate::Bandwidth + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: std::time::Duration, + ) -> Result<(), ProtocolError>; +} + +/// Generic Network Recv Protocol. See: [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[async_trait] +pub trait RecvProtocol { + /// Either recv an event or fail the Protocol, once the Recv side is closed + /// it cannot recover from the error. 
+ async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError>; +} + +/// This crate makes use of UnreliableDrains, they are expected to provide the +/// same guarantees like their IO-counterpart. E.g. ordered messages for TCP and +/// nothing for UDP. The respective Protocol needs then to handle this. +/// This trait is an abstraction above multiple Drains, e.g. [`tokio`](https://tokio.rs) [`async-std`] [`std`] or even [`async-channel`] +/// +/// [`async-std`]: async-std +/// [`std`]: std +/// [`async-channel`]: async-channel +#[async_trait] +pub trait UnreliableDrain: Send { + type DataFormat; + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; +} + +/// Sink counterpart of [`UnreliableDrain`] +/// +/// [`UnreliableDrain`]: crate::UnreliableDrain +#[async_trait] +pub trait UnreliableSink: Send { + type DataFormat; + async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError>; +} + +/// All possible Errors that can happen during Handshake [`InitProtocol`] +/// +/// [`InitProtocol`]: crate::InitProtocol +#[derive(Debug, PartialEq)] +pub enum InitProtocolError { + Closed, + WrongMagicNumber([u8; 7]), + WrongVersion([u32; 3]), +} + +/// When you return closed you must stay closed! +#[derive(Debug, PartialEq)] +pub enum ProtocolError { + Closed, +} + +impl From<ProtocolError> for InitProtocolError { + fn from(err: ProtocolError) -> Self { + match err { + ProtocolError::Closed => InitProtocolError::Closed, + } + } +} diff --git a/network/protocol/src/message.rs b/network/protocol/src/message.rs new file mode 100644 index 0000000000..c71c0b5515 --- /dev/null +++ b/network/protocol/src/message.rs @@ -0,0 +1,191 @@ +use crate::{ + frame::OTFrame, + types::{Mid, Sid}, +}; +use bytes::{Bytes, BytesMut}; + +pub(crate) const ALLOC_BLOCK: usize = 16_777_216; + +/// Contains a outgoing message for TCP protocol +/// All Chunks have the same size, except for the last chunk which can end +/// earlier. E.g. 
+/// ```ignore +/// msg = OTMessage::new(); +/// msg.next(); +/// msg.next(); +/// ``` +#[derive(Debug)] +pub(crate) struct OTMessage { + data: Bytes, + original_length: u64, + send_header: bool, + mid: Mid, + sid: Sid, + start: u64, /* remove */ +} + +#[derive(Debug)] +pub(crate) struct ITMessage { + pub data: BytesMut, + pub sid: Sid, + pub length: u64, +} + +impl OTMessage { + pub(crate) const FRAME_DATA_SIZE: u64 = 1400; + + pub(crate) fn new(data: Bytes, mid: Mid, sid: Sid) -> Self { + let original_length = data.len() as u64; + Self { + data, + original_length, + send_header: false, + mid, + sid, + start: 0, + } + } + + fn get_header(&self) -> OTFrame { + OTFrame::DataHeader { + mid: self.mid, + sid: self.sid, + length: self.data.len() as u64, + } + } + + fn get_next_data(&mut self) -> OTFrame { + let to_send = std::cmp::min(self.data.len(), Self::FRAME_DATA_SIZE as usize); + let data = self.data.split_to(to_send); + let start = self.start; + self.start += Self::FRAME_DATA_SIZE; + + OTFrame::Data { + mid: self.mid, + start, + data, + } + } + + /// returns if something was added + pub(crate) fn next(&mut self) -> Option { + if !self.send_header { + self.send_header = true; + Some(self.get_header()) + } else if !self.data.is_empty() { + Some(self.get_next_data()) + } else { + None + } + } + + pub(crate) fn get_sid_len(&self) -> (Sid, u64) { (self.sid, self.original_length) } +} + +impl ITMessage { + pub(crate) fn new(sid: Sid, length: u64, _allocator: &mut BytesMut) -> Self { + //allocator.reserve(ALLOC_BLOCK); + //TODO: grab mem from the allocatior, but this is only possible with unsafe + Self { + sid, + length, + data: BytesMut::with_capacity((length as usize).min(ALLOC_BLOCK /* anti-ddos */)), + } + } +} + +/* +/// Contains a outgoing message and store what was *send* and *confirmed* +/// All Chunks have the same size, except for the last chunk which can end +/// earlier. E.g. 
+/// ```ignore +/// msg = OutgoingMessage::new(); +/// msg.next(); +/// msg.next(); +/// msg.confirm(1); +/// msg.confirm(2); +/// ``` +#[derive(Debug)] +#[allow(dead_code)] +pub(crate) struct OUMessage { + buffer: Arc, + send_index: u64, // 3 => 4200 (3*FRAME_DATA_SIZE) + send_header: bool, + mid: Mid, + sid: Sid, + max_index: u64, //speedup + missing_header: bool, + missing_indices: VecDeque, +} + +#[allow(dead_code)] +impl OUMessage { + pub(crate) const FRAME_DATA_SIZE: u64 = 1400; + + pub(crate) fn new(buffer: Arc, mid: Mid, sid: Sid) -> Self { + let max_index = + (buffer.data.len() as u64 + Self::FRAME_DATA_SIZE - 1) / Self::FRAME_DATA_SIZE; + Self { + buffer, + send_index: 0, + send_header: false, + mid, + sid, + max_index, + missing_header: false, + missing_indices: VecDeque::new(), + } + } + + /// all has been send once, but might been resend due to failures. + #[allow(dead_code)] + pub(crate) fn initial_sent(&self) -> bool { self.send_index == self.max_index } + + pub fn get_header(&self) -> Frame { + Frame::DataHeader { + mid: self.mid, + sid: self.sid, + length: self.buffer.data.len() as u64, + } + } + + pub fn get_data(&self, index: u64) -> Frame { + let start = index * Self::FRAME_DATA_SIZE; + let to_send = std::cmp::min( + self.buffer.data[start as usize..].len() as u64, + Self::FRAME_DATA_SIZE, + ); + Frame::Data { + mid: self.mid, + start, + data: self.buffer.data[start as usize..][..to_send as usize].to_vec(), + } + } + + #[allow(dead_code)] + pub(crate) fn set_missing(&mut self, missing_header: bool, missing_indicies: VecDeque) { + self.missing_header = missing_header; + self.missing_indices = missing_indicies; + } + + /// returns if something was added + pub(crate) fn next(&mut self) -> Option { + if !self.send_header { + self.send_header = true; + Some(self.get_header()) + } else if self.send_index < self.max_index { + self.send_index += 1; + Some(self.get_data(self.send_index - 1)) + } else if self.missing_header { + self.missing_header = 
false; + Some(self.get_header()) + } else if let Some(index) = self.missing_indices.pop_front() { + Some(self.get_data(index)) + } else { + None + } + } + + pub(crate) fn get_sid_len(&self) -> (Sid, u64) { (self.sid, self.buffer.data.len() as u64) } +} +*/ diff --git a/network/protocol/src/metrics.rs b/network/protocol/src/metrics.rs new file mode 100644 index 0000000000..3ecacfa1a2 --- /dev/null +++ b/network/protocol/src/metrics.rs @@ -0,0 +1,418 @@ +use crate::types::Sid; +#[cfg(feature = "metrics")] +use prometheus::{ + core::{AtomicI64, AtomicU64, GenericCounter, GenericGauge}, + IntCounterVec, IntGaugeVec, Opts, Registry, +}; +#[cfg(feature = "metrics")] +use std::collections::HashMap; +use std::{error::Error, sync::Arc}; + +#[allow(dead_code)] +pub enum RemoveReason { + Finished, + Dropped, +} + +/// Use 1 `ProtocolMetrics` per `Network`. +/// I will contain all protocol related [`prometheus`] information +/// +/// [`prometheus`]: prometheus +#[cfg(feature = "metrics")] +pub struct ProtocolMetrics { + // smsg=send_msg rdata=receive_data + // i=in o=out + // t=total b=byte throughput + //e.g smsg_it = sending messages, in (responsibility of protocol) total + + // based on CHANNEL/STREAM + /// messages added to be send total, by STREAM, + smsg_it: IntCounterVec, + /// messages bytes added to be send throughput, by STREAM, + smsg_ib: IntCounterVec, + /// messages removed from to be send, because they where finished total, by + /// STREAM AND REASON(finished/canceled), + smsg_ot: IntCounterVec, + /// messages bytes removed from to be send throughput, because they where + /// finished total, by STREAM AND REASON(finished/dropped), + smsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + sdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + sdata_frames_b: IntCounterVec, + + // based on CHANNEL/STREAM + /// messages added to be received total, by STREAM, + rmsg_it: IntCounterVec, + /// messages bytes added to be received 
throughput, by STREAM, + rmsg_ib: IntCounterVec, + /// messages removed from to be received, because they where finished total, + /// by STREAM AND REASON(finished/canceled), + rmsg_ot: IntCounterVec, + /// messages bytes removed from to be received throughput, because they + /// where finished total, by STREAM AND REASON(finished/dropped), + rmsg_ob: IntCounterVec, + /// data frames send by prio by CHANNEL, + rdata_frames_t: IntCounterVec, + /// data frames bytes send by prio by CHANNEL, + rdata_frames_b: IntCounterVec, + /// ping per CHANNEL //TODO: implement + ping: IntGaugeVec, +} + +/// Cache for [`ProtocolMetrics`], more optimized and cleared up after channel +/// disconnect. +/// +/// [`ProtocolMetrics`]: crate::ProtocolMetrics +#[cfg(feature = "metrics")] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache { + cid: String, + m: Arc<ProtocolMetrics>, + cache: HashMap<Sid, CacheLine>, + sdata_frames_t: GenericCounter<AtomicU64>, + sdata_frames_b: GenericCounter<AtomicU64>, + rdata_frames_t: GenericCounter<AtomicU64>, + rdata_frames_b: GenericCounter<AtomicU64>, + ping: GenericGauge<AtomicI64>, +} + +#[cfg(not(feature = "metrics"))] +#[derive(Debug, Clone)] +pub struct ProtocolMetricCache {} + +#[cfg(feature = "metrics")] +impl ProtocolMetrics { + pub fn new() -> Result<Self, Box<dyn Error>> { + let smsg_it = IntCounterVec::new( + Opts::new( + "send_messages_in_total", + "All Messages that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ib = IntCounterVec::new( + Opts::new( + "send_messages_in_throughput", + "All Message bytes that are added to this Protocol to be send at stream level", + ), + &["channel", "stream"], + )?; + let smsg_ot = IntCounterVec::new( + Opts::new( + "send_messages_out_total", + "All Messages that are removed from this Protocol to be send at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let smsg_ob = IntCounterVec::new( + Opts::new( + "send_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be send 
at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let sdata_frames_t = IntCounterVec::new( + Opts::new( + "send_data_frames_total", + "Number of data frames send per channel", + ), + &["channel"], + )?; + let sdata_frames_b = IntCounterVec::new( + Opts::new( + "send_data_frames_throughput", + "Number of data frames bytes send per channel", + ), + &["channel"], + )?; + + let rmsg_it = IntCounterVec::new( + Opts::new( + "recv_messages_in_total", + "All Messages that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ib = IntCounterVec::new( + Opts::new( + "recv_messages_in_throughput", + "All Message bytes that are added to this Protocol to be received at stream level", + ), + &["channel", "stream"], + )?; + let rmsg_ot = IntCounterVec::new( + Opts::new( + "recv_messages_out_total", + "All Messages that are removed from this Protocol to be received at stream and \ + reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rmsg_ob = IntCounterVec::new( + Opts::new( + "recv_messages_out_throughput", + "All Message bytes that are removed from this Protocol to be received at stream \ + and reason(finished/canceled) level", + ), + &["channel", "stream", "reason"], + )?; + let rdata_frames_t = IntCounterVec::new( + Opts::new( + "recv_data_frames_total", + "Number of data frames received per channel", + ), + &["channel"], + )?; + let rdata_frames_b = IntCounterVec::new( + Opts::new( + "recv_data_frames_throughput", + "Number of data frames bytes received per channel", + ), + &["channel"], + )?; + let ping = IntGaugeVec::new(Opts::new("ping", "Ping per channel"), &["channel"])?; + + Ok(Self { + smsg_it, + smsg_ib, + smsg_ot, + smsg_ob, + sdata_frames_t, + sdata_frames_b, + rmsg_it, + rmsg_ib, + rmsg_ot, + rmsg_ob, + rdata_frames_t, + rdata_frames_b, + ping, + }) + } + + pub fn register(&self, registry: &Registry) -> Result<(), Box> { + 
registry.register(Box::new(self.smsg_it.clone()))?; + registry.register(Box::new(self.smsg_ib.clone()))?; + registry.register(Box::new(self.smsg_ot.clone()))?; + registry.register(Box::new(self.smsg_ob.clone()))?; + registry.register(Box::new(self.sdata_frames_t.clone()))?; + registry.register(Box::new(self.sdata_frames_b.clone()))?; + registry.register(Box::new(self.rmsg_it.clone()))?; + registry.register(Box::new(self.rmsg_ib.clone()))?; + registry.register(Box::new(self.rmsg_ot.clone()))?; + registry.register(Box::new(self.rmsg_ob.clone()))?; + registry.register(Box::new(self.rdata_frames_t.clone()))?; + registry.register(Box::new(self.rdata_frames_b.clone()))?; + registry.register(Box::new(self.ping.clone()))?; + Ok(()) + } +} + +#[cfg(not(feature = "metrics"))] +pub struct ProtocolMetrics {} + +#[cfg(feature = "metrics")] +#[derive(Debug, Clone)] +pub(crate) struct CacheLine { + pub smsg_it: GenericCounter<AtomicU64>, + pub smsg_ib: GenericCounter<AtomicU64>, + pub smsg_ot: [GenericCounter<AtomicU64>; 2], + pub smsg_ob: [GenericCounter<AtomicU64>; 2], + pub rmsg_it: GenericCounter<AtomicU64>, + pub rmsg_ib: GenericCounter<AtomicU64>, + pub rmsg_ot: [GenericCounter<AtomicU64>; 2], + pub rmsg_ob: [GenericCounter<AtomicU64>; 2], +} + +#[cfg(feature = "metrics")] +impl ProtocolMetricCache { + pub fn new(channel_key: &str, metrics: Arc<ProtocolMetrics>) -> Self { + let cid = channel_key.to_string(); + let sdata_frames_t = metrics.sdata_frames_t.with_label_values(&[&cid]); + let sdata_frames_b = metrics.sdata_frames_b.with_label_values(&[&cid]); + let rdata_frames_t = metrics.rdata_frames_t.with_label_values(&[&cid]); + let rdata_frames_b = metrics.rdata_frames_b.with_label_values(&[&cid]); + let ping = metrics.ping.with_label_values(&[&cid]); + Self { + cid, + m: metrics, + cache: HashMap::new(), + sdata_frames_t, + sdata_frames_b, + rdata_frames_t, + rdata_frames_b, + ping, + } + } + + pub(crate) fn init_sid(&mut self, sid: Sid) -> &CacheLine { + let cid = &self.cid; + let m = &self.m; + self.cache.entry(sid).or_insert_with_key(|sid| { + let s = sid.to_string(); + 
let finished = RemoveReason::Finished.to_str(); + let dropped = RemoveReason::Dropped.to_str(); + CacheLine { + smsg_it: m.smsg_it.with_label_values(&[&cid, &s]), + smsg_ib: m.smsg_ib.with_label_values(&[&cid, &s]), + smsg_ot: [ + m.smsg_ot.with_label_values(&[&cid, &s, &finished]), + m.smsg_ot.with_label_values(&[&cid, &s, &dropped]), + ], + smsg_ob: [ + m.smsg_ob.with_label_values(&[&cid, &s, &finished]), + m.smsg_ob.with_label_values(&[&cid, &s, &dropped]), + ], + rmsg_it: m.rmsg_it.with_label_values(&[&cid, &s]), + rmsg_ib: m.rmsg_ib.with_label_values(&[&cid, &s]), + rmsg_ot: [ + m.rmsg_ot.with_label_values(&[&cid, &s, &finished]), + m.rmsg_ot.with_label_values(&[&cid, &s, &dropped]), + ], + rmsg_ob: [ + m.rmsg_ob.with_label_values(&[&cid, &s, &finished]), + m.rmsg_ob.with_label_values(&[&cid, &s, &dropped]), + ], + } + }) + } + + pub(crate) fn smsg_ib(&mut self, sid: Sid, bytes: u64) { + let line = self.init_sid(sid); + line.smsg_it.inc(); + line.smsg_ib.inc_by(bytes); + } + + pub(crate) fn smsg_ob(&mut self, sid: Sid, reason: RemoveReason, bytes: u64) { + let line = self.init_sid(sid); + line.smsg_ot[reason.i()].inc(); + line.smsg_ob[reason.i()].inc_by(bytes); + } + + pub(crate) fn sdata_frames_b(&mut self, cnt: u64, bytes: u64) { + self.sdata_frames_t.inc_by(cnt); + self.sdata_frames_b.inc_by(bytes); + } + + pub(crate) fn rmsg_ib(&mut self, sid: Sid, bytes: u64) { + let line = self.init_sid(sid); + line.rmsg_it.inc(); + line.rmsg_ib.inc_by(bytes); + } + + pub(crate) fn rmsg_ob(&mut self, sid: Sid, reason: RemoveReason, bytes: u64) { + let line = self.init_sid(sid); + line.rmsg_ot[reason.i()].inc(); + line.rmsg_ob[reason.i()].inc_by(bytes); + } + + pub(crate) fn rdata_frames_b(&mut self, bytes: u64) { + self.rdata_frames_t.inc(); + self.rdata_frames_b.inc_by(bytes); + } + + #[cfg(test)] + pub(crate) fn assert_msg(&mut self, sid: Sid, cnt: u64, reason: RemoveReason) { + let line = self.init_sid(sid); + assert_eq!(line.smsg_it.get(), cnt); + 
assert_eq!(line.smsg_ot[reason.i()].get(), cnt); + assert_eq!(line.rmsg_it.get(), cnt); + assert_eq!(line.rmsg_ot[reason.i()].get(), cnt); + } + + #[cfg(test)] + pub(crate) fn assert_msg_bytes(&mut self, sid: Sid, bytes: u64, reason: RemoveReason) { + let line = self.init_sid(sid); + assert_eq!(line.smsg_ib.get(), bytes); + assert_eq!(line.smsg_ob[reason.i()].get(), bytes); + assert_eq!(line.rmsg_ib.get(), bytes); + assert_eq!(line.rmsg_ob[reason.i()].get(), bytes); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames(&mut self, cnt: u64) { + assert_eq!(self.sdata_frames_t.get(), cnt); + assert_eq!(self.rdata_frames_t.get(), cnt); + } + + #[cfg(test)] + pub(crate) fn assert_data_frames_bytes(&mut self, bytes: u64) { + assert_eq!(self.sdata_frames_b.get(), bytes); + assert_eq!(self.rdata_frames_b.get(), bytes); + } +} + +#[cfg(feature = "metrics")] +impl Drop for ProtocolMetricCache { + fn drop(&mut self) { + let cid = &self.cid; + let m = &self.m; + let finished = RemoveReason::Finished.to_str(); + let dropped = RemoveReason::Dropped.to_str(); + for (sid, _) in self.cache.drain() { + let s = sid.to_string(); + let _ = m.smsg_it.remove_label_values(&[&cid, &s]); + let _ = m.smsg_ib.remove_label_values(&[&cid, &s]); + let _ = m.smsg_ot.remove_label_values(&[&cid, &s, &finished]); + let _ = m.smsg_ot.remove_label_values(&[&cid, &s, &dropped]); + let _ = m.smsg_ob.remove_label_values(&[&cid, &s, &finished]); + let _ = m.smsg_ob.remove_label_values(&[&cid, &s, &dropped]); + let _ = m.rmsg_it.remove_label_values(&[&cid, &s]); + let _ = m.rmsg_ib.remove_label_values(&[&cid, &s]); + let _ = m.rmsg_ot.remove_label_values(&[&cid, &s, &finished]); + let _ = m.rmsg_ot.remove_label_values(&[&cid, &s, &dropped]); + let _ = m.rmsg_ob.remove_label_values(&[&cid, &s, &finished]); + let _ = m.rmsg_ob.remove_label_values(&[&cid, &s, &dropped]); + } + } +} + +#[cfg(feature = "metrics")] +impl std::fmt::Debug for ProtocolMetrics { + #[inline] + fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ProtocolMetrics()") + } +} + +#[cfg(not(feature = "metrics"))] +impl ProtocolMetricCache { + pub fn new(_channel_key: &str, _metrics: Arc) -> Self { Self {} } + + pub(crate) fn smsg_ib(&mut self, _sid: Sid, _b: u64) {} + + pub(crate) fn smsg_ob(&mut self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn sdata_frames_b(&mut self, _cnt: u64, _b: u64) {} + + pub(crate) fn rmsg_ib(&mut self, _sid: Sid, _b: u64) {} + + pub(crate) fn rmsg_ob(&mut self, _sid: Sid, _reason: RemoveReason, _b: u64) {} + + pub(crate) fn rdata_frames_b(&mut self, _b: u64) {} +} + +#[cfg(not(feature = "metrics"))] +impl ProtocolMetrics { + pub fn new() -> Result> { Ok(Self {}) } +} + +impl RemoveReason { + #[cfg(feature = "metrics")] + fn to_str(&self) -> &str { + match self { + RemoveReason::Finished => "Finished", + RemoveReason::Dropped => "Dropped", + } + } + + #[cfg(feature = "metrics")] + pub(crate) fn i(&self) -> usize { + match self { + RemoveReason::Finished => 0, + RemoveReason::Dropped => 1, + } + } +} diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs new file mode 100644 index 0000000000..1eb987c75c --- /dev/null +++ b/network/protocol/src/mpsc.rs @@ -0,0 +1,239 @@ +#[cfg(feature = "metrics")] +use crate::metrics::RemoveReason; +use crate::{ + event::ProtocolEvent, + frame::InitFrame, + handshake::{ReliableDrain, ReliableSink}, + metrics::ProtocolMetricCache, + types::Bandwidth, + ProtocolError, RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, +}; +use async_trait::async_trait; +use std::time::{Duration, Instant}; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; + +/// used for implementing your own MPSC `Sink` and `Drain` +#[derive(Debug)] +pub enum MpscMsg { + Event(ProtocolEvent), + InitFrame(InitFrame), +} + +/// MPSC implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[derive(Debug)] +pub struct MpscSendProtocol +where + D: 
UnreliableDrain, +{ + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +/// MPSC implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug)] +pub struct MpscRecvProtocol +where + S: UnreliableSink, +{ + sink: S, + metrics: ProtocolMetricCache, +} + +impl MpscSendProtocol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + drain, + last: Instant::now(), + metrics, + } + } +} + +impl MpscRecvProtocol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { Self { sink, metrics } } +} + +#[async_trait] +impl SendProtocol for MpscSendProtocol +where + D: UnreliableDrain, +{ + fn notify_from_recv(&mut self, _event: ProtocolEvent) {} + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match &event { + ProtocolEvent::Message { + data: _data, + mid: _, + sid: _sid, + } => { + #[cfg(feature = "metrics")] + let (bytes, line) = { + let sid = *_sid; + let bytes = _data.len() as u64; + let line = self.metrics.init_sid(sid); + line.smsg_it.inc(); + line.smsg_ib.inc_by(bytes); + (bytes, line) + }; + let r = self.drain.send(MpscMsg::Event(event)).await; + #[cfg(feature = "metrics")] + { + line.smsg_ot[RemoveReason::Finished.i()].inc(); + line.smsg_ob[RemoveReason::Finished.i()].inc_by(bytes); + } + r + }, + _ => self.drain.send(MpscMsg::Event(event)).await, + } + } + + async fn flush(&mut self, _: Bandwidth, _: Duration) -> Result<(), ProtocolError> { Ok(()) } +} + +#[async_trait] +impl RecvProtocol for MpscRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + let event = self.sink.recv().await?; + #[cfg(feature = "trace_pedantic")] + trace!(?event, "recv"); + match event { + MpscMsg::Event(e) => { + #[cfg(feature = "metrics")] + { + if let ProtocolEvent::Message { data, mid: _, sid } = &e { + let sid = *sid; + 
let bytes = data.len() as u64; + let line = self.metrics.init_sid(sid); + line.rmsg_it.inc(); + line.rmsg_ib.inc_by(bytes); + line.rmsg_ot[RemoveReason::Finished.i()].inc(); + line.rmsg_ob[RemoveReason::Finished.i()].inc_by(bytes); + } + } + Ok(e) + }, + MpscMsg::InitFrame(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl ReliableDrain for MpscSendProtocol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.drain.send(MpscMsg::InitFrame(frame)).await + } +} + +#[async_trait] +impl ReliableSink for MpscRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + match self.sink.recv().await? { + MpscMsg::Event(_) => Err(ProtocolError::Closed), + MpscMsg::InitFrame(f) => Ok(f), + } + } +} + +#[cfg(test)] +pub mod test_utils { + use super::*; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; + use async_channel::*; + use std::sync::Arc; + + pub struct ACDrain { + sender: Sender, + } + + pub struct ACSink { + receiver: Receiver, + } + + pub fn ac_bound( + cap: usize, + metrics: Option, + ) -> [(MpscSendProtocol, MpscRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("mpsc", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + MpscSendProtocol::new(ACDrain { sender: s1 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r2 }, m.clone()), + ), + ( + MpscSendProtocol::new(ACDrain { sender: s2 }, m.clone()), + MpscRecvProtocol::new(ACSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for ACDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for ACSink { + type DataFormat = MpscMsg; + + async fn 
recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + mpsc::test_utils::*, + types::{Pid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, + }; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = ac_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } +} diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs new file mode 100644 index 0000000000..374a1ac216 --- /dev/null +++ b/network/protocol/src/prio.rs @@ -0,0 +1,137 @@ +use crate::{ + frame::OTFrame, + message::OTMessage, + metrics::{ProtocolMetricCache, RemoveReason}, + types::{Bandwidth, Mid, Prio, Promises, Sid, HIGHEST_PRIO}, +}; +use bytes::Bytes; +use std::{ + collections::{HashMap, VecDeque}, + time::Duration, +}; + +#[derive(Debug)] +struct StreamInfo { + pub(crate) guaranteed_bandwidth: Bandwidth, + pub(crate) prio: Prio, + pub(crate) promises: Promises, + pub(crate) messages: VecDeque, +} + +/// Responsible for queueing messages. +/// every stream has a guaranteed bandwidth and a prio 0-7. +/// when `n` Bytes are available in the buffer, first the guaranteed bandwidth +/// is used. Then remaining bandwidth is used to fill up the prios. 
+#[derive(Debug)] +pub(crate) struct PrioManager { + streams: HashMap, + metrics: ProtocolMetricCache, +} + +// Send everything ONCE, then keep it till it's confirmed + +impl PrioManager { + pub fn new(metrics: ProtocolMetricCache) -> Self { + Self { + streams: HashMap::new(), + metrics, + } + } + + pub fn open_stream( + &mut self, + sid: Sid, + prio: Prio, + promises: Promises, + guaranteed_bandwidth: Bandwidth, + ) { + self.streams.insert(sid, StreamInfo { + guaranteed_bandwidth, + prio, + promises, + messages: VecDeque::new(), + }); + } + + pub fn try_close_stream(&mut self, sid: Sid) -> bool { + if let Some(si) = self.streams.get(&sid) { + if si.messages.is_empty() { + self.streams.remove(&sid); + return true; + } + } + false + } + + pub fn is_empty(&self) -> bool { self.streams.is_empty() } + + pub fn add(&mut self, buffer: Bytes, mid: Mid, sid: Sid) { + self.streams + .get_mut(&sid) + .unwrap() + .messages + .push_back(OTMessage::new(buffer, mid, sid)); + } + + /// bandwidth might be extended, as for technical reasons + /// guaranteed_bandwidth is used and frames are always 1400 bytes. + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec, Bandwidth) { + let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; + let mut cur_bytes = 0u64; + let mut frames = vec![]; + + let mut prios = [0u64; (HIGHEST_PRIO + 1) as usize]; + let metrics = &mut self.metrics; + + let mut process_stream = + |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + let mut finished = None; + 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { + while let Some(frame) = msg.next() { + let b = if let OTFrame::Data { data, .. 
} = &frame { + crate::frame::TCP_DATA_CNS + 1 + data.len() + } else { + crate::frame::TCP_DATA_HEADER_CNS + 1 + } as u64; + bandwidth -= b as i64; + *cur_bytes += b; + frames.push(frame); + if bandwidth <= 0 { + break 'outer; + } + } + let (sid, bytes) = msg.get_sid_len(); + metrics.smsg_ob(sid, RemoveReason::Finished, bytes); + finished = Some(i); + } + if let Some(i) = finished { + //cleanup + stream.messages.drain(..=i); + } + }; + + // Add guaranteed bandwidth + for stream in self.streams.values_mut() { + prios[stream.prio as usize] += 1; + let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; + process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + } + + if cur_bytes < total_bytes { + // Add optional bandwidth + for prio in 0..=HIGHEST_PRIO { + if prios[prio as usize] == 0 { + continue; + } + let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64; + for stream in self.streams.values_mut() { + if stream.prio != prio { + continue; + } + process_stream(stream, per_stream_bytes, &mut cur_bytes); + } + } + } + (frames, cur_bytes) + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs new file mode 100644 index 0000000000..07a44ab0de --- /dev/null +++ b/network/protocol/src/tcp.rs @@ -0,0 +1,706 @@ +use crate::{ + event::ProtocolEvent, + frame::{ITFrame, InitFrame, OTFrame}, + handshake::{ReliableDrain, ReliableSink}, + message::{ITMessage, ALLOC_BLOCK}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::{Bandwidth, Mid, Sid}, + ProtocolError, RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, +}; +use async_trait::async_trait; +use bytes::BytesMut; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; +use tracing::info; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; + +/// TCP implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[derive(Debug)] +pub struct TcpSendProtocol +where + 
D: UnreliableDrain, +{ + buffer: BytesMut, + store: PrioManager, + closing_streams: Vec, + notify_closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +/// TCP implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug)] +pub struct TcpRecvProtocol +where + S: UnreliableSink, +{ + buffer: BytesMut, + itmsg_allocator: BytesMut, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +impl TcpSendProtocol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: BytesMut::new(), + store: PrioManager::new(metrics.clone()), + closing_streams: vec![], + notify_closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } +} + +impl TcpRecvProtocol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + buffer: BytesMut::new(), + itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK), + incoming: HashMap::new(), + sink, + metrics, + } + } +} + +#[async_trait] +impl SendProtocol for TcpSendProtocol +where + D: UnreliableDrain, +{ + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + }, + ProtocolEvent::CloseStream { sid } => { + if !self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back notify close stream"); + self.notify_closing_streams.push(sid); + } + }, + _ => {}, + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + 
event.to_frame().write_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + event.to_frame().write_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back close stream"); + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + event.to_frame().write_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!("hold back shutdown"); + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { data, mid, sid } => { + self.metrics.smsg_ib(sid, data.len() as u64); + self.store.add(data, mid, sid); + }, + } + Ok(()) + } + + async fn flush(&mut self, bandwidth: Bandwidth, dt: Duration) -> Result<(), ProtocolError> { + let (frames, total_bytes) = self.store.grab(bandwidth, dt); + self.buffer.reserve(total_bytes as usize); + let mut data_frames = 0; + let mut data_bandwidth = 0; + for frame in frames { + if let OTFrame::Data { + mid: _, + start: _, + data, + } = &frame + { + data_bandwidth += data.len(); + data_frames += 1; + } + frame.write_bytes(&mut self.buffer); + } + self.drain.send(self.buffer.split()).await?; + self.metrics + .sdata_frames_b(data_frames, data_bandwidth as u64); + + let mut finished_streams = vec![]; + for (i, &sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + OTFrame::CloseStream { sid }.write_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + + let mut finished_streams = vec![]; + for (i, sid) in self.notify_closing_streams.iter().enumerate() { + if 
self.store.try_close_stream(*sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.notify_closing_streams.remove(*i); + } + + if self.pending_shutdown && self.store.is_empty() { + #[cfg(feature = "trace_pedantic")] + trace!("shutdown, as it's now empty"); + OTFrame::Shutdown {}.write_bytes(&mut self.buffer); + self.drain.send(self.buffer.split()).await?; + self.pending_shutdown = false; + } + Ok(()) + } +} + +#[async_trait] +impl RecvProtocol for TcpRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + 'outer: loop { + while let Some(frame) = ITFrame::read_frame(&mut self.buffer) { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + } => { + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth: 1_000_000, + }); + }, + ITFrame::CloseStream { sid } => { + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { + mid, + start: _, + data, + } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + info!( + ?mid, + "protocol violation by remote side: send Data before Header" + ); + break 'outer Err(ProtocolError::Closed); + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + 
mid, + data: m.data.freeze(), + }); + } + }, + }; + } + let chunk = self.sink.recv().await?; + if self.buffer.is_empty() { + self.buffer = chunk; + } else { + self.buffer.extend_from_slice(&chunk); + } + } + } +} + +#[async_trait] +impl ReliableDrain for TcpSendProtocol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + let mut buffer = BytesMut::with_capacity(500); + frame.write_bytes(&mut buffer); + self.drain.send(buffer).await + } +} + +#[async_trait] +impl ReliableSink for TcpRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + while self.buffer.len() < 100 { + let chunk = self.sink.recv().await?; + self.buffer.extend_from_slice(&chunk); + if let Some(frame) = InitFrame::read_frame(&mut self.buffer) { + return Ok(frame); + } + } + Err(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod test_utils { + //TCP protocol based on Channel + use super::*; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; + use async_channel::*; + use std::sync::Arc; + + pub struct TcpDrain { + pub sender: Sender, + } + + pub struct TcpSink { + pub receiver: Receiver, + } + + /// emulate Tcp protocol on Channels + pub fn tcp_bound( + cap: usize, + metrics: Option, + ) -> [(TcpSendProtocol, TcpRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + TcpSendProtocol::new(TcpDrain { sender: s1 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r2 }, m.clone()), + ), + ( + TcpSendProtocol::new(TcpDrain { sender: s2 }, m.clone()), + TcpRecvProtocol::new(TcpSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for TcpDrain { + type DataFormat = BytesMut; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + 
.map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for TcpSink { + type DataFormat = BytesMut; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + frame::OTFrame, + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + tcp::test_utils::*, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, ProtocolError, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use bytes::{Bytes, BytesMut}; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = tcp_bound(10, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 0u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = tcp_bound(10, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + data: Bytes::from(&[188u8; 600][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, 
Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 1, + data: Bytes::from(&[7u8; 30][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let mut metrics = + ProtocolMetricCache::new("long_tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, Some(metrics.clone())); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = 
r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_drop() { + let sid = Sid::new(1); + let [p1, p2] = tcp_bound(10000, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 77, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + mid: 78, + data: Bytes::from(&[100u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. 
})); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::TcpRecvProtocol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + } + .write_bytes(&mut bytes); + OTFrame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .write_bytes(&mut bytes); + s.send(bytes.split()).await.unwrap(); + + OTFrame::Data { + mid: 99, + start: 0, + data: Bytes::from(&DATA1[..]), + } + .write_bytes(&mut bytes); + OTFrame::Data { + mid: 99, + start: DATA1.len() as u64, + data: Bytes::from(&DATA2[..]), + } + .write_bytes(&mut bytes); + OTFrame::CloseStream { sid }.write_bytes(&mut bytes); + s.send(bytes.split()).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. 
})); + } + + #[tokio::test] + async fn drop_sink_while_recv() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("tcp", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::TcpRecvProtocol::new(super::test_utils::TcpSink { receiver: r }, m.clone()); + + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + } + .write_bytes(&mut bytes); + s.send(bytes.split()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + + let e = e.await.unwrap(); + assert_eq!(e, Err(ProtocolError::Closed)); + } + + #[tokio::test] + #[should_panic] + async fn send_on_stream_from_remote_without_notify() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = tcp_bound(10, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let _ = p2.1.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_on_stream_from_remote() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = tcp_bound(10, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + mid: 0, + data: Bytes::from(&[188u8; 600][..]), + }; + 
p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } +} diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs new file mode 100644 index 0000000000..afa5f0d866 --- /dev/null +++ b/network/protocol/src/types.rs @@ -0,0 +1,207 @@ +use bitflags::bitflags; +use bytes::{Buf, BufMut, BytesMut}; +use rand::Rng; + +/// MessageID, unique ID per Message. +pub type Mid = u64; +/// ChannelID, unique ID per Channel (Protocol) +pub type Cid = u64; +/// Every Stream has a `Prio` and guaranteed [`Bandwidth`]. +/// Every send, the guarantees part is used first. +/// If there is still bandwidth left, it will be shared by all Streams with the +/// same priority. Prio 0 will be send first, then 1, ... till the last prio 7 +/// is send. Prio must be < 8! +/// +/// [`Bandwidth`]: crate::Bandwidth +pub type Prio = u8; +/// guaranteed `Bandwidth`. See [`Prio`] +/// +/// [`Prio`]: crate::Prio +pub type Bandwidth = u64; + +bitflags! { + /// use promises to modify the behavior of [`Streams`]. + /// see the consts in this `struct` for + /// + /// [`Streams`]: crate::api::Stream + pub struct Promises: u8 { + /// this will guarantee that the order of messages which are send on one side, + /// is the same when received on the other. + const ORDERED = 0b00000001; + /// this will guarantee that messages received haven't been altered by errors, + /// like bit flips, this is done with a checksum. 
+ const CONSISTENCY = 0b00000010; + /// this will guarantee that the other side will receive every message exactly + /// once no messages are dropped + const GUARANTEED_DELIVERY = 0b00000100; + /// this will enable the internal compression on this, only useable with #[cfg(feature = "compression")] + /// [`Stream`](crate::api::Stream) + const COMPRESSED = 0b00001000; + /// this will enable the internal encryption on this + /// [`Stream`](crate::api::Stream) + const ENCRYPTED = 0b00010000; + } +} + +impl Promises { + pub const fn to_le_bytes(self) -> [u8; 1] { self.bits.to_le_bytes() } +} + +pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = *b"VELOREN"; +/// When this semver differs, 2 Networks can't communicate. +pub const VELOREN_NETWORK_VERSION: [u32; 3] = [0, 5, 0]; +pub(crate) const STREAM_ID_OFFSET1: Sid = Sid::new(0); +pub(crate) const STREAM_ID_OFFSET2: Sid = Sid::new(u64::MAX / 2); +/// Maximal possible Prio to choose (for performance reasons) +pub const HIGHEST_PRIO: u8 = 7; + +/// Support struct used for uniquely identifying `Participant` over the +/// `Network`. +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub struct Pid { + internal: u128, +} + +/// Unique ID per Stream, in one Channel. +/// one side will always start with 0, while the other start with u64::MAX / 2. +/// number increases for each created Stream. +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub struct Sid { + internal: u64, +} + +impl Pid { + /// create a new Pid with a random interior value + /// + /// # Example + /// ```rust + /// use veloren_network_protocol::Pid; + /// + /// let pid = Pid::new(); + /// ``` + pub fn new() -> Self { + Self { + internal: rand::thread_rng().gen(), + } + } + + /// don't use fake! just for testing! + /// This will panic if pid i greater than 7, as I do not want you to use + /// this in production! 
+ #[doc(hidden)] + pub fn fake(pid_offset: u8) -> Self { + assert!(pid_offset < 8); + let o = pid_offset as u128; + const OFF: [u128; 5] = [ + 0x40, + 0x40 * 0x40, + 0x40 * 0x40 * 0x40, + 0x40 * 0x40 * 0x40 * 0x40, + 0x40 * 0x40 * 0x40 * 0x40 * 0x40, + ]; + Self { + internal: o + o * OFF[0] + o * OFF[1] + o * OFF[2] + o * OFF[3] + o * OFF[4], + } + } + + #[inline] + pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { + Self { + internal: bytes.get_u128_le(), + } + } + + #[inline] + pub(crate) fn to_bytes(&self, bytes: &mut BytesMut) { bytes.put_u128_le(self.internal) } +} + +impl Sid { + pub const fn new(internal: u64) -> Self { Self { internal } } + + #[inline] + pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { + Self { + internal: bytes.get_u64_le(), + } + } + + #[inline] + pub(crate) fn to_bytes(&self, bytes: &mut BytesMut) { bytes.put_u64_le(self.internal) } +} + +impl std::fmt::Debug for Pid { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + const BITS_PER_SIXLET: usize = 6; + //only print last 6 chars of number as full u128 logs are unreadable + const CHAR_COUNT: usize = 6; + for i in 0..CHAR_COUNT { + write!( + f, + "{}", + sixlet_to_str((self.internal >> (i * BITS_PER_SIXLET)) & 0x3F) + )?; + } + Ok(()) + } +} + +impl Default for Pid { + fn default() -> Self { Pid::new() } +} + +impl std::fmt::Display for Pid { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) } +} + +impl std::ops::AddAssign for Sid { + fn add_assign(&mut self, other: Self) { + *self = Self { + internal: self.internal + other.internal, + }; + } +} + +impl std::fmt::Debug for Sid { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + //only print last 6 chars of number as full u128 logs are unreadable + write!(f, "{}", self.internal.rem_euclid(1000000)) + } +} + +impl From for Sid { + fn from(internal: u64) -> Self { Sid { internal } } +} + +impl 
std::fmt::Display for Sid { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.internal) + } +} + +fn sixlet_to_str(sixlet: u128) -> char { + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[sixlet as usize] as char +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn frame_creation() { + Pid::new(); + assert_eq!(format!("{}", Pid::fake(0)), "AAAAAA"); + assert_eq!(format!("{}", Pid::fake(1)), "BBBBBB"); + assert_eq!(format!("{}", Pid::fake(2)), "CCCCCC"); + } + + #[test] + fn test_sixlet_to_str() { + assert_eq!(sixlet_to_str(0), 'A'); + assert_eq!(sixlet_to_str(29), 'd'); + assert_eq!(sixlet_to_str(63), '/'); + } +} diff --git a/network/protocol/src/udp.rs b/network/protocol/src/udp.rs new file mode 100644 index 0000000000..ad5c31a126 --- /dev/null +++ b/network/protocol/src/udp.rs @@ -0,0 +1,37 @@ +// TODO: quick and dirty which activly waits for an ack! +/* +UDP protocol + +All Good Case: +S --HEADER--> R +S --DATA--> R +S --DATA--> R +S <--FINISHED-- R + + +Delayed HEADER: +S --HEADER--> +S --DATA--> R // STORE IT + --HEADER--> R // apply left data and continue +S --DATA--> R +S <--FINISHED-- R + + +NO HEADER: +S --HEADER--> ! +S --DATA--> R // STORE IT +S --DATA--> R // STORE IT +S <--MISSING_HEADER-- R // SEND AFTER 10 ms after DATA1 +S --HEADER--> R +S <--FINISHED-- R + + +NO DATA: +S --HEADER--> R +S --DATA--> R +S --DATA--> ! +S --STATUS--> R +S <--MISSING_DATA -- R +S --DATA--> R +S <--FINISHED-- R +*/ diff --git a/network/src/api.rs b/network/src/api.rs index 66ffa82096..f60b4c4743 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -1,21 +1,12 @@ -//! -//! -//! -//! 
(cd network/examples/async_recv && RUST_BACKTRACE=1 cargo run) use crate::{ - message::{partial_eq_bincode, IncomingMessage, Message, OutgoingMessage}, + message::{partial_eq_bincode, Message}, participant::{A2bStreamOpen, S2bShutdownBparticipant}, scheduler::Scheduler, - types::{Mid, Pid, Prio, Promises, Sid}, -}; -use async_std::{io, sync::Mutex, task}; -use futures::{ - channel::{mpsc, oneshot}, - sink::SinkExt, - stream::StreamExt, }; +use bytes::Bytes; #[cfg(feature = "compression")] use lz_fear::raw::DecodeError; +use network_protocol::{Bandwidth, Pid, Prio, Promises, Sid}; #[cfg(feature = "metrics")] use prometheus::Registry; use serde::{de::DeserializeOwned, Serialize}; @@ -26,9 +17,14 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::Duration, +}; +use tokio::{ + io, + runtime::Runtime, + sync::{mpsc, oneshot, Mutex}, }; use tracing::*; -use tracing_futures::Instrument; type A2sDisconnect = Arc>>>; @@ -50,7 +46,7 @@ pub enum ProtocolAddr { pub struct Participant { local_pid: Pid, remote_pid: Pid, - a2b_stream_open_s: Mutex>, + a2b_open_stream_s: Mutex>, b2a_stream_opened_r: Mutex>, a2s_disconnect_s: A2sDisconnect, } @@ -70,14 +66,15 @@ pub struct Participant { /// [`opened`]: Participant::opened #[derive(Debug)] pub struct Stream { - pid: Pid, + local_pid: Pid, + remote_pid: Pid, sid: Sid, - mid: Mid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: Option>, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, + b2a_msg_recv_r: Option>, a2b_close_stream_s: Option>, } @@ -125,17 +122,17 @@ pub enum StreamError { /// /// # Examples /// ```rust +/// # use std::sync::Arc; +/// use tokio::runtime::Runtime; /// use veloren_network::{Network, ProtocolAddr, Pid}; -/// use futures::executor::block_on; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2999` to accept connections and 
connect to port `8080` to connect to a (pseudo) database Application -/// let (network, f) = Network::new(Pid::new()); -/// std::thread::spawn(f); -/// block_on(async{ +/// let runtime = Arc::new(Runtime::new().unwrap()); +/// let network = Network::new(Pid::new(), Arc::clone(&runtime)); +/// runtime.block_on(async{ /// # //setup pseudo database! -/// # let (database, fd) = Network::new(Pid::new()); -/// # std::thread::spawn(fd); +/// # let database = Network::new(Pid::new(), Arc::clone(&runtime)); /// # database.listen(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2999".parse().unwrap())).await?; /// let database = network.connect(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; @@ -150,9 +147,10 @@ pub enum StreamError { /// [`connected`]: Network::connected pub struct Network { local_pid: Pid, + runtime: Arc, participant_disconnect_sender: Mutex>, listen_sender: - Mutex>)>>, + Mutex>)>>, connect_sender: Mutex>)>>, connected_receiver: Mutex>, @@ -165,35 +163,25 @@ impl Network { /// # Arguments /// * `participant_id` - provide it by calling [`Pid::new()`], usually you /// don't want to reuse a Pid for 2 `Networks` + /// * `runtime` - provide a tokio::Runtime, it's used to internally spawn + /// tasks. It is necessary to clean up in the non-async `Drop`. **All** + /// network related components **must** be dropped before the runtime is + /// stopped. dropping the runtime while a shutdown is still in progress + /// leaves the network in a bad state which might cause a panic! /// /// # Result /// * `Self` - returns a `Network` which can be `Send` to multiple areas of /// your code, including multiple threads. This is the base strct of this /// crate. - /// * `FnOnce` - you need to run the returning FnOnce exactly once, probably - /// in it's own thread. this is NOT done internally, so that you are free - /// to choose the threadpool implementation of your choice. 
We recommend - /// using [`ThreadPool`] from [`uvth`] crate. This fn will run the - /// Scheduler to handle all `Network` internals. Additional threads will - /// be allocated on an internal async-aware threadpool /// /// # Examples /// ```rust - /// //Example with uvth - /// use uvth::ThreadPoolBuilder; + /// use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// - /// let pool = ThreadPoolBuilder::new().build(); - /// let (network, f) = Network::new(Pid::new()); - /// pool.execute(f); - /// ``` - /// - /// ```rust - /// //Example with std::thread - /// use veloren_network::{Network, Pid, ProtocolAddr}; - /// - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); + /// let runtime = Runtime::new().unwrap(); + /// let network = Network::new(Pid::new(), Arc::new(runtime)); /// ``` /// /// Usually you only create a single `Network` for an application, @@ -201,12 +189,11 @@ impl Network { /// will want 2. However there are no technical limitations from /// creating more. 
/// - /// [`Pid::new()`]: crate::types::Pid::new - /// [`ThreadPool`]: https://docs.rs/uvth/newest/uvth/struct.ThreadPool.html - /// [`uvth`]: https://docs.rs/uvth - pub fn new(participant_id: Pid) -> (Self, impl std::ops::FnOnce()) { + /// [`Pid::new()`]: network_protocol::Pid::new + pub fn new(participant_id: Pid, runtime: Arc) -> Self { Self::internal_new( participant_id, + runtime, #[cfg(feature = "metrics")] None, ) @@ -221,53 +208,56 @@ impl Network { /// /// # Examples /// ```rust + /// # use std::sync::Arc; /// use prometheus::Registry; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// + /// let runtime = Runtime::new().unwrap(); /// let registry = Registry::new(); - /// let (network, f) = Network::new_with_registry(Pid::new(), ®istry); - /// std::thread::spawn(f); + /// let network = Network::new_with_registry(Pid::new(), Arc::new(runtime), ®istry); /// ``` /// [`new`]: crate::api::Network::new #[cfg(feature = "metrics")] pub fn new_with_registry( participant_id: Pid, + runtime: Arc, registry: &Registry, - ) -> (Self, impl std::ops::FnOnce()) { - Self::internal_new(participant_id, Some(registry)) + ) -> Self { + Self::internal_new(participant_id, runtime, Some(registry)) } fn internal_new( participant_id: Pid, + runtime: Arc, #[cfg(feature = "metrics")] registry: Option<&Registry>, - ) -> (Self, impl std::ops::FnOnce()) { + ) -> Self { let p = participant_id; - debug!(?p, "Starting Network"); + let span = tracing::info_span!("network", ?p); + span.in_scope(|| trace!("Starting Network")); let (scheduler, listen_sender, connect_sender, connected_receiver, shutdown_sender) = Scheduler::new( participant_id, #[cfg(feature = "metrics")] registry, ); - ( - Self { - local_pid: participant_id, - participant_disconnect_sender: Mutex::new(HashMap::new()), - listen_sender: Mutex::new(listen_sender), - connect_sender: Mutex::new(connect_sender), - connected_receiver: Mutex::new(connected_receiver), - shutdown_sender: 
Some(shutdown_sender), - }, - move || { - trace!(?p, "Starting scheduler in own thread"); - let _handle = task::block_on( - scheduler - .run() - .instrument(tracing::info_span!("scheduler", ?p)), - ); - trace!(?p, "Stopping scheduler and his own thread"); - }, - ) + runtime.spawn( + async move { + trace!("Starting scheduler in own thread"); + scheduler.run().await; + trace!("Stopping scheduler and his own thread"); + } + .instrument(tracing::info_span!("network", ?p)), + ); + Self { + local_pid: participant_id, + runtime, + participant_disconnect_sender: Mutex::new(HashMap::new()), + listen_sender: Mutex::new(listen_sender), + connect_sender: Mutex::new(connect_sender), + connected_receiver: Mutex::new(connected_receiver), + shutdown_sender: Some(shutdown_sender), + } } /// starts listening on an [`ProtocolAddr`]. @@ -278,15 +268,16 @@ impl Network { /// support multiple Protocols or NICs. /// /// # Examples - /// ```rust - /// use futures::executor::block_on; + /// ```ignore + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network /// .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .await?; @@ -299,14 +290,14 @@ impl Network { /// ``` /// /// [`connected`]: Network::connected + #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn listen(&self, address: ProtocolAddr) -> Result<(), NetworkError> { - let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); + let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); debug!(?address, "listening on 
address"); self.listen_sender .lock() .await - .send((address, s2a_result_s)) - .await?; + .send((address, s2a_result_s))?; match s2a_result_r.await? { //waiting guarantees that we either listened successfully or get an error like port in // use @@ -319,17 +310,17 @@ impl Network { /// When the method returns the Network either returns a [`Participant`] /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// can't connect, or invalid Handshake) # Examples - /// ```rust - /// use futures::executor::block_on; + /// ```ignore + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; /// # remote.listen(ProtocolAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// let p1 = network @@ -355,27 +346,24 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`ProtocolAddres`]: crate::api::ProtocolAddr + #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] pub async fn connect(&self, address: ProtocolAddr) -> Result { let (pid_sender, pid_receiver) = oneshot::channel::>(); debug!(?address, "Connect to address"); self.connect_sender .lock() .await - .send((address, pid_sender)) - .await?; + .send((address, pid_sender))?; let participant = match pid_receiver.await? 
{ Ok(p) => p, Err(e) => return Err(NetworkError::ConnectFailed(e)), }; - let pid = participant.remote_pid; - debug!( - ?pid, - "Received Participant id from remote and return to user" - ); + let remote_pid = participant.remote_pid; + trace!(?remote_pid, "connected"); self.participant_disconnect_sender .lock() .await - .insert(pid, Arc::clone(&participant.a2s_disconnect_s)); + .insert(remote_pid, Arc::clone(&participant.a2s_disconnect_s)); Ok(participant) } @@ -386,16 +374,16 @@ impl Network { /// /// # Examples /// ```rust - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2020` TCP and opens returns their Pid - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network /// .listen(ProtocolAddr::Tcp("127.0.0.1:2020".parse().unwrap())) /// .await?; @@ -412,8 +400,9 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn connected(&self) -> Result { - let participant = self.connected_receiver.lock().await.next().await?; + let participant = self.connected_receiver.lock().await.recv().await?; self.participant_disconnect_sender.lock().await.insert( participant.remote_pid, Arc::clone(&participant.a2s_disconnect_s), @@ -426,14 +415,14 @@ impl Participant { pub(crate) fn new( local_pid: Pid, remote_pid: Pid, - a2b_stream_open_s: mpsc::UnboundedSender, + a2b_open_stream_s: 
mpsc::UnboundedSender, b2a_stream_opened_r: mpsc::UnboundedReceiver, a2s_disconnect_s: mpsc::UnboundedSender<(Pid, S2bShutdownBparticipant)>, ) -> Self { Self { local_pid, remote_pid, - a2b_stream_open_s: Mutex::new(a2b_stream_open_s), + a2b_open_stream_s: Mutex::new(a2b_open_stream_s), b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } @@ -443,10 +432,8 @@ impl Participant { /// [`Promises`] /// /// # Arguments - /// * `prio` - valid between 0-63. The priority rates the throughput for - /// messages of the [`Stream`] e.g. prio 5 messages will get 1/2 the speed - /// prio0 messages have. Prio10 messages only 1/4 and Prio 15 only 1/8, - /// etc... + /// * `prio` - defines which stream is processed first when limited on + /// bandwidth. See [`Prio`] for documentation. /// * `promises` - use a combination of you prefered [`Promises`], see the /// link for further documentation. You can combine them, e.g. /// `Promises::ORDERED | Promises::CONSISTENCY` The Stream will then @@ -458,49 +445,52 @@ impl Participant { /// /// # Examples /// ```rust - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, Promises, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port 2100 and open a stream - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())).await?; /// let p1 = network /// .connect(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())) /// 
.await?; /// let _s1 = p1 - /// .open(16, Promises::ORDERED | Promises::CONSISTENCY) + /// .open(4, Promises::ORDERED | Promises::CONSISTENCY) /// .await?; /// # Ok(()) /// }) /// # } /// ``` /// + /// [`Prio`]: network_protocol::Prio + /// [`Promises`]: network_protocol::Promises /// [`Streams`]: crate::api::Stream + #[instrument(name="network", skip(self, prio, promises), fields(p = %self.local_pid))] pub async fn open(&self, prio: u8, promises: Promises) -> Result { - let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - if let Err(e) = self - .a2b_stream_open_s - .lock() - .await - .send((prio, promises, p2a_return_stream_s)) - .await - { + debug_assert!(prio <= network_protocol::HIGHEST_PRIO, "invalid prio"); + let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel::(); + if let Err(e) = self.a2b_open_stream_s.lock().await.send(( + prio, + promises, + 1_000_000, + p2a_return_stream_s, + )) { debug!(?e, "bParticipant is already closed, notifying"); return Err(ParticipantError::ParticipantDisconnected); } match p2a_return_stream_r.await { Ok(stream) => { let sid = stream.sid; - debug!(?sid, ?self.remote_pid, "opened stream"); + trace!(?sid, "opened stream"); Ok(stream) }, Err(_) => { - debug!(?self.remote_pid, "p2a_return_stream_r failed, closing participant"); + debug!("p2a_return_stream_r failed, closing participant"); Err(ParticipantError::ParticipantDisconnected) }, } @@ -515,21 +505,21 @@ impl Participant { /// /// # Examples /// ```rust + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr, Promises}; - /// use futures::executor::block_on; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port 2110 and wait for the other side to open a stream /// // Note: It's quite unusual to actively connect, but then wait on a stream to be connected, usually the Application taking initiative want's to also create the first Stream. 
- /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; /// let p1 = network.connect(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; /// # let p2 = remote.connected().await?; - /// # p2.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # p2.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// let _s1 = p1.opened().await?; /// # Ok(()) /// }) @@ -539,15 +529,16 @@ impl Participant { /// [`Streams`]: crate::api::Stream /// [`connected`]: Network::connected /// [`open`]: Participant::open + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn opened(&self) -> Result { - match self.b2a_stream_opened_r.lock().await.next().await { + match self.b2a_stream_opened_r.lock().await.recv().await { Some(stream) => { let sid = stream.sid; - debug!(?sid, ?self.remote_pid, "Receive opened stream"); + debug!(?sid, "Receive opened stream"); Ok(stream) }, None => { - debug!(?self.remote_pid, "stream_opened_receiver failed, closing participant"); + debug!("stream_opened_receiver failed, closing participant"); Err(ParticipantError::ParticipantDisconnected) }, } @@ -570,16 +561,16 @@ impl Participant { /// /// # Examples /// ```rust - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use veloren_network::{Network, Pid, ProtocolAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2030` TCP and opens returns their Pid and close connection. 
- /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network /// .listen(ProtocolAddr::Tcp("127.0.0.1:2030".parse().unwrap())) /// .await?; @@ -596,27 +587,26 @@ impl Participant { /// ``` /// /// [`Streams`]: crate::api::Stream + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn disconnect(self) -> Result<(), ParticipantError> { // Remove, Close and try_unwrap error when unwrap fails! - let pid = self.remote_pid; - debug!(?pid, "Closing participant from network"); + debug!("Closing participant from network"); //Streams will be closed by BParticipant match self.a2s_disconnect_s.lock().await.take() { - Some(mut a2s_disconnect_s) => { + Some(a2s_disconnect_s) => { let (finished_sender, finished_receiver) = oneshot::channel(); // Participant is connecting to Scheduler here, not as usual // Participant<->BParticipant a2s_disconnect_s - .send((pid, finished_sender)) - .await + .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) .expect("Something is wrong in internal scheduler coding"); match finished_receiver.await { Ok(res) => { match res { - Ok(()) => trace!(?pid, "Participant is now closed"), + Ok(()) => trace!("Participant is now closed"), Err(ref e) => { - trace!(?pid, ?e, "Error occurred during shutdown of participant") + trace!(?e, "Error occurred during shutdown of participant") }, }; res @@ -624,7 +614,6 @@ impl Participant { Err(e) => { //this is a bug. 
but as i am Participant i can't destroy the network error!( - ?pid, ?e, "Failed to get a message back from the scheduler, seems like the \ network is already closed" @@ -643,28 +632,31 @@ impl Participant { } } - /// Returns the remote [`Pid`] + /// Returns the remote [`Pid`](network_protocol::Pid) pub fn remote_pid(&self) -> Pid { self.remote_pid } } impl Stream { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - pid: Pid, + local_pid: Pid, + remote_pid: Pid, sid: Sid, prio: Prio, promises: Promises, + guaranteed_bandwidth: Bandwidth, send_closed: Arc, - a2b_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - b2a_msg_recv_r: mpsc::UnboundedReceiver, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, + b2a_msg_recv_r: async_channel::Receiver, a2b_close_stream_s: mpsc::UnboundedSender, ) -> Self { Self { - pid, + local_pid, + remote_pid, sid, - mid: 0, prio, promises, + guaranteed_bandwidth, send_closed, a2b_msg_s, b2a_msg_recv_r: Some(b2a_msg_recv_r), @@ -698,21 +690,21 @@ impl Stream { /// /// # Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2200` and wait for a Stream to be opened, then answer `Hello World` - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; /// # let remote_p = 
remote.connect(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; /// # // keep it alive - /// # let _stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let _stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; /// //Send Message @@ -738,26 +730,24 @@ impl Stream { /// /// # Example /// ```rust - /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; /// use bincode; - /// use std::sync::Arc; + /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; /// /// # fn main() -> std::result::Result<(), Box> { - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote1, fr1) = Network::new(Pid::new()); - /// # std::thread::spawn(fr1); - /// # let (remote2, fr2) = Network::new(Pid::new()); - /// # std::thread::spawn(fr2); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote1 = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote2 = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # let remote1_p = remote1.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # let remote2_p = remote2.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # assert_eq!(remote1_p.remote_pid(), remote2_p.remote_pid()); - /// # remote1_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; - /// # remote2_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # remote1_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; + 
/// # remote2_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// let participant_a = network.connected().await?; /// let participant_b = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -783,13 +773,7 @@ impl Stream { } #[cfg(debug_assertions)] message.verify(&self); - self.a2b_msg_s.send((self.prio, self.sid, OutgoingMessage { - buffer: Arc::clone(&message.buffer), - cursor: 0, - mid: self.mid, - sid: self.sid, - }))?; - self.mid += 1; + self.a2b_msg_s.send((self.sid, message.data.clone()))?; Ok(()) } @@ -804,20 +788,20 @@ impl Stream { /// /// # Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2220` and wait for a Stream to be opened, then listen on it - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -837,20 +821,20 @@ impl Stream { /// /// # 
Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2230` and wait for a Stream to be opened, then listen on it - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -869,13 +853,13 @@ impl Stream { pub async fn recv_raw(&mut self) -> Result { match &mut self.b2a_msg_recv_r { Some(b2a_msg_recv_r) => { - match b2a_msg_recv_r.next().await { - Some(msg) => Ok(Message { - buffer: Arc::new(msg.buffer), + match b2a_msg_recv_r.recv().await { + Ok(data) => Ok(Message { + data, #[cfg(feature = "compression")] compressed: self.promises.contains(Promises::COMPRESSED), }), - None => { + Err(_) => { self.b2a_msg_recv_r = None; //prevent panic Err(StreamError::StreamClosed) }, @@ -892,20 +876,20 @@ impl Stream { /// /// # Example /// ``` - /// use veloren_network::{Network, ProtocolAddr, Pid}; /// # use 
veloren_network::Promises; - /// use futures::executor::block_on; + /// # use std::sync::Arc; + /// use tokio::runtime::Runtime; + /// use veloren_network::{Network, ProtocolAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2240` and wait for a Stream to be opened, then listen on it - /// let (network, f) = Network::new(Pid::new()); - /// std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// block_on(async { + /// let runtime = Arc::new(Runtime::new().unwrap()); + /// let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// runtime.block_on(async { /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// # std::thread::sleep(std::time::Duration::from_secs(1)); /// let participant_a = network.connected().await?; @@ -921,20 +905,20 @@ impl Stream { #[inline] pub fn try_recv(&mut self) -> Result, StreamError> { match &mut self.b2a_msg_recv_r { - Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_next() { - Err(_) => Ok(None), - Ok(None) => { - self.b2a_msg_recv_r = None; //prevent panic - Err(StreamError::StreamClosed) - }, - Ok(Some(msg)) => Ok(Some( + Some(b2a_msg_recv_r) => match b2a_msg_recv_r.try_recv() { + Ok(data) => Ok(Some( Message { - buffer: Arc::new(msg.buffer), + data, #[cfg(feature = "compression")] compressed: self.promises().contains(Promises::COMPRESSED), } .deserialize()?, )), + Err(async_channel::TryRecvError::Empty) => Ok(None), + Err(async_channel::TryRecvError::Closed) => { + self.b2a_msg_recv_r = None; 
//prevent panic + Err(StreamError::StreamClosed) + }, }, None => Err(StreamError::StreamClosed), } @@ -952,111 +936,129 @@ impl core::cmp::PartialEq for Participant { } impl Drop for Network { + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] fn drop(&mut self) { - let pid = self.local_pid; - debug!(?pid, "Shutting down Network"); - trace!( - ?pid, - "Shutting down Participants of Network, while we still have metrics" - ); + debug!("Shutting down Network"); + trace!("Shutting down Participants of Network, while we still have metrics"); let mut finished_receiver_list = vec![]; - task::block_on(async { - // we MUST avoid nested block_on, good that Network::Drop no longer triggers - // Participant::Drop directly but just the BParticipant - for (remote_pid, a2s_disconnect_s) in - self.participant_disconnect_sender.lock().await.drain() - { - match a2s_disconnect_s.lock().await.take() { - Some(mut a2s_disconnect_s) => { - trace!(?remote_pid, "Participants will be closed"); - let (finished_sender, finished_receiver) = oneshot::channel(); - finished_receiver_list.push((remote_pid, finished_receiver)); - a2s_disconnect_s - .send((remote_pid, finished_sender)) - .await - .expect( - "Scheduler is closed, but nobody other should be able to close it", - ); - }, - None => trace!(?remote_pid, "Participant already disconnected gracefully"), + + if tokio::runtime::Handle::try_current().is_ok() { + error!("we have a runtime but we mustn't, DROP NETWORK from async runtime is illegal") + } + + tokio::task::block_in_place(|| { + self.runtime.block_on(async { + for (remote_pid, a2s_disconnect_s) in + self.participant_disconnect_sender.lock().await.drain() + { + match a2s_disconnect_s.lock().await.take() { + Some(a2s_disconnect_s) => { + trace!(?remote_pid, "Participants will be closed"); + let (finished_sender, finished_receiver) = oneshot::channel(); + finished_receiver_list.push((remote_pid, finished_receiver)); + a2s_disconnect_s + .send((remote_pid, 
(Duration::from_secs(120), finished_sender))) + .expect( + "Scheduler is closed, but nobody other should be able to \ + close it", + ); + }, + None => trace!(?remote_pid, "Participant already disconnected gracefully"), + } } - } - //wait after close is requested for all - for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) { - match finished_receiver.await { - Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), - Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), - Err(e) => warn!( - ?remote_pid, - ?e, - "Failed to get a message back from the scheduler, seems like the network \ - is already closed" - ), + //wait after close is requested for all + for (remote_pid, finished_receiver) in finished_receiver_list.drain(..) { + match finished_receiver.await { + Ok(Ok(())) => trace!(?remote_pid, "disconnect successful"), + Ok(Err(e)) => info!(?remote_pid, ?e, "unclean disconnect"), + Err(e) => warn!( + ?remote_pid, + ?e, + "Failed to get a message back from the scheduler, seems like the \ + network is already closed" + ), + } } - } + }); }); - trace!(?pid, "Participants have shut down!"); - trace!(?pid, "Shutting down Scheduler"); + trace!("Participants have shut down!"); + trace!("Shutting down Scheduler"); self.shutdown_sender .take() .unwrap() .send(()) .expect("Scheduler is closed, but nobody other should be able to close it"); - debug!(?pid, "Network has shut down"); + debug!("Network has shut down"); } } impl Drop for Participant { + #[instrument(name="remote", skip(self), fields(p = %self.remote_pid))] + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] fn drop(&mut self) { + use tokio::sync::oneshot::error::TryRecvError; // ignore closed, as we need to send it even though we disconnected the // participant from network - let pid = self.remote_pid; - debug!(?pid, "Shutting down Participant"); + debug!("Shutting down Participant"); - match task::block_on(self.a2s_disconnect_s.lock()).take() { - None => trace!( 
- ?pid, - "Participant has been shutdown cleanly, no further waiting is required!" - ), - Some(mut a2s_disconnect_s) => { - debug!(?pid, "Disconnect from Scheduler"); - task::block_on(async { - let (finished_sender, finished_receiver) = oneshot::channel(); - a2s_disconnect_s - .send((self.remote_pid, finished_sender)) - .await - .expect("Something is wrong in internal scheduler coding"); - if let Err(e) = finished_receiver - .await - .expect("Something is wrong in internal scheduler/participant coding") - { - error!( - ?pid, - ?e, - "Error while dropping the participant, couldn't send all outgoing \ - messages, dropping remaining" - ); - }; - }); + match self + .a2s_disconnect_s + .try_lock() + .expect("Participant in use while beeing dropped") + .take() + { + None => info!("Participant already has been shutdown gracefully"), + Some(a2s_disconnect_s) => { + debug!("Disconnect from Scheduler"); + let (finished_sender, mut finished_receiver) = oneshot::channel(); + a2s_disconnect_s + .send((self.remote_pid, (Duration::from_secs(120), finished_sender))) + .expect("Something is wrong in internal scheduler coding"); + loop { + match finished_receiver.try_recv() { + Ok(Ok(())) => { + info!("Participant dropped gracefully"); + break; + }, + Ok(Err(e)) => { + error!( + ?e, + "Error while dropping the participant, couldn't send all outgoing \ + messages, dropping remaining" + ); + break; + }, + Err(TryRecvError::Closed) => { + panic!("Something is wrong in internal scheduler/participant coding") + }, + Err(TryRecvError::Empty) => { + trace!("activly sleeping"); + std::thread::sleep(Duration::from_millis(20)); + }, + } + } }, } - debug!(?pid, "Participant dropped"); } } impl Drop for Stream { + #[instrument(name="remote", skip(self), fields(p = %self.remote_pid))] + #[instrument(name="network", skip(self), fields(p = %self.local_pid))] fn drop(&mut self) { // send if closed is unnecessary but doesn't hurt, we must not crash if !self.send_closed.load(Ordering::Relaxed) { let 
sid = self.sid; - let pid = self.pid; - debug!(?pid, ?sid, "Shutting down Stream"); - task::block_on(self.a2b_close_stream_s.take().unwrap().send(self.sid)) - .expect("bparticipant part of a gracefully shutdown must have crashed"); + debug!(?sid, "Shutting down Stream"); + if let Err(e) = self.a2b_close_stream_s.take().unwrap().send(self.sid) { + debug!( + ?e, + "bparticipant part of a gracefully shutdown was already closed" + ); + } } else { let sid = self.sid; - let pid = self.pid; - trace!(?pid, ?sid, "Stream Drop not needed"); + trace!(?sid, "Stream Drop not needed"); } } } @@ -1088,12 +1090,16 @@ impl From for NetworkError { fn from(_err: std::option::NoneError) -> Self { NetworkError::NetworkClosed } } -impl From for NetworkError { - fn from(_err: mpsc::SendError) -> Self { NetworkError::NetworkClosed } +impl From> for NetworkError { + fn from(_err: mpsc::error::SendError) -> Self { NetworkError::NetworkClosed } } -impl From for NetworkError { - fn from(_err: oneshot::Canceled) -> Self { NetworkError::NetworkClosed } +impl From for NetworkError { + fn from(_err: oneshot::error::RecvError) -> Self { NetworkError::NetworkClosed } +} + +impl From for NetworkError { + fn from(_err: std::io::Error) -> Self { NetworkError::NetworkClosed } } impl From> for StreamError { diff --git a/network/src/channel.rs b/network/src/channel.rs index c591fb0b88..22265a7f8e 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,359 +1,272 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - participant::C2pFrame, - protocols::Protocols, - types::{ - Cid, Frame, Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER, - VELOREN_NETWORK_VERSION, - }, +use async_trait::async_trait; +use bytes::BytesMut; +use network_protocol::{ + Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, ProtocolError, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, + 
UnreliableDrain, UnreliableSink, }; -use futures::{ - channel::{mpsc, oneshot}, - join, - sink::SinkExt, - stream::StreamExt, - FutureExt, +use std::{sync::Arc, time::Duration}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::tcp::{OwnedReadHalf, OwnedWriteHalf}, + sync::mpsc, }; -#[cfg(feature = "metrics")] use std::sync::Arc; -use tracing::*; -pub(crate) struct Channel { - cid: Cid, - c2w_frame_r: Option>, - read_stop_receiver: Option>, -} - -impl Channel { - pub fn new(cid: u64) -> (Self, mpsc::UnboundedSender, oneshot::Sender<()>) { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - ( - Self { - cid, - c2w_frame_r: Some(c2w_frame_r), - read_stop_receiver: Some(read_stop_receiver), - }, - c2w_frame_s, - read_stop_sender, - ) - } - - pub async fn run( - mut self, - protocol: Protocols, - mut w2c_cid_frame_s: mpsc::UnboundedSender, - mut leftover_cid_frame: Vec, - ) { - let c2w_frame_r = self.c2w_frame_r.take().unwrap(); - let read_stop_receiver = self.read_stop_receiver.take().unwrap(); - - //reapply leftovers from handshake - let cnt = leftover_cid_frame.len(); - trace!(?cnt, "Reapplying leftovers"); - for cid_frame in leftover_cid_frame.drain(..) 
{ - w2c_cid_frame_s.send(cid_frame).await.unwrap(); - } - trace!(?cnt, "All leftovers reapplied"); - - trace!("Start up channel"); - match protocol { - Protocols::Tcp(tcp) => { - join!( - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - Protocols::Udp(udp) => { - join!( - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - ); - }, - } - - trace!("Shut down channel"); - } +#[derive(Debug)] +pub(crate) enum Protocols { + Tcp((TcpSendProtocol, TcpRecvProtocol)), + Mpsc((MpscSendProtocol, MpscRecvProtocol)), } #[derive(Debug)] -pub(crate) struct Handshake { - cid: Cid, - local_pid: Pid, - secret: u128, - init_handshake: bool, - #[cfg(feature = "metrics")] - metrics: Arc, +pub(crate) enum SendProtocols { + Tcp(TcpSendProtocol), + Mpsc(MpscSendProtocol), } -impl Handshake { - #[cfg(debug_assertions)] - const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required by \ - veloren server.\nWe are not sure if you are a valid \ - veloren client.\nClosing the connection" - .as_bytes(); - #[cfg(debug_assertions)] - const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \ - invalid version.\nWe don't know how to communicate with \ - you.\nClosing the connection"; +#[derive(Debug)] +pub(crate) enum RecvProtocols { + Tcp(TcpRecvProtocol), + Mpsc(MpscRecvProtocol), +} - pub fn new( - cid: u64, +impl Protocols { + pub(crate) fn new_tcp( + stream: tokio::net::TcpStream, + cid: Cid, + metrics: Arc, + ) -> Self { + let (r, w) = stream.into_split(); + let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); + + let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone()); + let rp = TcpRecvProtocol::new( + TcpSink { + half: r, + buffer: BytesMut::new(), + }, + metrics, + ); + Protocols::Tcp((sp, rp)) + } + + pub(crate) fn new_mpsc( + sender: mpsc::Sender, + receiver: 
mpsc::Receiver, + cid: Cid, + metrics: Arc, + ) -> Self { + let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); + + let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone()); + let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics); + Protocols::Mpsc((sp, rp)) + } + + pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) { + match self { + Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), + Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + } + } +} + +#[async_trait] +impl network_protocol::InitProtocol for Protocols { + async fn initialize( + &mut self, + initializer: bool, local_pid: Pid, secret: u128, - #[cfg(feature = "metrics")] metrics: Arc, - init_handshake: bool, - ) -> Self { - Self { - cid, - local_pid, - secret, - #[cfg(feature = "metrics")] - metrics, - init_handshake, + ) -> Result<(Pid, Sid, u128), InitProtocolError> { + match self { + Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, + Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, } } - - pub async fn setup(self, protocol: &Protocols) -> Result<(Pid, Sid, u128, Vec), ()> { - let (c2w_frame_s, c2w_frame_r) = mpsc::unbounded::(); - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); - - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let handler_future = - self.frame_handler(&mut w2c_cid_frame_r, c2w_frame_s, read_stop_sender); - let res = match protocol { - Protocols::Tcp(tcp) => { - (join! { - tcp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - tcp.write_to_wire(self.cid, c2w_frame_r).fuse(), - handler_future, - }) - .2 - }, - Protocols::Udp(udp) => { - (join! 
{ - udp.read_from_wire(self.cid, &mut w2c_cid_frame_s, read_stop_receiver), - udp.write_to_wire(self.cid, c2w_frame_r), - handler_future, - }) - .2 - }, - }; - - match res { - Ok(res) => { - let mut leftover_frames = vec![]; - while let Ok(Some(cid_frame)) = w2c_cid_frame_r.try_next() { - leftover_frames.push(cid_frame); - } - let cnt = leftover_frames.len(); - if cnt > 0 { - debug!( - ?cnt, - "Some additional frames got already transferred, piping them to the \ - bparticipant as leftover_frames" - ); - } - Ok((res.0, res.1, res.2, leftover_frames)) - }, - Err(()) => Err(()), - } - } - - async fn frame_handler( - &self, - w2c_cid_frame_r: &mut mpsc::UnboundedReceiver, - mut c2w_frame_s: mpsc::UnboundedSender, - read_stop_sender: oneshot::Sender<()>, - ) -> Result<(Pid, Sid, u128), ()> { - const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \ - something went wrong on network layer and connection will be closed"; - #[cfg(feature = "metrics")] - let cid_string = self.cid.to_string(); - - if self.init_handshake { - self.send_handshake(&mut c2w_frame_s).await; - } - - let frame = w2c_cid_frame_r.next().await.map(|(_cid, frame)| frame); - #[cfg(feature = "metrics")] - { - if let Some(Ok(ref frame)) = frame { - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, &frame.get_string()]) - .inc(); - } - } - let r = match frame { - Some(Ok(Frame::Handshake { - magic_number, - version, - })) => { - trace!(?magic_number, ?version, "Recv handshake"); - if magic_number != VELOREN_MAGIC_NUMBER { - error!(?magic_number, "Connection with invalid magic_number"); - #[cfg(debug_assertions)] - self.send_raw_and_shutdown(&mut c2w_frame_s, Self::WRONG_NUMBER.to_vec()) - .await; - Err(()) - } else if version != VELOREN_NETWORK_VERSION { - error!(?version, "Connection with wrong network version"); - #[cfg(debug_assertions)] - self.send_raw_and_shutdown( - &mut c2w_frame_s, - format!( - "{} Our Version: {:?}\nYour Version: {:?}\nClosing 
the connection", - Self::WRONG_VERSION, - VELOREN_NETWORK_VERSION, - version, - ) - .as_bytes() - .to_vec(), - ) - .await; - Err(()) - } else { - debug!("Handshake completed"); - if self.init_handshake { - self.send_init(&mut c2w_frame_s).await; - } else { - self.send_handshake(&mut c2w_frame_s).await; - } - Ok(()) - } - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if let Err(()) = r { - if let Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - return Err(()); - } - - let frame = w2c_cid_frame_r.next().await.map(|(_cid, frame)| frame); - let r = match frame { - Some(Ok(Frame::Init { pid, secret })) => { - debug!(?pid, "Participant send their ID"); - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, "ParticipantId"]) - .inc(); - let stream_id_offset = if self.init_handshake { - STREAM_ID_OFFSET1 - } else { - self.send_init(&mut c2w_frame_s).await; - STREAM_ID_OFFSET2 - }; - info!(?pid, "This Handshake is now configured!"); - Ok((pid, stream_id_offset, secret)) - }, - Some(Ok(frame)) => { - #[cfg(feature = "metrics")] - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - if let Frame::Raw(bytes) = frame { - match std::str::from_utf8(bytes.as_slice()) { - Ok(string) => error!(?string, ERR_S), - _ => error!(?bytes, ERR_S), - } - } - Err(()) - }, - Some(Err(())) => { - info!("Protocol got interrupted"); - Err(()) - }, - None => Err(()), - }; - if r.is_err() { - if let 
Err(e) = read_stop_sender.send(()) { - trace!( - ?e, - "couldn't stop protocol, probably it encountered a Protocol Stop and closed \ - itself already, which is fine" - ); - } - } - r - } - - async fn send_handshake(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "Handshake"]) - .inc(); - c2w_frame_s - .send(Frame::Handshake { - magic_number: VELOREN_MAGIC_NUMBER, - version: VELOREN_NETWORK_VERSION, - }) - .await - .unwrap(); - } - - async fn send_init(&self, c2w_frame_s: &mut mpsc::UnboundedSender) { - #[cfg(feature = "metrics")] - self.metrics - .frames_out_total - .with_label_values(&[&self.cid.to_string(), "ParticipantId"]) - .inc(); - c2w_frame_s - .send(Frame::Init { - pid: self.local_pid, - secret: self.secret, - }) - .await - .unwrap(); - } - - #[cfg(debug_assertions)] - async fn send_raw_and_shutdown( - &self, - c2w_frame_s: &mut mpsc::UnboundedSender, - data: Vec, - ) { - debug!("Sending client instructions before killing"); - #[cfg(feature = "metrics")] - { - let cid_string = self.cid.to_string(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Raw"]) - .inc(); - self.metrics - .frames_out_total - .with_label_values(&[&cid_string, "Shutdown"]) - .inc(); - } - c2w_frame_s.send(Frame::Raw(data)).await.unwrap(); - c2w_frame_s.send(Frame::Shutdown).await.unwrap(); - } +} + +#[async_trait] +impl network_protocol::SendProtocol for SendProtocols { + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match self { + SendProtocols::Tcp(s) => s.notify_from_recv(event), + SendProtocols::Mpsc(s) => s.notify_from_recv(event), + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + match self { + SendProtocols::Tcp(s) => s.send(event).await, + SendProtocols::Mpsc(s) => s.send(event).await, + } + } + + async fn flush(&mut self, bandwidth: u64, dt: Duration) -> Result<(), ProtocolError> { + match 
self { + SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, + SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + } + } +} + +#[async_trait] +impl network_protocol::RecvProtocol for RecvProtocols { + async fn recv(&mut self) -> Result { + match self { + RecvProtocols::Tcp(r) => r.recv().await, + RecvProtocols::Mpsc(r) => r.recv().await, + } + } +} + +/////////////////////////////////////// +//// TCP +#[derive(Debug)] +pub struct TcpDrain { + half: OwnedWriteHalf, +} + +#[derive(Debug)] +pub struct TcpSink { + half: OwnedReadHalf, + buffer: BytesMut, +} + +#[async_trait] +impl UnreliableDrain for TcpDrain { + type DataFormat = BytesMut; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + match self.half.write_all(&data).await { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[async_trait] +impl UnreliableSink for TcpSink { + type DataFormat = BytesMut; + + async fn recv(&mut self) -> Result { + self.buffer.resize(1500, 0u8); + match self.half.read(&mut self.buffer).await { + Ok(0) => Err(ProtocolError::Closed), + Ok(n) => Ok(self.buffer.split_to(n)), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +/////////////////////////////////////// +//// MPSC +#[derive(Debug)] +pub struct MpscDrain { + sender: tokio::sync::mpsc::Sender, +} + +#[derive(Debug)] +pub struct MpscSink { + receiver: tokio::sync::mpsc::Receiver, +} + +#[async_trait] +impl UnreliableDrain for MpscDrain { + type DataFormat = MpscMsg; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } +} + +#[async_trait] +impl UnreliableSink for MpscSink { + type DataFormat = MpscMsg; + + async fn recv(&mut self) -> Result { + self.receiver.recv().await.ok_or(ProtocolError::Closed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use 
tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn tokio_sinks() { + let listener = TcpListener::bind("127.0.0.1:5000").await.unwrap(); + let r1 = tokio::spawn(async move { + let (server, _) = listener.accept().await.unwrap(); + (listener, server) + }); + let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); + let (_listener, server) = r1.await.unwrap(); + let metrics = Arc::new(ProtocolMetrics::new().unwrap()); + let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); + let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let (mut s, _) = client.split(); + let (_, mut r) = server.split(); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1), + prio: 4u8, + promises: Promises::GUARANTEED_DELIVERY, + guaranteed_bandwidth: 1_000, + }; + s.send(event.clone()).await.unwrap(); + s.send(ProtocolEvent::Message { + sid: Sid::new(1), + mid: 0, + data: Bytes::from(&[8u8; 8][..]), + }) + .await + .unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); // recv must work even after shutdown of send! 
+ tokio::time::sleep(Duration::from_secs(1)).await; + let res = r.recv().await; + match res { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth: _, + }) => { + assert_eq!(sid, Sid::new(1)); + assert_eq!(prio, 4u8); + assert_eq!(promises, Promises::GUARANTEED_DELIVERY); + }, + _ => { + panic!("wrong type {:?}", res); + }, + } + r.recv().await.unwrap(); + } + + #[tokio::test] + async fn tokio_sink_stop_after_drop() { + let listener = TcpListener::bind("127.0.0.1:5001").await.unwrap(); + let r1 = tokio::spawn(async move { + let (server, _) = listener.accept().await.unwrap(); + (listener, server) + }); + let client = TcpStream::connect("127.0.0.1:5001").await.unwrap(); + let (_listener, server) = r1.await.unwrap(); + let metrics = Arc::new(ProtocolMetrics::new().unwrap()); + let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); + let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let (s, _) = client.split(); + let (_, mut r) = server.split(); + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + let e = e.await.unwrap(); + assert!(e.is_err()); + assert_eq!(e.unwrap_err(), ProtocolError::Closed); + } } diff --git a/network/src/lib.rs b/network/src/lib.rs index bb14782a69..9981a9c987 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -39,29 +39,27 @@ //! //! # Examples //! ```rust -//! use async_std::task::sleep; -//! use futures::{executor::block_on, join}; +//! use std::sync::Arc; +//! use tokio::{join, runtime::Runtime, time::sleep}; //! use veloren_network::{Network, Pid, Promises, ProtocolAddr}; //! //! // Client -//! async fn client() -> std::result::Result<(), Box> { +//! async fn client(runtime: Arc) -> std::result::Result<(), Box> { //! sleep(std::time::Duration::from_secs(1)).await; // `connect` MUST be after `listen` -//! let (client_network, f) = Network::new(Pid::new()); -//! std::thread::spawn(f); +//! let client_network = Network::new(Pid::new(), runtime); //! 
let server = client_network //! .connect(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; //! let mut stream = server -//! .open(10, Promises::ORDERED | Promises::CONSISTENCY) +//! .open(4, Promises::ORDERED | Promises::CONSISTENCY) //! .await?; //! stream.send("Hello World")?; //! Ok(()) //! } //! //! // Server -//! async fn server() -> std::result::Result<(), Box> { -//! let (server_network, f) = Network::new(Pid::new()); -//! std::thread::spawn(f); +//! async fn server(runtime: Arc) -> std::result::Result<(), Box> { +//! let server_network = Network::new(Pid::new(), runtime); //! server_network //! .listen(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; @@ -74,8 +72,10 @@ //! } //! //! fn main() -> std::result::Result<(), Box> { -//! block_on(async { -//! let (result_c, result_s) = join!(client(), server(),); +//! let runtime = Arc::new(Runtime::new().unwrap()); +//! runtime.block_on(async { +//! let (result_c, result_s) = +//! join!(client(Arc::clone(&runtime)), server(Arc::clone(&runtime)),); //! result_c?; //! result_s?; //! Ok(()) @@ -95,23 +95,19 @@ //! [`Streams`]: crate::api::Stream //! [`send`]: crate::api::Stream::send //! [`recv`]: crate::api::Stream::recv -//! [`Pid`]: crate::types::Pid +//! [`Pid`]: network_protocol::Pid //! [`ProtocolAddr`]: crate::api::ProtocolAddr -//! [`Promises`]: crate::types::Promises +//! 
[`Promises`]: network_protocol::Promises mod api; mod channel; mod message; -#[cfg(feature = "metrics")] mod metrics; +mod metrics; mod participant; -mod prios; -mod protocols; mod scheduler; -#[macro_use] -mod types; pub use api::{ Network, NetworkError, Participant, ParticipantError, ProtocolAddr, Stream, StreamError, }; pub use message::Message; -pub use types::{Pid, Promises}; +pub use network_protocol::{Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index ad668908b6..27c50abf8d 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -1,12 +1,9 @@ -use serde::{de::DeserializeOwned, Serialize}; -//use std::collections::VecDeque; +use crate::api::{Stream, StreamError}; +use bytes::Bytes; #[cfg(feature = "compression")] -use crate::types::Promises; -use crate::{ - api::{Stream, StreamError}, - types::{Frame, Mid, Sid}, -}; -use std::{io, sync::Arc}; +use network_protocol::Promises; +use serde::{de::DeserializeOwned, Serialize}; +use std::io; #[cfg(all(feature = "compression", debug_assertions))] use tracing::warn; @@ -18,34 +15,11 @@ use tracing::warn; /// [`Stream`]: crate::api::Stream /// [`send_raw`]: crate::api::Stream::send_raw pub struct Message { - pub(crate) buffer: Arc, + pub(crate) data: Bytes, #[cfg(feature = "compression")] pub(crate) compressed: bool, } -//Todo: Evaluate switching to VecDeque for quickly adding and removing data -// from front, back. -// - It would prob require custom bincode code but thats possible. -pub(crate) struct MessageBuffer { - pub data: Vec, -} - -#[derive(Debug)] -pub(crate) struct OutgoingMessage { - pub buffer: Arc, - pub cursor: u64, - pub mid: Mid, - pub sid: Sid, -} - -#[derive(Debug)] -pub(crate) struct IncomingMessage { - pub buffer: MessageBuffer, - pub length: u64, - pub mid: Mid, - pub sid: Sid, -} - impl Message { /// This serializes any message, according to the [`Streams`] [`Promises`]. 
/// You can reuse this `Message` and send it via other [`Streams`], if the @@ -83,7 +57,7 @@ impl Message { let _stream = stream; Self { - buffer: Arc::new(MessageBuffer { data }), + data: Bytes::from(data), #[cfg(feature = "compression")] compressed, } @@ -98,18 +72,18 @@ impl Message { /// ``` /// # use veloren_network::{Network, ProtocolAddr, Pid}; /// # use veloren_network::Promises; - /// # use futures::executor::block_on; + /// # use tokio::runtime::Runtime; + /// # use std::sync::Arc; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2300` and wait for a Stream to be opened, then listen on it - /// # let (network, f) = Network::new(Pid::new()); - /// # std::thread::spawn(f); - /// # let (remote, fr) = Network::new(Pid::new()); - /// # std::thread::spawn(fr); - /// # block_on(async { + /// # let runtime = Arc::new(Runtime::new().unwrap()); + /// # let network = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # let remote = Network::new(Pid::new(), Arc::clone(&runtime)); + /// # runtime.block_on(async { /// # network.listen(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; - /// # let mut stream_p = remote_p.open(16, Promises::ORDERED | Promises::CONSISTENCY).await?; + /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY).await?; /// # stream_p.send("Hello World"); /// # let participant_a = network.connected().await?; /// let mut stream_a = participant_a.opened().await?; @@ -124,33 +98,27 @@ impl Message { /// [`recv_raw`]: crate::api::Stream::recv_raw pub fn deserialize(self) -> Result { #[cfg(not(feature = "compression"))] - let uncompressed_data = match Arc::try_unwrap(self.buffer) { - Ok(d) => d.data, - Err(b) => b.data.clone(), - }; + let uncompressed_data = self.data; #[cfg(feature = "compression")] let uncompressed_data = if self.compressed { { - let mut 
uncompressed_data = Vec::with_capacity(self.buffer.data.len() * 2); + let mut uncompressed_data = Vec::with_capacity(self.data.len() * 2); if let Err(e) = lz_fear::raw::decompress_raw( - &self.buffer.data, + &self.data, &[0; 0], &mut uncompressed_data, usize::MAX, ) { return Err(StreamError::Compression(e)); } - uncompressed_data + Bytes::from(uncompressed_data) } } else { - match Arc::try_unwrap(self.buffer) { - Ok(d) => d.data, - Err(b) => b.data.clone(), - } + self.data }; - match bincode::deserialize(uncompressed_data.as_slice()) { + match bincode::deserialize(&uncompressed_data) { Ok(m) => Ok(m), Err(e) => Err(StreamError::Deserialize(e)), } @@ -170,38 +138,6 @@ impl Message { } } -impl OutgoingMessage { - pub(crate) const FRAME_DATA_SIZE: u64 = 1400; - - /// returns if msg is empty - pub(crate) fn fill_next>( - &mut self, - msg_sid: Sid, - frames: &mut E, - ) -> bool { - let to_send = std::cmp::min( - self.buffer.data[self.cursor as usize..].len() as u64, - Self::FRAME_DATA_SIZE, - ); - if to_send > 0 { - if self.cursor == 0 { - frames.extend(std::iter::once((msg_sid, Frame::DataHeader { - mid: self.mid, - sid: self.sid, - length: self.buffer.data.len() as u64, - }))); - } - frames.extend(std::iter::once((msg_sid, Frame::Data { - mid: self.mid, - start: self.cursor, - data: self.buffer.data[self.cursor as usize..][..to_send as usize].to_vec(), - }))); - }; - self.cursor += to_send; - self.cursor >= self.buffer.data.len() as u64 - } -} - ///wouldn't trust this aaaassss much, fine for tests pub(crate) fn partial_eq_io_error(first: &io::Error, second: &io::Error) -> bool { if let Some(f) = first.raw_os_error() { @@ -231,36 +167,15 @@ pub(crate) fn partial_eq_bincode(first: &bincode::ErrorKind, second: &bincode::E } } -impl std::fmt::Debug for MessageBuffer { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //TODO: small messages! 
- let len = self.data.len(); - if len > 20 { - write!( - f, - "MessageBuffer(len: {}, {}, {}, {}, {:X?}..{:X?})", - len, - u32::from_le_bytes([self.data[0], self.data[1], self.data[2], self.data[3]]), - u32::from_le_bytes([self.data[4], self.data[5], self.data[6], self.data[7]]), - u32::from_le_bytes([self.data[8], self.data[9], self.data[10], self.data[11]]), - &self.data[13..16], - &self.data[len - 8..len] - ) - } else { - write!(f, "MessageBuffer(len: {}, {:?})", len, &self.data[..]) - } - } -} - #[cfg(test)] mod tests { use crate::{api::Stream, message::*}; - use futures::channel::mpsc; use std::sync::{atomic::AtomicBool, Arc}; + use tokio::sync::mpsc; fn stub_stream(compressed: bool) -> Stream { - use crate::{api::*, types::*}; + use crate::api::*; + use network_protocol::*; #[cfg(feature = "compression")] let promises = if compressed { @@ -273,14 +188,16 @@ mod tests { let promises = Promises::empty(); let (a2b_msg_s, _a2b_msg_r) = crossbeam_channel::unbounded(); - let (_b2a_msg_recv_s, b2a_msg_recv_r) = mpsc::unbounded(); - let (a2b_close_stream_s, _a2b_close_stream_r) = mpsc::unbounded(); + let (_b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded(); + let (a2b_close_stream_s, _a2b_close_stream_r) = mpsc::unbounded_channel(); Stream::new( Pid::fake(0), + Pid::fake(1), Sid::new(0), 0u8, promises, + 1_000_000, Arc::new(AtomicBool::new(true)), a2b_msg_s, b2a_msg_recv_r, @@ -291,25 +208,25 @@ mod tests { #[test] fn serialize_test() { let msg = Message::serialize("abc", &stub_stream(false)); - assert_eq!(msg.buffer.data.len(), 11); - assert_eq!(msg.buffer.data[0], 3); - assert_eq!(msg.buffer.data[1..7], [0, 0, 0, 0, 0, 0]); - assert_eq!(msg.buffer.data[8], b'a'); - assert_eq!(msg.buffer.data[9], b'b'); - assert_eq!(msg.buffer.data[10], b'c'); + assert_eq!(msg.data.len(), 11); + assert_eq!(msg.data[0], 3); + assert_eq!(msg.data[1..7], [0, 0, 0, 0, 0, 0]); + assert_eq!(msg.data[8], b'a'); + assert_eq!(msg.data[9], b'b'); + assert_eq!(msg.data[10], b'c'); } 
#[cfg(feature = "compression")] #[test] fn serialize_compress_small() { let msg = Message::serialize("abc", &stub_stream(true)); - assert_eq!(msg.buffer.data.len(), 12); - assert_eq!(msg.buffer.data[0], 176); - assert_eq!(msg.buffer.data[1], 3); - assert_eq!(msg.buffer.data[2..8], [0, 0, 0, 0, 0, 0]); - assert_eq!(msg.buffer.data[9], b'a'); - assert_eq!(msg.buffer.data[10], b'b'); - assert_eq!(msg.buffer.data[11], b'c'); + assert_eq!(msg.data.len(), 12); + assert_eq!(msg.data[0], 176); + assert_eq!(msg.data[1], 3); + assert_eq!(msg.data[2..8], [0, 0, 0, 0, 0, 0]); + assert_eq!(msg.data[9], b'a'); + assert_eq!(msg.data[10], b'b'); + assert_eq!(msg.data[11], b'c'); } #[cfg(feature = "compression")] @@ -327,14 +244,14 @@ mod tests { "assets/data/plants/flowers/greenrose.ron", ); let msg = Message::serialize(&msg, &stub_stream(true)); - assert_eq!(msg.buffer.data.len(), 79); - assert_eq!(msg.buffer.data[0], 34); - assert_eq!(msg.buffer.data[1], 5); - assert_eq!(msg.buffer.data[2], 0); - assert_eq!(msg.buffer.data[3], 1); - assert_eq!(msg.buffer.data[20], 20); - assert_eq!(msg.buffer.data[40], 115); - assert_eq!(msg.buffer.data[60], 111); + assert_eq!(msg.data.len(), 79); + assert_eq!(msg.data[0], 34); + assert_eq!(msg.data[1], 5); + assert_eq!(msg.data[2], 0); + assert_eq!(msg.data[3], 1); + assert_eq!(msg.data[20], 20); + assert_eq!(msg.data[40], 115); + assert_eq!(msg.data[60], 111); } #[cfg(feature = "compression")] @@ -357,6 +274,6 @@ mod tests { } } let msg = Message::serialize(&msg, &stub_stream(true)); - assert_eq!(msg.buffer.data.len(), 1331); + assert_eq!(msg.data.len(), 1331); } } diff --git a/network/src/metrics.rs b/network/src/metrics.rs index d43aeaae3a..d1b77d76d0 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,16 +1,10 @@ -use crate::types::{Cid, Frame, Pid}; -use prometheus::{ - core::{AtomicU64, GenericCounter}, - IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry, -}; +use network_protocol::{Cid, Pid}; 
+#[cfg(feature = "metrics")] +use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; use std::error::Error; -use tracing::*; /// 1:1 relation between NetworkMetrics and Network -/// use 2NF here and avoid redundant data like CHANNEL AND PARTICIPANT encoding. -/// as this will cause a matrix that is full of 0 but needs alot of bandwith and -/// storage -#[allow(dead_code)] +#[cfg(feature = "metrics")] pub struct NetworkMetrics { pub listen_requests_total: IntCounterVec, pub connect_requests_total: IntCounterVec, @@ -25,33 +19,13 @@ pub struct NetworkMetrics { pub streams_opened_total: IntCounterVec, pub streams_closed_total: IntCounterVec, pub network_info: IntGauge, - // Frames counted a channel level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_out_total: IntCounterVec, - pub frames_in_total: IntCounterVec, - // Frames counted at protocol level, seperated by CHANNEL (and PARTICIPANT) AND FRAME TYPE, - pub frames_wire_out_total: IntCounterVec, - pub frames_wire_in_total: IntCounterVec, - // throughput at protocol level, seperated by CHANNEL (and PARTICIPANT), - pub wire_out_throughput: IntCounterVec, - pub wire_in_throughput: IntCounterVec, - // send(prio) Messages count, seperated by STREAM AND PARTICIPANT, - pub message_out_total: IntCounterVec, - // send(prio) Messages throughput, seperated by STREAM AND PARTICIPANT, - pub message_out_throughput: IntCounterVec, - // flushed(prio) stream count, seperated by PARTICIPANT, - pub streams_flushed: IntCounterVec, - // TODO: queued Messages, seperated by STREAM (add PART, CHANNEL), - // queued Messages, seperated by PARTICIPANT - pub queued_count: IntGaugeVec, - // TODO: queued Messages bytes, seperated by STREAM (add PART, CHANNEL), - // queued Messages bytes, seperated by PARTICIPANT - pub queued_bytes: IntGaugeVec, - // ping calculated based on last msg seperated by PARTICIPANT - pub participants_ping: IntGaugeVec, } +#[cfg(not(feature = "metrics"))] +pub struct 
NetworkMetrics {} + +#[cfg(feature = "metrics")] impl NetworkMetrics { - #[allow(dead_code)] pub fn new(local_pid: &Pid) -> Result> { let listen_requests_total = IntCounterVec::new( Opts::new( @@ -115,99 +89,13 @@ impl NetworkMetrics { "version", &format!( "{}.{}.{}", - &crate::types::VELOREN_NETWORK_VERSION[0], - &crate::types::VELOREN_NETWORK_VERSION[1], - &crate::types::VELOREN_NETWORK_VERSION[2] + &network_protocol::VELOREN_NETWORK_VERSION[0], + &network_protocol::VELOREN_NETWORK_VERSION[1], + &network_protocol::VELOREN_NETWORK_VERSION[2] ), ) .const_label("local_pid", &format!("{}", &local_pid)); let network_info = IntGauge::with_opts(opts)?; - let frames_out_total = IntCounterVec::new( - Opts::new( - "frames_out_total", - "Number of all frames send per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_in_total = IntCounterVec::new( - Opts::new( - "frames_in_total", - "Number of all frames received per channel, at the channel level", - ), - &["channel", "frametype"], - )?; - let frames_wire_out_total = IntCounterVec::new( - Opts::new( - "frames_wire_out_total", - "Number of all frames send per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let frames_wire_in_total = IntCounterVec::new( - Opts::new( - "frames_wire_in_total", - "Number of all frames received per channel, at the protocol level", - ), - &["channel", "frametype"], - )?; - let wire_out_throughput = IntCounterVec::new( - Opts::new( - "wire_out_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - let wire_in_throughput = IntCounterVec::new( - Opts::new( - "wire_in_throughput", - "Throupgput of all data frames send per channel, at the protocol level", - ), - &["channel"], - )?; - //TODO IN - let message_out_total = IntCounterVec::new( - Opts::new( - "message_out_total", - "Number of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - //TODO IN - 
let message_out_throughput = IntCounterVec::new( - Opts::new( - "message_out_throughput", - "Throughput of messages send by streams on the network", - ), - &["participant", "stream"], - )?; - let streams_flushed = IntCounterVec::new( - Opts::new( - "stream_flushed", - "Number of flushed streams requested to PrioManager at participant level", - ), - &["participant"], - )?; - let queued_count = IntGaugeVec::new( - Opts::new( - "queued_count", - "Queued number of messages by participant on the network", - ), - &["channel"], - )?; - let queued_bytes = IntGaugeVec::new( - Opts::new( - "queued_bytes", - "Queued bytes of messages by participant on the network", - ), - &["channel"], - )?; - let participants_ping = IntGaugeVec::new( - Opts::new( - "participants_ping", - "Ping time to participants on the network", - ), - &["channel"], - )?; Ok(Self { listen_requests_total, @@ -220,18 +108,6 @@ impl NetworkMetrics { streams_opened_total, streams_closed_total, network_info, - frames_out_total, - frames_in_total, - frames_wire_out_total, - frames_wire_in_total, - wire_out_throughput, - wire_in_throughput, - message_out_total, - message_out_throughput, - streams_flushed, - queued_count, - queued_bytes, - participants_ping, }) } @@ -246,22 +122,48 @@ impl NetworkMetrics { registry.register(Box::new(self.streams_opened_total.clone()))?; registry.register(Box::new(self.streams_closed_total.clone()))?; registry.register(Box::new(self.network_info.clone()))?; - registry.register(Box::new(self.frames_out_total.clone()))?; - registry.register(Box::new(self.frames_in_total.clone()))?; - registry.register(Box::new(self.frames_wire_out_total.clone()))?; - registry.register(Box::new(self.frames_wire_in_total.clone()))?; - registry.register(Box::new(self.wire_out_throughput.clone()))?; - registry.register(Box::new(self.wire_in_throughput.clone()))?; - registry.register(Box::new(self.message_out_total.clone()))?; - registry.register(Box::new(self.message_out_throughput.clone()))?; - 
registry.register(Box::new(self.queued_count.clone()))?; - registry.register(Box::new(self.queued_bytes.clone()))?; - registry.register(Box::new(self.participants_ping.clone()))?; Ok(()) } - //pub fn _is_100th_tick(&self) -> bool { - // self.tick.load(Ordering::Relaxed).rem_euclid(100) == 0 } + pub(crate) fn channels_connected(&self, remote_p: &str, no: usize, cid: Cid) { + self.channels_connected_total + .with_label_values(&[remote_p]) + .inc(); + self.participants_channel_ids + .with_label_values(&[remote_p, &no.to_string()]) + .set(cid as i64); + } + + pub(crate) fn channels_disconnected(&self, remote_p: &str) { + self.channels_disconnected_total + .with_label_values(&[remote_p]) + .inc(); + } + + pub(crate) fn streams_opened(&self, remote_p: &str) { + self.streams_opened_total + .with_label_values(&[remote_p]) + .inc(); + } + + pub(crate) fn streams_closed(&self, remote_p: &str) { + self.streams_closed_total + .with_label_values(&[remote_p]) + .inc(); + } +} + +#[cfg(not(feature = "metrics"))] +impl NetworkMetrics { + pub fn new(_local_pid: &Pid) -> Result> { Ok(Self {}) } + + pub(crate) fn channels_connected(&self, _remote_p: &str, _no: usize, _cid: Cid) {} + + pub(crate) fn channels_disconnected(&self, _remote_p: &str) {} + + pub(crate) fn streams_opened(&self, _remote_p: &str) {} + + pub(crate) fn streams_closed(&self, _remote_p: &str) {} } impl std::fmt::Debug for NetworkMetrics { @@ -270,138 +172,3 @@ impl std::fmt::Debug for NetworkMetrics { write!(f, "NetworkMetrics()") } } - -/* -pub(crate) struct PidCidFrameCache { - metric: MetricVec, - pid: String, - cache: Vec<[T::M; 8]>, -} -*/ - -pub(crate) struct MultiCidFrameCache { - metric: IntCounterVec, - cache: Vec<[Option>; Frame::FRAMES_LEN as usize]>, -} - -impl MultiCidFrameCache { - const CACHE_SIZE: usize = 2048; - - pub fn new(metric: IntCounterVec) -> Self { - Self { - metric, - cache: vec![], - } - } - - fn populate(&mut self, cid: Cid) { - let start_cid = self.cache.len(); - if cid >= start_cid as 
u64 && cid > (Self::CACHE_SIZE as Cid) { - warn!( - ?cid, - "cid, getting quite high, is this a attack on the cache?" - ); - } - self.cache.resize((cid + 1) as usize, [ - None, None, None, None, None, None, None, None, - ]); - } - - pub fn with_label_values(&mut self, cid: Cid, frame: &Frame) -> &GenericCounter { - self.populate(cid); - let frame_int = frame.get_int() as usize; - let r = &mut self.cache[cid as usize][frame_int]; - if r.is_none() { - *r = Some( - self.metric - .with_label_values(&[&cid.to_string(), &frame_int.to_string()]), - ); - } - r.as_ref().unwrap() - } -} - -pub(crate) struct CidFrameCache { - cache: [GenericCounter; Frame::FRAMES_LEN as usize], -} - -impl CidFrameCache { - pub fn new(metric: IntCounterVec, cid: Cid) -> Self { - let cid = cid.to_string(); - let cache = [ - metric.with_label_values(&[&cid, Frame::int_to_string(0)]), - metric.with_label_values(&[&cid, Frame::int_to_string(1)]), - metric.with_label_values(&[&cid, Frame::int_to_string(2)]), - metric.with_label_values(&[&cid, Frame::int_to_string(3)]), - metric.with_label_values(&[&cid, Frame::int_to_string(4)]), - metric.with_label_values(&[&cid, Frame::int_to_string(5)]), - metric.with_label_values(&[&cid, Frame::int_to_string(6)]), - metric.with_label_values(&[&cid, Frame::int_to_string(7)]), - ]; - Self { cache } - } - - pub fn with_label_values(&mut self, frame: &Frame) -> &GenericCounter { - &self.cache[frame.get_int() as usize] - } -} - -#[cfg(test)] -mod tests { - use crate::{ - metrics::*, - types::{Frame, Pid}, - }; - - #[test] - fn register_metrics() { - let registry = Registry::new(); - let metrics = NetworkMetrics::new(&Pid::fake(1)).unwrap(); - metrics.register(®istry).unwrap(); - } - - #[test] - fn multi_cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = MultiCidFrameCache::new(metrics.frames_in_total); - let v1 = 
cache.with_label_values(1, &frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(1, &frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(1, &frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(3, &frame1); - v4.inc(); - assert_eq!(v4.get(), 1); - let v5 = cache.with_label_values(3, &Frame::Shutdown); - v5.inc(); - assert_eq!(v5.get(), 1); - } - - #[test] - fn cid_frame_cache() { - let pid = Pid::fake(1); - let frame1 = Frame::Raw(b"Foo".to_vec()); - let frame2 = Frame::Raw(b"Bar".to_vec()); - let metrics = NetworkMetrics::new(&pid).unwrap(); - let mut cache = CidFrameCache::new(metrics.frames_wire_out_total, 1); - let v1 = cache.with_label_values(&frame1); - v1.inc(); - assert_eq!(v1.get(), 1); - let v2 = cache.with_label_values(&frame1); - v2.inc(); - assert_eq!(v2.get(), 2); - let v3 = cache.with_label_values(&frame2); - v3.inc(); - assert_eq!(v3.get(), 3); - let v4 = cache.with_label_values(&Frame::Shutdown); - v4.inc(); - assert_eq!(v4.get(), 1); - } -} diff --git a/network/src/participant.rs b/network/src/participant.rs index 8e0d0a1904..c7d6a9b64d 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,44 +1,38 @@ -#[cfg(feature = "metrics")] -use crate::metrics::{MultiCidFrameCache, NetworkMetrics}; use crate::{ api::{ParticipantError, Stream}, - channel::Channel, - message::{IncomingMessage, MessageBuffer, OutgoingMessage}, - prios::PrioManager, - protocols::Protocols, - types::{Cid, Frame, Pid, Prio, Promises, Sid}, + channel::{Protocols, RecvProtocols, SendProtocols}, + metrics::NetworkMetrics, }; -use async_std::sync::{Mutex, RwLock}; -use futures::{ - channel::{mpsc, oneshot}, - future::FutureExt, - select, - sink::SinkExt, - stream::StreamExt, +use bytes::Bytes; +use futures_util::{FutureExt, StreamExt}; +use network_protocol::{ + Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid, }; use std::{ - 
collections::{HashMap, VecDeque}, + collections::HashMap, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicI32, Ordering}, Arc, }, time::{Duration, Instant}, }; +use tokio::{ + select, + sync::{mpsc, oneshot, Mutex, RwLock}, + task::JoinHandle, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use tracing_futures::Instrument; -pub(crate) type A2bStreamOpen = (Prio, Promises, oneshot::Sender); -pub(crate) type C2pFrame = (Cid, Result); -pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, Vec, oneshot::Sender<()>); -pub(crate) type S2bShutdownBparticipant = oneshot::Sender>; +pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender); +pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, oneshot::Sender<()>); +pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender>); pub(crate) type B2sPrioStatistic = (Pid, u64, u64); #[derive(Debug)] struct ChannelInfo { cid: Cid, cid_string: String, //optimisationmetrics - b2w_frame_s: mpsc::UnboundedSender, - b2r_read_shutdown: oneshot::Sender<()>, } #[derive(Debug)] @@ -46,47 +40,51 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: Mutex>, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] struct ControlChannels { - a2b_stream_open_r: mpsc::UnboundedReceiver, + a2b_open_stream_r: mpsc::UnboundedReceiver, b2a_stream_opened_s: mpsc::UnboundedSender, s2b_create_channel_r: mpsc::UnboundedReceiver, - a2b_close_stream_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, s2b_shutdown_bparticipant_r: oneshot::Receiver, /* own */ } #[derive(Debug)] struct ShutdownInfo { - //a2b_stream_open_r: mpsc::UnboundedReceiver, - b2a_stream_opened_s: mpsc::UnboundedSender, + b2b_close_stream_opened_sender_s: Option>, error: Option, } #[derive(Debug)] pub struct BParticipant { + local_pid: Pid, //tracing remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, channels: Arc>>>, 
streams: RwLock>, - running_mgr: AtomicUsize, run_channels: Option, - #[cfg(feature = "metrics")] + shutdown_barrier: AtomicI32, metrics: Arc, no_channel_error_info: RwLock<(Instant, u64)>, - shutdown_info: RwLock, } impl BParticipant { + // We use integer instead of Barrier to not block mgr from freeing at the end + const BARR_CHANNEL: i32 = 1; + const BARR_RECV: i32 = 4; + const BARR_SEND: i32 = 2; + const TICK_TIME: Duration = Duration::from_millis(Self::TICK_TIME_MS); + const TICK_TIME_MS: u64 = 10; + #[allow(clippy::type_complexity)] pub(crate) fn new( + local_pid: Pid, remote_pid: Pid, offset_sid: Sid, - #[cfg(feature = "metrics")] metrics: Arc, + metrics: Arc, ) -> ( Self, mpsc::UnboundedSender, @@ -94,42 +92,34 @@ impl BParticipant { mpsc::UnboundedSender, oneshot::Sender, ) { - let (a2b_steam_open_s, a2b_stream_open_r) = mpsc::unbounded::(); - let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded::(); - let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded(); + let (a2b_open_stream_s, a2b_open_stream_r) = mpsc::unbounded_channel::(); + let (b2a_stream_opened_s, b2a_stream_opened_r) = mpsc::unbounded_channel::(); let (s2b_shutdown_bparticipant_s, s2b_shutdown_bparticipant_r) = oneshot::channel(); - let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded(); - - let shutdown_info = RwLock::new(ShutdownInfo { - //a2b_stream_open_r: a2b_stream_open_r.clone(), - b2a_stream_opened_s: b2a_stream_opened_s.clone(), - error: None, - }); + let (s2b_create_channel_s, s2b_create_channel_r) = mpsc::unbounded_channel(); let run_channels = Some(ControlChannels { - a2b_stream_open_r, + a2b_open_stream_r, b2a_stream_opened_s, s2b_create_channel_r, - a2b_close_stream_r, - a2b_close_stream_s, s2b_shutdown_bparticipant_r, }); ( Self { + local_pid, remote_pid, remote_pid_string: remote_pid.to_string(), offset_sid, channels: Arc::new(RwLock::new(HashMap::new())), streams: RwLock::new(HashMap::new()), - running_mgr: AtomicUsize::new(0), + 
shutdown_barrier: AtomicI32::new( + Self::BARR_CHANNEL + Self::BARR_SEND + Self::BARR_RECV, + ), run_channels, - #[cfg(feature = "metrics")] metrics, no_channel_error_info: RwLock::new((Instant::now(), 0)), - shutdown_info, }, - a2b_steam_open_s, + a2b_open_stream_s, b2a_stream_opened_r, s2b_create_channel_s, s2b_shutdown_bparticipant_s, @@ -137,683 +127,496 @@ impl BParticipant { } pub async fn run(mut self, b2s_prio_statistic_s: mpsc::UnboundedSender) { - //those managers that listen on api::Participant need an additional oneshot for - // shutdown scenario, those handled by scheduler will be closed by it. - let (shutdown_send_mgr_sender, shutdown_send_mgr_receiver) = oneshot::channel(); - let (shutdown_stream_close_mgr_sender, shutdown_stream_close_mgr_receiver) = - oneshot::channel(); - let (shutdown_open_mgr_sender, shutdown_open_mgr_receiver) = oneshot::channel(); - let (w2b_frames_s, w2b_frames_r) = mpsc::unbounded::(); - let (prios, a2p_msg_s, b2p_notify_empty_stream_s) = PrioManager::new( - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - self.remote_pid_string.clone(), - ); + let (b2b_add_send_protocol_s, b2b_add_send_protocol_r) = + mpsc::unbounded_channel::<(Cid, SendProtocols)>(); + let (b2b_add_recv_protocol_s, b2b_add_recv_protocol_r) = + mpsc::unbounded_channel::<(Cid, RecvProtocols)>(); + let (b2b_close_send_protocol_s, b2b_close_send_protocol_r) = + async_channel::unbounded::(); + let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = + async_channel::unbounded::(); + let (b2b_notify_send_of_recv_s, b2b_notify_send_of_recv_r) = + crossbeam_channel::unbounded::(); + + let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); + const STREAM_BOUND: usize = 10_000; + let (a2b_msg_s, a2b_msg_r) = crossbeam_channel::bounded::<(Sid, Bytes)>(STREAM_BOUND); let run_channels = self.run_channels.take().unwrap(); - futures::join!( - self.open_mgr( - run_channels.a2b_stream_open_r, - 
run_channels.a2b_close_stream_s.clone(), - a2p_msg_s.clone(), - shutdown_open_mgr_receiver, - ), - self.handle_frames_mgr( - w2b_frames_r, + trace!("start all managers"); + tokio::join!( + self.send_mgr( + run_channels.a2b_open_stream_r, + a2b_close_stream_r, + a2b_msg_r, + b2b_add_send_protocol_r, + b2b_close_send_protocol_r, + b2b_notify_send_of_recv_r, + b2s_prio_statistic_s, + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self + ) + .instrument(tracing::info_span!("send")), + self.recv_mgr( run_channels.b2a_stream_opened_s, - run_channels.a2b_close_stream_s, - a2p_msg_s.clone(), - ), - self.create_channel_mgr(run_channels.s2b_create_channel_r, w2b_frames_s), - self.send_mgr(prios, shutdown_send_mgr_receiver, b2s_prio_statistic_s), - self.stream_close_mgr( - run_channels.a2b_close_stream_r, - shutdown_stream_close_mgr_receiver, - b2p_notify_empty_stream_s, + b2b_add_recv_protocol_r, + b2b_force_close_recv_protocol_r, + b2b_close_send_protocol_s.clone(), + b2b_notify_send_of_recv_s, + a2b_msg_s.clone(), //self + a2b_close_stream_s.clone(), //self + ) + .instrument(tracing::info_span!("recv")), + self.create_channel_mgr( + run_channels.s2b_create_channel_r, + b2b_add_send_protocol_s, + b2b_add_recv_protocol_s, ), self.participant_shutdown_mgr( run_channels.s2b_shutdown_bparticipant_r, - shutdown_open_mgr_sender, - shutdown_stream_close_mgr_sender, - shutdown_send_mgr_sender, + b2b_close_send_protocol_s.clone(), + b2b_force_close_recv_protocol_s, ), ); } + //TODO: local stream_cid: HashMap to know the respective protocol + #[allow(clippy::too_many_arguments)] async fn send_mgr( &self, - mut prios: PrioManager, - mut shutdown_send_mgr_receiver: oneshot::Receiver>, - mut b2s_prio_statistic_s: mpsc::UnboundedSender, + mut a2b_open_stream_r: mpsc::UnboundedReceiver, + mut a2b_close_stream_r: mpsc::UnboundedReceiver, + a2b_msg_r: crossbeam_channel::Receiver<(Sid, Bytes)>, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, + 
b2b_close_send_protocol_r: async_channel::Receiver, + b2b_notify_send_of_recv_r: crossbeam_channel::Receiver, + _b2s_prio_statistic_s: mpsc::UnboundedSender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, + a2b_close_stream_s: mpsc::UnboundedSender, ) { - //This time equals the MINIMUM Latency in average, so keep it down and //Todo: - // make it configurable or switch to await E.g. Prio 0 = await, prio 50 - // wait for more messages - const TICK_TIME: Duration = Duration::from_millis(10); - const FRAMES_PER_TICK: usize = 10005; - self.running_mgr.fetch_add(1, Ordering::Relaxed); - let mut b2b_prios_flushed_s = None; //closing up - trace!("Start send_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut i: u64 = 0; + let mut send_protocols: HashMap = HashMap::new(); + let mut interval = tokio::time::interval(Self::TICK_TIME); + let mut last_instant = Instant::now(); + let mut stream_ids = self.offset_sid; + let mut fake_mid = 0; //TODO: move MID to protocol, should be inc per stream ? or ? 
+ trace!("workaround, actively wait for first protocol"); + b2b_add_protocol_r + .recv() + .await + .map(|(c, p)| send_protocols.insert(c, p)); loop { - let mut frames = VecDeque::new(); - prios.fill_frames(FRAMES_PER_TICK, &mut frames).await; - let len = frames.len(); - for (_, frame) in frames { - self.send_frame( - frame, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; + let (open, close, _, addp, remp) = select!( + Some(n) = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None), + Some(n) = a2b_close_stream_r.recv().fuse() => (None, Some(n), None, None, None), + _ = interval.tick() => (None, None, Some(()), None, None), + Some(n) = b2b_add_protocol_r.recv().fuse() => (None, None, None, Some(n), None), + Ok(n) = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(n)), + ); + + addp.map(|(cid, p)| { + debug!(?cid, "add protocol"); + send_protocols.insert(cid, p) + }); + + let (cid, active) = match send_protocols.iter_mut().next() { + Some((cid, a)) => (*cid, a), + None => { + warn!("no channel"); + continue; + }, + }; + + let active_err = async { + if let Some((prio, promises, guaranteed_bandwidth, return_s)) = open { + let sid = stream_ids; + trace!(?sid, "open stream"); + stream_ids += Sid::from(1); + let stream = self + .create_stream( + sid, + prio, + promises, + guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) + .await; + + let event = ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + }; + + return_s.send(stream).unwrap(); + active.send(event).await?; + } + + // process recv content first + let mut closeevents = b2b_notify_send_of_recv_r + .try_iter() + .map(|e| { + if matches!(e, ProtocolEvent::OpenStream { .. 
}) { + active.notify_from_recv(e); + None + } else { + Some(e) + } + }) + .collect::>(); + + // get all messages and assign it to a channel + for (sid, buffer) in a2b_msg_r.try_iter() { + fake_mid += 1; + active + .send(ProtocolEvent::Message { + data: buffer, + mid: fake_mid, + sid, + }) + .await? + } + + // process recv content afterwards + let _ = closeevents.drain(..).map(|e| { + if let Some(e) = e { + active.notify_from_recv(e); + } + }); + + if let Some(sid) = close { + trace!(?stream_ids, "delete stream"); + self.delete_stream(sid).await; + // Fire&Forget the protocol will take care to verify that this Frame is delayed + // till the last msg was received! + active.send(ProtocolEvent::CloseStream { sid }).await?; + } + + let send_time = Instant::now(); + let diff = send_time.duration_since(last_instant); + last_instant = send_time; + active.flush(1_000_000_000, diff).await?; //this actually blocks, so we cant set streams while it. + let r: Result<(), network_protocol::ProtocolError> = Ok(()); + r } - b2s_prio_statistic_s - .send((self.remote_pid, len as u64, /* */ 0)) - .await - .unwrap(); - async_std::task::sleep(TICK_TIME).await; - i += 1; - if i.rem_euclid(1000) == 0 { - trace!("Did 1000 ticks"); + .await; + if let Err(e) = active_err { + info!(?cid, ?e, "protocol failed, shutting down channel"); + // remote recv will now fail, which will trigger remote send which will trigger + // recv + send_protocols.remove(&cid).unwrap(); + self.metrics.channels_disconnected(&self.remote_pid_string); } - //shutdown after all msg are send! 
- // Make sure this is called after the API is closed, and all streams are known - // to be droped to the priomgr - if b2b_prios_flushed_s.is_some() && (len == 0) { - break; - } - if b2b_prios_flushed_s.is_none() { - if let Some(prios_flushed_s) = shutdown_send_mgr_receiver.try_recv().unwrap() { - b2b_prios_flushed_s = Some(prios_flushed_s); + + if let Some(cid) = remp { + debug!(?cid, "remove protocol"); + match send_protocols.remove(&cid) { + Some(mut prot) => { + self.metrics.channels_disconnected(&self.remote_pid_string); + trace!("blocking flush"); + let _ = prot.flush(u64::MAX, Duration::from_secs(1)).await; + trace!("shutdown prot"); + let _ = prot.send(ProtocolEvent::Shutdown).await; + }, + None => trace!("tried to remove protocol twice"), + }; + if send_protocols.is_empty() { + break; } } } trace!("Stop send_mgr"); - b2b_prios_flushed_s - .expect("b2b_prios_flushed_s not set") - .send(()) - .unwrap(); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + self.shutdown_barrier + .fetch_sub(Self::BARR_SEND, Ordering::Relaxed); } - //returns false if sending isn't possible. In that case we have to render the - // Participant `closed` - #[must_use = "You need to check if the send was successful and report to client!"] - async fn send_frame( + #[allow(clippy::too_many_arguments)] + async fn recv_mgr( &self, - frame: Frame, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, - ) -> bool { - let mut drop_cid = None; - // TODO: find out ideal channel here - let res = if let Some(ci) = self.channels.read().await.values().next() { - let mut ci = ci.lock().await; - //we are increasing metrics without checking the result to please - // borrow_checker. otherwise we would need to close `frame` what we - // dont want! 
- #[cfg(feature = "metrics")] - frames_out_total_cache - .with_label_values(ci.cid, &frame) - .inc(); - if let Err(e) = ci.b2w_frame_s.send(frame).await { - let cid = ci.cid; - info!(?e, ?cid, "channel no longer available"); - drop_cid = Some(cid); - false - } else { - true - } - } else { - let mut guard = self.no_channel_error_info.write().await; - let now = Instant::now(); - if now.duration_since(guard.0) > Duration::from_secs(1) { - guard.0 = now; - let occurrences = guard.1 + 1; - guard.1 = 0; - let lastframe = frame; - error!( - ?occurrences, - ?lastframe, - "Participant has no channel to communicate on" - ); - } else { - guard.1 += 1; - } - false - }; - if let Some(cid) = drop_cid { - if let Some(ci) = self.channels.write().await.remove(&cid) { - let ci = ci.into_inner(); - trace!(?cid, "stopping read protocol"); - if let Err(e) = ci.b2r_read_shutdown.send(()) { - trace!(?cid, ?e, "seems like was already shut down"); - } - ci.b2w_frame_s.close_channel(); - } - //TODO FIXME tags: takeover channel multiple - info!( - "FIXME: the frame is actually drop. 
which is fine for now as the participant will \ - be closed, but not if we do channel-takeover" - ); - //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant - self.close_write_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) - .await; - }; - res - } - - async fn handle_frames_mgr( - &self, - mut w2b_frames_r: mpsc::UnboundedReceiver, - mut b2a_stream_opened_s: mpsc::UnboundedSender, + b2a_stream_opened_s: mpsc::UnboundedSender, + mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, + b2b_force_close_recv_protocol_r: async_channel::Receiver, + b2b_close_send_protocol_s: async_channel::Sender, + b2b_notify_send_of_recv_s: crossbeam_channel::Sender, + a2b_msg_s: crossbeam_channel::Sender<(Sid, Bytes)>, a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start handle_frames_mgr"); - let mut messages = HashMap::new(); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut dropped_instant = Instant::now(); - let mut dropped_cnt = 0u64; - let mut dropped_sid = Sid::new(0); + let mut recv_protocols: HashMap> = HashMap::new(); + // we should be able to directly await futures imo + let (hacky_recv_s, mut hacky_recv_r) = mpsc::unbounded_channel(); - while let Some((cid, result_frame)) = w2b_frames_r.next().await { - //trace!(?result_frame, "handling frame"); - let frame = match result_frame { - Ok(frame) => frame, - Err(()) => { - // The read protocol stopped, i need to make sure that write gets stopped - debug!("read protocol was closed. 
Stopping write protocol"); - if let Some(ci) = self.channels.read().await.get(&cid) { - let mut ci = ci.lock().await; - ci.b2w_frame_s - .close() - .await - .expect("couldn't stop write protocol"); - } - continue; + let retrigger = |cid: Cid, mut p: RecvProtocols, map: &mut HashMap<_, _>| { + let hacky_recv_s = hacky_recv_s.clone(); + let handle = tokio::spawn(async move { + let cid = cid; + let r = p.recv().await; + let _ = hacky_recv_s.send((cid, r, p)); // ignoring failed + }); + map.insert(cid, handle); + }; + + let remove_c = |recv_protocols: &mut HashMap>, cid: &Cid| { + match recv_protocols.remove(&cid) { + Some(h) => { + h.abort(); + debug!(?cid, "remove protocol"); }, + None => trace!("tried to remove protocol twice"), }; - #[cfg(feature = "metrics")] - { - let cid_string = cid.to_string(); - self.metrics - .frames_in_total - .with_label_values(&[&cid_string, frame.get_string()]) - .inc(); - } - match frame { - Frame::OpenStream { - sid, - prio, - promises, - } => { - trace!(?sid, ?prio, ?promises, "Opened frame from remote"); - let a2p_msg_s = a2p_msg_s.clone(); - let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) - .await; - if let Err(e) = b2a_stream_opened_s.send(stream).await { - warn!( - ?e, - ?sid, - "couldn't notify api::Participant that a stream got opened. Is the \ - participant already dropped?" 
- ); - } - }, - Frame::CloseStream { sid } => { - // no need to keep flushing as the remote no longer knows about this stream - // anyway - self.delete_stream( - sid, - None, - true, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - }, - Frame::DataHeader { mid, sid, length } => { - let imsg = IncomingMessage { - buffer: MessageBuffer { data: Vec::new() }, - length, - mid, - sid, - }; - messages.insert(mid, imsg); - }, - Frame::Data { - mid, - start: _, - mut data, - } => { - let finished = if let Some(imsg) = messages.get_mut(&mid) { - imsg.buffer.data.append(&mut data); - imsg.buffer.data.len() as u64 == imsg.length - } else { - false - }; - if finished { - //trace!(?mid, "finished receiving message"); - let imsg = messages.remove(&mid).unwrap(); - if let Some(si) = self.streams.read().await.get(&imsg.sid) { - if let Err(e) = si.b2a_msg_recv_s.lock().await.send(imsg).await { - warn!( - ?e, - ?mid, - "Dropping message, as streams seem to be in act of being \ - dropped right now" - ); - } - } else { - //aggregate errors - let n = Instant::now(); - if dropped_cnt > 0 - && (dropped_sid != imsg.sid - || n.duration_since(dropped_instant) > Duration::from_secs(1)) - { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to \ - exist because it was dropped probably." - ); - dropped_cnt = 0; - dropped_instant = n; - dropped_sid = imsg.sid; - } else { - dropped_cnt += 1; - } - } - } - }, - Frame::Shutdown => { - debug!("Shutdown received from remote side"); - self.close_api(Some(ParticipantError::ParticipantDisconnected)) - .await; - }, - f => { - unreachable!( - "Frame should never reach participant!: {:?}, cid: {}", - f, cid - ); - }, - } - } - if dropped_cnt > 0 { - warn!( - ?dropped_cnt, - "Dropping multiple messages as stream no longer seems to exist because it was \ - dropped probably." 
+ recv_protocols.is_empty() + }; + + loop { + let (event, addp, remp) = select!( + Some(n) = hacky_recv_r.recv().fuse() => (Some(n), None, None), + Some(n) = b2b_add_protocol_r.recv().fuse() => (None, Some(n), None), + Ok(n) = b2b_force_close_recv_protocol_r.recv().fuse() => (None, None, Some(n)), + else => { + error!("recv_mgr -> something is seriously wrong!, end recv_mgr"); + break; + } ); + + if let Some((cid, p)) = addp { + debug!(?cid, "add protocol"); + retrigger(cid, p, &mut recv_protocols); + }; + if let Some(cid) = remp { + // no need to stop the send_mgr here as it has been canceled before + if remove_c(&mut recv_protocols, &cid) { + break; + } + }; + + if let Some((cid, r, p)) = event { + match r { + Ok(ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + }) => { + trace!(?sid, "open stream"); + let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); + // waiting for receiving is not necessary, because the send_mgr will first + // process this before process messages! 
+ let stream = self + .create_stream( + sid, + prio, + promises, + guaranteed_bandwidth, + &a2b_msg_s, + &a2b_close_stream_s, + ) + .await; + b2a_stream_opened_s.send(stream).unwrap(); + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::CloseStream { sid }) => { + trace!(?sid, "close stream"); + let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); + self.delete_stream(sid).await; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Message { data, mid: _, sid }) => { + let lock = self.streams.read().await; + match lock.get(&sid) { + Some(stream) => { + let _ = stream.b2a_msg_recv_s.lock().await.send(data).await; + }, + None => warn!("recv a msg with orphan stream"), + }; + retrigger(cid, p, &mut recv_protocols); + }, + Ok(ProtocolEvent::Shutdown) => { + info!(?cid, "shutdown protocol"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + Err(e) => { + info!(?e, ?cid, "protocol failed, shutting down channel"); + if let Err(e) = b2b_close_send_protocol_s.send(cid).await { + debug!(?e, ?cid, "send_mgr was already closed simultaneously"); + } + if remove_c(&mut recv_protocols, &cid) { + break; + } + }, + } + } } - trace!("Stop handle_frames_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + trace!("receiving no longer possible, closing all streams"); + for (_, si) in self.streams.write().await.drain() { + si.send_closed.store(true, Ordering::Relaxed); + self.metrics.streams_closed(&self.remote_pid_string); + } + trace!("Stop recv_mgr"); + self.shutdown_barrier + .fetch_sub(Self::BARR_RECV, Ordering::Relaxed); } async fn create_channel_mgr( &self, s2b_create_channel_r: mpsc::UnboundedReceiver, - w2b_frames_s: mpsc::UnboundedSender, + b2b_add_send_protocol_s: mpsc::UnboundedSender<(Cid, SendProtocols)>, + b2b_add_recv_protocol_s: mpsc::UnboundedSender<(Cid, RecvProtocols)>, ) { - 
self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start create_channel_mgr"); + let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r - .for_each_concurrent( - None, - |(cid, _, protocol, leftover_cid_frame, b2s_create_channel_done_s)| { - // This channel is now configured, and we are running it in scope of the - // participant. - let w2b_frames_s = w2b_frames_s.clone(); - let channels = Arc::clone(&self.channels); - async move { - let (channel, b2w_frame_s, b2r_read_shutdown) = Channel::new(cid); - let mut lock = channels.write().await; - #[cfg(feature = "metrics")] - let mut channel_no = lock.len(); - #[cfg(not(feature = "metrics"))] - let channel_no = lock.len(); - lock.insert( + .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { + // This channel is now configured, and we are running it in scope of the + // participant. + let channels = Arc::clone(&self.channels); + let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); + let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); + async move { + let mut lock = channels.write().await; + let mut channel_no = lock.len(); + lock.insert( + cid, + Mutex::new(ChannelInfo { cid, - Mutex::new(ChannelInfo { - cid, - cid_string: cid.to_string(), - b2w_frame_s, - b2r_read_shutdown, - }), - ); - drop(lock); - b2s_create_channel_done_s.send(()).unwrap(); - #[cfg(feature = "metrics")] - { - self.metrics - .channels_connected_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - if channel_no > 5 { - debug!(?channel_no, "metrics will overwrite channel #5"); - channel_no = 5; - } - self.metrics - .participants_channel_ids - .with_label_values(&[ - &self.remote_pid_string, - &channel_no.to_string(), - ]) - .set(cid as i64); - } - trace!(?cid, ?channel_no, "Running channel in participant"); - channel - .run(protocol, w2b_frames_s, leftover_cid_frame) - .instrument(tracing::info_span!("", ?cid)) - .await; - #[cfg(feature = 
"metrics")] - self.metrics - .channels_disconnected_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - info!(?cid, "Channel got closed"); - //maybe channel got already dropped, we don't know. - channels.write().await.remove(&cid); - trace!(?cid, "Channel cleanup completed"); - //TEMP FIX: as we dont have channel takeover yet drop the whole - // bParticipant - self.close_write_api(None).await; + cid_string: cid.to_string(), + }), + ); + drop(lock); + let (send, recv) = protocol.split(); + b2b_add_send_protocol_s.send((cid, send)).unwrap(); + b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); + b2s_create_channel_done_s.send(()).unwrap(); + if channel_no > 5 { + debug!(?channel_no, "metrics will overwrite channel #5"); + channel_no = 5; } - }, - ) + self.metrics + .channels_connected(&self.remote_pid_string, channel_no, cid); + } + }) .await; trace!("Stop create_channel_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); + self.shutdown_barrier + .fetch_sub(Self::BARR_CHANNEL, Ordering::Relaxed); } - async fn open_mgr( - &self, - mut a2b_stream_open_r: mpsc::UnboundedReceiver, - a2b_close_stream_s: mpsc::UnboundedSender, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, - shutdown_open_mgr_receiver: oneshot::Receiver<()>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start open_mgr"); - let mut stream_ids = self.offset_sid; - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_open_mgr_receiver = shutdown_open_mgr_receiver.fuse(); - //from api or shutdown signal - while let Some((prio, promises, p2a_return_stream)) = select! { - next = a2b_stream_open_r.next().fuse() => next, - _ = shutdown_open_mgr_receiver => None, - } { - debug!(?prio, ?promises, "Got request to open a new steam"); - //TODO: a2b_stream_open_r isn't closed on api_close yet. This needs to change. 
- //till then just check here if we are closed and in that case do nothing (not - // even answer) - if self.shutdown_info.read().await.error.is_some() { - continue; - } - - let a2p_msg_s = a2p_msg_s.clone(); - let sid = stream_ids; - let stream = self - .create_stream(sid, prio, promises, a2p_msg_s, &a2b_close_stream_s) - .await; - if self - .send_frame( - Frame::OpenStream { - sid, - prio, - promises, - }, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - //On error, we drop this, so it gets closed and client will handle this as an - // Err any way (: - p2a_return_stream.send(stream).unwrap(); - stream_ids += Sid::from(1); - } - } - trace!("Stop open_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - /// when activated this function will drop the participant completely and - /// wait for everything to go right! Then return 1. Shutting down - /// Streams for API and End user! 2. Wait for all "prio queued" Messages - /// to be send. 3. Send Stream + /// sink shutdown: + /// Situation AS, AR, BS, BR. A wants to close. + /// AS shutdown. + /// BR notices shutdown and tries to stops BS. (success) + /// BS shutdown + /// AR notices shutdown and tries to stop AS. (fails) + /// For the case where BS didn't get shutdowned, e.g. by a handing situation + /// on the remote, we have a timeout to also force close AR. + /// + /// This fn will: + /// - 1. stop api to interact with bparticipant by closing sendmsg and + /// openstream + /// - 2. stop the send_mgr (it will take care of clearing the + /// queue and finish with a Shutdown) + /// - (3). force stop recv after 60 + /// seconds + /// - (4). this fn finishes last and afterwards BParticipant + /// drops + /// + /// before calling this fn, make sure `s2b_create_channel` is closed! 
/// If BParticipant kills itself managers stay active till this function is /// called by api to get the result status async fn participant_shutdown_mgr( &self, s2b_shutdown_bparticipant_r: oneshot::Receiver, - shutdown_open_mgr_sender: oneshot::Sender<()>, - shutdown_stream_close_mgr_sender: oneshot::Sender>, - shutdown_send_mgr_sender: oneshot::Sender>, + b2b_close_send_protocol_s: async_channel::Sender, + b2b_force_close_recv_protocol_s: async_channel::Sender, ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start participant_shutdown_mgr"); - let sender = s2b_shutdown_bparticipant_r.await.unwrap(); + let wait_for_manager = || async { + let mut sleep = 0.01f64; + loop { + let bytes = self.shutdown_barrier.load(Ordering::Relaxed); + if bytes == 0 { + break; + } + sleep *= 1.4; + tokio::time::sleep(Duration::from_secs_f64(sleep)).await; + if sleep > 0.2 { + trace!(?bytes, "wait for mgr to close"); + } + } + }; - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - - self.close_api(None).await; - - debug!("Closing all managers"); - shutdown_open_mgr_sender - .send(()) - .expect("open_mgr must have crashed before"); - let (b2b_stream_close_shutdown_confirmed_s, b2b_stream_close_shutdown_confirmed_r) = - oneshot::channel(); - shutdown_stream_close_mgr_sender - .send(b2b_stream_close_shutdown_confirmed_s) - .expect("stream_close_mgr must have crashed before"); - // We need to wait for the stream_close_mgr BEFORE send_mgr, as the - // stream_close_mgr needs to wait on the API to drop `Stream` and be triggered - // It will then sleep for streams to be flushed in PRIO, and send_mgr is - // responsible for ticking PRIO WHILE this happens, so we cant close it before! 
- b2b_stream_close_shutdown_confirmed_r.await.unwrap(); - - //closing send_mgr now: - let (b2b_prios_flushed_s, b2b_prios_flushed_r) = oneshot::channel(); - shutdown_send_mgr_sender - .send(b2b_prios_flushed_s) - .expect("stream_close_mgr must have crashed before"); - b2b_prios_flushed_r.await.unwrap(); - - if Some(ParticipantError::ParticipantDisconnected) != self.shutdown_info.read().await.error + let (timeout_time, sender) = s2b_shutdown_bparticipant_r.await.unwrap(); + debug!("participant_shutdown_mgr triggered. Closing all streams for send"); { - debug!("Sending shutdown frame after flushed all prios"); - if !self - .send_frame( - Frame::Shutdown, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await - { - warn!("couldn't send shutdown frame, are channels already closed?"); + let lock = self.streams.read().await; + for si in lock.values() { + si.send_closed.store(true, Ordering::Relaxed); } } - debug!("Closing all channels, after flushed prios"); - for (cid, ci) in self.channels.write().await.drain() { - let ci = ci.into_inner(); - if let Err(e) = ci.b2r_read_shutdown.send(()) { + let lock = self.channels.read().await; + assert!( + !lock.is_empty(), + "no channel existed remote_pid={}", + self.remote_pid + ); + for cid in lock.keys() { + if let Err(e) = b2b_close_send_protocol_s.send(*cid).await { debug!( ?e, ?cid, - "Seems like this read protocol got already dropped by closing the Stream \ - itself, ignoring" - ); - }; - ci.b2w_frame_s.close_channel(); - } - - //Wait for other bparticipants mgr to close via AtomicUsize - const SLEEP_TIME: Duration = Duration::from_millis(5); - const ALLOWED_MANAGER: usize = 1; - async_std::task::sleep(SLEEP_TIME).await; - let mut i: u32 = 1; - while self.running_mgr.load(Ordering::Relaxed) > ALLOWED_MANAGER { - i += 1; - if i.rem_euclid(10) == 1 { - trace!( - ?ALLOWED_MANAGER, - "Waiting for bparticipant mgr to shut down, remaining {}", - self.running_mgr.load(Ordering::Relaxed) - ALLOWED_MANAGER + "closing send_mgr 
may fail if we got a recv error simultaneously" ); } - async_std::task::sleep(SLEEP_TIME * i).await; } - trace!("All BParticipant mgr (except me) are shut down now"); + drop(lock); + + trace!("wait for other managers"); + let timeout = tokio::time::sleep(timeout_time); + let timeout = tokio::select! { + _ = wait_for_manager() => false, + _ = timeout => true, + }; + if timeout { + warn!("timeout triggered: for killing recv"); + let lock = self.channels.read().await; + for cid in lock.keys() { + if let Err(e) = b2b_force_close_recv_protocol_s.send(*cid).await { + debug!( + ?e, + ?cid, + "closing recv_mgr may fail if we got a recv error simultaneously" + ); + } + } + } + + trace!("wait again"); + wait_for_manager().await; + + sender.send(Ok(())).unwrap(); #[cfg(feature = "metrics")] self.metrics.participants_disconnected_total.inc(); - debug!("BParticipant close done"); - - let mut lock = self.shutdown_info.write().await; - sender - .send(match lock.error.take() { - None => Ok(()), - Some(ParticipantError::ProtocolFailedUnrecoverable) => { - Err(ParticipantError::ProtocolFailedUnrecoverable) - }, - Some(ParticipantError::ParticipantDisconnected) => Ok(()), - }) - .unwrap(); - trace!("Stop participant_shutdown_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); } - async fn stream_close_mgr( - &self, - mut a2b_close_stream_r: mpsc::UnboundedReceiver, - shutdown_stream_close_mgr_receiver: oneshot::Receiver>, - b2p_notify_empty_stream_s: crossbeam_channel::Sender<(Sid, oneshot::Sender<()>)>, - ) { - self.running_mgr.fetch_add(1, Ordering::Relaxed); - trace!("Start stream_close_mgr"); - #[cfg(feature = "metrics")] - let mut send_cache = MultiCidFrameCache::new(self.metrics.frames_out_total.clone()); - let mut shutdown_stream_close_mgr_receiver = shutdown_stream_close_mgr_receiver.fuse(); - let mut b2b_stream_close_shutdown_confirmed_s = None; - - //from api or shutdown signal - while let Some(sid) = select! 
{ - next = a2b_close_stream_r.next().fuse() => next, - sender = shutdown_stream_close_mgr_receiver => { - b2b_stream_close_shutdown_confirmed_s = Some(sender.unwrap()); - None - } - } { - //TODO: make this concurrent! - //TODO: Performance, closing is slow! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; + /// Stopping API and participant usage + /// Protocol will take care of the order of the frame + async fn delete_stream(&self, sid: Sid) { + let stream = { self.streams.write().await.remove(&sid) }; + match stream { + Some(si) => { + si.send_closed.store(true, Ordering::Relaxed); + si.b2a_msg_recv_s.lock().await.close(); + }, + None => { + trace!("Couldn't find the stream, might be simultaneous close from local/remote") + }, } - trace!("deleting all leftover streams"); - let sids = self - .streams - .read() - .await - .keys() - .cloned() - .collect::>(); - for sid in sids { - //flushing is still important, e.g. when Participant::drop is called (but - // Stream:drop isn't)! - self.delete_stream( - sid, - Some(b2p_notify_empty_stream_s.clone()), - false, - #[cfg(feature = "metrics")] - &mut send_cache, - ) - .await; - } - if b2b_stream_close_shutdown_confirmed_s.is_none() { - b2b_stream_close_shutdown_confirmed_s = - Some(shutdown_stream_close_mgr_receiver.await.unwrap()); - } - b2b_stream_close_shutdown_confirmed_s - .unwrap() - .send(()) - .unwrap(); - trace!("Stop stream_close_mgr"); - self.running_mgr.fetch_sub(1, Ordering::Relaxed); - } - - async fn delete_stream( - &self, - sid: Sid, - b2p_notify_empty_stream_s: Option)>>, - from_remote: bool, - #[cfg(feature = "metrics")] frames_out_total_cache: &mut MultiCidFrameCache, - ) { - //This needs to first stop clients from sending any more. 
- //Then it will wait for all pending messages (in prio) to be send to the - // protocol After this happened the stream is closed - //Only after all messages are send to the protocol, we can send the CloseStream - // frame! If we would send it before, all followup messages couldn't - // be handled at the remote side. - async { - trace!("Stopping api to use this stream"); - match self.streams.read().await.get(&sid) { - Some(si) => { - si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.lock().await.close_channel(); - }, - None => trace!( - "Couldn't find the stream, might be simultaneous close from local/remote" - ), - } - - if !from_remote { - trace!("Wait for stream to be flushed"); - let (s2b_stream_finished_closed_s, s2b_stream_finished_closed_r) = - oneshot::channel(); - b2p_notify_empty_stream_s - .expect("needs to be set when from_remote is false") - .send((sid, s2b_stream_finished_closed_s)) - .unwrap(); - s2b_stream_finished_closed_r.await.unwrap(); - - trace!("Stream was successfully flushed"); - } - - #[cfg(feature = "metrics")] - self.metrics - .streams_closed_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); - //only now remove the Stream, that means we can still recv on it. 
- self.streams.write().await.remove(&sid); - - if !from_remote { - self.send_frame( - Frame::CloseStream { sid }, - #[cfg(feature = "metrics")] - frames_out_total_cache, - ) - .await; - } - } - .instrument(tracing::info_span!("close", ?sid, ?from_remote)) - .await; + self.metrics.streams_closed(&self.remote_pid_string); } async fn create_stream( @@ -821,10 +624,11 @@ impl BParticipant { sid: Sid, prio: Prio, promises: Promises, - a2p_msg_s: crossbeam_channel::Sender<(Prio, Sid, OutgoingMessage)>, + guaranteed_bandwidth: Bandwidth, + a2b_msg_s: &crossbeam_channel::Sender<(Sid, Bytes)>, a2b_close_stream_s: &mpsc::UnboundedSender, ) -> Stream { - let (b2a_msg_recv_s, b2a_msg_recv_r) = mpsc::unbounded::(); + let (b2a_msg_recv_s, b2a_msg_recv_r) = async_channel::unbounded::(); let send_closed = Arc::new(AtomicBool::new(false)); self.streams.write().await.insert(sid, StreamInfo { prio, @@ -832,46 +636,259 @@ impl BParticipant { send_closed: Arc::clone(&send_closed), b2a_msg_recv_s: Mutex::new(b2a_msg_recv_s), }); - #[cfg(feature = "metrics")] - self.metrics - .streams_opened_total - .with_label_values(&[&self.remote_pid_string]) - .inc(); + self.metrics.streams_opened(&self.remote_pid_string); Stream::new( + self.local_pid, self.remote_pid, sid, prio, promises, + guaranteed_bandwidth, send_closed, - a2p_msg_s, + a2b_msg_s.clone(), b2a_msg_recv_r, a2b_close_stream_s.clone(), ) } +} - async fn close_write_api(&self, reason: Option) { - trace!(?reason, "close_api"); - let mut lock = self.shutdown_info.write().await; - if let Some(r) = reason { - lock.error = Some(r); - } - lock.b2a_stream_opened_s.close_channel(); +#[cfg(test)] +mod tests { + use super::*; + use network_protocol::ProtocolMetrics; + use tokio::{ + runtime::Runtime, + sync::{mpsc, oneshot}, + task::JoinHandle, + }; - debug!("Closing all streams for write"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream for write"); - si.send_closed.store(true, Ordering::Relaxed); - 
} + #[allow(clippy::type_complexity)] + fn mock_bparticipant() -> ( + Arc, + mpsc::UnboundedSender, + mpsc::UnboundedReceiver, + mpsc::UnboundedSender, + oneshot::Sender, + mpsc::UnboundedReceiver, + JoinHandle<()>, + ) { + let runtime = Arc::new(tokio::runtime::Runtime::new().unwrap()); + let runtime_clone = Arc::clone(&runtime); + + let (b2s_prio_statistic_s, b2s_prio_statistic_r) = + mpsc::unbounded_channel::(); + + let ( + bparticipant, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + ) = runtime_clone.block_on(async move { + let local_pid = Pid::fake(0); + let remote_pid = Pid::fake(1); + let sid = Sid::new(1000); + let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap()); + + BParticipant::new(local_pid, remote_pid, sid, Arc::clone(&metrics)) + }); + + let handle = runtime_clone.spawn(bparticipant.run(b2s_prio_statistic_s)); + ( + runtime_clone, + a2b_open_stream_s, + b2a_stream_opened_r, + s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) } - ///closing api::Participant is done by closing all channels, expect for the - /// shutdown channel at this point! 
- async fn close_api(&self, reason: Option) { - self.close_write_api(reason).await; - debug!("Closing all streams"); - for (sid, si) in self.streams.read().await.iter() { - trace!(?sid, "Shutting down Stream"); - si.b2a_msg_recv_s.lock().await.close_channel(); - } + async fn mock_mpsc( + cid: Cid, + _runtime: &Arc, + create_channel: &mut mpsc::UnboundedSender, + ) -> Protocols { + let (s1, r1) = mpsc::channel(100); + let (s2, r2) = mpsc::channel(100); + let metrics = Arc::new(ProtocolMetrics::new().unwrap()); + let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics)); + let (complete_s, complete_r) = oneshot::channel(); + create_channel + .send((cid, Sid::new(0), p1, complete_s)) + .unwrap(); + complete_r.await.unwrap(); + Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics)) + } + + #[test] + fn close_bparticipant_by_timeout_during_close() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let _remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() > Duration::from_millis(900), + "timeout wasn't triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn close_bparticipant_cleanly() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + 
std::thread::sleep(Duration::from_millis(50)); + + let (s, r) = oneshot::channel(); + let before = Instant::now(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(2), s)) + .unwrap(); + drop(remote); // remote needs to be dropped as soon as local.sender is closed + r.await.unwrap().unwrap(); + }); + assert!( + before.elapsed() < Duration::from_millis(1900), + "timeout was triggered" + ); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn create_stream() { + let ( + runtime, + a2b_open_stream_s, + b2a_stream_opened_r, + mut s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // created stream + let (rs, mut rr) = remote.split(); + let (stream_sender, _stream_receiver) = oneshot::channel(); + a2b_open_stream_s + .send((7u8, Promises::ENCRYPTED, 1_000_000, stream_sender)) + .unwrap(); + + let stream_event = runtime.block_on(rr.recv()).unwrap(); + match stream_event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + assert_eq!(sid, Sid::new(1000)); + assert_eq!(prio, 7u8); + assert_eq!(promises, Promises::ENCRYPTED); + assert_eq!(guaranteed_bandwidth, 1_000_000); + }, + _ => panic!("wrong event"), + }; + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); + } + + #[test] + fn created_stream() { + let ( + runtime, + a2b_open_stream_s, + mut b2a_stream_opened_r, + mut 
s2b_create_channel_s, + s2b_shutdown_bparticipant_s, + b2s_prio_statistic_r, + handle, + ) = mock_bparticipant(); + + let remote = runtime.block_on(mock_mpsc(0, &runtime, &mut s2b_create_channel_s)); + std::thread::sleep(Duration::from_millis(50)); + + // create stream + let (mut rs, rr) = remote.split(); + runtime + .block_on(rs.send(ProtocolEvent::OpenStream { + sid: Sid::new(1000), + prio: 9u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + })) + .unwrap(); + + let stream = runtime.block_on(b2a_stream_opened_r.recv()).unwrap(); + assert_eq!(stream.promises(), Promises::ORDERED); + + let (s, r) = oneshot::channel(); + runtime.block_on(async { + drop(s2b_create_channel_s); + s2b_shutdown_bparticipant_s + .send((Duration::from_secs(1), s)) + .unwrap(); + drop((rs, rr)); + r.await.unwrap().unwrap(); + }); + + runtime.block_on(handle).unwrap(); + + drop((a2b_open_stream_s, b2a_stream_opened_r, b2s_prio_statistic_r)); + drop(runtime); } } diff --git a/network/src/prios.rs b/network/src/prios.rs deleted file mode 100644 index 46d31024b5..0000000000 --- a/network/src/prios.rs +++ /dev/null @@ -1,665 +0,0 @@ -//!Priorities are handled the following way. -//!Prios from 0-63 are allowed. -//!all 5 numbers the throughput is halved. -//!E.g. in the same time 100 prio0 messages are send, only 50 prio5, 25 prio10, -//! 12 prio15 or 6 prio20 messages are send. Note: TODO: prio0 will be send -//! immediately when found! 
-#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; -use crate::{ - message::OutgoingMessage, - types::{Frame, Prio, Sid}, -}; -use crossbeam_channel::{unbounded, Receiver, Sender}; -use futures::channel::oneshot; -use std::collections::{HashMap, HashSet, VecDeque}; -#[cfg(feature = "metrics")] use std::sync::Arc; -use tracing::trace; - -const PRIO_MAX: usize = 64; - -#[derive(Default)] -struct PidSidInfo { - len: u64, - empty_notify: Option>, -} - -pub(crate) struct PrioManager { - points: [u32; PRIO_MAX], - messages: [VecDeque<(Sid, OutgoingMessage)>; PRIO_MAX], - messages_rx: Receiver<(Prio, Sid, OutgoingMessage)>, - sid_owned: HashMap, - //you can register to be notified if a pid_sid combination is flushed completely here - sid_flushed_rx: Receiver<(Sid, oneshot::Sender<()>)>, - queued: HashSet, - #[cfg(feature = "metrics")] - metrics: Arc, - #[cfg(feature = "metrics")] - pid: String, -} - -impl PrioManager { - const PRIOS: [u32; PRIO_MAX] = [ - 100, 115, 132, 152, 174, 200, 230, 264, 303, 348, 400, 459, 528, 606, 696, 800, 919, 1056, - 1213, 1393, 1600, 1838, 2111, 2425, 2786, 3200, 3676, 4222, 4850, 5572, 6400, 7352, 8445, - 9701, 11143, 12800, 14703, 16890, 19401, 22286, 25600, 29407, 33779, 38802, 44572, 51200, - 58813, 67559, 77605, 89144, 102400, 117627, 135118, 155209, 178289, 204800, 235253, 270235, - 310419, 356578, 409600, 470507, 540470, 620838, - ]; - - #[allow(clippy::type_complexity)] - pub fn new( - #[cfg(feature = "metrics")] metrics: Arc, - pid: String, - ) -> ( - Self, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - #[cfg(not(feature = "metrics"))] - let _pid = pid; - // (a2p_msg_s, a2p_msg_r) - let (messages_tx, messages_rx) = unbounded(); - let (sid_flushed_tx, sid_flushed_rx) = unbounded(); - ( - Self { - points: [0; PRIO_MAX], - messages: [ - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - 
VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - VecDeque::new(), - ], - messages_rx, - queued: HashSet::new(), //TODO: optimize with u64 and 64 bits - sid_flushed_rx, - sid_owned: HashMap::new(), - #[cfg(feature = "metrics")] - metrics, - #[cfg(feature = "metrics")] - pid, - }, - messages_tx, - sid_flushed_tx, - ) - } - - async fn tick(&mut self) { - // Check Range - for (prio, sid, msg) in self.messages_rx.try_iter() { - debug_assert!(prio as usize <= PRIO_MAX); - #[cfg(feature = "metrics")] - { - let sid_string = sid.to_string(); - self.metrics - .message_out_total - .with_label_values(&[&self.pid, &sid_string]) - .inc(); - self.metrics - .message_out_throughput - .with_label_values(&[&self.pid, &sid_string]) - .inc_by(msg.buffer.data.len() as u64); - } - - //trace!(?prio, ?sid_string, "tick"); - self.queued.insert(prio); - self.messages[prio as usize].push_back((sid, msg)); - self.sid_owned.entry(sid).or_default().len += 1; - } - //this must be AFTER messages - for 
(sid, return_sender) in self.sid_flushed_rx.try_iter() { - #[cfg(feature = "metrics")] - self.metrics - .streams_flushed - .with_label_values(&[&self.pid]) - .inc(); - if let Some(cnt) = self.sid_owned.get_mut(&sid) { - // register sender - cnt.empty_notify = Some(return_sender); - trace!(?sid, "register empty notify"); - } else { - // return immediately - return_sender.send(()).unwrap(); - trace!(?sid, "return immediately that stream is empty"); - } - } - } - - //if None returned, we are empty! - fn calc_next_prio(&self) -> Option { - // compare all queued prios, max 64 operations - let mut lowest = std::u32::MAX; - let mut lowest_id = None; - for &n in &self.queued { - let n_points = self.points[n as usize]; - if n_points < lowest { - lowest = n_points; - lowest_id = Some(n) - } else if n_points == lowest && lowest_id.is_some() && n < lowest_id.unwrap() { - //on equal points lowest first! - lowest_id = Some(n) - } - } - lowest_id - /* - self.queued - .iter() - .min_by_key(|&n| self.points[*n as usize]).cloned()*/ - } - - /// no_of_frames = frames.len() - /// Your goal is to try to find a realistic no_of_frames! - /// no_of_frames should be choosen so, that all Frames can be send out till - /// the next tick! - /// - if no_of_frames is too high you will fill either the Socket buffer, - /// or your internal buffer. In that case you will increase latency for - /// high prio messages! 
- /// - if no_of_frames is too low you wont saturate your Socket fully, thus - /// have a lower bandwidth as possible - pub async fn fill_frames>( - &mut self, - no_of_frames: usize, - frames: &mut E, - ) { - for v in self.messages.iter_mut() { - v.reserve_exact(no_of_frames) - } - self.tick().await; - for _ in 0..no_of_frames { - match self.calc_next_prio() { - Some(prio) => { - //let prio2 = self.calc_next_prio().unwrap(); - //trace!(?prio, "handle next prio"); - self.points[prio as usize] += Self::PRIOS[prio as usize]; - //pop message from front of VecDeque, handle it and push it back, so that all - // => messages with same prio get a fair chance :) - //TODO: evaluate not popping every time - let (sid, mut msg) = self.messages[prio as usize].pop_front().unwrap(); - if msg.fill_next(sid, frames) { - //trace!(?m.mid, "finish message"); - //check if prio is empty - if self.messages[prio as usize].is_empty() { - self.queued.remove(&prio); - } - //decrease pid_sid counter by 1 again - let cnt = self.sid_owned.get_mut(&sid).expect( - "The pid_sid_owned counter works wrong, more pid,sid removed than \ - inserted", - ); - cnt.len -= 1; - if cnt.len == 0 { - let cnt = self.sid_owned.remove(&sid).unwrap(); - if let Some(empty_notify) = cnt.empty_notify { - empty_notify.send(()).unwrap(); - trace!(?sid, "returned that stream is empty"); - } - } - } else { - self.messages[prio as usize].push_front((sid, msg)); - } - }, - None => { - //QUEUE is empty, we are clearing the POINTS to not build up huge pipes of - // POINTS on a prio from the past - self.points = [0; PRIO_MAX]; - break; - }, - } - } - } -} - -impl std::fmt::Debug for PrioManager { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut cnt = 0; - for m in self.messages.iter() { - cnt += m.len(); - } - write!(f, "PrioManager(len: {}, queued: {:?})", cnt, &self.queued,) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - message::{MessageBuffer, OutgoingMessage}, - 
metrics::NetworkMetrics, - prios::*, - types::{Frame, Pid, Prio, Sid}, - }; - use crossbeam_channel::Sender; - use futures::{channel::oneshot, executor::block_on}; - use std::{collections::VecDeque, sync::Arc}; - - const SIZE: u64 = OutgoingMessage::FRAME_DATA_SIZE; - const USIZE: usize = OutgoingMessage::FRAME_DATA_SIZE as usize; - - #[allow(clippy::type_complexity)] - fn mock_new() -> ( - PrioManager, - Sender<(Prio, Sid, OutgoingMessage)>, - Sender<(Sid, oneshot::Sender<()>)>, - ) { - let pid = Pid::fake(1); - PrioManager::new( - Arc::new(NetworkMetrics::new(&pid).unwrap()), - pid.to_string(), - ) - } - - fn mock_out(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { - data: vec![48, 49, 50], - }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn mock_out_large(prio: Prio, sid: u64) -> (Prio, Sid, OutgoingMessage) { - let sid = Sid::new(sid); - let mut data = vec![48; USIZE]; - data.append(&mut vec![49; USIZE]); - data.append(&mut vec![50; 20]); - (prio, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - }) - } - - fn assert_header(frames: &mut VecDeque<(Sid, Frame)>, f_sid: u64, f_length: u64) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::DataHeader { mid, sid, length } = frame { - assert_eq!(mid, 1); - assert_eq!(sid, Sid::new(f_sid)); - assert_eq!(length, f_length); - } else { - panic!("Wrong frame type!, expected DataHeader"); - } - } - - fn assert_data(frames: &mut VecDeque<(Sid, Frame)>, f_start: u64, f_data: Vec) { - let frame = frames - .pop_front() - .expect("Frames vecdeque doesn't contain enough frames!") - .1; - if let Frame::Data { mid, start, data } = frame { - assert_eq!(mid, 1); - assert_eq!(start, f_start); - assert_eq!(data, f_data); - } else { - panic!("Wrong frame type!, expected Data"); - } - } - - #[test] - fn single_p16() { - 
let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - msg_tx.send(mock_out(20, 42)).unwrap(); - let mut frames = VecDeque::new(); - - block_on(mgr.fill_frames(100, &mut frames)); - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 42)).unwrap(); - msg_tx.send(mock_out(16, 1337)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1337, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 42, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); - - for i in 1..14 { - assert_header(&mut frames, i, 3); - 
assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn multiple_fill_frames_p16_p20() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out(20, 2)).unwrap(); - msg_tx.send(mock_out(16, 1)).unwrap(); - msg_tx.send(mock_out(16, 3)).unwrap(); - msg_tx.send(mock_out(16, 5)).unwrap(); - msg_tx.send(mock_out(20, 4)).unwrap(); - msg_tx.send(mock_out(20, 7)).unwrap(); - msg_tx.send(mock_out(16, 6)).unwrap(); - msg_tx.send(mock_out(20, 10)).unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - msg_tx.send(mock_out(20, 12)).unwrap(); - msg_tx.send(mock_out(16, 9)).unwrap(); - msg_tx.send(mock_out(16, 11)).unwrap(); - msg_tx.send(mock_out(20, 13)).unwrap(); - - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(3, &mut frames)); - for i in 1..4 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - block_on(mgr.fill_frames(11, &mut frames)); - for i in 4..14 { - assert_header(&mut frames, i, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - assert!(frames.is_empty()); - } - - #[test] - fn single_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - 
assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn multiple_large_p16_sudden_p0() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - msg_tx.send(mock_out_large(16, 1)).unwrap(); - msg_tx.send(mock_out_large(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2, &mut frames)); - - assert_header(&mut frames, 1, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - - msg_tx.send(mock_out(0, 3)).unwrap(); - block_on(mgr.fill_frames(100, &mut frames)); - - assert_header(&mut frames, 3, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert_header(&mut frames, 2, SIZE * 2 + 20); - assert_data(&mut frames, 0, vec![48; USIZE]); - assert_data(&mut frames, SIZE, vec![49; USIZE]); - assert_data(&mut frames, SIZE * 2, vec![50; 20]); - assert!(frames.is_empty()); - } - - #[test] - fn single_p20_thousand_p16_at_once() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - msg_tx.send(mock_out(16, 2)).unwrap(); - } - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn single_p20_thousand_p16_later() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - for _ in 0..998 { - 
msg_tx.send(mock_out(16, 2)).unwrap(); - } - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); - //^unimportant frames, gonna be dropped - msg_tx.send(mock_out(20, 1)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - msg_tx.send(mock_out(16, 2)).unwrap(); - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); - - //important in that test is, that after the first frames got cleared i reset - // the Points even though 998 prio 16 messages have been send at this - // point and 0 prio20 messages the next message is a prio16 message - // again, and only then prio20! we dont want to build dept over a idling - // connection - assert_header(&mut frames, 2, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 1, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_header(&mut frames, 2, 3); - //unimportant - } - - #[test] - fn gigantic_message() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } - - #[test] - fn gigantic_message_order() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - 
data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(16, 8)).unwrap(); - - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - } - - #[test] - fn gigantic_message_order_other_prio() { - let (mut mgr, msg_tx, _flush_tx) = mock_new(); - let mut data = vec![1; USIZE]; - data.extend_from_slice(&[2; USIZE]); - data.extend_from_slice(&[3; USIZE]); - data.extend_from_slice(&[4; USIZE]); - data.extend_from_slice(&[5; USIZE]); - let sid = Sid::new(2); - msg_tx - .send((16, sid, OutgoingMessage { - buffer: Arc::new(MessageBuffer { data }), - cursor: 0, - mid: 1, - sid, - })) - .unwrap(); - msg_tx.send(mock_out(20, 8)).unwrap(); - - let mut frames = VecDeque::new(); - block_on(mgr.fill_frames(2000, &mut frames)); - - assert_header(&mut frames, 2, 7000); - assert_data(&mut frames, 0, vec![1; USIZE]); - assert_header(&mut frames, 8, 3); - assert_data(&mut frames, 0, vec![48, 49, 50]); - assert_data(&mut frames, 1400, vec![2; USIZE]); - assert_data(&mut frames, 2800, vec![3; USIZE]); - assert_data(&mut frames, 4200, vec![4; USIZE]); - assert_data(&mut frames, 5600, vec![5; USIZE]); - } -} diff --git a/network/src/protocols.rs b/network/src/protocols.rs deleted file mode 100644 index 7b0b8651b6..0000000000 --- a/network/src/protocols.rs +++ /dev/null @@ -1,596 +0,0 @@ -#[cfg(feature = "metrics")] -use crate::metrics::{CidFrameCache, NetworkMetrics}; -use crate::{ - participant::C2pFrame, - types::{Cid, Frame}, -}; 
-use async_std::{ - io::prelude::*, - net::{TcpStream, UdpSocket}, -}; - -use futures::{ - channel::{mpsc, oneshot}, - future::{Fuse, FutureExt}, - lock::Mutex, - select, - sink::SinkExt, - stream::StreamExt, -}; -use std::{convert::TryFrom, net::SocketAddr, sync::Arc}; -use tracing::*; - -// Reserving bytes 0, 10, 13 as i have enough space and want to make it easy to -// detect a invalid client, e.g. sending an empty line would make 10 first char -// const FRAME_RESERVED_1: u8 = 0; -const FRAME_HANDSHAKE: u8 = 1; -const FRAME_INIT: u8 = 2; -const FRAME_SHUTDOWN: u8 = 3; -const FRAME_OPEN_STREAM: u8 = 4; -const FRAME_CLOSE_STREAM: u8 = 5; -const FRAME_DATA_HEADER: u8 = 6; -const FRAME_DATA: u8 = 7; -const FRAME_RAW: u8 = 8; -//const FRAME_RESERVED_2: u8 = 10; -//const FRAME_RESERVED_3: u8 = 13; - -#[derive(Debug)] -pub(crate) enum Protocols { - Tcp(TcpProtocol), - Udp(UdpProtocol), - //Mpsc(MpscChannel), -} - -#[derive(Debug)] -pub(crate) struct TcpProtocol { - stream: TcpStream, - #[cfg(feature = "metrics")] - metrics: Arc, -} - -#[derive(Debug)] -pub(crate) struct UdpProtocol { - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] - metrics: Arc, - data_in: Mutex>>, -} - -//TODO: PERFORMACE: Use BufWriter and BufReader from std::io! -impl TcpProtocol { - pub(crate) fn new( - stream: TcpStream, - #[cfg(feature = "metrics")] metrics: Arc, - ) -> Self { - Self { - stream, - #[cfg(feature = "metrics")] - metrics, - } - } - - async fn read_frame( - r: &mut R, - mut end_receiver: &mut Fuse>, - ) -> Result> { - let handle = |read_result| match read_result { - Ok(_) => Ok(()), - Err(e) => Err(Some(e)), - }; - - let mut frame_no = [0u8; 1]; - match select! 
{ - r = r.read_exact(&mut frame_no).fuse() => Some(r), - _ = end_receiver => None, - } { - Some(read_result) => handle(read_result)?, - None => { - trace!("shutdown requested"); - return Err(None); - }, - }; - - match frame_no[0] { - FRAME_HANDSHAKE => { - let mut bytes = [0u8; 19]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_handshake(bytes)) - }, - FRAME_INIT => { - let mut bytes = [0u8; 32]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_init(bytes)) - }, - FRAME_SHUTDOWN => Ok(Frame::Shutdown), - FRAME_OPEN_STREAM => { - let mut bytes = [0u8; 10]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_open_stream(bytes)) - }, - FRAME_CLOSE_STREAM => { - let mut bytes = [0u8; 8]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_close_stream(bytes)) - }, - FRAME_DATA_HEADER => { - let mut bytes = [0u8; 24]; - handle(r.read_exact(&mut bytes).await)?; - Ok(Frame::gen_data_header(bytes)) - }, - FRAME_DATA => { - let mut bytes = [0u8; 18]; - handle(r.read_exact(&mut bytes).await)?; - let (mid, start, length) = Frame::gen_data(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Data { mid, start, data }) - }, - FRAME_RAW => { - let mut bytes = [0u8; 2]; - handle(r.read_exact(&mut bytes).await)?; - let length = Frame::gen_raw(bytes); - let mut data = vec![0; length as usize]; - handle(r.read_exact(&mut data).await)?; - Ok(Frame::Raw(data)) - }, - other => { - // report a RAW frame, but cannot rely on the next 2 bytes to be a size. - // guessing 32 bytes, which might help to sort down issues - let mut data = vec![0; 32]; - //keep the first byte! 
- match r.read(&mut data[1..]).await { - Ok(n) => { - data.truncate(n + 1); - Ok(()) - }, - Err(e) => Err(Some(e)), - }?; - data[0] = other; - warn!(?data, "got a unexpected RAW msg"); - Ok(Frame::Raw(data)) - }, - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up tcp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut stream = self.stream.clone(); - let mut end_r = end_r.fuse(); - - loop { - match Self::read_frame(&mut stream, &mut end_r).await { - Ok(frame) => { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - w2c_cid_frame_s - .send((cid, Ok(frame))) - .await - .expect("Channel or Participant seems no longer to exist"); - }, - Err(e_option) => { - if let Some(e) = e_option { - info!(?e, "Closing tcp protocol due to read error"); - //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as - // every channel is holding a sender! 
thats why Ne - // need a explicit STOP here - w2c_cid_frame_s - .send((cid, Err(()))) - .await - .expect("Channel or Participant seems no longer to exist"); - } - //None is clean shutdown - break; - }, - } - } - trace!("Shutting down tcp read()"); - } - - pub async fn write_frame( - w: &mut W, - frame: Frame, - ) -> Result<(), std::io::Error> { - match frame { - Frame::Handshake { - magic_number, - version, - } => { - w.write_all(&FRAME_HANDSHAKE.to_be_bytes()).await?; - w.write_all(&magic_number).await?; - w.write_all(&version[0].to_le_bytes()).await?; - w.write_all(&version[1].to_le_bytes()).await?; - w.write_all(&version[2].to_le_bytes()).await?; - }, - Frame::Init { pid, secret } => { - w.write_all(&FRAME_INIT.to_be_bytes()).await?; - w.write_all(&pid.to_le_bytes()).await?; - w.write_all(&secret.to_le_bytes()).await?; - }, - Frame::Shutdown => { - w.write_all(&FRAME_SHUTDOWN.to_be_bytes()).await?; - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - w.write_all(&FRAME_OPEN_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&prio.to_le_bytes()).await?; - w.write_all(&promises.to_le_bytes()).await?; - }, - Frame::CloseStream { sid } => { - w.write_all(&FRAME_CLOSE_STREAM.to_be_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - }, - Frame::DataHeader { mid, sid, length } => { - w.write_all(&FRAME_DATA_HEADER.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&sid.to_le_bytes()).await?; - w.write_all(&length.to_le_bytes()).await?; - }, - Frame::Data { mid, start, data } => { - w.write_all(&FRAME_DATA.to_be_bytes()).await?; - w.write_all(&mid.to_le_bytes()).await?; - w.write_all(&start.to_le_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - Frame::Raw(data) => { - w.write_all(&FRAME_RAW.to_be_bytes()).await?; - w.write_all(&(data.len() as u16).to_le_bytes()).await?; - w.write_all(&data).await?; - }, - }; - Ok(()) - } - - 
pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up tcp write()"); - let mut stream = self.stream.clone(); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - .with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - - while let Some(frame) = c2w_frame_r.next().await { - #[cfg(feature = "metrics")] - { - metrics_cache.with_label_values(&frame).inc(); - if let Frame::Data { - mid: _, - start: _, - ref data, - } = frame - { - throughput_cache.inc_by(data.len() as u64); - } - } - if let Err(e) = Self::write_frame(&mut stream, frame).await { - info!( - ?e, - "Got an error writing to tcp, going to close this channel" - ); - c2w_frame_r.close(); - break; - }; - } - trace!("shutting down tcp write()"); - } -} - -impl UdpProtocol { - pub(crate) fn new( - socket: Arc, - remote_addr: SocketAddr, - #[cfg(feature = "metrics")] metrics: Arc, - data_in: mpsc::UnboundedReceiver>, - ) -> Self { - Self { - socket, - remote_addr, - #[cfg(feature = "metrics")] - metrics, - data_in: Mutex::new(data_in), - } - } - - pub async fn read_from_wire( - &self, - cid: Cid, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - end_r: oneshot::Receiver<()>, - ) { - trace!("Starting up udp read()"); - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_in_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_in_throughput - .with_label_values(&[&cid.to_string()]); - let mut data_in = self.data_in.lock().await; - let mut end_r = end_r.fuse(); - while let Some(bytes) = select! 
{ - r = data_in.next().fuse() => match r { - Some(r) => Some(r), - None => { - info!("Udp read ended"); - w2c_cid_frame_s.send((cid, Err(()))).await.expect("Channel or Participant seems no longer to exist"); - None - } - }, - _ = end_r => None, - } { - trace!("Got raw UDP message with len: {}", bytes.len()); - let frame_no = bytes[0]; - let frame = match frame_no { - FRAME_HANDSHAKE => { - Frame::gen_handshake(*<&[u8; 19]>::try_from(&bytes[1..20]).unwrap()) - }, - FRAME_INIT => Frame::gen_init(*<&[u8; 32]>::try_from(&bytes[1..33]).unwrap()), - FRAME_SHUTDOWN => Frame::Shutdown, - FRAME_OPEN_STREAM => { - Frame::gen_open_stream(*<&[u8; 10]>::try_from(&bytes[1..11]).unwrap()) - }, - FRAME_CLOSE_STREAM => { - Frame::gen_close_stream(*<&[u8; 8]>::try_from(&bytes[1..9]).unwrap()) - }, - FRAME_DATA_HEADER => { - Frame::gen_data_header(*<&[u8; 24]>::try_from(&bytes[1..25]).unwrap()) - }, - FRAME_DATA => { - let (mid, start, length) = - Frame::gen_data(*<&[u8; 18]>::try_from(&bytes[1..19]).unwrap()); - let mut data = vec![0; length as usize]; - #[cfg(feature = "metrics")] - throughput_cache.inc_by(length as u64); - data.copy_from_slice(&bytes[19..]); - Frame::Data { mid, start, data } - }, - FRAME_RAW => { - let length = Frame::gen_raw(*<&[u8; 2]>::try_from(&bytes[1..3]).unwrap()); - let mut data = vec![0; length as usize]; - data.copy_from_slice(&bytes[3..]); - Frame::Raw(data) - }, - _ => Frame::Raw(bytes), - }; - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - w2c_cid_frame_s.send((cid, Ok(frame))).await.unwrap(); - } - trace!("Shutting down udp read()"); - } - - pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { - trace!("Starting up udp write()"); - let mut buffer = [0u8; 2000]; - #[cfg(feature = "metrics")] - let mut metrics_cache = CidFrameCache::new(self.metrics.frames_wire_out_total.clone(), cid); - #[cfg(feature = "metrics")] - let throughput_cache = self - .metrics - .wire_out_throughput - 
.with_label_values(&[&cid.to_string()]); - #[cfg(not(feature = "metrics"))] - let _cid = cid; - while let Some(frame) = c2w_frame_r.next().await { - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - let len = match frame { - Frame::Handshake { - magic_number, - version, - } => { - let x = FRAME_HANDSHAKE.to_be_bytes(); - buffer[0] = x[0]; - buffer[1..8].copy_from_slice(&magic_number); - buffer[8..12].copy_from_slice(&version[0].to_le_bytes()); - buffer[12..16].copy_from_slice(&version[1].to_le_bytes()); - buffer[16..20].copy_from_slice(&version[2].to_le_bytes()); - 20 - }, - Frame::Init { pid, secret } => { - buffer[0] = FRAME_INIT.to_be_bytes()[0]; - buffer[1..17].copy_from_slice(&pid.to_le_bytes()); - buffer[17..33].copy_from_slice(&secret.to_le_bytes()); - 33 - }, - Frame::Shutdown => { - buffer[0] = FRAME_SHUTDOWN.to_be_bytes()[0]; - 1 - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - buffer[0] = FRAME_OPEN_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - buffer[9] = prio.to_le_bytes()[0]; - buffer[10] = promises.to_le_bytes()[0]; - 11 - }, - Frame::CloseStream { sid } => { - buffer[0] = FRAME_CLOSE_STREAM.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&sid.to_le_bytes()); - 9 - }, - Frame::DataHeader { mid, sid, length } => { - buffer[0] = FRAME_DATA_HEADER.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&sid.to_le_bytes()); - buffer[17..25].copy_from_slice(&length.to_le_bytes()); - 25 - }, - Frame::Data { mid, start, data } => { - buffer[0] = FRAME_DATA.to_be_bytes()[0]; - buffer[1..9].copy_from_slice(&mid.to_le_bytes()); - buffer[9..17].copy_from_slice(&start.to_le_bytes()); - buffer[17..19].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[19..(data.len() + 19)].clone_from_slice(&data[..]); - #[cfg(feature = "metrics")] - throughput_cache.inc_by(data.len() as u64); - 19 + data.len() - }, - Frame::Raw(data) => { - 
buffer[0] = FRAME_RAW.to_be_bytes()[0]; - buffer[1..3].copy_from_slice(&(data.len() as u16).to_le_bytes()); - buffer[3..(data.len() + 3)].clone_from_slice(&data[..]); - 3 + data.len() - }, - }; - let mut start = 0; - while start < len { - trace!(?start, ?len, "Splitting up udp frame in multiple packages"); - match self - .socket - .send_to(&buffer[start..len], self.remote_addr) - .await - { - Ok(n) => { - start += n; - if n != len { - error!( - "THIS DOESN'T WORK, as RECEIVER CURRENTLY ONLY HANDLES 1 FRAME \ - per UDP message. splitting up will fail!" - ); - } - }, - Err(e) => error!(?e, "Need to handle that error!"), - } - } - } - trace!("Shutting down udp write()"); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{metrics::NetworkMetrics, types::Pid}; - use async_std::net; - use futures::{executor::block_on, stream::StreamExt}; - use std::sync::Arc; - - #[test] - fn tcp_read_handshake() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50500); - block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let s_stream = server.incoming().next().await.unwrap().unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client.write_all(&[FRAME_HANDSHAKE]).await.unwrap(); - client.write_all(b"HELLOWO").await.unwrap(); - client.write_all(&1337u32.to_le_bytes()).await.unwrap(); - client.write_all(&0u32.to_le_bytes()).await.unwrap(); - client.write_all(&42u32.to_le_bytes()).await.unwrap(); - client.flush(); - - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - block_on(async { - prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) - 
.await; - }) - }); - // Assert than we get some value back! Its a Handshake! - //async_std::task::sleep(std::time::Duration::from_millis(1000)); - let (cid_r, frame) = w2c_cid_frame_r.next().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Handshake { - magic_number, - version, - }) = frame - { - assert_eq!(&magic_number, b"HELLOWO"); - assert_eq!(version, [1337, 0, 42]); - } else { - panic!("wrong handshake"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } - - #[test] - fn tcp_read_garbage() { - let pid = Pid::new(); - let cid = 80085; - let metrics = Arc::new(NetworkMetrics::new(&pid).unwrap()); - let addr = std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 50501); - block_on(async { - let server = net::TcpListener::bind(addr).await.unwrap(); - let mut client = net::TcpStream::connect(addr).await.unwrap(); - - let s_stream = server.incoming().next().await.unwrap().unwrap(); - let prot = TcpProtocol::new(s_stream, metrics); - - //Send Handshake - client - .write_all("x4hrtzsektfhxugzdtz5r78gzrtzfhxfdthfthuzhfzzufasgasdfg".as_bytes()) - .await - .unwrap(); - client.flush(); - //handle data - let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); - let (read_stop_sender, read_stop_receiver) = oneshot::channel(); - let cid2 = cid; - let t = std::thread::spawn(move || { - block_on(async { - prot.read_from_wire(cid2, &mut w2c_cid_frame_s, read_stop_receiver) - .await; - }) - }); - // Assert than we get some value back! Its a Raw! 
- let (cid_r, frame) = w2c_cid_frame_r.next().await.unwrap(); - assert_eq!(cid, cid_r); - if let Ok(Frame::Raw(data)) = frame { - assert_eq!(&data.as_slice(), b"x4hrtzsektfhxugzdtz5r78gzrtzfhxf"); - } else { - panic!("wrong frame type"); - } - read_stop_sender.send(()).unwrap(); - t.join().unwrap(); - }); - } -} diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 33cb4ed054..c2ac365b0f 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -1,21 +1,11 @@ -#[cfg(feature = "metrics")] -use crate::metrics::NetworkMetrics; use crate::{ api::{Participant, ProtocolAddr}, - channel::Handshake, + channel::Protocols, + metrics::NetworkMetrics, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, - protocols::{Protocols, TcpProtocol, UdpProtocol}, - types::Pid, -}; -use async_std::{io, net, sync::Mutex}; -use futures::{ - channel::{mpsc, oneshot}, - executor::ThreadPool, - future::FutureExt, - select, - sink::SinkExt, - stream::StreamExt, }; +use futures_util::{FutureExt, StreamExt}; +use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -25,18 +15,29 @@ use std::{ atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, + time::Duration, }; +use tokio::{ + io, net, select, + sync::{mpsc, oneshot, Mutex}, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use tracing_futures::Instrument; -/// Naming of Channels `x2x` -/// - a: api -/// - s: scheduler -/// - b: bparticipant -/// - p: prios -/// - r: protocol -/// - w: wire -/// - c: channel/handshake +// Naming of Channels `x2x` +// - a: api +// - s: scheduler +// - b: bparticipant +// - p: prios +// - r: protocol +// - w: wire +// - c: channel/handshake + +lazy_static::lazy_static! 
{ + static ref MPSC_POOL: Mutex>> = { + Mutex::new(HashMap::new()) + }; +} #[derive(Debug)] struct ParticipantInfo { @@ -48,6 +49,10 @@ struct ParticipantInfo { type A2sListen = (ProtocolAddr, oneshot::Sender>); type A2sConnect = (ProtocolAddr, oneshot::Sender>); type A2sDisconnect = (Pid, S2bShutdownBparticipant); +type S2sMpscConnect = ( + mpsc::Sender, + oneshot::Sender>, +); #[derive(Debug)] struct ControlChannels { @@ -70,17 +75,18 @@ pub struct Scheduler { local_pid: Pid, local_secret: u128, closed: AtomicBool, - pool: Arc, run_channels: Option, participant_channels: Arc>>, participants: Arc>>, channel_ids: Arc, channel_listener: Mutex>>, - #[cfg(feature = "metrics")] metrics: Arc, + protocol_metrics: Arc, } impl Scheduler { + const MPSC_CHANNEL_BOUND: usize = 1000; + pub fn new( local_pid: Pid, #[cfg(feature = "metrics")] registry: Option<&Registry>, @@ -91,12 +97,13 @@ impl Scheduler { mpsc::UnboundedReceiver, oneshot::Sender<()>, ) { - let (a2s_listen_s, a2s_listen_r) = mpsc::unbounded::(); - let (a2s_connect_s, a2s_connect_r) = mpsc::unbounded::(); - let (s2a_connected_s, s2a_connected_r) = mpsc::unbounded::(); + let (a2s_listen_s, a2s_listen_r) = mpsc::unbounded_channel::(); + let (a2s_connect_s, a2s_connect_r) = mpsc::unbounded_channel::(); + let (s2a_connected_s, s2a_connected_r) = mpsc::unbounded_channel::(); let (a2s_scheduler_shutdown_s, a2s_scheduler_shutdown_r) = oneshot::channel::<()>(); - let (a2s_disconnect_s, a2s_disconnect_r) = mpsc::unbounded::(); - let (b2s_prio_statistic_s, b2s_prio_statistic_r) = mpsc::unbounded::(); + let (a2s_disconnect_s, a2s_disconnect_r) = mpsc::unbounded_channel::(); + let (b2s_prio_statistic_s, b2s_prio_statistic_r) = + mpsc::unbounded_channel::(); let run_channels = Some(ControlChannels { a2s_listen_r, @@ -112,13 +119,14 @@ impl Scheduler { b2s_prio_statistic_s, }; - #[cfg(feature = "metrics")] let metrics = Arc::new(NetworkMetrics::new(&local_pid).unwrap()); + let protocol_metrics = 
Arc::new(ProtocolMetrics::new().unwrap()); #[cfg(feature = "metrics")] { if let Some(registry) = registry { metrics.register(registry).unwrap(); + protocol_metrics.register(registry).unwrap(); } } @@ -130,14 +138,13 @@ impl Scheduler { local_pid, local_secret, closed: AtomicBool::new(false), - pool: Arc::new(ThreadPool::new().unwrap()), run_channels, participant_channels: Arc::new(Mutex::new(Some(participant_channels))), participants: Arc::new(Mutex::new(HashMap::new())), channel_ids: Arc::new(AtomicU64::new(0)), channel_listener: Mutex::new(HashMap::new()), - #[cfg(feature = "metrics")] metrics, + protocol_metrics, }, a2s_listen_s, a2s_connect_s, @@ -149,7 +156,7 @@ impl Scheduler { pub async fn run(mut self) { let run_channels = self.run_channels.take().unwrap(); - futures::join!( + tokio::join!( self.listen_mgr(run_channels.a2s_listen_r), self.connect_mgr(run_channels.a2s_connect_r), self.disconnect_mgr(run_channels.a2s_disconnect_r), @@ -160,6 +167,7 @@ impl Scheduler { async fn listen_mgr(&self, a2s_listen_r: mpsc::UnboundedReceiver) { trace!("Start listen_mgr"); + let a2s_listen_r = UnboundedReceiverStream::new(a2s_listen_r); a2s_listen_r .for_each_concurrent(None, |(address, s2a_listen_result_s)| { let address = address; @@ -196,8 +204,8 @@ impl Scheduler { )>, ) { trace!("Start connect_mgr"); - while let Some((addr, pid_sender)) = a2s_connect_r.next().await { - let (protocol, handshake) = match addr { + while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { + let (protocol, cid, handshake) = match addr { ProtocolAddr::Tcp(addr) => { #[cfg(feature = "metrics")] self.metrics @@ -211,51 +219,84 @@ impl Scheduler { continue; }, }; + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); ( - Protocols::Tcp(TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - )), + Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), + cid, false, ) }, 
- ProtocolAddr::Udp(addr) => { - #[cfg(feature = "metrics")] - self.metrics - .connect_requests_total - .with_label_values(&["udp"]) - .inc(); - let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - Ok(socket) => Arc::new(socket), - Err(e) => { - pid_sender.send(Err(e)).unwrap(); + ProtocolAddr::Mpsc(addr) => { + let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { + Some(s) => s.clone(), + None => { + pid_sender + .send(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "no mpsc listen on this addr", + ))) + .unwrap(); continue; }, }; - if let Err(e) = socket.connect(addr).await { - pid_sender.send(Err(e)).unwrap(); - continue; - }; - info!("Connecting Udp to: {}", addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded::>(); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.pool.spawn_ok( - Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - .instrument(tracing::info_span!("udp", ?addr)), - ); - (Protocols::Udp(protocol), true) + let (remote_to_local_s, remote_to_local_r) = + mpsc::channel(Self::MPSC_CHANNEL_BOUND); + let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); + mpsc_s + .send((remote_to_local_s, local_to_remote_oneshot_s)) + .unwrap(); + let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap(); + + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); + info!(?addr, "Connecting Mpsc"); + ( + Protocols::new_mpsc( + local_to_remote_s, + remote_to_local_r, + cid, + Arc::clone(&self.protocol_metrics), + ), + cid, + false, + ) }, + /* */ + //ProtocolAddr::Udp(addr) => { + //#[cfg(feature = "metrics")] + //self.metrics + //.connect_requests_total + //.with_label_values(&["udp"]) + //.inc(); + //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { + //Ok(socket) => Arc::new(socket), + //Err(e) => { + //pid_sender.send(Err(e)).unwrap(); + //continue; 
+ //}, + //}; + //if let Err(e) = socket.connect(addr).await { + //pid_sender.send(Err(e)).unwrap(); + //continue; + //}; + //info!("Connecting Udp to: {}", addr); + //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); + //let protocol = UdpProtocol::new( + //Arc::clone(&socket), + //addr, + //#[cfg(feature = "metrics")] + //Arc::clone(&self.metrics), + //udp_data_receiver, + //); + //self.runtime.spawn( + //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) + //.instrument(tracing::info_span!("udp", ?addr)), + //); + //(Protocols::Udp(protocol), true) + //}, _ => unimplemented!(), }; - self.init_protocol(protocol, Some(pid_sender), handshake) + self.init_protocol(protocol, cid, Some(pid_sender), handshake) .await; } trace!("Stop connect_mgr"); @@ -263,7 +304,9 @@ impl Scheduler { async fn disconnect_mgr(&self, mut a2s_disconnect_r: mpsc::UnboundedReceiver) { trace!("Start disconnect_mgr"); - while let Some((pid, return_once_successful_shutdown)) = a2s_disconnect_r.next().await { + while let Some((pid, (timeout_time, return_once_successful_shutdown))) = + a2s_disconnect_r.recv().await + { //Closing Participants is done the following way: // 1. We drop our senders and receivers // 2. we need to close BParticipant, this will drop its senderns and receivers @@ -277,7 +320,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((timeout_time, finished_sender)) .unwrap(); drop(pi); trace!(?pid, "dropped bparticipant, waiting for finish"); @@ -298,7 +341,7 @@ impl Scheduler { mut b2s_prio_statistic_r: mpsc::UnboundedReceiver, ) { trace!("Start prio_adj_mgr"); - while let Some((_pid, _frame_cnt, _unused)) = b2s_prio_statistic_r.next().await { + while let Some((_pid, _frame_cnt, _unused)) = b2s_prio_statistic_r.recv().await { //TODO adjust prios in participants here! 
} @@ -320,7 +363,7 @@ impl Scheduler { pi.s2b_shutdown_bparticipant_s .take() .unwrap() - .send(finished_sender) + .send((Duration::from_secs(120), finished_sender)) .unwrap(); (pid, finished_receiver) }) @@ -370,43 +413,51 @@ impl Scheduler { info!( ?addr, ?e, - "Listener couldn't be started due to error on tcp bind" + "Tcp bind error durin listener startup" ); s2a_listen_result_s.send(Err(e)).unwrap(); return; }, }; trace!(?addr, "Listener bound"); - let mut incoming = listener.incoming(); let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(stream) = select! { - next = incoming.next().fuse() => next, - _ = end_receiver => None, + while let Some(data) = select! { + next = listener.accept().fuse() => Some(next), + _ = &mut end_receiver => None, } { - let stream = match stream { - Ok(s) => s, + let (stream, remote_addr) = match data { + Ok((s, p)) => (s, p), Err(e) => { warn!(?e, "TcpStream Error, ignoring connection attempt"); continue; }, }; - let peer_addr = match stream.peer_addr() { - Ok(s) => s, - Err(e) => { - warn!(?e, "TcpStream Error, ignoring connection attempt"); - continue; - }, - }; - info!("Accepting Tcp from: {}", peer_addr); - let protocol = TcpProtocol::new( - stream, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - ); - self.init_protocol(Protocols::Tcp(protocol), None, true) + info!("Accepting Tcp from: {}", remote_addr); + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); + self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) .await; } }, + ProtocolAddr::Mpsc(addr) => { + let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); + MPSC_POOL.lock().await.insert(addr, mpsc_s); + s2a_listen_result_s.send(Ok(())).unwrap(); + trace!(?addr, "Listener bound"); + + let mut end_receiver = s2s_stop_listening_r.fuse(); + while let Some((local_to_remote_s, local_remote_to_local_s)) = select! 
{ + next = mpsc_r.recv().fuse() => next, + _ = &mut end_receiver => None, + } { + let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); + local_remote_to_local_s.send(remote_to_local_s).unwrap(); + info!(?addr, "Accepting Mpsc from"); + let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); + self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) + .await; + } + warn!("MpscStream Failed, stopping"); + },/* ProtocolAddr::Udp(addr) => { let socket = match net::UdpSocket::bind(addr).await { Ok(socket) => { @@ -432,7 +483,7 @@ impl Scheduler { let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER]; while let Ok((size, remote_addr)) = select! { next = socket.recv_from(&mut data).fuse() => next, - _ = end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), + _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), } { let mut datavec = Vec::with_capacity(size); datavec.extend_from_slice(&data[0..size]); @@ -441,7 +492,8 @@ impl Scheduler { #[allow(clippy::map_entry)] if !listeners.contains_key(&remote_addr) { info!("Accepting Udp from: {}", &remote_addr); - let (udp_data_sender, udp_data_receiver) = mpsc::unbounded::>(); + let (udp_data_sender, udp_data_receiver) = + mpsc::unbounded_channel::>(); listeners.insert(remote_addr, udp_data_sender); let protocol = UdpProtocol::new( Arc::clone(&socket), @@ -454,17 +506,18 @@ impl Scheduler { .await; } let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); - udp_data_sender.send(datavec).await.unwrap(); + udp_data_sender.send(datavec).unwrap(); } - }, + },*/ _ => unimplemented!(), } trace!(?addr, "Ending channel creator"); } + #[allow(dead_code)] async fn udp_single_channel_connect( socket: Arc, - mut w2p_udp_package_s: mpsc::UnboundedSender>, + w2p_udp_package_s: mpsc::UnboundedSender>, ) { let addr = socket.local_addr(); trace!(?addr, "Start 
udp_single_channel_connect"); @@ -477,18 +530,19 @@ impl Scheduler { let mut data = [0u8; 9216]; while let Ok(size) = select! { next = socket.recv(&mut data).fuse() => next, - _ = end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), + _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), } { let mut datavec = Vec::with_capacity(size); datavec.extend_from_slice(&data[0..size]); - w2p_udp_package_s.send(datavec).await.unwrap(); + w2p_udp_package_s.send(datavec).unwrap(); } trace!(?addr, "Stop udp_single_channel_connect"); } async fn init_protocol( &self, - protocol: Protocols, + mut protocol: Protocols, + cid: Cid, s2a_return_pid_s: Option>>, send_handshake: bool, ) { @@ -498,36 +552,26 @@ impl Scheduler { Contra: - DOS possibility because we answer first - Speed, because otherwise the message can be send with the creation */ - let mut participant_channels = self.participant_channels.lock().await.clone().unwrap(); + let participant_channels = self.participant_channels.lock().await.clone().unwrap(); // spawn is needed here, e.g. for TCP connect it would mean that only 1 // participant can be in handshake phase ever! Someone could deadlock // the whole server easily for new clients UDP doesnt work at all, as // the UDP listening is done in another place. 
- let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); let participants = Arc::clone(&self.participants); - #[cfg(feature = "metrics")] let metrics = Arc::clone(&self.metrics); - let pool = Arc::clone(&self.pool); let local_pid = self.local_pid; let local_secret = self.local_secret; // this is necessary for UDP to work at all and to remove code duplication - self.pool.spawn_ok( + tokio::spawn( async move { trace!(?cid, "Open channel and be ready for Handshake"); - let handshake = Handshake::new( - cid, - local_pid, - local_secret, - #[cfg(feature = "metrics")] - Arc::clone(&metrics), - send_handshake, - ); - match handshake - .setup(&protocol) + use network_protocol::InitProtocol; + let init_result = protocol + .initialize(send_handshake, local_pid, local_secret) .instrument(tracing::info_span!("handshake", ?cid)) - .await - { - Ok((pid, sid, secret, leftover_cid_frame)) => { + .await; + match init_result { + Ok((pid, sid, secret)) => { trace!( ?cid, ?pid, @@ -538,21 +582,16 @@ impl Scheduler { debug!(?cid, "New participant connected via a channel"); let ( bparticipant, - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, - mut s2b_create_channel_s, + s2b_create_channel_s, s2b_shutdown_bparticipant_s, - ) = BParticipant::new( - pid, - sid, - #[cfg(feature = "metrics")] - Arc::clone(&metrics), - ); + ) = BParticipant::new(local_pid, pid, sid, Arc::clone(&metrics)); let participant = Participant::new( local_pid, pid, - a2b_stream_open_s, + a2b_open_stream_s, b2a_stream_opened_r, participant_channels.a2s_disconnect_s, ); @@ -566,24 +605,18 @@ impl Scheduler { }); drop(participants); trace!("dropped participants lock"); - pool.spawn_ok( + let p = pid; + tokio::spawn( bparticipant .run(participant_channels.b2s_prio_statistic_s) - .instrument(tracing::info_span!("participant", ?pid)), + .instrument(tracing::info_span!("remote", ?p)), ); //create a new channel within BParticipant and wait for it to run let (b2s_create_channel_done_s, 
b2s_create_channel_done_r) = oneshot::channel(); //From now on wire connects directly with bparticipant! s2b_create_channel_s - .send(( - cid, - sid, - protocol, - leftover_cid_frame, - b2s_create_channel_done_s, - )) - .await + .send((cid, sid, protocol, b2s_create_channel_done_s)) .unwrap(); b2s_create_channel_done_r.await.unwrap(); if let Some(pid_oneshot) = s2a_return_pid_s { @@ -594,7 +627,6 @@ impl Scheduler { participant_channels .s2a_connected_s .send(participant) - .await .unwrap(); } } else { @@ -632,8 +664,8 @@ impl Scheduler { //From now on this CHANNEL can receiver other frames! // move directly to participant! }, - Err(()) => { - debug!(?cid, "Handshake from a new connection failed"); + Err(e) => { + debug!(?cid, ?e, "Handshake from a new connection failed"); if let Some(pid_oneshot) = s2a_return_pid_s { // someone is waiting with `connect`, so give them their Error trace!(?cid, "returning the Err to api who requested the connect"); diff --git a/network/src/types.rs b/network/src/types.rs deleted file mode 100644 index d257ed808f..0000000000 --- a/network/src/types.rs +++ /dev/null @@ -1,327 +0,0 @@ -use bitflags::bitflags; -use rand::Rng; -use std::convert::TryFrom; - -pub type Mid = u64; -pub type Cid = u64; -pub type Prio = u8; - -bitflags! { - /// use promises to modify the behavior of [`Streams`]. - /// see the consts in this `struct` for - /// - /// [`Streams`]: crate::api::Stream - pub struct Promises: u8 { - /// this will guarantee that the order of messages which are send on one side, - /// is the same when received on the other. - const ORDERED = 0b00000001; - /// this will guarantee that messages received haven't been altered by errors, - /// like bit flips, this is done with a checksum. 
- const CONSISTENCY = 0b00000010; - /// this will guarantee that the other side will receive every message exactly - /// once no messages are dropped - const GUARANTEED_DELIVERY = 0b00000100; - /// this will enable the internal compression on this - /// [`Stream`](crate::api::Stream) - #[cfg(feature = "compression")] - const COMPRESSED = 0b00001000; - /// this will enable the internal encryption on this - /// [`Stream`](crate::api::Stream) - const ENCRYPTED = 0b00010000; - } -} - -impl Promises { - pub const fn to_le_bytes(self) -> [u8; 1] { self.bits.to_le_bytes() } -} - -pub(crate) const VELOREN_MAGIC_NUMBER: [u8; 7] = [86, 69, 76, 79, 82, 69, 78]; //VELOREN -pub const VELOREN_NETWORK_VERSION: [u32; 3] = [0, 5, 0]; -pub(crate) const STREAM_ID_OFFSET1: Sid = Sid::new(0); -pub(crate) const STREAM_ID_OFFSET2: Sid = Sid::new(u64::MAX / 2); - -/// Support struct used for uniquely identifying [`Participant`] over the -/// [`Network`]. -/// -/// [`Participant`]: crate::api::Participant -/// [`Network`]: crate::api::Network -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub struct Pid { - internal: u128, -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub(crate) struct Sid { - internal: u64, -} - -// Used for Communication between Channel <----(TCP/UDP)----> Channel -#[derive(Debug)] -pub(crate) enum Frame { - Handshake { - magic_number: [u8; 7], - version: [u32; 3], - }, - Init { - pid: Pid, - secret: u128, - }, - Shutdown, /* Shutdown this channel gracefully, if all channels are shutdown, Participant - * is deleted */ - OpenStream { - sid: Sid, - prio: Prio, - promises: Promises, - }, - CloseStream { - sid: Sid, - }, - DataHeader { - mid: Mid, - sid: Sid, - length: u64, - }, - Data { - mid: Mid, - start: u64, - data: Vec, - }, - /* WARNING: Sending RAW is only used for debug purposes in case someone write a new API - * against veloren Server! 
*/ - Raw(Vec), -} - -impl Frame { - #[cfg(feature = "metrics")] - pub const FRAMES_LEN: u8 = 8; - - #[cfg(feature = "metrics")] - pub const fn int_to_string(i: u8) -> &'static str { - match i { - 0 => "Handshake", - 1 => "Init", - 2 => "Shutdown", - 3 => "OpenStream", - 4 => "CloseStream", - 5 => "DataHeader", - 6 => "Data", - 7 => "Raw", - _ => "", - } - } - - #[cfg(feature = "metrics")] - pub fn get_int(&self) -> u8 { - match self { - Frame::Handshake { .. } => 0, - Frame::Init { .. } => 1, - Frame::Shutdown => 2, - Frame::OpenStream { .. } => 3, - Frame::CloseStream { .. } => 4, - Frame::DataHeader { .. } => 5, - Frame::Data { .. } => 6, - Frame::Raw(_) => 7, - } - } - - #[cfg(feature = "metrics")] - pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } - - pub fn gen_handshake(buf: [u8; 19]) -> Self { - let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); - Frame::Handshake { - magic_number, - version: [ - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..19]).unwrap()), - ], - } - } - - pub fn gen_init(buf: [u8; 32]) -> Self { - Frame::Init { - pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), - secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..32]).unwrap()), - } - } - - pub fn gen_open_stream(buf: [u8; 10]) -> Self { - Frame::OpenStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - prio: buf[8], - promises: Promises::from_bits_truncate(buf[9]), - } - } - - pub fn gen_close_stream(buf: [u8; 8]) -> Self { - Frame::CloseStream { - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - } - } - - pub fn gen_data_header(buf: [u8; 24]) -> Self { - Frame::DataHeader { - mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), - sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), - length: 
u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[16..24]).unwrap()), - } - } - - pub fn gen_data(buf: [u8; 18]) -> (Mid, u64, u16) { - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); - let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); - let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..18]).unwrap()); - (mid, start, length) - } - - pub fn gen_raw(buf: [u8; 2]) -> u16 { - u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..2]).unwrap()) - } -} - -impl Pid { - /// create a new Pid with a random interior value - /// - /// # Example - /// ```rust - /// use veloren_network::{Network, Pid}; - /// - /// let pid = Pid::new(); - /// let _ = Network::new(pid); - /// ``` - pub fn new() -> Self { - Self { - internal: rand::thread_rng().gen(), - } - } - - /// don't use fake! just for testing! - /// This will panic if pid i greater than 7, as I do not want you to use - /// this in production! - #[doc(hidden)] - pub fn fake(pid_offset: u8) -> Self { - assert!(pid_offset < 8); - let o = pid_offset as u128; - const OFF: [u128; 5] = [ - 0x40, - 0x40 * 0x40, - 0x40 * 0x40 * 0x40, - 0x40 * 0x40 * 0x40 * 0x40, - 0x40 * 0x40 * 0x40 * 0x40 * 0x40, - ]; - Self { - internal: o + o * OFF[0] + o * OFF[1] + o * OFF[2] + o * OFF[3] + o * OFF[4], - } - } - - pub(crate) fn to_le_bytes(&self) -> [u8; 16] { self.internal.to_le_bytes() } - - pub(crate) fn from_le_bytes(bytes: [u8; 16]) -> Self { - Self { - internal: u128::from_le_bytes(bytes), - } - } -} - -impl Sid { - pub const fn new(internal: u64) -> Self { Self { internal } } - - pub(crate) fn to_le_bytes(&self) -> [u8; 8] { self.internal.to_le_bytes() } - - pub(crate) fn from_le_bytes(bytes: [u8; 8]) -> Self { - Self { - internal: u64::from_le_bytes(bytes), - } - } -} - -impl std::fmt::Debug for Pid { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - const BITS_PER_SIXLET: usize = 6; - //only print last 6 chars of number as full u128 logs are 
unreadable - const CHAR_COUNT: usize = 6; - for i in 0..CHAR_COUNT { - write!( - f, - "{}", - sixlet_to_str((self.internal >> (i * BITS_PER_SIXLET)) & 0x3F) - )?; - } - Ok(()) - } -} - -impl Default for Pid { - fn default() -> Self { Pid::new() } -} - -impl std::fmt::Display for Pid { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) } -} - -impl std::ops::AddAssign for Sid { - fn add_assign(&mut self, other: Self) { - *self = Self { - internal: self.internal + other.internal, - }; - } -} - -impl std::fmt::Debug for Sid { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - //only print last 6 chars of number as full u128 logs are unreadable - write!(f, "{}", self.internal.rem_euclid(1000000)) - } -} - -impl From for Sid { - fn from(internal: u64) -> Self { Sid { internal } } -} - -impl std::fmt::Display for Sid { - #[inline] - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.internal) - } -} - -fn sixlet_to_str(sixlet: u128) -> char { - b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"[sixlet as usize] as char -} - -#[cfg(test)] -mod tests { - use crate::types::*; - - #[test] - fn frame_int2str() { - assert_eq!(Frame::int_to_string(3), "OpenStream"); - assert_eq!(Frame::int_to_string(7), "Raw"); - assert_eq!(Frame::int_to_string(8), ""); - } - - #[test] - fn frame_get_int() { - assert_eq!(Frame::get_int(&Frame::Raw(b"Foo".to_vec())), 7); - assert_eq!(Frame::get_int(&Frame::Shutdown), 2); - } - - #[test] - fn frame_creation() { - Pid::new(); - assert_eq!(format!("{}", Pid::fake(0)), "AAAAAA"); - assert_eq!(format!("{}", Pid::fake(1)), "BBBBBB"); - assert_eq!(format!("{}", Pid::fake(2)), "CCCCCC"); - } - - #[test] - fn test_sixlet_to_str() { - assert_eq!(sixlet_to_str(0), 'A'); - assert_eq!(sixlet_to_str(29), 'd'); - assert_eq!(sixlet_to_str(63), '/'); - } -} diff --git a/network/tests/closing.rs 
b/network/tests/closing.rs index fac118ff5a..8606879174 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -18,8 +18,8 @@ //! - You sometimes see sleep(1000ms) this is used when we rely on the //! underlying TCP functionality, as this simulates client and server -use async_std::task; -use task::block_on; +use std::sync::Arc; +use tokio::runtime::Runtime; use veloren_network::{Network, ParticipantError, Pid, Promises, StreamError}; mod helper; use helper::{network_participant_stream, tcp}; @@ -27,26 +27,26 @@ use helper::{network_participant_stream, tcp}; #[test] fn close_network() { let (_, _) = helper::setup(false, 0); - let (_, _p1_a, mut s1_a, _, _p1_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _, _p1_a, mut s1_a, _, _p1_b, mut s1_b) = network_participant_stream(tcp()); std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); - let msg1: Result = block_on(s1_b.recv()); + let msg1: Result = r.block_on(s1_b.recv()); assert_eq!(msg1, Err(StreamError::StreamClosed)); } #[test] fn close_participant() { let (_, _) = helper::setup(false, 0); - let (_n_a, p1_a, mut s1_a, _n_b, p1_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, p1_a, mut s1_a, _n_b, p1_b, mut s1_b) = network_participant_stream(tcp()); - block_on(p1_a.disconnect()).unwrap(); - block_on(p1_b.disconnect()).unwrap(); + r.block_on(p1_a.disconnect()).unwrap(); + r.block_on(p1_b.disconnect()).unwrap(); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); assert_eq!( - block_on(s1_b.recv::()), + r.block_on(s1_b.recv::()), Err(StreamError::StreamClosed) ); } @@ -54,26 +54,25 @@ fn close_participant() { #[test] fn close_stream() { let (_, _) = helper::setup(false, 0); - let (_n_a, _, mut s1_a, _n_b, _, _) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _, mut s1_a, _n_b, _, _) = network_participant_stream(tcp()); // s1_b is dropped 
directly while s1_a isn't std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); assert_eq!( - block_on(s1_a.recv::()), + r.block_on(s1_a.recv::()), Err(StreamError::StreamClosed) ); } -///THIS is actually a bug which currently luckily doesn't trigger, but with new -/// async-std WE must make sure, if a stream is `drop`ed inside a `block_on`, -/// that no panic is thrown. +///WE must NOT create runtimes inside a Runtime, this check needs to verify +/// that we dont panic there #[test] fn close_streams_in_block_on() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = block_on(network_participant_stream(tcp())); - block_on(async { + let (r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = network_participant_stream(tcp()); + r.block_on(async { //make it locally so that they are dropped later let mut s1_a = s1_a; let mut s1_b = s1_b; @@ -81,19 +80,20 @@ fn close_streams_in_block_on() { assert_eq!(s1_b.recv().await, Ok("ping".to_string())); drop(s1_a); }); + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] fn stream_simple_3msg_then_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(1u8).unwrap(); s1_a.send(42).unwrap(); s1_a.send("3rdMessage").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok(1u8)); - assert_eq!(block_on(s1_b.recv()), Ok(42)); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1u8)); + assert_eq!(r.block_on(s1_b.recv()), Ok(42)); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); drop(s1_a); std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!(s1_b.send("Hello World"), Err(StreamError::StreamClosed)); @@ -103,43 +103,43 @@ fn stream_simple_3msg_then_close() { fn 
stream_send_first_then_receive() { // recv should still be possible even if stream got closed if they are in queue let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(1u8).unwrap(); s1_a.send(42).unwrap(); s1_a.send("3rdMessage").unwrap(); drop(s1_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - assert_eq!(block_on(s1_b.recv()), Ok(1u8)); - assert_eq!(block_on(s1_b.recv()), Ok(42)); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1u8)); + assert_eq!(r.block_on(s1_b.recv()), Ok(42)); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); assert_eq!(s1_b.send("Hello World"), Err(StreamError::StreamClosed)); } #[test] fn stream_send_1_then_close_stream() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("this message must be received, even if stream is closed already!") .unwrap(); drop(s1_a); std::thread::sleep(std::time::Duration::from_millis(1000)); let exp = Ok("this message must be received, even if stream is closed already!".to_string()); - assert_eq!(block_on(s1_b.recv()), exp); + assert_eq!(r.block_on(s1_b.recv()), exp); println!("all received and done"); } #[test] fn stream_send_100000_then_close_stream() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } drop(s1_a); let exp = Ok("woop_PARTY_HARD_woop".to_string()); println!("start receiving"); - block_on(async { 
+ r.block_on(async { for _ in 0..100000 { assert_eq!(s1_b.recv().await, exp); } @@ -150,19 +150,20 @@ fn stream_send_100000_then_close_stream() { #[test] fn stream_send_100000_then_close_stream_remote() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } drop(s1_a); drop(_s1_b); //no receiving + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] fn stream_send_100000_then_close_stream_remote2() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -170,12 +171,13 @@ fn stream_send_100000_then_close_stream_remote2() { std::thread::sleep(std::time::Duration::from_millis(1000)); drop(s1_a); //no receiving + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] fn stream_send_100000_then_close_stream_remote3() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..100000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -183,12 +185,13 @@ fn stream_send_100000_then_close_stream_remote3() { std::thread::sleep(std::time::Duration::from_millis(1000)); drop(s1_a); //no receiving + drop((_n_a, _p_a, _n_b, _p_b)); //clean teardown } #[test] fn close_part_then_network() { let (_, _) = helper::setup(false, 0); - let (n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..1000 { 
s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -201,7 +204,7 @@ fn close_part_then_network() { #[test] fn close_network_then_part() { let (_, _) = helper::setup(false, 0); - let (n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (_, n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..1000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } @@ -214,140 +217,143 @@ fn close_network_then_part() { #[test] fn close_network_then_disconnect_part() { let (_, _) = helper::setup(false, 0); - let (n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = block_on(network_participant_stream(tcp())); + let (r, n_a, p_a, mut s1_a, _n_b, _p_b, _s1_b) = network_participant_stream(tcp()); for _ in 0..1000 { s1_a.send("woop_PARTY_HARD_woop").unwrap(); } drop(n_a); - assert!(block_on(p_a.disconnect()).is_err()); + assert!(r.block_on(p_a.disconnect()).is_err()); std::thread::sleep(std::time::Duration::from_millis(1000)); } #[test] fn opened_stream_before_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); - let mut s2_a = block_on(p_a.open(10, Promises::empty())).unwrap(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); + let mut s2_a = r.block_on(p_a.open(4, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); - let mut s2_b = block_on(p_b.opened()).unwrap(); + let mut s2_b = r.block_on(p_b.opened()).unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] fn opened_stream_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); - let mut s2_a = block_on(p_a.open(10, 
Promises::empty())).unwrap(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); + let mut s2_a = r.block_on(p_a.open(3, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - let mut s2_b = block_on(p_b.opened()).unwrap(); - assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + let mut s2_b = r.block_on(p_b.opened()).unwrap(); + assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); assert_eq!( - block_on(p_b.opened()).unwrap_err(), + r.block_on(p_b.opened()).unwrap_err(), ParticipantError::ParticipantDisconnected ); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] fn open_stream_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); - let mut s2_a = block_on(p_a.open(10, Promises::empty())).unwrap(); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); + let mut s2_a = r.block_on(p_a.open(4, Promises::empty())).unwrap(); s2_a.send("HelloWorld").unwrap(); drop(p_a); std::thread::sleep(std::time::Duration::from_millis(1000)); - let mut s2_b = block_on(p_b.opened()).unwrap(); - assert_eq!(block_on(s2_b.recv()), Ok("HelloWorld".to_string())); + let mut s2_b = r.block_on(p_b.opened()).unwrap(); + assert_eq!(r.block_on(s2_b.recv()), Ok("HelloWorld".to_string())); assert_eq!( - block_on(p_b.open(20, Promises::empty())).unwrap_err(), + r.block_on(p_b.open(5, Promises::empty())).unwrap_err(), ParticipantError::ParticipantDisconnected ); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] fn failed_stream_open_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (_n_a, p_a, _, _n_b, p_b, _) = block_on(network_participant_stream(tcp())); + let (r, _n_a, p_a, _, _n_b, p_b, _) = network_participant_stream(tcp()); drop(p_a); - std::thread::sleep(std::time::Duration::from_millis(1000)); assert_eq!( - 
block_on(p_b.opened()).unwrap_err(), + r.block_on(p_b.opened()).unwrap_err(), ParticipantError::ParticipantDisconnected ); + drop((_n_a, _n_b, p_b)); //clean teardown } #[test] fn open_participant_before_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (n_a, f) = Network::new(Pid::fake(0)); - std::thread::spawn(f); - let (n_b, f) = Network::new(Pid::fake(1)); - std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let n_a = Network::new(Pid::fake(0), Arc::clone(&r)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&r)); let addr = tcp(); - block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = block_on(p_b.open(10, Promises::empty())).unwrap(); + r.block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = r.block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = r.block_on(p_b.open(4, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); - let p_a = block_on(n_a.connected()).unwrap(); + let p_a = r.block_on(n_a.connected()).unwrap(); drop(s1_b); drop(p_b); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); - let mut s1_a = block_on(p_a.opened()).unwrap(); - assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); + let mut s1_a = r.block_on(p_a.opened()).unwrap(); + assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); } #[test] fn open_participant_after_remote_part_is_closed() { let (_, _) = helper::setup(false, 0); - let (n_a, f) = Network::new(Pid::fake(0)); - std::thread::spawn(f); - let (n_b, f) = Network::new(Pid::fake(1)); - std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let n_a = Network::new(Pid::fake(0), Arc::clone(&r)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&r)); let addr = tcp(); - block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = block_on(p_b.open(10, Promises::empty())).unwrap(); + 
r.block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = r.block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = r.block_on(p_b.open(4, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); drop(s1_b); drop(p_b); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); - let p_a = block_on(n_a.connected()).unwrap(); - let mut s1_a = block_on(p_a.opened()).unwrap(); - assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); + let p_a = r.block_on(n_a.connected()).unwrap(); + let mut s1_a = r.block_on(p_a.opened()).unwrap(); + assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); } #[test] fn close_network_scheduler_completely() { let (_, _) = helper::setup(false, 0); - let (n_a, f) = Network::new(Pid::fake(0)); - let ha = std::thread::spawn(f); - let (n_b, f) = Network::new(Pid::fake(1)); - let hb = std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let n_a = Network::new(Pid::fake(0), Arc::clone(&r)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&r)); let addr = tcp(); - block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = block_on(n_b.connect(addr)).unwrap(); - let mut s1_b = block_on(p_b.open(10, Promises::empty())).unwrap(); + r.block_on(n_a.listen(addr.clone())).unwrap(); + let p_b = r.block_on(n_b.connect(addr)).unwrap(); + let mut s1_b = r.block_on(p_b.open(4, Promises::empty())).unwrap(); s1_b.send("HelloWorld").unwrap(); - let p_a = block_on(n_a.connected()).unwrap(); - let mut s1_a = block_on(p_a.opened()).unwrap(); - assert_eq!(block_on(s1_a.recv()), Ok("HelloWorld".to_string())); + let p_a = r.block_on(n_a.connected()).unwrap(); + let mut s1_a = r.block_on(p_a.opened()).unwrap(); + assert_eq!(r.block_on(s1_a.recv()), Ok("HelloWorld".to_string())); drop(n_a); drop(n_b); std::thread::sleep(std::time::Duration::from_millis(1000)); - ha.join().unwrap(); - hb.join().unwrap(); + + drop(p_b); + drop(p_a); + let runtime = Arc::try_unwrap(r).expect("runtime is not alone, there still exist 
a reference"); + runtime.shutdown_timeout(std::time::Duration::from_secs(300)); } #[test] fn dont_panic_on_multiply_recv_after_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(11u32).unwrap(); drop(s1_a); @@ -362,7 +368,7 @@ fn dont_panic_on_multiply_recv_after_close() { #[test] fn dont_panic_on_recv_send_after_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(11u32).unwrap(); drop(s1_a); @@ -375,7 +381,7 @@ fn dont_panic_on_recv_send_after_close() { #[test] fn dont_panic_on_multiple_send_after_close() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(11u32).unwrap(); drop(s1_a); diff --git a/network/tests/helper.rs b/network/tests/helper.rs index 93ee64a1e6..eb5806190b 100644 --- a/network/tests/helper.rs +++ b/network/tests/helper.rs @@ -1,10 +1,14 @@ use lazy_static::*; use std::{ net::SocketAddr, - sync::atomic::{AtomicU16, Ordering}, + sync::{ + atomic::{AtomicU16, AtomicU64, Ordering}, + Arc, + }, thread, time::Duration, }; +use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; @@ -43,38 +47,57 @@ pub fn setup(tracing: bool, sleep: u64) -> (u64, u64) { } #[allow(dead_code)] -pub async fn network_participant_stream( +pub fn network_participant_stream( addr: ProtocolAddr, -) -> (Network, Participant, Stream, Network, Participant, Stream) { - let (n_a, f_a) = 
Network::new(Pid::fake(0)); - std::thread::spawn(f_a); - let (n_b, f_b) = Network::new(Pid::fake(1)); - std::thread::spawn(f_b); +) -> ( + Arc, + Network, + Participant, + Stream, + Network, + Participant, + Stream, +) { + let runtime = Arc::new(Runtime::new().unwrap()); + let (n_a, p1_a, s1_a, n_b, p1_b, s1_b) = runtime.block_on(async { + let n_a = Network::new(Pid::fake(0), Arc::clone(&runtime)); + let n_b = Network::new(Pid::fake(1), Arc::clone(&runtime)); - n_a.listen(addr.clone()).await.unwrap(); - let p1_b = n_b.connect(addr).await.unwrap(); - let p1_a = n_a.connected().await.unwrap(); + n_a.listen(addr.clone()).await.unwrap(); + let p1_b = n_b.connect(addr).await.unwrap(); + let p1_a = n_a.connected().await.unwrap(); - let s1_a = p1_a.open(10, Promises::empty()).await.unwrap(); - let s1_b = p1_b.opened().await.unwrap(); + let s1_a = p1_a.open(4, Promises::empty()).await.unwrap(); + let s1_b = p1_b.opened().await.unwrap(); - (n_a, p1_a, s1_a, n_b, p1_b, s1_b) + (n_a, p1_a, s1_a, n_b, p1_b, s1_b) + }); + (runtime, n_a, p1_a, s1_a, n_b, p1_b, s1_b) } #[allow(dead_code)] -pub fn tcp() -> veloren_network::ProtocolAddr { +pub fn tcp() -> ProtocolAddr { lazy_static! { static ref PORTS: AtomicU16 = AtomicU16::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - veloren_network::ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))) + ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))) } #[allow(dead_code)] -pub fn udp() -> veloren_network::ProtocolAddr { +pub fn udp() -> ProtocolAddr { lazy_static! { static ref PORTS: AtomicU16 = AtomicU16::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - veloren_network::ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) + ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) +} + +#[allow(dead_code)] +pub fn mpsc() -> ProtocolAddr { + lazy_static! 
{ + static ref PORTS: AtomicU64 = AtomicU64::new(5000); + } + let port = PORTS.fetch_add(1, Ordering::Relaxed); + ProtocolAddr::Mpsc(port) } diff --git a/network/tests/integration.rs b/network/tests/integration.rs index f4c8367841..af30b1c89f 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -1,8 +1,8 @@ -use async_std::task; -use task::block_on; +use std::sync::Arc; +use tokio::runtime::Runtime; use veloren_network::{NetworkError, StreamError}; mod helper; -use helper::{network_participant_stream, tcp, udp}; +use helper::{mpsc, network_participant_stream, tcp, udp}; use std::io::ErrorKind; use veloren_network::{Network, Pid, Promises, ProtocolAddr}; @@ -10,73 +10,105 @@ use veloren_network::{Network, Pid, Promises, ProtocolAddr}; #[ignore] fn network_20s() { let (_, _) = helper::setup(false, 0); - let (_n_a, _, _, _n_b, _, _) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _, _, _n_b, _, _) = network_participant_stream(tcp()); std::thread::sleep(std::time::Duration::from_secs(30)); } #[test] fn stream_simple() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("Hello World").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] fn stream_try_recv() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(4242u32).unwrap(); std::thread::sleep(std::time::Duration::from_secs(1)); assert_eq!(s1_b.try_recv(), Ok(Some(4242u32))); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] fn 
stream_simple_3msg() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("Hello World").unwrap(); s1_a.send(1337).unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); - assert_eq!(block_on(s1_b.recv()), Ok(1337)); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); s1_a.send("3rdMessage").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] +fn stream_simple_mpsc() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(mpsc()); + + s1_a.send("Hello World").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + +#[test] +fn stream_simple_mpsc_3msg() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(mpsc()); + + s1_a.send("Hello World").unwrap(); + s1_a.send(1337).unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); + s1_a.send("3rdMessage").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + +#[test] +#[ignore] fn stream_simple_udp() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(udp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(udp()); s1_a.send("Hello World").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); + 
assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] +#[ignore] fn stream_simple_udp_3msg() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(udp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(udp()); s1_a.send("Hello World").unwrap(); s1_a.send(1337).unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("Hello World".to_string())); - assert_eq!(block_on(s1_b.recv()), Ok(1337)); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); s1_a.send("3rdMessage").unwrap(); - assert_eq!(block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] #[ignore] fn tcp_and_udp_2_connections() -> std::result::Result<(), Box> { let (_, _) = helper::setup(false, 0); - let (network, f) = Network::new(Pid::new()); - let (remote, fr) = Network::new(Pid::new()); - std::thread::spawn(f); - std::thread::spawn(fr); - block_on(async { + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); + let remote = Network::new(Pid::new(), Arc::clone(&r)); + r.block_on(async { + let network = network; + let remote = remote; remote .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) .await?; @@ -95,20 +127,20 @@ fn tcp_and_udp_2_connections() -> std::result::Result<(), Box std::result::Result<(), Box> { let (_, _) = helper::setup(false, 0); - let (network, f) = Network::new(Pid::new()); - std::thread::spawn(f); + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); let udp1 = udp(); let tcp1 = tcp(); - block_on(network.listen(udp1.clone()))?; - block_on(network.listen(tcp1.clone()))?; + 
r.block_on(network.listen(udp1.clone()))?; + r.block_on(network.listen(tcp1.clone()))?; std::thread::sleep(std::time::Duration::from_millis(200)); - let (network2, f2) = Network::new(Pid::new()); - std::thread::spawn(f2); - let e1 = block_on(network2.listen(udp1)); - let e2 = block_on(network2.listen(tcp1)); + let network2 = Network::new(Pid::new(), Arc::clone(&r)); + let e1 = r.block_on(network2.listen(udp1)); + let e2 = r.block_on(network2.listen(tcp1)); match e1 { Err(NetworkError::ListenFailed(e)) if e.kind() == ErrorKind::AddrInUse => (), _ => panic!(), @@ -117,6 +149,7 @@ fn failed_listen_on_used_ports() -> std::result::Result<(), Box (), _ => panic!(), }; + drop((network, network2)); //clean teardown Ok(()) } @@ -130,11 +163,12 @@ fn api_stream_send_main() -> std::result::Result<(), Box> let (_, _) = helper::setup(false, 0); // Create a Network, listen on Port `1200` and wait for a Stream to be opened, // then answer `Hello World` - let (network, f) = Network::new(Pid::new()); - let (remote, fr) = Network::new(Pid::new()); - std::thread::spawn(f); - std::thread::spawn(fr); - block_on(async { + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); + let remote = Network::new(Pid::new(), Arc::clone(&r)); + r.block_on(async { + let network = network; + let remote = remote; network .listen(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; @@ -143,7 +177,7 @@ fn api_stream_send_main() -> std::result::Result<(), Box> .await?; // keep it alive let _stream_p = remote_p - .open(16, Promises::ORDERED | Promises::CONSISTENCY) + .open(4, Promises::ORDERED | Promises::CONSISTENCY) .await?; let participant_a = network.connected().await?; let mut stream_a = participant_a.opened().await?; @@ -158,11 +192,12 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> let (_, _) = helper::setup(false, 0); // Create a Network, listen on Port `1220` and wait for a Stream to be opened, // then listen on it - let 
(network, f) = Network::new(Pid::new()); - let (remote, fr) = Network::new(Pid::new()); - std::thread::spawn(f); - std::thread::spawn(fr); - block_on(async { + let r = Arc::new(Runtime::new().unwrap()); + let network = Network::new(Pid::new(), Arc::clone(&r)); + let remote = Network::new(Pid::new(), Arc::clone(&r)); + r.block_on(async { + let network = network; + let remote = remote; network .listen(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; @@ -170,7 +205,7 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> .connect(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; let mut stream_p = remote_p - .open(16, Promises::ORDERED | Promises::CONSISTENCY) + .open(4, Promises::ORDERED | Promises::CONSISTENCY) .await?; stream_p.send("Hello World")?; let participant_a = network.connected().await?; @@ -184,19 +219,20 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> #[test] fn wrong_parse() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send(1337).unwrap(); - match block_on(s1_b.recv::()) { + match r.block_on(s1_b.recv::()) { Err(StreamError::Deserialize(_)) => (), _ => panic!("this should fail, but it doesnt!"), } + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } #[test] fn multiple_try_recv() { let (_, _) = helper::setup(false, 0); - let (_n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = block_on(network_participant_stream(tcp())); + let (_, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(tcp()); s1_a.send("asd").unwrap(); s1_a.send(11u32).unwrap(); @@ -208,4 +244,5 @@ fn multiple_try_recv() { drop(s1_a); std::thread::sleep(std::time::Duration::from_secs(1)); assert_eq!(s1_b.try_recv::(), Err(StreamError::StreamClosed)); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } diff --git a/server-cli/Cargo.toml 
b/server-cli/Cargo.toml index 6269f04932..d75e1987ec 100644 --- a/server-cli/Cargo.toml +++ b/server-cli/Cargo.toml @@ -15,6 +15,7 @@ server = { package = "veloren-server", path = "../server", default-features = fa common = { package = "veloren-common", path = "../common" } common-net = { package = "veloren-common-net", path = "../common/net" } +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } ansi-parser = "0.7" clap = "2.33" crossterm = "0.18" diff --git a/server-cli/src/logging.rs b/server-cli/src/logging.rs index 6a738ed2c5..74eb2dc304 100644 --- a/server-cli/src/logging.rs +++ b/server-cli/src/logging.rs @@ -17,8 +17,11 @@ pub fn init(basic: bool) { env.add_directive("veloren_world::sim=info".parse().unwrap()) .add_directive("veloren_world::civ=info".parse().unwrap()) .add_directive("uvth=warn".parse().unwrap()) - .add_directive("tiny_http=warn".parse().unwrap()) + .add_directive("hyper=info".parse().unwrap()) + .add_directive("prometheus_hyper=info".parse().unwrap()) + .add_directive("mio::pool=info".parse().unwrap()) .add_directive("mio::sys::windows=debug".parse().unwrap()) + .add_directive("veloren_network_protocol=info".parse().unwrap()) .add_directive( "veloren_server::persistence::character=info" .parse() diff --git a/server-cli/src/main.rs b/server-cli/src/main.rs index ba83e1d67f..b930bd9de9 100644 --- a/server-cli/src/main.rs +++ b/server-cli/src/main.rs @@ -129,8 +129,19 @@ fn main() -> io::Result<()> { let server_port = &server_settings.gameserver_address.port(); let metrics_port = &server_settings.metrics_address.port(); // Create server - let mut server = Server::new(server_settings, editable_settings, &server_data_dir) - .expect("Failed to create server instance!"); + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ); + let mut server = Server::new( + server_settings, + editable_settings, + &server_data_dir, + runtime, + ) + .expect("Failed to 
create server instance!"); info!( ?server_port, diff --git a/server/Cargo.toml b/server/Cargo.toml index e1f4b7c7e3..f73ed1f02f 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -16,7 +16,7 @@ common = { package = "veloren-common", path = "../common" } common-sys = { package = "veloren-common-sys", path = "../common/sys" } common-net = { package = "veloren-common-net", path = "../common/net" } world = { package = "veloren-world", path = "../world" } -network = { package = "veloren_network", path = "../network", features = ["metrics", "compression"], default-features = false } +network = { package = "veloren-network", path = "../network", features = ["metrics", "compression"], default-features = false } specs = { git = "https://github.com/amethyst/specs.git", features = ["shred-derive"], rev = "d4435bdf496cf322c74886ca09dd8795984919b4" } specs-idvs = { git = "https://gitlab.com/veloren/specs-idvs.git", rev = "9fab7b396acd6454585486e50ae4bfe2069858a9" } @@ -28,6 +28,8 @@ futures-util = "0.3.7" futures-executor = "0.3" futures-timer = "3.0" futures-channel = "0.3" +tokio = { version = "1", default-features = false, features = ["rt"] } +prometheus-hyper = "0.1.1" itertools = "0.9" lazy_static = "1.4.0" scan_fmt = { git = "https://github.com/Imberflur/scan_fmt" } @@ -40,7 +42,6 @@ hashbrown = { version = "0.9", features = ["rayon", "serde", "nightly"] } rayon = "1.5" crossbeam-channel = "0.5" prometheus = { version = "0.11", default-features = false} -tiny_http = "0.8.0" portpicker = { git = "https://github.com/xMAC94x/portpicker-rs", rev = "df6b37872f3586ac3b21d08b56c8ec7cd92fb172" } authc = { git = "https://gitlab.com/veloren/auth.git", rev = "bffb5181a35c19ddfd33ee0b4aedba741aafb68d" } libsqlite3-sys = { version = "0.18", features = ["bundled"] } diff --git a/server/src/connection_handler.rs b/server/src/connection_handler.rs index 4e4dc02e0a..4128d3115c 100644 --- a/server/src/connection_handler.rs +++ b/server/src/connection_handler.rs @@ -104,11 +104,11 @@ 
impl ConnectionHandler { let reliable = Promises::ORDERED | Promises::CONSISTENCY; let reliablec = reliable | Promises::COMPRESSED; - let general_stream = participant.open(10, reliablec).await?; - let ping_stream = participant.open(5, reliable).await?; - let mut register_stream = participant.open(10, reliablec).await?; - let character_screen_stream = participant.open(10, reliablec).await?; - let in_game_stream = participant.open(10, reliablec).await?; + let general_stream = participant.open(3, reliablec).await?; + let ping_stream = participant.open(2, reliable).await?; + let mut register_stream = participant.open(3, reliablec).await?; + let character_screen_stream = participant.open(3, reliablec).await?; + let in_game_stream = participant.open(3, reliablec).await?; let server_data = receiver.recv()?; diff --git a/server/src/lib.rs b/server/src/lib.rs index 93c5916d72..40aa2b0804 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -76,12 +76,14 @@ use common_net::{ use common_sys::plugin::PluginMgr; use common_sys::state::State; use futures_executor::block_on; -use metrics::{PhysicsMetrics, ServerMetrics, StateTickMetrics, TickMetrics}; +use metrics::{PhysicsMetrics, StateTickMetrics, TickMetrics}; use network::{Network, Pid, ProtocolAddr}; use persistence::{ character_loader::{CharacterLoader, CharacterLoaderResponseKind}, character_updater::CharacterUpdater, }; +use prometheus::Registry; +use prometheus_hyper::Server as PrometheusServer; use specs::{join::Join, Builder, Entity as EcsEntity, RunNow, SystemData, WorldExt}; use std::{ i32, @@ -91,6 +93,7 @@ use std::{ }; #[cfg(not(feature = "worldgen"))] use test_world::{IndexOwned, World}; +use tokio::{runtime::Runtime, sync::Notify}; use tracing::{debug, error, info, trace}; use uvth::{ThreadPool, ThreadPoolBuilder}; use vek::*; @@ -120,9 +123,10 @@ pub struct Server { connection_handler: ConnectionHandler, + _runtime: Arc, thread_pool: ThreadPool, - metrics: ServerMetrics, + metrics_shutdown: Arc, 
tick_metrics: TickMetrics, state_tick_metrics: StateTickMetrics, physics_metrics: PhysicsMetrics, @@ -136,6 +140,7 @@ impl Server { settings: Settings, editable_settings: EditableSettings, data_dir: &std::path::Path, + runtime: Arc, ) -> Result { info!("Server is data dir is: {}", data_dir.display()); if settings.auth_server_address.is_none() { @@ -347,28 +352,35 @@ impl Server { state.ecs_mut().insert(DeletedEntities::default()); - let mut metrics = ServerMetrics::new(); // register all metrics submodules here - let (tick_metrics, registry_tick) = TickMetrics::new(metrics.tick_clone()) - .expect("Failed to initialize server tick metrics submodule."); + let (tick_metrics, registry_tick) = + TickMetrics::new().expect("Failed to initialize server tick metrics submodule."); let (state_tick_metrics, registry_state) = StateTickMetrics::new().unwrap(); let (physics_metrics, registry_physics) = PhysicsMetrics::new().unwrap(); - registry_chunk(&metrics.registry()).expect("failed to register chunk gen metrics"); - registry_network(&metrics.registry()).expect("failed to register network request metrics"); - registry_player(&metrics.registry()).expect("failed to register player metrics"); - registry_tick(&metrics.registry()).expect("failed to register tick metrics"); - registry_state(&metrics.registry()).expect("failed to register state metrics"); - registry_physics(&metrics.registry()).expect("failed to register state metrics"); + let registry = Arc::new(Registry::new()); + registry_chunk(®istry).expect("failed to register chunk gen metrics"); + registry_network(®istry).expect("failed to register network request metrics"); + registry_player(®istry).expect("failed to register player metrics"); + registry_tick(®istry).expect("failed to register tick metrics"); + registry_state(®istry).expect("failed to register state metrics"); + registry_physics(®istry).expect("failed to register state metrics"); let thread_pool = ThreadPoolBuilder::new() .name("veloren-worker".to_string()) 
.build(); - let (network, f) = Network::new_with_registry(Pid::new(), &metrics.registry()); - metrics - .run(settings.metrics_address) - .expect("Failed to initialize server metrics submodule."); - thread_pool.execute(f); + let network = Network::new_with_registry(Pid::new(), Arc::clone(&runtime), ®istry); + let metrics_shutdown = Arc::new(Notify::new()); + let metrics_shutdown_clone = Arc::clone(&metrics_shutdown); + let addr = settings.metrics_address; + runtime.spawn(async move { + PrometheusServer::run( + Arc::clone(®istry), + addr, + metrics_shutdown_clone.notified(), + ) + .await + }); block_on(network.listen(ProtocolAddr::Tcp(settings.gameserver_address)))?; let connection_handler = ConnectionHandler::new(network); @@ -386,9 +398,10 @@ impl Server { connection_handler, + _runtime: runtime, thread_pool, - metrics, + metrics_shutdown, tick_metrics, state_tick_metrics, physics_metrics, @@ -900,7 +913,7 @@ impl Server { .tick_time .with_label_values(&["metrics"]) .set(end_of_server_tick.elapsed().as_nanos() as i64); - self.metrics.tick(); + self.tick_metrics.tick(); // 9) Finish the tick, pass control back to the frontend. 
@@ -1146,6 +1159,7 @@ impl Server { impl Drop for Server { fn drop(&mut self) { + self.metrics_shutdown.notify_one(); self.state .notify_players(ServerGeneral::Disconnect(DisconnectReason::Shutdown)); } diff --git a/server/src/metrics.rs b/server/src/metrics.rs index d52ae33538..ac57121437 100644 --- a/server/src/metrics.rs +++ b/server/src/metrics.rs @@ -1,19 +1,16 @@ use prometheus::{ - Encoder, Gauge, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec, - Opts, Registry, TextEncoder, + Gauge, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, + Registry, }; use std::{ convert::TryInto, error::Error, - net::SocketAddr, sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, + atomic::{AtomicU64, Ordering}, Arc, }, - thread, time::{Duration, SystemTime, UNIX_EPOCH}, }; -use tracing::{debug, error}; type RegistryFn = Box Result<(), prometheus::Error>>; @@ -60,13 +57,6 @@ pub struct TickMetrics { tick: Arc, } -pub struct ServerMetrics { - running: Arc, - handle: Option>, - registry: Option, - tick: Arc, -} - impl PhysicsMetrics { pub fn new() -> Result<(Self, RegistryFn), prometheus::Error> { let entity_entity_collision_checks_count = IntCounter::with_opts(Opts::new( @@ -265,7 +255,7 @@ impl ChunkGenMetrics { } impl TickMetrics { - pub fn new(tick: Arc) -> Result<(Self, RegistryFn), Box> { + pub fn new() -> Result<(Self, RegistryFn), Box> { let chonks_count = IntGauge::with_opts(Opts::new( "chonks_count", "number of all chonks currently active on the server", @@ -315,6 +305,7 @@ impl TickMetrics { let time_of_day_clone = time_of_day.clone(); let light_count_clone = light_count.clone(); let tick_time_clone = tick_time.clone(); + let tick = Arc::new(AtomicU64::new(0)); let f = |registry: &Registry| { registry.register(Box::new(chonks_count_clone))?; @@ -346,87 +337,7 @@ impl TickMetrics { )) } + pub fn tick(&self) { self.tick.fetch_add(1, Ordering::Relaxed); } + pub fn is_100th_tick(&self) -> bool { 
self.tick.load(Ordering::Relaxed).rem_euclid(100) == 0 } } - -impl ServerMetrics { - #[allow(clippy::new_without_default)] // TODO: Pending review in #587 - pub fn new() -> Self { - let running = Arc::new(AtomicBool::new(false)); - let tick = Arc::new(AtomicU64::new(0)); - let registry = Some(Registry::new()); - - Self { - running, - handle: None, - registry, - tick, - } - } - - pub fn registry(&self) -> &Registry { - match self.registry { - Some(ref r) => r, - None => panic!("You cannot longer register new metrics after the server has started!"), - } - } - - pub fn run(&mut self, addr: SocketAddr) -> Result<(), Box> { - self.running.store(true, Ordering::Relaxed); - let running2 = Arc::clone(&self.running); - - let registry = self - .registry - .take() - .expect("ServerMetrics must be already started"); - - //TODO: make this a job - self.handle = Some(thread::spawn(move || { - let server = tiny_http::Server::http(addr).unwrap(); - const TIMEOUT: Duration = Duration::from_secs(1); - debug!("starting tiny_http server to serve metrics"); - while running2.load(Ordering::Relaxed) { - let request = match server.recv_timeout(TIMEOUT) { - Ok(Some(rq)) => rq, - Ok(None) => continue, - Err(e) => { - error!(?e, "metrics http server error"); - break; - }, - }; - let mf = registry.gather(); - let encoder = TextEncoder::new(); - let mut buffer = vec![]; - encoder - .encode(&mf, &mut buffer) - .expect("Failed to encoder metrics text."); - let response = tiny_http::Response::from_string( - String::from_utf8(buffer).expect("Failed to parse bytes as a string."), - ); - if let Err(e) = request.respond(response) { - error!( - ?e, - "The metrics HTTP server had encountered and error with answering", - ); - } - } - debug!("stopping tiny_http server to serve metrics"); - })); - Ok(()) - } - - pub fn tick(&self) -> u64 { self.tick.fetch_add(1, Ordering::Relaxed) + 1 } - - pub fn tick_clone(&self) -> Arc { Arc::clone(&self.tick) } -} - -impl Drop for ServerMetrics { - fn drop(&mut self) { 
- self.running.store(false, Ordering::Relaxed); - let handle = self.handle.take(); - handle - .expect("ServerMetrics worker handle does not exist.") - .join() - .expect("Error shutting down prometheus metric exporter"); - } -} diff --git a/voxygen/Cargo.toml b/voxygen/Cargo.toml index 7480a7ed2f..08a7a72c46 100644 --- a/voxygen/Cargo.toml +++ b/voxygen/Cargo.toml @@ -82,6 +82,8 @@ ron = {version = "0.6", default-features = false} serde = {version = "1.0", features = [ "rc", "derive" ]} treeculler = "0.1.0" uvth = "3.1.1" +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } +num_cpus = "1.0" # vec_map = { version = "0.8.2" } inline_tweak = "1.0.2" itertools = "0.10.0" diff --git a/voxygen/src/hud/chat.rs b/voxygen/src/hud/chat.rs index b2dc69c780..1ecc5fe621 100644 --- a/voxygen/src/hud/chat.rs +++ b/voxygen/src/hud/chat.rs @@ -373,7 +373,7 @@ impl<'a> Widget for Chat<'a> { let ChatMsg { chat_type, .. } = &message; // For each ChatType needing localization get/set matching pre-formatted // localized string. 
This string will be formatted with the data - // provided in ChatType in the client/src/lib.rs + // provided in ChatType in the client/src/mod.rs // fn format_message called below message.message = match chat_type { ChatType::Online(_) => self diff --git a/voxygen/src/logging.rs b/voxygen/src/logging.rs index e0b4a962df..d0c3bd98d5 100644 --- a/voxygen/src/logging.rs +++ b/voxygen/src/logging.rs @@ -45,6 +45,7 @@ pub fn init(settings: &Settings) -> Vec { .add_directive("uvth=warn".parse().unwrap()) .add_directive("tiny_http=warn".parse().unwrap()) .add_directive("mio::sys::windows=debug".parse().unwrap()) + .add_directive("veloren_network_protocol=info".parse().unwrap()) .add_directive( "veloren_server::persistence::character=info" .parse() diff --git a/voxygen/src/menu/main/client_init.rs b/voxygen/src/menu/main/client_init.rs index a1e01ab71a..2297bf9232 100644 --- a/voxygen/src/menu/main/client_init.rs +++ b/voxygen/src/menu/main/client_init.rs @@ -71,6 +71,15 @@ impl ClientInit { let mut last_err = None; + let cores = num_cpus::get(); + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .worker_threads(if cores > 4 { cores - 1 } else { cores }) + .build() + .unwrap(), + ); + const FOUR_MINUTES_RETRIES: u64 = 48; 'tries: for _ in 0..FOUR_MINUTES_RETRIES { if cancel2.load(Ordering::Relaxed) { @@ -79,7 +88,7 @@ impl ClientInit { for socket_addr in first_addrs.clone().into_iter().chain(second_addrs.clone()) { - match Client::new(socket_addr, view_distance) { + match Client::new(socket_addr, view_distance, Arc::clone(&runtime)) { Ok(mut client) => { if let Err(e) = client.register(username, password, |auth_server| { diff --git a/voxygen/src/singleplayer.rs b/voxygen/src/singleplayer.rs index 0e1a050988..bda1c1fd2b 100644 --- a/voxygen/src/singleplayer.rs +++ b/voxygen/src/singleplayer.rs @@ -82,6 +82,14 @@ impl Singleplayer { let editable_settings = server::EditableSettings::singleplayer(&server_data_dir); let thread_pool = 
client.map(|c| c.thread_pool().clone()); + let cores = num_cpus::get(); + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .worker_threads(if cores > 4 { cores - 1 } else { cores }) + .build() + .unwrap(), + ); let settings2 = settings.clone(); let paused = Arc::new(AtomicBool::new(false)); @@ -92,7 +100,7 @@ impl Singleplayer { let thread = thread::spawn(move || { let mut server = None; if let Err(e) = result_sender.send( - match Server::new(settings2, editable_settings, &server_data_dir) { + match Server::new(settings2, editable_settings, &server_data_dir, runtime) { Ok(s) => { server = Some(s); Ok(())