diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs index 5a00f27209..76969aa0be 100644 --- a/network/protocol/src/mpsc.rs +++ b/network/protocol/src/mpsc.rs @@ -6,7 +6,7 @@ use crate::{ frame::InitFrame, handshake::{ReliableDrain, ReliableSink}, metrics::ProtocolMetricCache, - types::Bandwidth, + types::{Bandwidth, Promises}, RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, }; use async_trait::async_trait; @@ -57,6 +57,16 @@ where metrics, } } + + /// returns all promises that this Protocol can take care of + /// If you open a Stream anyway, unsupported promises are ignored. + pub fn supported_promises() -> Promises { + Promises::ORDERED + | Promises::CONSISTENCY + | Promises::GUARANTEED_DELIVERY + | Promises::COMPRESSED + | Promises::ENCRYPTED /*assume a direct mpsc connection is secure*/ + } } impl MpscRecvProtocol diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index c944674ad6..315fd0d076 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -6,7 +6,7 @@ use crate::{ message::{ITMessage, ALLOC_BLOCK}, metrics::{ProtocolMetricCache, RemoveReason}, prio::PrioManager, - types::{Bandwidth, Mid, Sid}, + types::{Bandwidth, Mid, Promises, Sid}, RecvProtocol, SendProtocol, UnreliableDrain, UnreliableSink, }; use async_trait::async_trait; @@ -70,6 +70,15 @@ where metrics, } } + + /// returns all promises that this Protocol can take care of + /// If you open a Stream anyway, unsupported promises are ignored. + pub fn supported_promises() -> Promises { + Promises::ORDERED + | Promises::CONSISTENCY + | Promises::GUARANTEED_DELIVERY + | Promises::COMPRESSED + } } impl TcpRecvProtocol diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs index 0037795314..dfc9142f38 100644 --- a/network/protocol/src/types.rs +++ b/network/protocol/src/types.rs @@ -65,7 +65,7 @@ pub struct Pid { /// Unique ID per Stream, in one Channel. /// one side will always start with 0, while the other start with u64::MAX / 2. /// number increases for each created Stream. -#[derive(PartialEq, Eq, Hash, Clone, Copy)] +#[derive(PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)] pub struct Sid { internal: u64, } diff --git a/network/src/lib.rs b/network/src/lib.rs index 607eed9d91..448b50f41c 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -104,7 +104,7 @@ mod message; mod metrics; mod participant; mod scheduler; -mod trace; +mod util; pub use api::{ Network, NetworkConnectError, NetworkError, Participant, ParticipantError, ProtocolAddr, diff --git a/network/src/participant.rs b/network/src/participant.rs index 9e8c496172..c1b350c817 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -2,7 +2,7 @@ use crate::{ api::{ParticipantError, Stream}, channel::{Protocols, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, - trace::DeferredTracer, + util::{DeferredTracer, SortedVec}, }; use bytes::Bytes; use futures_util::{FutureExt, StreamExt}; @@ -137,7 +137,7 @@ impl BParticipant { let (b2b_force_close_recv_protocol_s, b2b_force_close_recv_protocol_r) = async_channel::unbounded::(); let (b2b_notify_send_of_recv_s, b2b_notify_send_of_recv_r) = - crossbeam_channel::unbounded::(); + crossbeam_channel::unbounded::<(Cid, ProtocolEvent)>(); let (a2b_close_stream_s, a2b_close_stream_r) = mpsc::unbounded_channel::(); let (a2b_msg_s, a2b_msg_r) = crossbeam_channel::unbounded::<(Sid, Bytes)>(); @@ -180,6 +180,28 @@ impl BParticipant { ); } + fn best_protocol(all: &SortedVec, promises: Promises) -> Option { + // check for mpsc + for (cid, p) in all.data.iter() { + if matches!(p, SendProtocols::Mpsc(_)) { + return Some(*cid); + } + } + // check for tcp + if network_protocol::TcpSendProtocol::::supported_promises() + == promises + { + for (cid, p) in all.data.iter() { + if matches!(p, SendProtocols::Tcp(_)) { + return Some(*cid); + } + } + } + + warn!("couldn't satisfy promises"); + all.data.first().map(|(c, _)| *c) + } + //TODO: local stream_cid: HashMap to know the respective protocol #[allow(clippy::too_many_arguments)] async fn send_mgr( @@ -189,18 +211,18 @@ impl BParticipant { a2b_msg_r: crossbeam_channel::Receiver<(Sid, Bytes)>, mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, SendProtocols)>, b2b_close_send_protocol_r: async_channel::Receiver, - b2b_notify_send_of_recv_r: crossbeam_channel::Receiver, + b2b_notify_send_of_recv_r: crossbeam_channel::Receiver<(Cid, ProtocolEvent)>, _b2s_prio_statistic_s: mpsc::UnboundedSender, ) { - let mut send_protocols: HashMap = HashMap::new(); + let mut sorted_send_protocols = SortedVec::::default(); + let mut sorted_stream_protocols = SortedVec::::default(); let mut interval = tokio::time::interval(Self::TICK_TIME); let mut last_instant = Instant::now(); let mut stream_ids = self.offset_sid; trace!("workaround, actively wait for first protocol"); - b2b_add_protocol_r - .recv() - .await - .map(|(c, p)| send_protocols.insert(c, p)); + if let Some((c, p)) = b2b_add_protocol_r.recv().await { + sorted_send_protocols.insert(c, p) + } loop { let (open, close, _, addp, remp) = select!( Some(n) = a2b_open_stream_r.recv().fuse() => (Some(n), None, None, None, None), @@ -210,25 +232,29 @@ impl BParticipant { Ok(n) = b2b_close_send_protocol_r.recv().fuse() => (None, None, None, None, Some(n)), ); - addp.map(|(cid, p)| { + if let Some((cid, p)) = addp { debug!(?cid, "add protocol"); - send_protocols.insert(cid, p) - }); + sorted_send_protocols.insert(cid, p); + } - let (cid, active) = match send_protocols.iter_mut().next() { - Some((cid, a)) => (*cid, a), - None => { - warn!("no channel"); - tokio::time::sleep(Self::TICK_TIME * 1000).await; //TODO: failover - continue; - }, - }; + //verify that we have at LEAST 1 channel before continuing + if sorted_send_protocols.data.is_empty() { + warn!("no channel"); + tokio::time::sleep(Self::TICK_TIME * 1000).await; //TODO: failover + continue; + } + + //let (cid, active) = sorted_send_protocols.data.iter_mut().next().unwrap(); + //used for error handling + let mut cid = u64::MAX; let active_err = async { if let Some((prio, promises, guaranteed_bandwidth, return_s)) = open { let sid = stream_ids; - trace!(?sid, "open stream"); stream_ids += Sid::from(1); + cid = Self::best_protocol(&sorted_send_protocols, promises).unwrap(); + trace!(?sid, ?cid, "open stream"); + let stream = self .create_stream(sid, prio, promises, guaranteed_bandwidth) .await; @@ -240,34 +266,62 @@ impl BParticipant { guaranteed_bandwidth, }; + sorted_stream_protocols.insert(sid, cid); return_s.send(stream).unwrap(); - active.send(event).await?; + sorted_send_protocols + .get_mut(&cid) + .unwrap() + .send(event) + .await?; } // process recv content first let mut closeevents = b2b_notify_send_of_recv_r .try_iter() - .map(|e| { - if matches!(e, ProtocolEvent::OpenStream { .. }) { - active.notify_from_recv(e); + .map(|(cid, e)| match e { + ProtocolEvent::OpenStream { sid, .. } => { + match sorted_send_protocols.get_mut(&cid) { + Some(p) => { + sorted_stream_protocols.insert(sid, cid); + p.notify_from_recv(e); + }, + None => { + warn!(?cid, "couldn't notify create protocol, doesn't exist") + }, + }; None - } else { - Some(e) - } + }, + e => Some((cid, e)), }) .collect::>(); // get all messages and assign it to a channel for (sid, buffer) in a2b_msg_r.try_iter() { - active - .send(ProtocolEvent::Message { data: buffer, sid }) - .await? + cid = *sorted_stream_protocols.get(&sid).unwrap(); + let event = ProtocolEvent::Message { data: buffer, sid }; + sorted_send_protocols + .get_mut(&cid) + .unwrap() + .send(event) + .await?; } // process recv content afterwards + //TODO: this might get skipped when a send msg fails on another channel in the + // previous line let _ = closeevents.drain(..).map(|e| { - if let Some(e) = e { - active.notify_from_recv(e); + if let Some((cid, e)) = e { + match sorted_send_protocols.get_mut(&cid) { + Some(p) => { + if let ProtocolEvent::OpenStream { sid, .. } = e { + let _ = sorted_stream_protocols.delete(&sid); + p.notify_from_recv(e); + } else { + unreachable!("we dont send other over this channel"); + } + }, + None => warn!(?cid, "couldn't notify close protocol, doesn't exist"), + }; } }); @@ -276,13 +330,21 @@ impl BParticipant { self.delete_stream(sid).await; // Fire&Forget the protocol will take care to verify that this Frame is delayed // till the last msg was received! - active.send(ProtocolEvent::CloseStream { sid }).await?; + cid = sorted_stream_protocols.delete(&sid).unwrap(); + let event = ProtocolEvent::CloseStream { sid }; + sorted_send_protocols + .get_mut(&cid) + .unwrap() + .send(event) + .await?; } let send_time = Instant::now(); let diff = send_time.duration_since(last_instant); last_instant = send_time; - active.flush(1_000_000_000, diff).await?; //this actually blocks, so we cant set streams while it. + for (_, p) in sorted_send_protocols.data.iter_mut() { + p.flush(1_000_000_000, diff).await?; //this actually blocks, so we cant set streams while it. + } let r: Result<(), network_protocol::ProtocolError> = Ok(()); r } @@ -292,16 +354,16 @@ impl BParticipant { // remote recv will now fail, which will trigger remote send which will trigger // recv trace!("TODO: for now decide to FAIL this participant and not wait for a failover"); - send_protocols.remove(&cid).unwrap(); + sorted_send_protocols.delete(&cid).unwrap(); self.metrics.channels_disconnected(&self.remote_pid_string); - if send_protocols.is_empty() { + if sorted_send_protocols.data.is_empty() { break; } } if let Some(cid) = remp { debug!(?cid, "remove protocol"); - match send_protocols.remove(&cid) { + match sorted_send_protocols.delete(&cid) { Some(mut prot) => { self.metrics.channels_disconnected(&self.remote_pid_string); trace!("blocking flush"); @@ -311,7 +373,7 @@ impl BParticipant { }, None => trace!("tried to remove protocol twice"), }; - if send_protocols.is_empty() { + if sorted_send_protocols.data.is_empty() { break; } } @@ -330,7 +392,7 @@ impl BParticipant { mut b2b_add_protocol_r: mpsc::UnboundedReceiver<(Cid, RecvProtocols)>, b2b_force_close_recv_protocol_r: async_channel::Receiver, b2b_close_send_protocol_s: async_channel::Sender, - b2b_notify_send_of_recv_s: crossbeam_channel::Sender, + b2b_notify_send_of_recv_s: crossbeam_channel::Sender<(Cid, ProtocolEvent)>, ) { let mut recv_protocols: HashMap> = HashMap::new(); // we should be able to directly await futures imo @@ -390,7 +452,7 @@ impl BParticipant { guaranteed_bandwidth, }) => { trace!(?sid, "open stream"); - let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); + let _ = b2b_notify_send_of_recv_s.send((cid, r.unwrap())); // waiting for receiving is not necessary, because the send_mgr will first // process this before process messages! let stream = self @@ -401,7 +463,7 @@ impl BParticipant { }, Ok(ProtocolEvent::CloseStream { sid }) => { trace!(?sid, "close stream"); - let _ = b2b_notify_send_of_recv_s.send(r.unwrap()); + let _ = b2b_notify_send_of_recv_s.send((cid, r.unwrap())); self.delete_stream(sid).await; retrigger(cid, p, &mut recv_protocols); }, diff --git a/network/src/trace.rs b/network/src/trace.rs deleted file mode 100644 index 640d65ee55..0000000000 --- a/network/src/trace.rs +++ /dev/null @@ -1,46 +0,0 @@ -use core::hash::Hash; -use std::{collections::HashMap, time::Instant}; -use tracing::Level; - -/// used to collect multiple traces and not spam the console -pub(crate) struct DeferredTracer { - level: Level, - items: HashMap, - last: Instant, - last_cnt: u32, -} - -impl DeferredTracer { - pub(crate) fn new(level: Level) -> Self { - Self { - level, - items: HashMap::new(), - last: Instant::now(), - last_cnt: 0, - } - } - - pub(crate) fn log(&mut self, t: T) { - if tracing::level_enabled!(self.level) { - *self.items.entry(t).or_default() += 1; - self.last = Instant::now(); - self.last_cnt += 1; - } else { - } - } - - pub(crate) fn print(&mut self) -> Option> { - const MAX_LOGS: u32 = 10_000; - const MAX_SECS: u64 = 1; - if tracing::level_enabled!(self.level) - && (self.last_cnt > MAX_LOGS || self.last.elapsed().as_secs() >= MAX_SECS) - { - if self.last_cnt > MAX_LOGS { - tracing::debug!("this seems to be logged continuesly"); - } - Some(std::mem::take(&mut self.items)) - } else { - None - } - } -} diff --git a/network/src/util.rs b/network/src/util.rs new file mode 100644 index 0000000000..b9a8801263 --- /dev/null +++ b/network/src/util.rs @@ -0,0 +1,117 @@ +use core::hash::Hash; +use std::{collections::HashMap, time::Instant}; +use tracing::Level; + +/// used to collect multiple traces and not spam the console +pub(crate) struct DeferredTracer { + level: Level, + items: HashMap, + last: Instant, + last_cnt: u32, +} + +impl DeferredTracer { + pub(crate) fn new(level: Level) -> Self { + Self { + level, + items: HashMap::new(), + last: Instant::now(), + last_cnt: 0, + } + } + + pub(crate) fn log(&mut self, t: T) { + if tracing::level_enabled!(self.level) { + *self.items.entry(t).or_default() += 1; + self.last = Instant::now(); + self.last_cnt += 1; + } else { + } + } + + pub(crate) fn print(&mut self) -> Option> { + const MAX_LOGS: u32 = 10_000; + const MAX_SECS: u64 = 1; + if tracing::level_enabled!(self.level) + && (self.last_cnt > MAX_LOGS || self.last.elapsed().as_secs() >= MAX_SECS) + { + if self.last_cnt > MAX_LOGS { + tracing::debug!("this seems to be logged continuesly"); + } + Some(std::mem::take(&mut self.items)) + } else { + None + } + } +} + +/// Used for storing Protocols in a Participant or Stream <-> Protocol +pub(crate) struct SortedVec { + pub data: Vec<(K, V)>, +} + +impl Default for SortedVec { + fn default() -> Self { Self { data: vec![] } } +} + +impl SortedVec +where + K: Ord + Copy, +{ + pub fn insert(&mut self, k: K, v: V) { + self.data.push((k, v)); + self.data.sort_by_key(|&(k, _)| k); + } + + pub fn delete(&mut self, k: &K) -> Option { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(self.data.remove(i).1) + } else { + None + } + } + + pub fn get(&self, k: &K) -> Option<&V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&self.data[i].1) + } else { + None + } + } + + pub fn get_mut(&mut self, k: &K) -> Option<&mut V> { + if let Ok(i) = self.data.binary_search_by_key(k, |&(k, _)| k) { + Some(&mut self.data[i].1) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sorted_vec() { + let mut vec = SortedVec::default(); + vec.insert(10, "Hello"); + println!("{:?}", vec.data); + vec.insert(30, "World"); + println!("{:?}", vec.data); + vec.insert(20, " "); + println!("{:?}", vec.data); + assert_eq!(vec.data[0].1, "Hello"); + assert_eq!(vec.data[1].1, " "); + assert_eq!(vec.data[2].1, "World"); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), Some(&"Hello")); + assert_eq!(vec.delete(&40), None); + assert_eq!(vec.delete(&10), Some("Hello")); + assert_eq!(vec.delete(&10), None); + assert_eq!(vec.get(&30), Some(&"World")); + assert_eq!(vec.get_mut(&20), Some(&mut " ")); + assert_eq!(vec.get(&10), None); + } +}