diff --git a/Cargo.lock b/Cargo.lock index f867e99e19..5600abfcae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2268,7 +2268,7 @@ dependencies = [ "httpdate", "itoa", "pin-project", - "socket2", + "socket2 0.4.0", "tokio", "tower-service", "tracing", @@ -3638,6 +3638,17 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "pem" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd56cbd21fea48d0c440b41cd69c589faacade08c992d9a54e471b79d0fd13eb" +dependencies = [ + "base64", + "once_cell", + "regex", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -3861,6 +3872,45 @@ dependencies = [ "tracing", ] +[[package]] +name = "quinn" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c82c0a393b300104f989f3db8b8637c0d11f7a32a9c214560b47849ba8f119aa" +dependencies = [ + "bytes", + "futures", + "lazy_static", + "libc", + "mio 0.7.11", + "quinn-proto", + "rustls", + "socket2 0.3.19", + "thiserror", + "tokio", + "tracing", + "webpki", +] + +[[package]] +name = "quinn-proto" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09027365a21874b71e1fbd9d31cb99bff8e11ba81cc9ef2b9425bb607e42d3b2" +dependencies = [ + "bytes", + "ct-logs", + "rand 0.8.3", + "ring", + "rustls", + "rustls-native-certs", + "slab", + "thiserror", + "tinyvec", + "tracing", + "webpki", +] + [[package]] name = "quote" version = "0.6.13" @@ -4029,6 +4079,18 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "rcgen" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e80a701a04edd9cab874a3d59323bebe24c9a92dd602088c78da83732066d1b" +dependencies = [ + "chrono", + "pem", + "ring", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.1.57" @@ -4680,6 +4742,17 @@ dependencies = [ "wayland-client 0.28.5", ] +[[package]] +name = "socket2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "winapi 0.3.9", +] + [[package]] name = "socket2" version = "0.4.0" @@ -5585,7 +5658,9 @@ dependencies = [ "lz-fear", "prometheus", "prometheus-hyper", + "quinn", "rand 0.8.3", + "rcgen", "serde", "shellexpand", "tokio", @@ -5604,6 +5679,7 @@ dependencies = [ "bitflags", "bytes", "criterion", + "hashbrown", "prometheus", "rand 0.8.3", "tokio", @@ -6597,3 +6673,12 @@ name = "xml-rs" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b07db065a5cf61a7e4ba64f29e67db906fb1787316516c4e6e5ff0fea1efcd8a" + +[[package]] +name = "yasna" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de7bff972b4f2a06c85f6d8454b09df153af7e3a4ec2aac81db1b105b684ddb" +dependencies = [ + "chrono", +] diff --git a/client/src/lib.rs b/client/src/lib.rs index a3d6e9798e..df53d4d46e 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -61,7 +61,7 @@ use comp::BuffKind; use futures_util::FutureExt; use hashbrown::{HashMap, HashSet}; use image::DynamicImage; -use network::{Network, Participant, Pid, ProtocolAddr, Stream}; +use network::{ConnectAddr, Network, Participant, Pid, Stream}; use num::traits::FloatConst; use rayon::prelude::*; use specs::Component; @@ -218,7 +218,7 @@ impl Client { // Try to connect to all IP's and return the first that works let mut participant = None; for addr in addrs { - match network.connect(ProtocolAddr::Tcp(addr)).await { + match network.connect(ConnectAddr::Tcp(addr)).await { Ok(p) => { participant = Some(Ok(p)); break; @@ -229,7 +229,7 @@ impl Client { participant .unwrap_or_else(|| Err(Error::Other("No Ip Addr provided".to_string())))? }, - ConnectionArgs::Mpsc(id) => network.connect(ProtocolAddr::Mpsc(id)).await?, + ConnectionArgs::Mpsc(id) => network.connect(ConnectAddr::Mpsc(id)).await?, }; let stream = participant.opened().await?; diff --git a/network/Cargo.toml b/network/Cargo.toml index 7f0c45c073..bcb509aea1 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -9,8 +9,9 @@ edition = "2018" [features] metrics = ["prometheus", "network-protocol/metrics"] compression = ["lz-fear"] +quic = ["quinn"] -default = ["metrics","compression"] +default = ["metrics","compression","quic"] [dependencies] @@ -33,6 +34,8 @@ async-channel = "1.5.1" #use for .close() channels #mpsc channel registry lazy_static = { version = "1.4", default-features = false } rand = { version = "0.8" } +#quic support +quinn = { version = "0.7.2", optional = true } #stream flags bitflags = "1.2.1" lz-fear = { version = "0.1.1", optional = true } @@ -49,6 +52,8 @@ shellexpand = "2.0.0" serde = { version = "1.0", features = ["derive"] } prometheus-hyper = "0.1.2" criterion = { version = "0.3.4", features = ["default", "async_tokio"] } +#quic +rcgen = { version = "0.8.10"} [[bench]] name = "speed" diff --git a/network/benches/speed.rs b/network/benches/speed.rs index b110d308ae..d7f0f2b63c 100644 --- a/network/benches/speed.rs +++ b/network/benches/speed.rs @@ -1,7 +1,9 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use std::{net::SocketAddr, sync::Arc}; use tokio::{runtime::Runtime, sync::Mutex}; -use veloren_network::{Message, Network, Participant, Pid, Promises, ProtocolAddr, Stream}; +use veloren_network::{ + ConnectAddr, ListenAddr, Message, Network, Participant, Pid, Promises, Stream, +}; fn serialize(data: &[u8], stream: &Stream) { let _ = Message::serialize(data, stream.params()); } @@ -30,7 +32,7 @@ fn criterion_util(c: &mut Criterion) { c.significance_level(0.1).sample_size(100); let (r, _n_a, p_a, s1_a, _n_b, _p_b, _s1_b) = - network_participant_stream(ProtocolAddr::Mpsc(5000)); + network_participant_stream((ListenAddr::Mpsc(5000), ConnectAddr::Mpsc(5000))); let s2_a = r.block_on(p_a.open(4, Promises::COMPRESSED, 0)).unwrap(); c.throughput(Throughput::Bytes(1000)) @@ -50,7 +52,7 @@ fn criterion_mpsc(c: &mut Criterion) { c.significance_level(0.1).sample_size(10); let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = - network_participant_stream(ProtocolAddr::Mpsc(5000)); + network_participant_stream((ListenAddr::Mpsc(5000), ConnectAddr::Mpsc(5000))); let s1_a = Arc::new(Mutex::new(s1_a)); let s1_b = Arc::new(Mutex::new(s1_b)); @@ -82,8 +84,9 @@ fn criterion_tcp(c: &mut Criterion) { let mut c = c.benchmark_group("net_tcp"); c.significance_level(0.1).sample_size(10); + let socket_addr = SocketAddr::from(([127, 0, 0, 1], 5000)); let (_r, _n_a, _p_a, s1_a, _n_b, _p_b, s1_b) = - network_participant_stream(ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], 5000)))); + network_participant_stream((ListenAddr::Tcp(socket_addr), ConnectAddr::Tcp(socket_addr))); let s1_a = Arc::new(Mutex::new(s1_a)); let s1_b = Arc::new(Mutex::new(s1_b)); @@ -115,7 +118,7 @@ criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); criterion_main!(benches); pub fn network_participant_stream( - addr: ProtocolAddr, + addr: (ListenAddr, ConnectAddr), ) -> ( Runtime, Network, @@ -130,8 +133,8 @@ pub fn network_participant_stream( let n_a = Network::new(Pid::fake(0), &runtime); let n_b = Network::new(Pid::fake(1), &runtime); - n_a.listen(addr.clone()).await.unwrap(); - let p1_b = n_b.connect(addr).await.unwrap(); + n_a.listen(addr.0).await.unwrap(); + let p1_b = n_b.connect(addr.1).await.unwrap(); let p1_a = n_a.connected().await.unwrap(); let s1_a = p1_a.open(4, Promises::empty(), 0).await.unwrap(); diff --git a/network/examples/chat.rs b/network/examples/chat.rs index 8746479f73..2dc1e56e78 100644 --- a/network/examples/chat.rs +++ b/network/examples/chat.rs @@ -8,7 +8,7 @@ use std::{sync::Arc, thread, time::Duration}; use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::RwLock}; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr}; +use veloren_network::{ConnectAddr, ListenAddr, Network, Participant, Pid, Promises}; ///This example contains a simple chatserver, that allows to send messages /// between participants, it's neither pretty nor perfect, but it should show @@ -75,21 +75,27 @@ fn main() { let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); let ip: &str = matches.value_of("ip").unwrap(); - let address = match matches.value_of("protocol") { - Some("tcp") => ProtocolAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), - Some("udp") => ProtocolAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + let addresses = match matches.value_of("protocol") { + Some("tcp") => ( + ListenAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ), + Some("udp") => ( + ListenAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ), _ => panic!("invalid mode, run --help!"), }; let mut background = None; match matches.value_of("mode") { - Some("server") => server(address), - Some("client") => client(address), + Some("server") => server(addresses.0), + Some("client") => client(addresses.1), Some("both") => { - let address1 = address.clone(); - background = Some(thread::spawn(|| server(address1))); + let s = addresses.0; + background = Some(thread::spawn(|| server(s))); thread::sleep(Duration::from_millis(200)); //start client after server - client(address) + client(addresses.1) }, _ => panic!("invalid mode, run --help!"), }; @@ -98,7 +104,7 @@ fn main() { } } -fn server(address: ProtocolAddr) { +fn server(address: ListenAddr) { let r = Arc::new(Runtime::new().unwrap()); let server = Network::new(Pid::new(), &r); let server = Arc::new(server); @@ -144,7 +150,7 @@ async fn client_connection( println!("[{}] disconnected", username); } -fn client(address: ProtocolAddr) { +fn client(address: ConnectAddr) { let r = Arc::new(Runtime::new().unwrap()); let client = Network::new(Pid::new(), &r); diff --git a/network/examples/fileshare/commands.rs b/network/examples/fileshare/commands.rs index a18c90b38e..9f23ddb6aa 100644 --- a/network/examples/fileshare/commands.rs +++ b/network/examples/fileshare/commands.rs @@ -2,7 +2,7 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use tokio::fs; -use veloren_network::{Participant, ProtocolAddr, Stream}; +use veloren_network::{ConnectAddr, Participant, Stream}; use std::collections::HashMap; @@ -10,7 +10,7 @@ use std::collections::HashMap; pub enum LocalCommand { Shutdown, Disconnect, - Connect(ProtocolAddr), + Connect(ConnectAddr), List, Serve(FileInfo), Get(u32, Option), diff --git a/network/examples/fileshare/main.rs b/network/examples/fileshare/main.rs index f000f371e0..158b825073 100644 --- a/network/examples/fileshare/main.rs +++ b/network/examples/fileshare/main.rs @@ -9,7 +9,7 @@ use std::{path::PathBuf, sync::Arc, thread, time::Duration}; use tokio::{io, io::AsyncBufReadExt, runtime::Runtime, sync::mpsc}; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::ProtocolAddr; +use veloren_network::{ConnectAddr, ListenAddr}; mod commands; mod server; use commands::{FileInfo, LocalCommand}; @@ -50,7 +50,7 @@ fn main() { .init(); let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); - let address = ProtocolAddr::Tcp(format!("{}:{}", "127.0.0.1", port).parse().unwrap()); + let address = ListenAddr::Tcp(format!("{}:{}", "127.0.0.1", port).parse().unwrap()); let runtime = Arc::new(Runtime::new().unwrap()); let (server, cmd_sender) = Server::new(Arc::clone(&runtime)); @@ -158,12 +158,12 @@ async fn client(cmd_sender: mpsc::UnboundedSender) { .parse() .unwrap(); cmd_sender - .send(LocalCommand::Connect(ProtocolAddr::Tcp(socketaddr))) + .send(LocalCommand::Connect(ConnectAddr::Tcp(socketaddr))) .unwrap(); }, ("t", _) => { cmd_sender - .send(LocalCommand::Connect(ProtocolAddr::Tcp( + .send(LocalCommand::Connect(ConnectAddr::Tcp( "127.0.0.1:1231".parse().unwrap(), ))) .unwrap(); diff --git a/network/examples/fileshare/server.rs b/network/examples/fileshare/server.rs index 252ebdf32c..7a40d5be11 100644 --- a/network/examples/fileshare/server.rs +++ b/network/examples/fileshare/server.rs @@ -8,7 +8,7 @@ use tokio::{ }; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; -use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; +use veloren_network::{ListenAddr, Network, Participant, Pid, Promises, Stream}; #[derive(Debug)] struct ControlChannels { @@ -42,7 +42,7 @@ impl Server { ) } - pub async fn run(mut self, address: ProtocolAddr) { + pub async fn run(mut self, address: ListenAddr) { let run_channels = self.run_channels.take().unwrap(); self.network.listen(address).await.unwrap(); diff --git a/network/examples/network-speed/main.rs b/network/examples/network-speed/main.rs index e058aac7a8..e8ccc8f278 100644 --- a/network/examples/network-speed/main.rs +++ b/network/examples/network-speed/main.rs @@ -16,7 +16,7 @@ use std::{ use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{Message, Network, Pid, Promises, ProtocolAddr}; +use veloren_network::{ConnectAddr, ListenAddr, Message, Network, Pid, Promises}; #[derive(Serialize, Deserialize, Debug)] enum Msg { @@ -96,23 +96,29 @@ fn main() { let port: u16 = matches.value_of("port").unwrap().parse().unwrap(); let ip: &str = matches.value_of("ip").unwrap(); - let address = match matches.value_of("protocol") { - Some("tcp") => ProtocolAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), - Some("udp") => ProtocolAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), - _ => panic!("Invalid mode, run --help!"), + let addresses = match matches.value_of("protocol") { + Some("tcp") => ( + ListenAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Tcp(format!("{}:{}", ip, port).parse().unwrap()), + ), + Some("udp") => ( + ListenAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ConnectAddr::Udp(format!("{}:{}", ip, port).parse().unwrap()), + ), + _ => panic!("invalid mode, run --help!"), }; let mut background = None; let runtime = Arc::new(Runtime::new().unwrap()); match matches.value_of("mode") { - Some("server") => server(address, Arc::clone(&runtime)), - Some("client") => client(address, Arc::clone(&runtime)), + Some("server") => server(addresses.0, Arc::clone(&runtime)), + Some("client") => client(addresses.1, Arc::clone(&runtime)), Some("both") => { - let address1 = address.clone(); + let s = addresses.0; let runtime2 = Arc::clone(&runtime); - background = Some(thread::spawn(|| server(address1, runtime2))); + background = Some(thread::spawn(|| server(s, runtime2))); thread::sleep(Duration::from_millis(200)); //start client after server - client(address, Arc::clone(&runtime)); + client(addresses.1, Arc::clone(&runtime)); }, _ => panic!("Invalid mode, run --help!"), }; @@ -121,7 +127,7 @@ fn main() { } } -fn server(address: ProtocolAddr, runtime: Arc) { +fn server(address: ListenAddr, runtime: Arc) { let registry = Arc::new(Registry::new()); let server = Network::new_with_registry(Pid::new(), &runtime, ®istry); runtime.spawn(Server::run( @@ -153,7 +159,7 @@ fn server(address: ProtocolAddr, runtime: Arc) { } } -fn client(address: ProtocolAddr, runtime: Arc) { +fn client(address: ConnectAddr, runtime: Arc) { let registry = Arc::new(Registry::new()); let client = Network::new_with_registry(Pid::new(), &runtime, ®istry); runtime.spawn(Server::run( diff --git a/network/protocol/Cargo.toml b/network/protocol/Cargo.toml index 5b06792fc9..4043ac488f 100644 --- a/network/protocol/Cargo.toml +++ b/network/protocol/Cargo.toml @@ -24,6 +24,7 @@ rand = { version = "0.8" } # async traits async-trait = "0.1.42" bytes = "^1" +hashbrown = { version = ">=0.9, <0.12" } [dev-dependencies] async-channel = "1.5.1" diff --git a/network/protocol/benches/protocols.rs b/network/protocol/benches/protocols.rs index d8859943ed..cb1dc02038 100644 --- a/network/protocol/benches/protocols.rs +++ b/network/protocol/benches/protocols.rs @@ -6,8 +6,9 @@ use std::{sync::Arc, time::Duration}; use tokio::runtime::Runtime; use veloren_network_protocol::{ InitProtocol, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, Promises, ProtocolError, - ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, RecvProtocol, SendProtocol, Sid, - TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, _internal::OTFrame, + ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, QuicDataFormat, QuicRecvProtocol, + QuicSendProtocol, RecvProtocol, SendProtocol, Sid, TcpRecvProtocol, TcpSendProtocol, + UnreliableDrain, UnreliableSink, _internal::OTFrame, }; fn frame_serialize(frame: OTFrame, buffer: &mut BytesMut) { frame.write_bytes(buffer); } @@ -145,7 +146,35 @@ fn criterion_tcp(c: &mut Criterion) { c.finish(); } -criterion_group!(benches, criterion_util, criterion_mpsc, criterion_tcp); +fn criterion_quic(c: &mut Criterion) { + let mut c = c.benchmark_group("quic"); + c.significance_level(0.1).sample_size(10); + c.throughput(Throughput::Bytes(1000000000)) + .bench_function("1GB_in_10000_msg", |b| { + let buf = Bytes::from(&[155u8; 100_000][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 10_000), + ) + }); + c.throughput(Throughput::Elements(1000000)) + .bench_function("1000000_tiny_msg", |b| { + let buf = Bytes::from(&[3u8; 5][..]); + b.to_async(rt()).iter_with_setup( + || (buf.clone(), utils::quic_bound(10000, None)), + |(b, p)| send_and_recv_msg(p, b, 1_000_000), + ) + }); + c.finish(); +} + +criterion_group!( + benches, + criterion_util, + criterion_mpsc, + criterion_tcp, + criterion_quic +); criterion_main!(benches); mod utils { @@ -210,6 +239,36 @@ mod utils { ] } + pub struct QuicDrain { + pub sender: Sender, + } + + pub struct QuicSink { + pub receiver: Receiver, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + metrics: Option, + ) -> [(QuicSendProtocol, QuicRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new(QuicDrain { sender: s1 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new(QuicDrain { sender: s2 }, m.clone()), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + #[async_trait] impl UnreliableDrain for ACDrain { type DataFormat = MpscMsg; @@ -257,4 +316,28 @@ mod utils { .map_err(|_| ProtocolError::Closed) } } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } } diff --git a/network/protocol/src/lib.rs b/network/protocol/src/lib.rs index 79c1ae867a..3c2eb70c75 100644 --- a/network/protocol/src/lib.rs +++ b/network/protocol/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(drain_filter)] //! Network Protocol //! //! a I/O-Free protocol for the veloren network crate. @@ -13,9 +14,9 @@ //! This crate currently defines: //! - TCP //! - MPSC +//! - QUIC //! -//! a UDP implementation will quickly follow, and it's also possible to abstract -//! over QUIC. +//! eventually a pure UDP implementation will follow //! //! warning: don't mix protocol, using the TCP variant for actual UDP socket //! will result in dropped data using UDP with a TCP socket will be a waste of @@ -57,8 +58,10 @@ mod message; mod metrics; mod mpsc; mod prio; +mod quic; mod tcp; mod types; +mod util; pub use error::{InitProtocolError, ProtocolError}; pub use event::ProtocolEvent; @@ -66,12 +69,16 @@ pub use metrics::ProtocolMetricCache; #[cfg(feature = "metrics")] pub use metrics::ProtocolMetrics; pub use mpsc::{MpscMsg, MpscRecvProtocol, MpscSendProtocol}; +pub use quic::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; pub use tcp::{TcpRecvProtocol, TcpSendProtocol}; pub use types::{Bandwidth, Cid, Pid, Prio, Promises, Sid, HIGHEST_PRIO, VELOREN_NETWORK_VERSION}; ///use at own risk, might change any time, for internal benchmarks pub mod _internal { - pub use crate::frame::{ITFrame, OTFrame}; + pub use crate::{ + frame::{ITFrame, OTFrame}, + util::SortedVec, + }; } use async_trait::async_trait; diff --git a/network/protocol/src/prio.rs b/network/protocol/src/prio.rs index 374a1ac216..7304086a1b 100644 --- a/network/protocol/src/prio.rs +++ b/network/protocol/src/prio.rs @@ -75,7 +75,7 @@ impl PrioManager { /// bandwidth might be extended, as for technical reasons /// guaranteed_bandwidth is used and frames are always 1400 bytes. - pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec, Bandwidth) { + pub fn grab(&mut self, bandwidth: Bandwidth, dt: Duration) -> (Vec<(Sid, OTFrame)>, Bandwidth) { let total_bytes = (bandwidth as f64 * dt.as_secs_f64()) as u64; let mut cur_bytes = 0u64; let mut frames = vec![]; @@ -84,7 +84,7 @@ impl PrioManager { let metrics = &mut self.metrics; let mut process_stream = - |stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { + |sid: &Sid, stream: &mut StreamInfo, mut bandwidth: i64, cur_bytes: &mut u64| { let mut finished = None; 'outer: for (i, msg) in stream.messages.iter_mut().enumerate() { while let Some(frame) = msg.next() { @@ -95,7 +95,7 @@ impl PrioManager { } as u64; bandwidth -= b as i64; *cur_bytes += b; - frames.push(frame); + frames.push((*sid, frame)); if bandwidth <= 0 { break 'outer; } @@ -111,10 +111,10 @@ impl PrioManager { }; // Add guaranteed bandwidth - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { prios[stream.prio as usize] += 1; let stream_byte_cnt = (stream.guaranteed_bandwidth as f64 * dt.as_secs_f64()) as u64; - process_stream(stream, stream_byte_cnt as i64, &mut cur_bytes); + process_stream(sid, stream, stream_byte_cnt as i64, &mut cur_bytes); } if cur_bytes < total_bytes { @@ -124,11 +124,11 @@ impl PrioManager { continue; } let per_stream_bytes = ((total_bytes - cur_bytes) / prios[prio as usize]) as i64; - for stream in self.streams.values_mut() { + for (sid, stream) in self.streams.iter_mut() { if stream.prio != prio { continue; } - process_stream(stream, per_stream_bytes, &mut cur_bytes); + process_stream(sid, stream, per_stream_bytes, &mut cur_bytes); } } } diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs new file mode 100644 index 0000000000..cd76c0b3ef --- /dev/null +++ b/network/protocol/src/quic.rs @@ -0,0 +1,958 @@ +use crate::{ + error::ProtocolError, + event::ProtocolEvent, + frame::{ITFrame, InitFrame, OTFrame}, + handshake::{ReliableDrain, ReliableSink}, + message::{ITMessage, ALLOC_BLOCK}, + metrics::{ProtocolMetricCache, RemoveReason}, + prio::PrioManager, + types::{Bandwidth, Mid, Promises, Sid}, + util::SortedVec, + RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, +}; +use async_trait::async_trait; +use bytes::BytesMut; +use hashbrown::HashMap; +use std::time::{Duration, Instant}; +use tracing::info; +#[cfg(feature = "trace_pedantic")] +use tracing::trace; + +#[derive(PartialEq)] +pub enum QuicDataFormatStream { + Main, + Reliable(Sid), + Unreliable, +} + +pub struct QuicDataFormat { + pub stream: QuicDataFormatStream, + pub data: BytesMut, +} + +impl QuicDataFormat { + fn with_main(buffer: &mut BytesMut) -> Self { + Self { + stream: QuicDataFormatStream::Main, + data: buffer.split(), + } + } + + fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self { + Self { + stream: QuicDataFormatStream::Reliable(sid), + data: buffer.split(), + } + } + + fn with_unreliable(frame: OTFrame) -> Self { + let mut buffer = BytesMut::new(); + frame.write_bytes(&mut buffer); + Self { + stream: QuicDataFormatStream::Unreliable, + data: buffer, + } + } +} + +/// QUIC implementation of [`SendProtocol`] +/// +/// [`SendProtocol`]: crate::SendProtocol +#[derive(Debug)] +pub struct QuicSendProtocol +where + D: UnreliableDrain, +{ + main_buffer: BytesMut, + reliable_buffers: SortedVec, + store: PrioManager, + next_mid: Mid, + closing_streams: Vec, + notify_closing_streams: Vec, + pending_shutdown: bool, + drain: D, + last: Instant, + metrics: ProtocolMetricCache, +} + +/// QUIC implementation of [`RecvProtocol`] +/// +/// [`RecvProtocol`]: crate::RecvProtocol +#[derive(Debug)] +pub struct QuicRecvProtocol +where + S: UnreliableSink, +{ + main_buffer: BytesMut, + unreliable_buffer: BytesMut, + reliable_buffers: SortedVec, + pending_reliable_buffers: Vec<(Sid, BytesMut)>, + itmsg_allocator: BytesMut, + incoming: HashMap, + sink: S, + metrics: ProtocolMetricCache, +} + +fn is_reliable(p: &Promises) -> bool { + p.contains(Promises::ORDERED) + || p.contains(Promises::CONSISTENCY) + || p.contains(Promises::GUARANTEED_DELIVERY) +} + +impl QuicSendProtocol +where + D: UnreliableDrain, +{ + pub fn new(drain: D, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + store: PrioManager::new(metrics.clone()), + next_mid: 0u64, + closing_streams: vec![], + notify_closing_streams: vec![], + pending_shutdown: false, + drain, + last: Instant::now(), + metrics, + } + } + + /// returns all promises that this Protocol can take care of + /// If you open a Stream anyway, unsupported promises are ignored. + pub fn supported_promises() -> Promises { + Promises::ORDERED + | Promises::CONSISTENCY + | Promises::GUARANTEED_DELIVERY + | Promises::COMPRESSED + | Promises::ENCRYPTED + } +} + +impl QuicRecvProtocol +where + S: UnreliableSink, +{ + pub fn new(sink: S, metrics: ProtocolMetricCache) -> Self { + Self { + main_buffer: BytesMut::new(), + unreliable_buffer: BytesMut::new(), + reliable_buffers: SortedVec::default(), + pending_reliable_buffers: vec![], + itmsg_allocator: BytesMut::with_capacity(ALLOC_BLOCK), + incoming: HashMap::new(), + sink, + metrics, + } + } + + async fn recv_into_stream(&mut self) -> Result { + let chunk = self.sink.recv().await?; + let buffer = match chunk.stream { + QuicDataFormatStream::Main => &mut self.main_buffer, + QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer, + QuicDataFormatStream::Reliable(id) => { + match self.reliable_buffers.get_mut(&id) { + Some(buffer) => buffer, + None => { + self.pending_reliable_buffers.push((id, BytesMut::new())); + //Violated but will never happen + &mut self + .pending_reliable_buffers + .last_mut() + .ok_or(ProtocolError::Violated)? + .1 + }, + } + }, + }; + if buffer.is_empty() { + *buffer = chunk.data + } else { + buffer.extend_from_slice(&chunk.data) + } + Ok(chunk.stream) + } +} + +#[async_trait] +impl SendProtocol for QuicSendProtocol +where + D: UnreliableDrain, +{ + fn notify_from_recv(&mut self, event: ProtocolEvent) { + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if is_reliable(&promises) { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + }, + ProtocolEvent::CloseStream { sid } => { + if !self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back notify close stream"); + self.notify_closing_streams.push(sid); + } + }, + _ => {}, + } + } + + async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { + #[cfg(feature = "trace_pedantic")] + trace!(?event, "send"); + match event { + ProtocolEvent::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + self.store + .open_stream(sid, prio, promises, guaranteed_bandwidth); + if is_reliable(&promises) { + self.reliable_buffers.insert(sid, BytesMut::new()); + //Send a empty message to notify local drain of stream + self.drain + .send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)) + .await?; + } + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + }, + ProtocolEvent::CloseStream { sid } => { + if self.store.try_close_stream(sid) { + let _ = self.reliable_buffers.delete(&sid); //delete if it was reliable + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "hold back close stream"); + self.closing_streams.push(sid); + } + }, + ProtocolEvent::Shutdown => { + if self.store.is_empty() { + event.to_frame().write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + } else { + #[cfg(feature = "trace_pedantic")] + trace!("hold back shutdown"); + self.pending_shutdown = true; + } + }, + ProtocolEvent::Message { data, sid } => { + self.metrics.smsg_ib(sid, data.len() as u64); + self.store.add(data, self.next_mid, sid); + self.next_mid += 1; + }, + } + Ok(()) + } + + async fn flush( + &mut self, + bandwidth: Bandwidth, + dt: Duration, + ) -> Result { + let (frames, _) = self.store.grab(bandwidth, dt); + //Todo: optimize reserve + let mut data_frames = 0; + let mut data_bandwidth = 0; + for (sid, frame) in frames { + if let OTFrame::Data { mid: _, data } = &frame { + data_bandwidth += data.len(); + data_frames += 1; + } + match self.reliable_buffers.get_mut(&sid) { + Some(buffer) => frame.write_bytes(buffer), + None => { + self.drain + .send(QuicDataFormat::with_unreliable(frame)) + .await? + }, + } + } + for (sid, buffer) in self.reliable_buffers.data.iter_mut() { + if !buffer.is_empty() { + self.drain + .send(QuicDataFormat::with_reliable(buffer, *sid)) + .await?; + } + } + self.metrics + .sdata_frames_b(data_frames, data_bandwidth as u64); + + let mut finished_streams = vec![]; + for (i, &sid) in self.closing_streams.iter().enumerate() { + if self.store.try_close_stream(sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + OTFrame::CloseStream { sid }.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.closing_streams.remove(*i); + } + + let mut finished_streams = vec![]; + for (i, sid) in self.notify_closing_streams.iter().enumerate() { + if self.store.try_close_stream(*sid) { + #[cfg(feature = "trace_pedantic")] + trace!(?sid, "close stream, as it's now empty"); + finished_streams.push(i); + } + } + for i in finished_streams.iter().rev() { + self.notify_closing_streams.remove(*i); + } + + if self.pending_shutdown && self.store.is_empty() { + #[cfg(feature = "trace_pedantic")] + trace!("shutdown, as it's now empty"); + OTFrame::Shutdown {}.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await?; + self.pending_shutdown = false; + } + Ok(data_bandwidth as u64) + } +} + +#[async_trait] +impl RecvProtocol for QuicRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + 'outer: loop { + match ITFrame::read_frame(&mut self.main_buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + if is_reliable(&promises) { + self.reliable_buffers.insert(sid, BytesMut::new()); + } + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth, + }); + }, + ITFrame::CloseStream { sid } => { + //FIXME: defer close! + //let _ = self.reliable_buffers.delete(sid); // if it was reliable + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => {}, + Err(()) => return Err(ProtocolError::Violated), + } + + // try to order pending + let mut pending_violated = false; + let mut reliable = vec![]; + self.pending_reliable_buffers.drain_filter(|(_, buffer)| { + // try to get Sid without touching buffer + let mut testbuffer = buffer.clone(); + match ITFrame::read_frame(&mut testbuffer) { + Ok(Some(ITFrame::DataHeader { + sid, + mid: _, + length: _, + })) => { + reliable.push((sid, buffer.clone())); + true + }, + Ok(Some(_)) | Err(_) => { + pending_violated = true; + true + }, + Ok(None) => false, + } + }); + + if pending_violated { + break 'outer Err(ProtocolError::Violated); + } + for (sid, buffer) in reliable.into_iter() { + self.reliable_buffers.insert(sid, buffer) + } + + let mut iter = self + .reliable_buffers + .data + .iter_mut() + .map(|(_, b)| (b, true)) + .collect::>(); + iter.push((&mut self.unreliable_buffer, false)); + + for (buffer, reliable) in iter { + loop { + match ITFrame::read_frame(buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { mid, data } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + if reliable { + info!( + ?mid, + "protocol violation by remote side: send Data \ + before Header" + ); + break 'outer Err(ProtocolError::Violated); + } else { + //TODO: cleanup old messages from time to time + continue; + } + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self + .incoming + .remove(&mid) + .ok_or(ProtocolError::Violated)?; + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + data: m.data.freeze(), + }); + } + }, + _ => break 'outer Err(ProtocolError::Violated), + }; + }, + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } + } + } + + self.recv_into_stream().await?; + } + } +} + +#[async_trait] +impl ReliableDrain for QuicSendProtocol +where + D: UnreliableDrain, +{ + async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { + self.main_buffer.reserve(500); + frame.write_bytes(&mut self.main_buffer); + self.drain + .send(QuicDataFormat::with_main(&mut self.main_buffer)) + .await + } +} + +#[async_trait] +impl ReliableSink for QuicRecvProtocol +where + S: UnreliableSink, +{ + async fn recv(&mut self) -> Result { + while self.main_buffer.len() < 100 { + if self.recv_into_stream().await? == QuicDataFormatStream::Main { + if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) { + return Ok(frame); + } + } + } + Err(ProtocolError::Violated) + } +} + +#[cfg(test)] +mod test_utils { + //Quic protocol based on Channel + use super::*; + use crate::metrics::{ProtocolMetricCache, ProtocolMetrics}; + use async_channel::*; + use std::sync::Arc; + + pub struct QuicDrain { + pub sender: Sender, + pub drop_ratio: f32, + } + + pub struct QuicSink { + pub receiver: Receiver, + } + + /// emulate Quic protocol on Channels + pub fn quic_bound( + cap: usize, + drop_ratio: f32, + metrics: Option, + ) -> [(QuicSendProtocol, QuicRecvProtocol); 2] { + let (s1, r1) = async_channel::bounded(cap); + let (s2, r2) = async_channel::bounded(cap); + let m = metrics.unwrap_or_else(|| { + ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())) + }); + [ + ( + QuicSendProtocol::new( + QuicDrain { + sender: s1, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r2 }, m.clone()), + ), + ( + QuicSendProtocol::new( + QuicDrain { + sender: s2, + drop_ratio, + }, + m.clone(), + ), + QuicRecvProtocol::new(QuicSink { receiver: r1 }, m), + ), + ] + } + + #[async_trait] + impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + use rand::Rng; + if matches!(data.stream, QuicDataFormatStream::Unreliable) + && rand::thread_rng().gen::() < self.drop_ratio + { + return Ok(()); + } + self.sender + .send(data) + .await + .map_err(|_| ProtocolError::Closed) + } + } + + #[async_trait] + impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + self.receiver + .recv() + .await + .map_err(|_| ProtocolError::Closed) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + error::ProtocolError, + frame::OTFrame, + metrics::{ProtocolMetricCache, ProtocolMetrics, RemoveReason}, + quic::{test_utils::*, QuicDataFormat}, + types::{Pid, Promises, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2}, + InitProtocol, ProtocolEvent, RecvProtocol, SendProtocol, + }; + use bytes::{Bytes, BytesMut}; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn handshake_all_good() { + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await }); + let r2 = tokio::spawn(async move { p2.initialize(false, Pid::fake(3), 42).await }); + let (r1, r2) = tokio::join!(r1, r2); + assert_eq!(r1.unwrap(), Ok((Pid::fake(3), STREAM_ID_OFFSET1, 42))); + assert_eq!(r2.unwrap(), Ok((Pid::fake(2), STREAM_ID_OFFSET2, 1337))); + } + + #[tokio::test] + async fn open_stream() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 0u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event.clone()).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_short_msg() { + let [p1, p2] = quic_bound(10, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + // 2nd short message + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[7u8; 30][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e) + } + + #[tokio::test] + async fn send_long_msg() { + let mut metrics = + ProtocolMetricCache::new("long_quic", Arc::new(ProtocolMetrics::new().unwrap())); + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, Some(metrics.clone())); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event.clone()).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert_eq!(event, e); + metrics.assert_msg(sid, 1, RemoveReason::Finished); + metrics.assert_msg_bytes(sid, 500_000, RemoveReason::Finished); + metrics.assert_data_frames(358); + metrics.assert_data_frames_bytes(500_000); + } + + #[tokio::test] + async fn msg_finishes_after_close() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_shutdown() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let _ = r.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Shutdown {}; + s.send(event).await.unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + s.send(event).await.unwrap(); + //send + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Shutdown { .. })); + } + + #[tokio::test] + async fn msg_finishes_after_drop() { + let sid = Sid::new(1); + let [p1, p2] = quic_bound(10000, 0.5, None); + let (mut s, mut r) = (p1.0, p2.1); + let event = ProtocolEvent::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 0, + }; + s.send(event).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[99u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let event = ProtocolEvent::Message { + sid, + data: Bytes::from(&[100u8; 500_000][..]), + }; + s.send(event).await.unwrap(); + s.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + drop(s); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + } + + #[tokio::test] + async fn header_and_data_in_seperate_msg() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + const DATA1: &[u8; 69] = + b"We need to make sure that its okay to send OPEN_STREAM and DATA_HEAD "; + const DATA2: &[u8; 95] = b"in one chunk and (DATA and CLOSE_STREAM) in the second chunk. and then keep the connection open"; + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED | Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + OTFrame::DataHeader { + mid: 99, + sid, + length: (DATA1.len() + DATA2.len()) as u64, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, sid)) + .await + .unwrap(); + + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA1[..]), + } + .write_bytes(&mut bytes); + OTFrame::Data { + mid: 99, + data: Bytes::from(&DATA2[..]), + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_reliable(&mut bytes, sid)) + .await + .unwrap(); + + OTFrame::CloseStream { sid }.write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::Message { .. })); + + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::CloseStream { .. })); + } + + #[tokio::test] + async fn drop_sink_while_recv() { + let sid = Sid::new(1); + let (s, r) = async_channel::bounded(10); + let m = ProtocolMetricCache::new("quic", Arc::new(ProtocolMetrics::new().unwrap())); + let mut r = + super::QuicRecvProtocol::new(super::test_utils::QuicSink { receiver: r }, m.clone()); + + let mut bytes = BytesMut::with_capacity(1500); + OTFrame::OpenStream { + sid, + prio: 5u8, + promises: Promises::COMPRESSED, + guaranteed_bandwidth: 1_000_000, + } + .write_bytes(&mut bytes); + s.send(QuicDataFormat::with_main(&mut bytes)).await.unwrap(); + let e = r.recv().await.unwrap(); + assert!(matches!(e, ProtocolEvent::OpenStream { .. })); + + let e = tokio::spawn(async move { r.recv().await }); + drop(s); + + let e = e.await.unwrap(); + assert_eq!(e, Err(ProtocolError::Closed)); + } + + #[tokio::test] + #[should_panic] + async fn send_on_stream_from_remote_without_notify() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let _ = p2.1.recv().await.unwrap(); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn send_on_stream_from_remote() { + //remote opens stream + //we send on it + let [mut p1, mut p2] = quic_bound(10, 0.5, None); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(10), + prio: 3u8, + promises: Promises::ORDERED, + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(10), + data: Bytes::from(&[188u8; 600][..]), + }; + p2.0.send(event.clone()).await.unwrap(); + p2.0.flush(1_000_000, Duration::from_secs(1)).await.unwrap(); + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + + #[tokio::test] + async fn unrealiable_test() { + const MIN_CHECK: usize = 10; + const COUNT: usize = 10_000; + //We send COUNT msg with 50% of be send each. we check that >= MIN_CHECK && != + // COUNT reach their target + + let [mut p1, mut p2] = quic_bound( + COUNT * 2 - 1, /* 2 times as it is HEADER + DATA but -1 as we want to see not all + * succeed */ + 0.5, + None, + ); + let event = ProtocolEvent::OpenStream { + sid: Sid::new(1337), + prio: 3u8, + promises: Promises::empty(), /* on purpose! */ + guaranteed_bandwidth: 1_000_000, + }; + p1.0.send(event).await.unwrap(); + let e = p2.1.recv().await.unwrap(); + p2.0.notify_from_recv(e); + let event = ProtocolEvent::Message { + sid: Sid::new(1337), + data: Bytes::from(&[188u8; 600][..]), + }; + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + p2.0.flush(1_000_000_000, Duration::from_secs(1)) + .await + .unwrap(); + for _ in 0..COUNT { + p2.0.send(event.clone()).await.unwrap(); + } + for _ in 0..MIN_CHECK { + let e = p1.1.recv().await.unwrap(); + assert_eq!(event, e); + } + } +} diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index 43d14e2a1e..0909336a58 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -11,10 +11,8 @@ use crate::{ }; use async_trait::async_trait; use bytes::BytesMut; -use std::{ - collections::HashMap, - time::{Duration, Instant}, -}; +use hashbrown::HashMap; +use std::time::{Duration, Instant}; use tracing::info; #[cfg(feature = "trace_pedantic")] use tracing::trace; @@ -176,7 +174,7 @@ where self.buffer.reserve(total_bytes as usize); let mut data_frames = 0; let mut data_bandwidth = 0; - for frame in frames { + for (_, frame) in frames { if let OTFrame::Data { mid: _, data } = &frame { data_bandwidth += data.len(); data_frames += 1; diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs index dfc9142f38..2e189b412d 100644 --- a/network/protocol/src/types.rs +++ b/network/protocol/src/types.rs @@ -118,6 +118,8 @@ impl Pid { impl Sid { pub const fn new(internal: u64) -> Self { Self { internal } } + pub fn get_u64(&self) -> u64 { self.internal } + #[inline] pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self { Self { diff --git a/network/protocol/src/util.rs b/network/protocol/src/util.rs new file mode 100644 index 0000000000..1e28d4c4ab --- /dev/null +++ b/network/protocol/src/util.rs @@ -0,0 +1,71 @@ +/// Used for storing Buffers in a QUIC +#[derive(Debug)] +pub struct SortedVec { + pub data: Vec<(K, V)>, +} + +impl Default for SortedVec { + fn default() -> Self { Self { data: vec![] } } +} + +impl SortedVec +where + K: Ord + Copy, +{ + pub fn insert(&mut self, k: K, v: V) { + self.data.push((k, v)); + self.data.sort_by_key(|&(k, _)| k); + } + + pub fn delete(&mut self, k: &K) -> Option { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(self.data.remove(i).1) + } else { + None + } + } + + pub fn get(&self, k: &K) -> Option<&V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&self.data[i].1) + } else { + None + } + } + + pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&mut self.data[i].1) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sorted_vec() { + let mut vec = SortedVec::default(); + vec.insert(10, "Hello"); + println!("{:?}", vec.data); + vec.insert(30, "World"); + println!("{:?}", vec.data); + vec.insert(20, " "); + println!("{:?}", vec.data); + assert_eq!(vec.data[0].1, "Hello"); + assert_eq!(vec.data[1].1, " "); + assert_eq!(vec.data[2].1, "World"); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), Some(&"Hello")); + assert_eq!(vec.delete(&40), None); + assert_eq!(vec.delete(&10), Some("Hello")); + assert_eq!(vec.delete(&10), None); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), None); + } +} diff --git a/network/src/api.rs b/network/src/api.rs index e04094d6ce..0da58aa6d5 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -28,11 +28,23 @@ use tracing::*; type A2sDisconnect = Arc>>>; -/// Represents a Tcp or Udp or Mpsc address -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub enum ProtocolAddr { +/// Represents a Tcp, Quic, Udp or Mpsc connection address +#[derive(Clone, Debug)] +pub enum ConnectAddr { Tcp(SocketAddr), Udp(SocketAddr), + #[cfg(feature = "quic")] + Quic(SocketAddr, quinn::ClientConfig, String), + Mpsc(u64), +} + +/// Represents a Tcp, Quic, Udp or Mpsc listen address +#[derive(Clone, Debug)] +pub enum ListenAddr { + Tcp(SocketAddr), + Udp(SocketAddr), + #[cfg(feature = "quic")] + Quic(SocketAddr, quinn::ServerConfig), Mpsc(u64), } @@ -133,8 +145,8 @@ pub struct StreamParams { /// [`Arc`](std::sync::Arc) as all commands have internal mutability. /// /// The `Network` has methods to [`connect`] to other [`Participants`] actively -/// via their [`ProtocolAddr`], or [`listen`] passively for [`connected`] -/// [`Participants`]. +/// via their [`ConnectAddr`], or [`listen`] passively for [`connected`] +/// [`Participants`] via [`ListenAddr`]. /// /// Too guarantee a clean shutdown, the [`Runtime`] MUST NOT be droped before /// the Network. @@ -142,7 +154,7 @@ pub struct StreamParams { /// # Examples /// ```rust /// use tokio::runtime::Runtime; -/// use veloren_network::{Network, ProtocolAddr, Pid}; +/// use veloren_network::{Network, ConnectAddr, ListenAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2999` to accept connections and connect to port `8080` to connect to a (pseudo) database Application @@ -151,9 +163,9 @@ pub struct StreamParams { /// runtime.block_on(async{ /// # //setup pseudo database! /// # let database = Network::new(Pid::new(), &runtime); -/// # database.listen(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; -/// network.listen(ProtocolAddr::Tcp("127.0.0.1:2999".parse().unwrap())).await?; -/// let database = network.connect(ProtocolAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; +/// # database.listen(ListenAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; +/// network.listen(ListenAddr::Tcp("127.0.0.1:2999".parse().unwrap())).await?; +/// let database = network.connect(ConnectAddr::Tcp("127.0.0.1:8080".parse().unwrap())).await?; /// drop(network); /// # drop(database); /// # Ok(()) @@ -166,10 +178,12 @@ pub struct StreamParams { /// [`connect`]: Network::connect /// [`listen`]: Network::listen /// [`connected`]: Network::connected +/// [`ConnectAddr`]: crate::api::ConnectAddr +/// [`ListenAddr`]: crate::api::ListenAddr pub struct Network { local_pid: Pid, participant_disconnect_sender: Arc>>, - listen_sender: Mutex>)>>, + listen_sender: Mutex>)>>, connect_sender: Mutex>, connected_receiver: Mutex>, shutdown_network_s: Option>>, @@ -195,7 +209,7 @@ impl Network { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid}; /// /// let runtime = Runtime::new().unwrap(); /// let network = Network::new(Pid::new(), &runtime); @@ -228,7 +242,7 @@ impl Network { /// ```rust /// use prometheus::Registry; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid}; /// /// let runtime = Runtime::new().unwrap(); /// let registry = Registry::new(); @@ -281,7 +295,7 @@ impl Network { } } - /// starts listening on an [`ProtocolAddr`]. + /// starts listening on an [`ListenAddr`]. /// When the method returns the `Network` is ready to listen for incoming /// connections OR has returned a [`NetworkError`] (e.g. port already used). /// You can call [`connected`] to asynchrony wait for a [`Participant`] to @@ -291,7 +305,7 @@ impl Network { /// # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid, ListenAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2000` TCP on all NICs and `2001` UDP locally @@ -299,10 +313,10 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { /// network - /// .listen(ProtocolAddr::Tcp("127.0.0.1:2000".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2000".parse().unwrap())) /// .await?; /// network - /// .listen(ProtocolAddr::Udp("127.0.0.1:2001".parse().unwrap())) + /// .listen(ListenAddr::Udp("127.0.0.1:2001".parse().unwrap())) /// .await?; /// drop(network); /// # Ok(()) @@ -311,8 +325,9 @@ impl Network { /// ``` /// /// [`connected`]: Network::connected + /// [`ListenAddr`]: crate::api::ListenAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] - pub async fn listen(&self, address: ProtocolAddr) -> Result<(), NetworkError> { + pub async fn listen(&self, address: ListenAddr) -> Result<(), NetworkError> { let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); debug!(?address, "listening on address"); self.listen_sender @@ -327,13 +342,13 @@ impl Network { } } - /// starts connection to an [`ProtocolAddr`]. + /// starts connection to an [`ConnectAddr`]. /// When the method returns the Network either returns a [`Participant`] /// ready to open [`Streams`] on OR has returned a [`NetworkError`] (e.g. /// can't connect, or invalid Handshake) # Examples /// ```ignore /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port `2010` TCP and `2011` UDP like listening above @@ -341,16 +356,16 @@ impl Network { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; - /// # remote.listen(ProtocolAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2010".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Udp("127.0.0.1:2011".parse().unwrap())).await?; /// let p1 = network - /// .connect(ProtocolAddr::Tcp("127.0.0.1:2010".parse().unwrap())) + /// .connect(ConnectAddr::Tcp("127.0.0.1:2010".parse().unwrap())) /// .await?; /// # //this doesn't work yet, so skip the test /// # //TODO fixme! /// # return Ok(()); /// let p2 = network - /// .connect(ProtocolAddr::Udp("127.0.0.1:2011".parse().unwrap())) + /// .connect(ConnectAddr::Udp("127.0.0.1:2011".parse().unwrap())) /// .await?; /// assert_eq!(&p1, &p2); /// # Ok(()) @@ -362,15 +377,15 @@ impl Network { /// ``` /// Usually the `Network` guarantees that a operation on a [`Participant`] /// succeeds, e.g. by automatic retrying unless it fails completely e.g. by - /// disconnecting from the remote. If 2 [`ProtocolAddres`] you `connect` to - /// belongs to the same [`Participant`], you get the same [`Participant`] as - /// a result. This is useful e.g. by connecting to the same - /// [`Participant`] via multiple Protocols. + /// disconnecting from the remote. If 2 [`ConnectAddr] you + /// `connect` to belongs to the same [`Participant`], you get the same + /// [`Participant`] as a result. This is useful e.g. by connecting to + /// the same [`Participant`] via multiple Protocols. /// /// [`Streams`]: crate::api::Stream - /// [`ProtocolAddres`]: crate::api::ProtocolAddr + /// [`ConnectAddr`]: crate::api::ConnectAddr #[instrument(name="network", skip(self, address), fields(p = %self.local_pid))] - pub async fn connect(&self, address: ProtocolAddr) -> Result { + pub async fn connect(&self, address: ConnectAddr) -> Result { let (pid_sender, pid_receiver) = oneshot::channel::>(); debug!(?address, "Connect to address"); @@ -391,15 +406,15 @@ impl Network { Ok(participant) } - /// returns a [`Participant`] created from a [`ProtocolAddr`] you called - /// [`listen`] on before. This function will either return a working - /// [`Participant`] ready to open [`Streams`] on OR has returned a - /// [`NetworkError`] (e.g. Network got closed) + /// returns a [`Participant`] created from a [`ListenAddr`] you + /// called [`listen`] on before. This function will either return a + /// working [`Participant`] ready to open [`Streams`] on OR has returned + /// a [`NetworkError`] (e.g. Network got closed) /// /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{ConnectAddr, ListenAddr, Network, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2020` TCP and opens returns their Pid @@ -408,9 +423,9 @@ impl Network { /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { /// network - /// .listen(ProtocolAddr::Tcp("127.0.0.1:2020".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2020".parse().unwrap())) /// .await?; - /// # remote.connect(ProtocolAddr::Tcp("127.0.0.1:2020".parse().unwrap())).await?; + /// # remote.connect(ConnectAddr::Tcp("127.0.0.1:2020".parse().unwrap())).await?; /// while let Ok(participant) = network.connected().await { /// println!("Participant connected: {}", participant.remote_pid()); /// # //skip test here as it would be a endless loop @@ -425,6 +440,7 @@ impl Network { /// /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen + /// [`ListenAddr`]: crate::api::ListenAddr #[instrument(name="network", skip(self), fields(p = %self.local_pid))] pub async fn connected(&self) -> Result { let participant = self.connected_receiver.lock().await.recv().await?; @@ -528,7 +544,7 @@ impl Participant { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, Promises, ProtocolAddr}; + /// use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port 2100 and open a stream @@ -536,9 +552,9 @@ impl Participant { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2100".parse().unwrap())).await?; /// let p1 = network - /// .connect(ProtocolAddr::Tcp("127.0.0.1:2100".parse().unwrap())) + /// .connect(ConnectAddr::Tcp("127.0.0.1:2100".parse().unwrap())) /// .await?; /// let _s1 = p1 /// .open(4, Promises::ORDERED | Promises::CONSISTENCY, 1000) @@ -595,7 +611,7 @@ impl Participant { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr, Promises}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr, Promises}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, connect on port 2110 and wait for the other side to open a stream @@ -604,8 +620,8 @@ impl Participant { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// # remote.listen(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; - /// let p1 = network.connect(ProtocolAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; + /// # remote.listen(ListenAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; + /// let p1 = network.connect(ConnectAddr::Tcp("127.0.0.1:2110".parse().unwrap())).await?; /// # let p2 = remote.connected().await?; /// # p2.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// let _s1 = p1.opened().await?; @@ -652,7 +668,7 @@ impl Participant { /// # Examples /// ```rust /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, Pid, ProtocolAddr}; + /// use veloren_network::{Network, Pid, ListenAddr, ConnectAddr}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on port `2030` TCP and opens returns their Pid and close connection. @@ -661,9 +677,9 @@ impl Participant { /// # let remote = Network::new(Pid::new(), &runtime); /// let err = runtime.block_on(async { /// network - /// .listen(ProtocolAddr::Tcp("127.0.0.1:2030".parse().unwrap())) + /// .listen(ListenAddr::Tcp("127.0.0.1:2030".parse().unwrap())) /// .await?; - /// # let keep_alive = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2030".parse().unwrap())).await?; + /// # let keep_alive = remote.connect(ConnectAddr::Tcp("127.0.0.1:2030".parse().unwrap())).await?; /// while let Ok(participant) = network.connected().await { /// println!("Participant connected: {}", participant.remote_pid()); /// participant.disconnect().await?; @@ -788,7 +804,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2200` and wait for a Stream to be opened, then answer `Hello World` @@ -796,8 +812,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2200".parse().unwrap())).await?; /// # // keep it alive /// # let _stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// let participant_a = network.connected().await?; @@ -830,7 +846,7 @@ impl Stream { /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; /// use bincode; - /// use veloren_network::{Network, ProtocolAddr, Pid, Message}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid, Message}; /// /// # fn main() -> std::result::Result<(), Box> { /// let runtime = Runtime::new().unwrap(); @@ -838,9 +854,9 @@ impl Stream { /// # let remote1 = Network::new(Pid::new(), &runtime); /// # let remote2 = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; - /// # let remote1_p = remote1.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; - /// # let remote2_p = remote2.connect(ProtocolAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; + /// # let remote1_p = remote1.connect(ConnectAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; + /// # let remote2_p = remote2.connect(ConnectAddr::Tcp("127.0.0.1:2210".parse().unwrap())).await?; /// # assert_eq!(remote1_p.remote_pid(), remote2_p.remote_pid()); /// # remote1_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # remote2_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; @@ -889,7 +905,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2220` and wait for a Stream to be opened, then listen on it @@ -897,8 +913,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2220".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; @@ -923,7 +939,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2230` and wait for a Stream to be opened, then listen on it @@ -931,8 +947,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2230".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// let participant_a = network.connected().await?; @@ -979,7 +995,7 @@ impl Stream { /// ``` /// # use veloren_network::Promises; /// use tokio::runtime::Runtime; - /// use veloren_network::{Network, ProtocolAddr, Pid}; + /// use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// /// # fn main() -> std::result::Result<(), Box> { /// // Create a Network, listen on Port `2240` and wait for a Stream to be opened, then listen on it @@ -987,8 +1003,8 @@ impl Stream { /// let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// runtime.block_on(async { - /// network.listen(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; + /// network.listen(ListenAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2240".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// # std::thread::sleep(std::time::Duration::from_secs(1)); diff --git a/network/src/channel.rs b/network/src/channel.rs index cf7a7851bd..03930c03ec 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,44 +1,121 @@ +use crate::api::NetworkConnectError; use async_trait::async_trait; use bytes::BytesMut; +use futures_util::FutureExt; +#[cfg(feature = "quic")] +use futures_util::StreamExt; use network_protocol::{ Bandwidth, Cid, InitProtocolError, MpscMsg, MpscRecvProtocol, MpscSendProtocol, Pid, ProtocolError, ProtocolEvent, ProtocolMetricCache, ProtocolMetrics, Sid, TcpRecvProtocol, TcpSendProtocol, UnreliableDrain, UnreliableSink, }; -use std::{sync::Arc, time::Duration}; +#[cfg(feature = "quic")] +use network_protocol::{QuicDataFormat, QuicDataFormatStream, QuicRecvProtocol, QuicSendProtocol}; +use std::{ + collections::HashMap, + io, + net::SocketAddr, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, + net, net::tcp::{OwnedReadHalf, OwnedWriteHalf}, - sync::mpsc, + select, + sync::{mpsc, oneshot, Mutex}, }; +use tracing::{error, info, trace, warn}; +#[allow(clippy::large_enum_variant)] #[derive(Debug)] pub(crate) enum Protocols { Tcp((TcpSendProtocol, TcpRecvProtocol)), Mpsc((MpscSendProtocol, MpscRecvProtocol)), + #[cfg(feature = "quic")] + Quic((QuicSendProtocol, QuicRecvProtocol)), } #[derive(Debug)] pub(crate) enum SendProtocols { Tcp(TcpSendProtocol), Mpsc(MpscSendProtocol), + #[cfg(feature = "quic")] + Quic(QuicSendProtocol), } #[derive(Debug)] pub(crate) enum RecvProtocols { Tcp(TcpRecvProtocol), Mpsc(MpscRecvProtocol), + #[cfg(feature = "quic")] + Quic(QuicRecvProtocol), } -impl Protocols { - pub(crate) fn new_tcp( - stream: tokio::net::TcpStream, - cid: Cid, - metrics: Arc, - ) -> Self { - let (r, w) = stream.into_split(); - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); +lazy_static::lazy_static! { + pub(crate) static ref MPSC_POOL: Mutex>> = { + Mutex::new(HashMap::new()) + }; +} +pub(crate) type C2cMpscConnect = ( + mpsc::Sender, + oneshot::Sender>, +); + +impl Protocols { + const MPSC_CHANNEL_BOUND: usize = 1000; + + pub(crate) async fn with_tcp_connect( + addr: SocketAddr, + metrics: ProtocolMetricCache, + ) -> Result { + let stream = net::TcpStream::connect(addr) + .await + .map_err(NetworkConnectError::Io)?; + info!( + "Connecting Tcp to: {}", + stream.peer_addr().map_err(NetworkConnectError::Io)? + ); + Ok(Self::new_tcp(stream, metrics)) + } + + pub(crate) async fn with_tcp_listen( + addr: SocketAddr, + cids: Arc, + metrics: Arc, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let listener = net::TcpListener::bind(addr).await?; + trace!(?addr, "Tcp Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some(data) = select! { + next = listener.accept().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let (stream, remote_addr) = match data { + Ok((s, p)) => (s, p), + Err(e) => { + trace!(?e, "TcpStream Error, ignoring connection attempt"); + continue; + }, + }; + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Tcp from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid)); + } + }); + Ok(()) + } + + pub(crate) fn new_tcp(stream: tokio::net::TcpStream, metrics: ProtocolMetricCache) -> Self { + let (r, w) = stream.into_split(); let sp = TcpSendProtocol::new(TcpDrain { half: w }, metrics.clone()); let rp = TcpRecvProtocol::new( TcpSink { @@ -50,23 +127,232 @@ impl Protocols { Protocols::Tcp((sp, rp)) } + pub(crate) async fn with_mpsc_connect( + addr: u64, + metrics: ProtocolMetricCache, + ) -> Result { + let mpsc_s = MPSC_POOL + .lock() + .await + .get(&addr) + .ok_or_else(|| { + NetworkConnectError::Io(io::Error::new( + io::ErrorKind::NotConnected, + "no mpsc listen on this addr", + )) + })? + .clone(); + let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); + let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); + mpsc_s + .send((remote_to_local_s, local_to_remote_oneshot_s)) + .map_err(|_| { + NetworkConnectError::Io(io::Error::new( + io::ErrorKind::BrokenPipe, + "mpsc pipe broke during connect", + )) + })?; + let local_to_remote_s = local_to_remote_oneshot_r + .await + .map_err(|e| NetworkConnectError::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?; + info!(?addr, "Connecting Mpsc"); + Ok(Self::new_mpsc( + local_to_remote_s, + remote_to_local_r, + metrics, + )) + } + + pub(crate) async fn with_mpsc_listen( + addr: u64, + cids: Arc, + metrics: Arc, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); + MPSC_POOL.lock().await.insert(addr, mpsc_s); + trace!(?addr, "Mpsc Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { + next = mpsc_r.recv().fuse() => next, + _ = &mut end_receiver => None, + } { + let (remote_to_local_s, remote_to_local_r) = + mpsc::channel(Self::MPSC_CHANNEL_BOUND); + if let Err(e) = local_remote_to_local_s.send(remote_to_local_s) { + error!(?e, "mpsc listen aborted"); + } + + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?addr, ?cid, "Accepting Mpsc from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + let _ = c2s_protocol_s.send(( + Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()), + cid, + )); + } + warn!("MpscStream Failed, stopping"); + }); + Ok(()) + } + pub(crate) fn new_mpsc( sender: mpsc::Sender, receiver: mpsc::Receiver, - cid: Cid, - metrics: Arc, + metrics: ProtocolMetricCache, ) -> Self { - let metrics = ProtocolMetricCache::new(&cid.to_string(), metrics); - let sp = MpscSendProtocol::new(MpscDrain { sender }, metrics.clone()); let rp = MpscRecvProtocol::new(MpscSink { receiver }, metrics); Protocols::Mpsc((sp, rp)) } + #[cfg(feature = "quic")] + pub(crate) async fn with_quic_connect( + addr: SocketAddr, + config: quinn::ClientConfig, + name: String, + metrics: ProtocolMetricCache, + ) -> Result { + let config = config.clone(); + let endpoint = quinn::Endpoint::builder(); + + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + let bindsock = match addr { + SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0), + SocketAddr::V6(_) => { + SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) + }, + }; + let (endpoint, _) = match endpoint.bind(&bindsock) { + Ok(e) => e, + Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)), + }; + + info!("Connecting Quic to: {}", &addr); + let connecting = endpoint.connect_with(config, &addr, &name).map_err(|e| { + trace!(?e, "error setting up quic"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + })?; + let connection = connecting.await.map_err(|e| { + trace!(?e, "error with quic connection"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + })?; + Self::new_quic(connection, false, metrics) + .await + .map_err(|e| { + trace!(?e, "error with quic"); + NetworkConnectError::Io(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + }) + } + + #[cfg(feature = "quic")] + pub(crate) async fn with_quic_listen( + addr: SocketAddr, + server_config: quinn::ServerConfig, + cids: Arc, + metrics: Arc, + s2s_stop_listening_r: oneshot::Receiver<()>, + c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + ) -> std::io::Result<()> { + let mut endpoint = quinn::Endpoint::builder(); + endpoint.listen(server_config); + let (_endpoint, mut listener) = match endpoint.bind(&addr) { + Ok(v) => v, + Err(quinn::EndpointError::Socket(e)) => return Err(e), + }; + trace!(?addr, "Quic Listener bound"); + let mut end_receiver = s2s_stop_listening_r.fuse(); + tokio::spawn(async move { + while let Some(Some(connecting)) = select! { + next = listener.next().fuse() => Some(next), + _ = &mut end_receiver => None, + } { + let remote_addr = connecting.remote_address(); + let connection = match connecting.await { + Ok(c) => c, + Err(e) => { + tracing::debug!(?e, ?remote_addr, "skipping connection attempt"); + continue; + }, + }; + + let cid = cids.fetch_add(1, Ordering::Relaxed); + info!(?remote_addr, ?cid, "Accepting Quic from"); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); + match Protocols::new_quic(connection, true, metrics).await { + Ok(quic) => { + let _ = c2s_protocol_s.send((quic, cid)); + }, + Err(e) => { + trace!(?e, "failed to start quic"); + continue; + }, + } + } + }); + Ok(()) + } + + #[cfg(feature = "quic")] + pub(crate) async fn new_quic( + mut connection: quinn::NewConnection, + listen: bool, + metrics: ProtocolMetricCache, + ) -> Result { + let (sendstream, recvstream) = if listen { + connection.connection.open_bi().await? + } else { + connection + .bi_streams + .next() + .await + .ok_or(quinn::ConnectionError::LocallyClosed)?? + }; + let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel(); + let streams_s_clone = recvstreams_s.clone(); + let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel(); + let sp = QuicSendProtocol::new( + QuicDrain { + con: connection.connection.clone(), + main: sendstream, + reliables: std::collections::HashMap::new(), + recvstreams_s: streams_s_clone, + sendstreams_r, + }, + metrics.clone(), + ); + spawn_new(recvstream, None, &recvstreams_s); + let rp = QuicRecvProtocol::new( + QuicSink { + con: connection.connection, + bi: connection.bi_streams, + recvstreams_r, + recvstreams_s, + sendstreams_s, + }, + metrics, + ); + Ok(Protocols::Quic((sp, rp))) + } + pub(crate) fn split(self) -> (SendProtocols, RecvProtocols) { match self { Protocols::Tcp((s, r)) => (SendProtocols::Tcp(s), RecvProtocols::Tcp(r)), Protocols::Mpsc((s, r)) => (SendProtocols::Mpsc(s), RecvProtocols::Mpsc(r)), + #[cfg(feature = "quic")] + Protocols::Quic((s, r)) => (SendProtocols::Quic(s), RecvProtocols::Quic(r)), } } } @@ -82,6 +368,8 @@ impl network_protocol::InitProtocol for Protocols { match self { Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, + #[cfg(feature = "quic")] + Protocols::Quic(p) => p.initialize(initializer, local_pid, secret).await, } } } @@ -92,6 +380,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.notify_from_recv(event), SendProtocols::Mpsc(s) => s.notify_from_recv(event), + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.notify_from_recv(event), } } @@ -99,6 +389,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.send(event).await, SendProtocols::Mpsc(s) => s.send(event).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.send(event).await, } } @@ -110,6 +402,8 @@ impl network_protocol::SendProtocol for SendProtocols { match self { SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, + #[cfg(feature = "quic")] + SendProtocols::Quic(s) => s.flush(bandwidth, dt).await, } } } @@ -120,6 +414,8 @@ impl network_protocol::RecvProtocol for RecvProtocols { match self { RecvProtocols::Tcp(r) => r.recv().await, RecvProtocols::Mpsc(r) => r.recv().await, + #[cfg(feature = "quic")] + RecvProtocols::Quic(r) => r.recv().await, } } } @@ -196,11 +492,164 @@ impl UnreliableSink for MpscSink { } } +/////////////////////////////////////// +//// QUIC +#[cfg(feature = "quic")] +type QuicStream = ( + BytesMut, + Result, quinn::ReadError>, + quinn::RecvStream, + Option, +); + +#[cfg(feature = "quic")] +#[derive(Debug)] +pub struct QuicDrain { + con: quinn::Connection, + main: quinn::SendStream, + reliables: std::collections::HashMap, + recvstreams_s: mpsc::UnboundedSender, + sendstreams_r: mpsc::UnboundedReceiver, +} + +#[cfg(feature = "quic")] +#[derive(Debug)] +pub struct QuicSink { + con: quinn::Connection, + bi: quinn::IncomingBiStreams, + recvstreams_r: mpsc::UnboundedReceiver, + recvstreams_s: mpsc::UnboundedSender, + sendstreams_s: mpsc::UnboundedSender, +} + +#[cfg(feature = "quic")] +fn spawn_new( + mut recvstream: quinn::RecvStream, + sid: Option, + streams_s: &mpsc::UnboundedSender, +) { + let streams_s_clone = streams_s.clone(); + tokio::spawn(async move { + let mut buffer = BytesMut::new(); + buffer.resize(1500, 0u8); + let r = recvstream.read(&mut buffer).await; + let _ = streams_s_clone.send((buffer, r, recvstream, sid)); + }); +} + +#[cfg(feature = "quic")] +#[async_trait] +impl UnreliableDrain for QuicDrain { + type DataFormat = QuicDataFormat; + + async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { + match match data.stream { + QuicDataFormatStream::Main => self.main.write_all(&data.data).await, + QuicDataFormatStream::Unreliable => unimplemented!(), + QuicDataFormatStream::Reliable(sid) => { + use std::collections::hash_map::Entry; + tracing::trace!(?sid, "Reliable"); + match self.reliables.entry(sid) { + Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await, + Entry::Vacant(vacant) => { + // IF the buffer is empty this was created localy and WE are allowed to + // open_bi(), if not, we NEED to block on sendstreams_r + if data.data.is_empty() { + match self.con.open_bi().await { + Ok((mut sendstream, recvstream)) => { + // send SID as first msg + if sendstream.write_u64(sid.get_u64()).await.is_err() { + return Err(ProtocolError::Closed); + } + spawn_new(recvstream, Some(sid), &self.recvstreams_s); + vacant.insert(sendstream).write_all(&data.data).await + }, + Err(_) => return Err(ProtocolError::Closed), + } + } else { + let sendstream = self + .sendstreams_r + .recv() + .await + .ok_or(ProtocolError::Closed)?; + vacant.insert(sendstream).write_all(&data.data).await + } + }, + } + }, + } { + Ok(()) => Ok(()), + Err(_) => Err(ProtocolError::Closed), + } + } +} + +#[cfg(feature = "quic")] +#[async_trait] +impl UnreliableSink for QuicSink { + type DataFormat = QuicDataFormat; + + async fn recv(&mut self) -> Result { + let (mut buffer, result, mut recvstream, id) = loop { + use futures_util::FutureExt; + // first handle all bi streams! + let (a, b) = tokio::select! { + biased; + Some(n) = self.bi.next().fuse() => (Some(n), None), + Some(n) = self.recvstreams_r.recv().fuse() => (None, Some(n)), + }; + + if let Some(remote_stream) = a { + match remote_stream { + Ok((sendstream, mut recvstream)) => { + let sid = match recvstream.read_u64().await { + Ok(u64::MAX) => None, //unreliable + Ok(sid) => Some(Sid::new(sid)), + Err(_) => return Err(ProtocolError::Violated), + }; + if self.sendstreams_s.send(sendstream).is_err() { + return Err(ProtocolError::Closed); + } + spawn_new(recvstream, sid, &self.recvstreams_s); + }, + Err(_) => return Err(ProtocolError::Closed), + } + } + + if let Some(data) = b { + break data; + } + }; + + let r = match result { + Ok(Some(0)) => Err(ProtocolError::Closed), + Ok(Some(n)) => Ok(QuicDataFormat { + stream: match id { + Some(id) => QuicDataFormatStream::Reliable(id), + None => QuicDataFormatStream::Main, + }, + data: buffer.split_to(n), + }), + Ok(None) => Err(ProtocolError::Closed), + Err(_) => Err(ProtocolError::Closed), + }?; + + let streams_s_clone = self.recvstreams_s.clone(); + tokio::spawn(async move { + buffer.resize(1500, 0u8); + let r = recvstream.read(&mut buffer).await; + let _ = streams_s_clone.send((buffer, r, recvstream, id)); + }); + Ok(r) + } +} + #[cfg(test)] mod tests { use super::*; use bytes::Bytes; - use network_protocol::{Promises, RecvProtocol, SendProtocol}; + use network_protocol::{Promises, ProtocolMetrics, RecvProtocol, SendProtocol}; + use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; #[tokio::test] @@ -212,9 +661,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5000").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); - let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap())); + let client = Protocols::new_tcp(client, metrics.clone()); + let server = Protocols::new_tcp(server, metrics); let (mut s, _) = client.split(); let (_, mut r) = server.split(); let event = ProtocolEvent::OpenStream { @@ -261,9 +710,9 @@ mod tests { }); let client = TcpStream::connect("127.0.0.1:5001").await.unwrap(); let (_listener, server) = r1.await.unwrap(); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let client = Protocols::new_tcp(client, 0, Arc::clone(&metrics)); - let server = Protocols::new_tcp(server, 0, Arc::clone(&metrics)); + let metrics = ProtocolMetricCache::new("0", Arc::new(ProtocolMetrics::new().unwrap())); + let client = Protocols::new_tcp(client, metrics.clone()); + let server = Protocols::new_tcp(server, metrics); let (s, _) = client.split(); let (_, mut r) = server.split(); let e = tokio::spawn(async move { r.recv().await }); diff --git a/network/src/lib.rs b/network/src/lib.rs index 448b50f41c..70b68fbafb 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -13,14 +13,14 @@ //! Say you have an application that wants to communicate with other application //! over a Network or on the same computer. Now each application instances the //! struct [`Network`] once with a new [`Pid`]. The Pid is necessary to identify -//! other [`Networks`] over the network protocols (e.g. TCP, UDP) +//! other [`Networks`] over the network protocols (e.g. TCP, UDP, QUIC, MPSC) //! -//! To connect to another application, you must know it's [`ProtocolAddr`]. One +//! To connect to another application, you must know it's [`ConnectAddr`]. One //! side will call [`connect`], the other [`connected`]. If successful both //! applications will now get a [`Participant`]. //! //! This [`Participant`] represents the connection between those 2 applications. -//! over the respective [`ProtocolAddr`] and with it the chosen network +//! over the respective [`ConnectAddr`] and with it the chosen network //! protocol. However messages can't be send directly via [`Participants`], //! instead you must open a [`Stream`] on it. Like above, one side has to call //! [`open`], the other [`opened`]. [`Streams`] can have a different priority @@ -41,14 +41,14 @@ //! ```rust //! use std::sync::Arc; //! use tokio::{join, runtime::Runtime, time::sleep}; -//! use veloren_network::{Network, Pid, Promises, ProtocolAddr}; +//! use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; //! //! // Client //! async fn client(runtime: &Runtime) -> std::result::Result<(), Box> { //! sleep(std::time::Duration::from_secs(1)).await; // `connect` MUST be after `listen` //! let client_network = Network::new(Pid::new(), runtime); //! let server = client_network -//! .connect(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) +//! .connect(ConnectAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; //! let mut stream = server //! .open(4, Promises::ORDERED | Promises::CONSISTENCY, 0) @@ -61,7 +61,7 @@ //! async fn server(runtime: &Runtime) -> std::result::Result<(), Box> { //! let server_network = Network::new(Pid::new(), runtime); //! server_network -//! .listen(ProtocolAddr::Tcp("127.0.0.1:12345".parse().unwrap())) +//! .listen(ListenAddr::Tcp("127.0.0.1:12345".parse().unwrap())) //! .await?; //! let client = server_network.connected().await?; //! let mut stream = client.opened().await?; @@ -95,7 +95,8 @@ //! [`send`]: crate::api::Stream::send //! [`recv`]: crate::api::Stream::recv //! [`Pid`]: network_protocol::Pid -//! [`ProtocolAddr`]: crate::api::ProtocolAddr +//! [`ListenAddr`]: crate::api::ListenAddr +//! [`ConnectAddr`]: crate::api::ConnectAddr //! [`Promises`]: network_protocol::Promises mod api; @@ -107,8 +108,8 @@ mod scheduler; mod util; pub use api::{ - Network, NetworkConnectError, NetworkError, Participant, ParticipantError, ProtocolAddr, - Stream, StreamError, StreamParams, + ConnectAddr, ListenAddr, Network, NetworkConnectError, NetworkError, Participant, + ParticipantError, Stream, StreamError, StreamParams, }; pub use message::Message; pub use network_protocol::{InitProtocolError, Pid, Promises}; diff --git a/network/src/message.rs b/network/src/message.rs index bc81e25802..f821511450 100644 --- a/network/src/message.rs +++ b/network/src/message.rs @@ -30,7 +30,7 @@ impl Message { /// # Example /// for example coding, see [`send_raw`] /// - /// [`send_raw`]: Stream::send_raw + /// [`send_raw`]: crate::api::Stream::send_raw /// [`Participants`]: crate::api::Participant /// [`compress`]: lz_fear::raw::compress2 /// [`Message::serialize`]: crate::message::Message::serialize @@ -70,7 +70,7 @@ impl Message { /// /// # Example /// ``` - /// # use veloren_network::{Network, ProtocolAddr, Pid}; + /// # use veloren_network::{Network, ListenAddr, ConnectAddr, Pid}; /// # use veloren_network::Promises; /// # use tokio::runtime::Runtime; /// # use std::sync::Arc; @@ -81,8 +81,8 @@ impl Message { /// # let network = Network::new(Pid::new(), &runtime); /// # let remote = Network::new(Pid::new(), &runtime); /// # runtime.block_on(async { - /// # network.listen(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; - /// # let remote_p = remote.connect(ProtocolAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; + /// # network.listen(ListenAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; + /// # let remote_p = remote.connect(ConnectAddr::Tcp("127.0.0.1:2300".parse().unwrap())).await?; /// # let mut stream_p = remote_p.open(4, Promises::ORDERED | Promises::CONSISTENCY, 0).await?; /// # stream_p.send("Hello World"); /// # let participant_a = network.connected().await?; diff --git a/network/src/metrics.rs b/network/src/metrics.rs index c46fe16bda..f3341e392b 100644 --- a/network/src/metrics.rs +++ b/network/src/metrics.rs @@ -1,8 +1,29 @@ -use crate::api::ProtocolAddr; +use crate::api::{ConnectAddr, ListenAddr}; use network_protocol::{Cid, Pid}; #[cfg(feature = "metrics")] use prometheus::{IntCounter, IntCounterVec, IntGauge, IntGaugeVec, Opts, Registry}; -use std::error::Error; +use std::{error::Error, net::SocketAddr}; + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub(crate) enum ProtocolInfo { + Tcp(SocketAddr), + Udp(SocketAddr), + #[cfg(feature = "quic")] + Quic(SocketAddr), + Mpsc(u64), +} + +impl From for ProtocolInfo { + fn from(other: ListenAddr) -> ProtocolInfo { + match other { + ListenAddr::Tcp(s) => ProtocolInfo::Tcp(s), + ListenAddr::Udp(s) => ProtocolInfo::Udp(s), + #[cfg(feature = "quic")] + ListenAddr::Quic(s, _) => ProtocolInfo::Quic(s), + ListenAddr::Mpsc(s) => ProtocolInfo::Mpsc(s), + } + } +} /// 1:1 relation between NetworkMetrics and Network #[cfg(feature = "metrics")] @@ -154,9 +175,9 @@ impl NetworkMetrics { Ok(()) } - pub(crate) fn connect_requests_cache(&self, protocol: &ProtocolAddr) -> prometheus::IntCounter { + pub(crate) fn connect_requests_cache(&self, protocol: &ListenAddr) -> prometheus::IntCounter { self.incoming_connections_total - .with_label_values(&[protocol_name(protocol)]) + .with_label_values(&[protocollisten_name(protocol)]) } pub(crate) fn channels_connected(&self, remote_p: &str, no: usize, cid: Cid) { @@ -192,15 +213,15 @@ impl NetworkMetrics { .inc(); } - pub(crate) fn listen_request(&self, protocol: &ProtocolAddr) { + pub(crate) fn listen_request(&self, protocol: &ListenAddr) { self.listen_requests_total - .with_label_values(&[protocol_name(protocol)]) + .with_label_values(&[protocollisten_name(protocol)]) .inc(); } - pub(crate) fn connect_request(&self, protocol: &ProtocolAddr) { + pub(crate) fn connect_request(&self, protocol: &ConnectAddr) { self.connect_requests_total - .with_label_values(&[protocol_name(protocol)]) + .with_label_values(&[protocolconnect_name(protocol)]) .inc(); } @@ -225,11 +246,24 @@ impl NetworkMetrics { } #[cfg(feature = "metrics")] -fn protocol_name(protocol: &ProtocolAddr) -> &str { +fn protocolconnect_name(protocol: &ConnectAddr) -> &str { match protocol { - ProtocolAddr::Tcp(_) => "tcp", - ProtocolAddr::Udp(_) => "udp", - ProtocolAddr::Mpsc(_) => "mpsc", + ConnectAddr::Tcp(_) => "tcp", + ConnectAddr::Udp(_) => "udp", + ConnectAddr::Mpsc(_) => "mpsc", + #[cfg(feature = "quic")] + ConnectAddr::Quic(_, _, _) => "quic", + } +} + +#[cfg(feature = "metrics")] +fn protocollisten_name(protocol: &ListenAddr) -> &str { + match protocol { + ListenAddr::Tcp(_) => "tcp", + ListenAddr::Udp(_) => "udp", + ListenAddr::Mpsc(_) => "mpsc", + #[cfg(feature = "quic")] + ListenAddr::Quic(_, _) => "quic", } } @@ -247,9 +281,9 @@ impl NetworkMetrics { pub(crate) fn streams_closed(&self, _remote_p: &str) {} - pub(crate) fn listen_request(&self, _protocol: &ProtocolAddr) {} + pub(crate) fn listen_request(&self, _protocol: &ListenAddr) {} - pub(crate) fn connect_request(&self, _protocol: &ProtocolAddr) {} + pub(crate) fn connect_request(&self, _protocol: &ConnectAddr) {} pub(crate) fn cleanup_participant(&self, _remote_p: &str) {} } diff --git a/network/src/participant.rs b/network/src/participant.rs index afa30266e8..2735fd5bdd 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -2,12 +2,13 @@ use crate::{ api::{ParticipantError, Stream}, channel::{Protocols, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, - util::{DeferredTracer, SortedVec}, + util::DeferredTracer, }; use bytes::Bytes; use futures_util::{FutureExt, StreamExt}; use network_protocol::{ Bandwidth, Cid, Pid, Prio, Promises, ProtocolEvent, RecvProtocol, SendProtocol, Sid, + _internal::SortedVec, }; use std::{ collections::HashMap, @@ -755,7 +756,7 @@ impl BParticipant { #[cfg(test)] mod tests { use super::*; - use network_protocol::ProtocolMetrics; + use network_protocol::{ProtocolMetricCache, ProtocolMetrics}; use tokio::{ runtime::Runtime, sync::{mpsc, oneshot}, @@ -815,14 +816,16 @@ mod tests { ) -> Protocols { let (s1, r1) = mpsc::channel(100); let (s2, r2) = mpsc::channel(100); - let metrics = Arc::new(ProtocolMetrics::new().unwrap()); - let p1 = Protocols::new_mpsc(s1, r2, cid, Arc::clone(&metrics)); + let met = Arc::new(ProtocolMetrics::new().unwrap()); + let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&met)); + let p1 = Protocols::new_mpsc(s1, r2, metrics); let (complete_s, complete_r) = oneshot::channel(); create_channel .send((cid, Sid::new(0), p1, complete_s)) .unwrap(); complete_r.await.unwrap(); - Protocols::new_mpsc(s2, r1, cid, Arc::clone(&metrics)) + let metrics = ProtocolMetricCache::new(&cid.to_string(), met); + Protocols::new_mpsc(s2, r1, metrics) } #[test] diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 527ea6f5fe..a232be440b 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -1,11 +1,11 @@ use crate::{ - api::{NetworkConnectError, Participant, ProtocolAddr}, + api::{ConnectAddr, ListenAddr, NetworkConnectError, Participant}, channel::Protocols, - metrics::NetworkMetrics, + metrics::{NetworkMetrics, ProtocolInfo}, participant::{B2sPrioStatistic, BParticipant, S2bCreateChannel, S2bShutdownBparticipant}, }; -use futures_util::{FutureExt, StreamExt}; -use network_protocol::{Cid, MpscMsg, Pid, ProtocolMetrics}; +use futures_util::StreamExt; +use network_protocol::{Cid, Pid, ProtocolMetricCache, ProtocolMetrics}; #[cfg(feature = "metrics")] use prometheus::Registry; use rand::Rng; @@ -18,7 +18,7 @@ use std::{ time::Duration, }; use tokio::{ - io, net, select, + io, sync::{mpsc, oneshot, Mutex}, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -33,12 +33,6 @@ use tracing::*; // - w: wire // - c: channel/handshake -lazy_static::lazy_static! { - static ref MPSC_POOL: Mutex>> = { - Mutex::new(HashMap::new()) - }; -} - #[derive(Debug)] struct ParticipantInfo { secret: u128, @@ -46,16 +40,12 @@ struct ParticipantInfo { s2b_shutdown_bparticipant_s: Option>, } -type A2sListen = (ProtocolAddr, oneshot::Sender>); +type A2sListen = (ListenAddr, oneshot::Sender>); pub(crate) type A2sConnect = ( - ProtocolAddr, + ConnectAddr, oneshot::Sender>, ); type A2sDisconnect = (Pid, S2bShutdownBparticipant); -type S2sMpscConnect = ( - mpsc::Sender, - oneshot::Sender>, -); #[derive(Debug)] struct ControlChannels { @@ -82,14 +72,12 @@ pub struct Scheduler { participant_channels: Arc>>, participants: Arc>>, channel_ids: Arc, - channel_listener: Mutex>>, + channel_listener: Mutex>>, metrics: Arc, protocol_metrics: Arc, } impl Scheduler { - const MPSC_CHANNEL_BOUND: usize = 1000; - pub fn new( local_pid: Pid, #[cfg(feature = "metrics")] registry: Option<&Registry>, @@ -157,7 +145,10 @@ impl Scheduler { } pub async fn run(mut self) { - let run_channels = self.run_channels.take().unwrap(); + let run_channels = self + .run_channels + .take() + .expect("run() can only be called once"); tokio::join!( self.listen_mgr(run_channels.a2s_listen_r), @@ -174,17 +165,66 @@ impl Scheduler { a2s_listen_r .for_each_concurrent(None, |(address, s2a_listen_result_s)| { let address = address; + let cids = Arc::clone(&self.channel_ids); + + #[cfg(feature = "metrics")] + let mcache = self.metrics.connect_requests_cache(&address); + + debug!(?address, "Got request to open a channel_creator"); + self.metrics.listen_request(&address); + let (s2s_stop_listening_s, s2s_stop_listening_r) = oneshot::channel::<()>(); + let (c2s_protocol_s, mut c2s_protocol_r) = mpsc::unbounded_channel(); + let metrics = Arc::clone(&self.protocol_metrics); async move { - debug!(?address, "Got request to open a channel_creator"); - self.metrics.listen_request(&address); - let (end_sender, end_receiver) = oneshot::channel::<()>(); self.channel_listener .lock() .await - .insert(address.clone(), end_sender); - self.channel_creator(address, end_receiver, s2a_listen_result_s) - .await; + .insert(address.clone().into(), s2s_stop_listening_s); + + #[cfg(feature = "metrics")] + mcache.inc(); + + let res = match address { + ListenAddr::Tcp(addr) => { + Protocols::with_tcp_listen( + addr, + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + #[cfg(feature = "quic")] + ListenAddr::Quic(addr, ref server_config) => { + Protocols::with_quic_listen( + addr, + server_config.clone(), + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + ListenAddr::Mpsc(addr) => { + Protocols::with_mpsc_listen( + addr, + cids, + metrics, + s2s_stop_listening_r, + c2s_protocol_s, + ) + .await + }, + _ => unimplemented!(), + }; + let _ = s2a_listen_result_s.send(res); + + while let Some((prot, cid)) = c2s_protocol_r.recv().await { + self.init_protocol(prot, cid, None, true).await; + } } }) .await; @@ -195,82 +235,26 @@ impl Scheduler { trace!("Start connect_mgr"); while let Some((addr, pid_sender)) = a2s_connect_r.recv().await { let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - let metrics = Arc::clone(&self.protocol_metrics); + let metrics = + ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics)); self.metrics.connect_request(&addr); - let (protocol, handshake) = match addr { - ProtocolAddr::Tcp(addr) => { - let stream = match net::TcpStream::connect(addr).await { - Ok(stream) => stream, - Err(e) => { - pid_sender.send(Err(NetworkConnectError::Io(e))).unwrap(); - continue; - }, - }; - info!("Connecting Tcp to: {}", stream.peer_addr().unwrap()); - (Protocols::new_tcp(stream, cid, metrics), false) + let protocol = match addr { + ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await, + #[cfg(feature = "quic")] + ConnectAddr::Quic(addr, ref config, name) => { + Protocols::with_quic_connect(addr, config.clone(), name, metrics).await }, - ProtocolAddr::Mpsc(addr) => { - let mpsc_s = match MPSC_POOL.lock().await.get(&addr) { - Some(s) => s.clone(), - None => { - pid_sender - .send(Err(NetworkConnectError::Io(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "no mpsc listen on this addr", - )))) - .unwrap(); - continue; - }, - }; - let (remote_to_local_s, remote_to_local_r) = - mpsc::channel(Self::MPSC_CHANNEL_BOUND); - let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel(); - mpsc_s - .send((remote_to_local_s, local_to_remote_oneshot_s)) - .unwrap(); - let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap(); - info!(?addr, "Connecting Mpsc"); - ( - Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, metrics), - false, - ) - }, - /* */ - //ProtocolAddr::Udp(addr) => { - //#[cfg(feature = "metrics")] - //self.metrics - //.connect_requests_total - //.with_label_values(&["udp"]) - //.inc(); - //let socket = match net::UdpSocket::bind("0.0.0.0:0").await { - //Ok(socket) => Arc::new(socket), - //Err(e) => { - //pid_sender.send(Err(e)).unwrap(); - //continue; - //}, - //}; - //if let Err(e) = socket.connect(addr).await { - //pid_sender.send(Err(e)).unwrap(); - //continue; - //}; - //info!("Connecting Udp to: {}", addr); - //let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::>(); - //let protocol = UdpProtocol::new( - //Arc::clone(&socket), - //addr, - //#[cfg(feature = "metrics")] - //Arc::clone(&self.metrics), - //udp_data_receiver, - //); - //self.runtime.spawn( - //Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender) - //.instrument(tracing::info_span!("udp", ?addr)), - //); - //(Protocols::Udp(protocol), true) - //}, + ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, metrics).await, _ => unimplemented!(), }; - self.init_protocol(protocol, cid, Some(pid_sender), handshake) + let protocol = match protocol { + Ok(p) => p, + Err(e) => { + pid_sender.send(Err(e)).unwrap(); + continue; + }, + }; + self.init_protocol(protocol, cid, Some(pid_sender), false) .await; } trace!("Stop connect_mgr"); @@ -384,156 +368,6 @@ impl Scheduler { trace!("Stop scheduler_shutdown_mgr"); } - async fn channel_creator( - &self, - addr: ProtocolAddr, - s2s_stop_listening_r: oneshot::Receiver<()>, - s2a_listen_result_s: oneshot::Sender>, - ) { - trace!(?addr, "Start up channel creator"); - #[cfg(feature = "metrics")] - let mcache = self.metrics.connect_requests_cache(&addr); - match addr { - ProtocolAddr::Tcp(addr) => { - let listener = match net::TcpListener::bind(addr).await { - Ok(listener) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - listener - }, - Err(e) => { - info!( - ?addr, - ?e, - "Tcp bind error during listener startup" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - }, - }; - trace!(?addr, "Listener bound"); - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some(data) = select! { - next = listener.accept().fuse() => Some(next), - _ = &mut end_receiver => None, - } { - let (stream, remote_addr) = match data { - Ok((s, p)) => (s, p), - Err(e) => { - warn!(?e, "TcpStream Error, ignoring connection attempt"); - continue; - }, - }; - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?remote_addr, ?cid, "Accepting Tcp from"); - self.init_protocol(Protocols::new_tcp(stream, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) - .await; - } - }, - ProtocolAddr::Mpsc(addr) => { - let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); - MPSC_POOL.lock().await.insert(addr, mpsc_s); - s2a_listen_result_s.send(Ok(())).unwrap(); - trace!(?addr, "Listener bound"); - - let mut end_receiver = s2s_stop_listening_r.fuse(); - while let Some((local_to_remote_s, local_remote_to_local_s)) = select! { - next = mpsc_r.recv().fuse() => next, - _ = &mut end_receiver => None, - } { - let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND); - local_remote_to_local_s.send(remote_to_local_s).unwrap(); - #[cfg(feature = "metrics")] - mcache.inc(); - let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed); - info!(?addr, ?cid, "Accepting Mpsc from"); - self.init_protocol(Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, Arc::clone(&self.protocol_metrics)), cid, None, true) - .await; - } - warn!("MpscStream Failed, stopping"); - },/* - ProtocolAddr::Udp(addr) => { - let socket = match net::UdpSocket::bind(addr).await { - Ok(socket) => { - s2a_listen_result_s.send(Ok(())).unwrap(); - Arc::new(socket) - }, - Err(e) => { - info!( - ?addr, - ?e, - "Listener couldn't be started due to error on udp bind" - ); - s2a_listen_result_s.send(Err(e)).unwrap(); - return; - }, - }; - trace!(?addr, "Listener bound"); - // receiving is done from here and will be piped to protocol as UDP does not - // have any state - let mut listeners = HashMap::new(); - let mut end_receiver = s2s_stop_listening_r.fuse(); - const UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER: usize = 9216; - let mut data = [0u8; UDP_MAXIMUM_SINGLE_PACKET_SIZE_EVER]; - while let Ok((size, remote_addr)) = select! { - next = socket.recv_from(&mut data).fuse() => next, - _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), - } { - let mut datavec = Vec::with_capacity(size); - datavec.extend_from_slice(&data[0..size]); - //Due to the async nature i cannot make of .entry() as it would lead to a still - // borrowed in another branch situation - #[allow(clippy::map_entry)] - if !listeners.contains_key(&remote_addr) { - info!("Accepting Udp from: {}", &remote_addr); - let (udp_data_sender, udp_data_receiver) = - mpsc::unbounded_channel::>(); - listeners.insert(remote_addr, udp_data_sender); - let protocol = UdpProtocol::new( - Arc::clone(&socket), - remote_addr, - #[cfg(feature = "metrics")] - Arc::clone(&self.metrics), - udp_data_receiver, - ); - self.init_protocol(Protocols::Udp(protocol), None, false) - .await; - } - let udp_data_sender = listeners.get_mut(&remote_addr).unwrap(); - udp_data_sender.send(datavec).unwrap(); - } - },*/ - _ => unimplemented!(), - } - trace!(?addr, "Ending channel creator"); - } - - #[allow(dead_code)] - async fn udp_single_channel_connect( - socket: Arc, - w2p_udp_package_s: mpsc::UnboundedSender>, - ) { - let addr = socket.local_addr(); - trace!(?addr, "Start udp_single_channel_connect"); - //TODO: implement real closing - let (_end_sender, end_receiver) = oneshot::channel::<()>(); - - // receiving is done from here and will be piped to protocol as UDP does not - // have any state - let mut end_receiver = end_receiver.fuse(); - let mut data = [0u8; 9216]; - while let Ok(size) = select! { - next = socket.recv(&mut data).fuse() => next, - _ = &mut end_receiver => Err(std::io::Error::new(std::io::ErrorKind::Other, "")), - } { - let mut datavec = Vec::with_capacity(size); - datavec.extend_from_slice(&data[0..size]); - w2p_udp_package_s.send(datavec).unwrap(); - } - trace!(?addr, "Stop udp_single_channel_connect"); - } - async fn init_protocol( &self, mut protocol: Protocols, diff --git a/network/src/util.rs b/network/src/util.rs index b9a8801263..640d65ee55 100644 --- a/network/src/util.rs +++ b/network/src/util.rs @@ -44,74 +44,3 @@ impl DeferredTracer { } } } - -/// Used for storing Protocols in a Participant or Stream <-> Protocol -pub(crate) struct SortedVec { - pub data: Vec<(K, V)>, -} - -impl Default for SortedVec { - fn default() -> Self { Self { data: vec![] } } -} - -impl SortedVec -where - K: Ord + Copy, -{ - pub fn insert(&mut self, k: K, v: V) { - self.data.push((k, v)); - self.data.sort_by_key(|&(k, _)| k); - } - - pub fn delete(&mut self, k: &K) -> Option { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(self.data.remove(i).1) - } else { - None - } - } - - pub fn get(&self, k: &K) -> Option<&V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&self.data[i].1) - } else { - None - } - } - - pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { - if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { - Some(&mut self.data[i].1) - } else { - None - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sorted_vec() { - let mut vec = SortedVec::default(); - vec.insert(10, "Hello"); - println!("{:?}", vec.data); - vec.insert(30, "World"); - println!("{:?}", vec.data); - vec.insert(20, " "); - println!("{:?}", vec.data); - assert_eq!(vec.data[0].1, "Hello"); - assert_eq!(vec.data[1].1, " "); - assert_eq!(vec.data[2].1, "World"); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), Some(&"Hello")); - assert_eq!(vec.delete(&40), None); - assert_eq!(vec.delete(&10), Some("Hello")); - assert_eq!(vec.delete(&10), None); - assert_eq!(vec.get(&30), Some(&"World")); - assert_eq!(vec.get_mut(&20), Some(&mut " ")); - assert_eq!(vec.get(&10), None); - } -} diff --git a/network/tests/closing.rs b/network/tests/closing.rs index 7d6a2cb0ee..100e84e544 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -347,8 +347,8 @@ fn open_participant_before_remote_part_is_closed() { let n_a = Network::new(Pid::fake(0), &r); let n_b = Network::new(Pid::fake(1), &r); let addr = tcp(); - r.block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = r.block_on(n_b.connect(addr)).unwrap(); + r.block_on(n_a.listen(addr.0)).unwrap(); + let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); let p_a = r.block_on(n_a.connected()).unwrap(); @@ -367,8 +367,8 @@ fn open_participant_after_remote_part_is_closed() { let n_a = Network::new(Pid::fake(0), &r); let n_b = Network::new(Pid::fake(1), &r); let addr = tcp(); - r.block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = r.block_on(n_b.connect(addr)).unwrap(); + r.block_on(n_a.listen(addr.0)).unwrap(); + let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); drop(s1_b); @@ -387,8 +387,8 @@ fn close_network_scheduler_completely() { let n_a = Network::new(Pid::fake(0), &r); let n_b = Network::new(Pid::fake(1), &r); let addr = tcp(); - r.block_on(n_a.listen(addr.clone())).unwrap(); - let p_b = r.block_on(n_b.connect(addr)).unwrap(); + r.block_on(n_a.listen(addr.0)).unwrap(); + let p_b = r.block_on(n_b.connect(addr.1)).unwrap(); let mut s1_b = r.block_on(p_b.open(4, Promises::empty(), 0)).unwrap(); s1_b.send("HelloWorld").unwrap(); diff --git a/network/tests/helper.rs b/network/tests/helper.rs index 68d5cebd87..9e78928f55 100644 --- a/network/tests/helper.rs +++ b/network/tests/helper.rs @@ -11,7 +11,7 @@ use std::{ use tokio::runtime::Runtime; use tracing::*; use tracing_subscriber::EnvFilter; -use veloren_network::{Network, Participant, Pid, Promises, ProtocolAddr, Stream}; +use veloren_network::{ConnectAddr, ListenAddr, Network, Participant, Pid, Promises, Stream}; #[allow(dead_code)] pub fn setup(tracing: bool, sleep: u64) -> (u64, u64) { @@ -47,7 +47,7 @@ pub fn setup(tracing: bool, sleep: u64) -> (u64, u64) { #[allow(dead_code)] pub fn network_participant_stream( - addr: ProtocolAddr, + addr: (ListenAddr, ConnectAddr), ) -> ( Arc, Network, @@ -62,11 +62,11 @@ pub fn network_participant_stream( let n_a = Network::new(Pid::fake(0), &runtime); let n_b = Network::new(Pid::fake(1), &runtime); - n_a.listen(addr.clone()).await.unwrap(); - let p1_b = n_b.connect(addr).await.unwrap(); + n_a.listen(addr.0).await.unwrap(); + let p1_b = n_b.connect(addr.1).await.unwrap(); let p1_a = n_a.connected().await.unwrap(); - let s1_a = p1_a.open(4, Promises::empty(), 0).await.unwrap(); + let s1_a = p1_a.open(4, Promises::ORDERED, 0).await.unwrap(); let s1_b = p1_b.opened().await.unwrap(); (n_a, p1_a, s1_a, n_b, p1_b, s1_b) @@ -75,28 +75,76 @@ pub fn network_participant_stream( } #[allow(dead_code)] -pub fn tcp() -> ProtocolAddr { +pub fn tcp() -> (ListenAddr, ConnectAddr) { lazy_static! { static ref PORTS: AtomicU16 = AtomicU16::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - ProtocolAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))) + ( + ListenAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))), + ConnectAddr::Tcp(SocketAddr::from(([127, 0, 0, 1], port))), + ) +} + +lazy_static! { + static ref UDP_PORTS: AtomicU16 = AtomicU16::new(5000); } #[allow(dead_code)] -pub fn udp() -> ProtocolAddr { - lazy_static! { - static ref PORTS: AtomicU16 = AtomicU16::new(5000); - } - let port = PORTS.fetch_add(1, Ordering::Relaxed); - ProtocolAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))) +pub fn quic() -> (ListenAddr, ConnectAddr) { + const LOCALHOST: &str = "localhost"; + let port = UDP_PORTS.fetch_add(1, Ordering::Relaxed); + + let transport_config = quinn::TransportConfig::default(); + let mut server_config = quinn::ServerConfig::default(); + server_config.transport = Arc::new(transport_config); + let mut server_config = quinn::ServerConfigBuilder::new(server_config); + server_config.protocols(&[b"veloren"]); + + trace!("generating self-signed certificate"); + let cert = rcgen::generate_simple_self_signed(vec![LOCALHOST.into()]).unwrap(); + let key = cert.serialize_private_key_der(); + let cert = cert.serialize_der().unwrap(); + + let key = quinn::PrivateKey::from_der(&key).expect("private key failed"); + let cert = quinn::Certificate::from_der(&cert).expect("cert failed"); + server_config + .certificate(quinn::CertificateChain::from_certs(vec![cert.clone()]), key) + .expect("set cert failed"); + + let server_config = server_config.build(); + + let mut client_config = quinn::ClientConfigBuilder::default(); + client_config.protocols(&[b"veloren"]); + client_config + .add_certificate_authority(cert) + .expect("adding certificate failed"); + + let client_config = client_config.build(); + ( + ListenAddr::Quic(SocketAddr::from(([127, 0, 0, 1], port)), server_config), + ConnectAddr::Quic( + SocketAddr::from(([127, 0, 0, 1], port)), + client_config, + LOCALHOST.to_owned(), + ), + ) } #[allow(dead_code)] -pub fn mpsc() -> ProtocolAddr { +pub fn udp() -> (ListenAddr, ConnectAddr) { + let port = UDP_PORTS.fetch_add(1, Ordering::Relaxed); + ( + ListenAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))), + ConnectAddr::Udp(SocketAddr::from(([127, 0, 0, 1], port))), + ) +} + +#[allow(dead_code)] +pub fn mpsc() -> (ListenAddr, ConnectAddr) { lazy_static! { static ref PORTS: AtomicU64 = AtomicU64::new(5000); } let port = PORTS.fetch_add(1, Ordering::Relaxed); - ProtocolAddr::Mpsc(port) + (ListenAddr::Mpsc(port), ConnectAddr::Mpsc(port)) } diff --git a/network/tests/integration.rs b/network/tests/integration.rs index 93534ac082..9d2e57bf77 100644 --- a/network/tests/integration.rs +++ b/network/tests/integration.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use tokio::runtime::Runtime; use veloren_network::{NetworkError, StreamError}; mod helper; -use helper::{mpsc, network_participant_stream, tcp, udp}; +use helper::{mpsc, network_participant_stream, quic, tcp, udp}; use std::io::ErrorKind; -use veloren_network::{Network, Pid, Promises, ProtocolAddr}; +use veloren_network::{ConnectAddr, ListenAddr, Network, Pid, Promises}; #[test] #[ignore] @@ -73,6 +73,30 @@ fn stream_simple_mpsc_3msg() { drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown } +#[test] +fn stream_simple_quic() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); + + s1_a.send("Hello World").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + +#[test] +fn stream_simple_quic_3msg() { + let (_, _) = helper::setup(false, 0); + let (r, _n_a, _p_a, mut s1_a, _n_b, _p_b, mut s1_b) = network_participant_stream(quic()); + + s1_a.send("Hello World").unwrap(); + s1_a.send(1337).unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("Hello World".to_string())); + assert_eq!(r.block_on(s1_b.recv()), Ok(1337)); + s1_a.send("3rdMessage").unwrap(); + assert_eq!(r.block_on(s1_b.recv()), Ok("3rdMessage".to_string())); + drop((_n_a, _n_b, _p_a, _p_b)); //clean teardown +} + #[test] #[ignore] fn stream_simple_udp() { @@ -110,16 +134,16 @@ fn tcp_and_udp_2_connections() -> std::result::Result<(), Box std::result::Result<(), Box (), _ => panic!(), @@ -170,10 +194,10 @@ fn api_stream_send_main() -> std::result::Result<(), Box> let network = network; let remote = remote; network - .listen(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) + .listen(ListenAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; let remote_p = remote - .connect(ProtocolAddr::Tcp("127.0.0.1:1200".parse().unwrap())) + .connect(ConnectAddr::Tcp("127.0.0.1:1200".parse().unwrap())) .await?; // keep it alive let _stream_p = remote_p @@ -199,10 +223,10 @@ fn api_stream_recv_main() -> std::result::Result<(), Box> let network = network; let remote = remote; network - .listen(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) + .listen(ListenAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; let remote_p = remote - .connect(ProtocolAddr::Tcp("127.0.0.1:1220".parse().unwrap())) + .connect(ConnectAddr::Tcp("127.0.0.1:1220".parse().unwrap())) .await?; let mut stream_p = remote_p .open(4, Promises::ORDERED | Promises::CONSISTENCY, 0) diff --git a/server/src/lib.rs b/server/src/lib.rs index 5093ad64a4..31f44535b1 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -86,7 +86,7 @@ use common_state::plugin::PluginMgr; use common_state::{BuildAreas, State}; use common_systems::add_local_systems; use metrics::{EcsSystemMetrics, PhysicsMetrics, TickMetrics}; -use network::{Network, Pid, ProtocolAddr}; +use network::{ListenAddr, Network, Pid}; use persistence::{ character_loader::{CharacterLoader, CharacterLoaderResponseKind}, character_updater::CharacterUpdater, @@ -391,8 +391,8 @@ impl Server { ) .await }); - runtime.block_on(network.listen(ProtocolAddr::Tcp(settings.gameserver_address)))?; - runtime.block_on(network.listen(ProtocolAddr::Mpsc(14004)))?; + runtime.block_on(network.listen(ListenAddr::Tcp(settings.gameserver_address)))?; + runtime.block_on(network.listen(ListenAddr::Mpsc(14004)))?; let connection_handler = ConnectionHandler::new(network, &runtime); // Initiate real-time world simulation