diff --git a/network/src/api.rs b/network/src/api.rs index 54de28cf35..07fbcd1375 100644 --- a/network/src/api.rs +++ b/network/src/api.rs @@ -8,11 +8,7 @@ use crate::{ scheduler::Scheduler, types::{Mid, Pid, Prio, Promises, Sid}, }; -use async_std::{ - io, - sync::{Mutex, RwLock}, - task, -}; +use async_std::{io, sync::Mutex, task}; use futures::{ channel::{mpsc, oneshot}, sink::SinkExt, @@ -52,8 +48,8 @@ pub enum ProtocolAddr { pub struct Participant { local_pid: Pid, remote_pid: Pid, - a2b_stream_open_s: RwLock>, - b2a_stream_opened_r: RwLock>, + a2b_stream_open_s: Mutex>, + b2a_stream_opened_r: Mutex>, a2s_disconnect_s: A2sDisconnect, } @@ -147,12 +143,12 @@ pub enum StreamError { /// [`connected`]: Network::connected pub struct Network { local_pid: Pid, - participant_disconnect_sender: RwLock>, + participant_disconnect_sender: Mutex>, listen_sender: - RwLock>)>>, + Mutex>)>>, connect_sender: - RwLock>)>>, - connected_receiver: RwLock>, + Mutex>)>>, + connected_receiver: Mutex>, shutdown_sender: Option>, } @@ -249,10 +245,10 @@ impl Network { ( Self { local_pid: participant_id, - participant_disconnect_sender: RwLock::new(HashMap::new()), - listen_sender: RwLock::new(listen_sender), - connect_sender: RwLock::new(connect_sender), - connected_receiver: RwLock::new(connected_receiver), + participant_disconnect_sender: Mutex::new(HashMap::new()), + listen_sender: Mutex::new(listen_sender), + connect_sender: Mutex::new(connect_sender), + connected_receiver: Mutex::new(connected_receiver), shutdown_sender: Some(shutdown_sender), }, move || { @@ -300,7 +296,7 @@ impl Network { let (s2a_result_s, s2a_result_r) = oneshot::channel::>(); debug!(?address, "listening on address"); self.listen_sender - .write() + .lock() .await .send((address, s2a_result_s)) .await?; @@ -356,7 +352,7 @@ impl Network { let (pid_sender, pid_receiver) = oneshot::channel::>(); debug!(?address, "Connect to address"); self.connect_sender - .write() + .lock() .await .send((address, pid_sender)) .await?; @@ -370,7 +366,7 @@ impl Network { "Received Participant id from remote and return to user" ); self.participant_disconnect_sender - .write() + .lock() .await .insert(pid, participant.a2s_disconnect_s.clone()); Ok(participant) @@ -410,9 +406,9 @@ impl Network { /// [`Streams`]: crate::api::Stream /// [`listen`]: crate::api::Network::listen pub async fn connected(&self) -> Result { - let participant = self.connected_receiver.write().await.next().await?; + let participant = self.connected_receiver.lock().await.next().await?; self.participant_disconnect_sender - .write() + .lock() .await .insert(participant.remote_pid, participant.a2s_disconnect_s.clone()); Ok(participant) @@ -430,8 +426,8 @@ impl Participant { Self { local_pid, remote_pid, - a2b_stream_open_s: RwLock::new(a2b_stream_open_s), - b2a_stream_opened_r: RwLock::new(b2a_stream_opened_r), + a2b_stream_open_s: Mutex::new(a2b_stream_open_s), + b2a_stream_opened_r: Mutex::new(b2a_stream_opened_r), a2s_disconnect_s: Arc::new(Mutex::new(Some(a2s_disconnect_s))), } } @@ -477,11 +473,11 @@ impl Participant { /// /// [`Streams`]: crate::api::Stream pub async fn open(&self, prio: u8, promises: Promises) -> Result { - //use this lock for now to make sure that only one open at a time is made, - // TODO: not sure if we can paralise that, check in future - let mut a2b_stream_open_s = self.a2b_stream_open_s.write().await; let (p2a_return_stream_s, p2a_return_stream_r) = oneshot::channel(); - if let Err(e) = a2b_stream_open_s + if let Err(e) = self + .a2b_stream_open_s + .lock() + .await .send((prio, promises, p2a_return_stream_s)) .await { @@ -535,10 +531,7 @@ impl Participant { /// [`connected`]: Network::connected /// [`open`]: Participant::open pub async fn opened(&self) -> Result { - //use this lock for now to make sure that only one open at a time is made, - // TODO: not sure if we can paralise that, check in future - let mut stream_opened_receiver = self.b2a_stream_opened_r.write().await; - match stream_opened_receiver.next().await { + match self.b2a_stream_opened_r.lock().await.next().await { Some(stream) => { let sid = stream.sid; debug!(?sid, ?self.remote_pid, "Receive opened stream"); @@ -861,7 +854,7 @@ impl Drop for Network { // we MUST avoid nested block_on, good that Network::Drop no longer triggers // Participant::Drop directly but just the BParticipant for (remote_pid, a2s_disconnect_s) in - self.participant_disconnect_sender.write().await.drain() + self.participant_disconnect_sender.lock().await.drain() { match a2s_disconnect_s.lock().await.take() { Some(mut a2s_disconnect_s) => { diff --git a/network/src/participant.rs b/network/src/participant.rs index 741d85fa2b..ecfc548879 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -8,7 +8,7 @@ use crate::{ protocols::Protocols, types::{Cid, Frame, Pid, Prio, Promises, Sid}, }; -use async_std::sync::RwLock; +use async_std::sync::{Mutex, RwLock}; use futures::{ channel::{mpsc, oneshot}, future::FutureExt, @@ -46,7 +46,7 @@ struct StreamInfo { prio: Prio, promises: Promises, send_closed: Arc, - b2a_msg_recv_s: mpsc::UnboundedSender, + b2a_msg_recv_s: Mutex>, } #[derive(Debug)] @@ -71,7 +71,7 @@ pub struct BParticipant { remote_pid: Pid, remote_pid_string: String, //optimisation offset_sid: Sid, - channels: Arc>>, + channels: Arc>>>, streams: RwLock>, running_mgr: AtomicUsize, run_channels: Option, @@ -252,10 +252,10 @@ impl BParticipant { frame: Frame, #[cfg(feature = "metrics")] frames_out_total_cache: &mut PidCidFrameCache, ) -> bool { - // find out ideal channel here - //TODO: just take first - let mut lock = self.channels.write().await; - if let Some(ci) = lock.values_mut().next() { + let mut drop_cid = None; + // TODO: find out ideal channel here + let res = if let Some(ci) = self.channels.read().await.values().next() { + let mut ci = ci.lock().await; //we are increasing metrics without checking the result to please // borrow_checker. otherwise we would need to close `frame` what we // dont want! @@ -266,20 +266,7 @@ impl BParticipant { if let Err(e) = ci.b2w_frame_s.send(frame).await { let cid = ci.cid; info!(?e, ?cid, "channel no longer available"); - if let Some(ci) = self.channels.write().await.remove(&cid) { - trace!(?cid, "stopping read protocol"); - if let Err(e) = ci.b2r_read_shutdown.send(()) { - trace!(?cid, ?e, "seems like was already shut down"); - } - } - //TODO FIXME tags: takeover channel multiple - info!( - "FIXME: the frame is actually drop. which is fine for now as the participant \ - will be closed, but not if we do channel-takeover" - ); - //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant - self.close_write_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) - .await; + drop_cid = Some(cid); false } else { true @@ -301,7 +288,25 @@ impl BParticipant { guard.1 += 1; } false - } + }; + if let Some(cid) = drop_cid { + if let Some(ci) = self.channels.write().await.remove(&cid) { + let ci = ci.into_inner(); + trace!(?cid, "stopping read protocol"); + if let Err(e) = ci.b2r_read_shutdown.send(()) { + trace!(?cid, ?e, "seems like was already shut down"); + } + } + //TODO FIXME tags: takeover channel multiple + info!( + "FIXME: the frame is actually drop. which is fine for now as the participant will \ + be closed, but not if we do channel-takeover" + ); + //TEMP FIX: as we dont have channel takeover yet drop the whole bParticipant + self.close_write_api(Some(ParticipantError::ProtocolFailedUnrecoverable)) + .await; + }; + res } async fn handle_frames_mgr( @@ -325,7 +330,8 @@ impl BParticipant { Err(()) => { // The read protocol stopped, i need to make sure that write gets stopped debug!("read protocol was closed. Stopping write protocol"); - if let Some(ci) = self.channels.write().await.get_mut(&cid) { + if let Some(ci) = self.channels.read().await.get(&cid) { + let mut ci = ci.lock().await; ci.b2w_frame_s .close() .await @@ -381,7 +387,7 @@ impl BParticipant { .with_label_values(&[&self.remote_pid_string]) .inc(); si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.close_channel(); + si.b2a_msg_recv_s.into_inner().close_channel(); trace!(?sid, "Closed stream from remote"); } else { warn!( @@ -414,8 +420,8 @@ impl BParticipant { if finished { //trace!(?mid, "finished receiving message"); let imsg = messages.remove(&mid).unwrap(); - if let Some(si) = self.streams.write().await.get_mut(&imsg.sid) { - if let Err(e) = si.b2a_msg_recv_s.send(imsg).await { + if let Some(si) = self.streams.read().await.get(&imsg.sid) { + if let Err(e) = si.b2a_msg_recv_s.lock().await.send(imsg).await { warn!( ?e, ?mid, @@ -449,8 +455,10 @@ impl BParticipant { .await; }, f => { - //unreachable!("Frame should never reache participant!: {:?}", f); - error!(?f, ?cid, "Frame should never reache participant!"); + unreachable!( + "Frame should never reach participant!: {:?}, cid: {}", + f, cid + ); }, } } @@ -482,12 +490,15 @@ impl BParticipant { let channels = self.channels.clone(); async move { let (channel, b2w_frame_s, b2r_read_shutdown) = Channel::new(cid); - channels.write().await.insert(cid, ChannelInfo { + channels.write().await.insert( cid, - cid_string: cid.to_string(), - b2w_frame_s, - b2r_read_shutdown, - }); + Mutex::new(ChannelInfo { + cid, + cid_string: cid.to_string(), + b2w_frame_s, + b2r_read_shutdown, + }), + ); b2s_create_channel_done_s.send(()).unwrap(); #[cfg(feature = "metrics")] self.metrics @@ -619,6 +630,7 @@ impl BParticipant { debug!("Closing all channels, after flushed prios"); for (cid, ci) in self.channels.write().await.drain() { + let ci = ci.into_inner(); if let Err(e) = ci.b2r_read_shutdown.send(()) { debug!( ?e, @@ -655,7 +667,10 @@ impl BParticipant { sender .send(match lock.error.take() { None => Ok(()), - Some(e) => Err(e), + Some(ParticipantError::ProtocolFailedUnrecoverable) => { + Err(ParticipantError::ProtocolFailedUnrecoverable) + }, + Some(ParticipantError::ParticipantDisconnected) => Ok(()), }) .unwrap(); @@ -695,7 +710,7 @@ impl BParticipant { match self.streams.read().await.get(&sid) { Some(si) => { si.send_closed.store(true, Ordering::Relaxed); - si.b2a_msg_recv_s.close_channel(); + si.b2a_msg_recv_s.lock().await.close_channel(); }, None => warn!("Couldn't find the stream, might be simultaneous close from remote"), } @@ -742,7 +757,7 @@ impl BParticipant { prio, promises, send_closed: send_closed.clone(), - b2a_msg_recv_s, + b2a_msg_recv_s: Mutex::new(b2a_msg_recv_s), }); #[cfg(feature = "metrics")] self.metrics @@ -770,7 +785,7 @@ impl BParticipant { lock.b2a_stream_opened_s.close_channel(); debug!("Closing all streams for write"); - for (sid, si) in self.streams.write().await.iter() { + for (sid, si) in self.streams.read().await.iter() { trace!(?sid, "Shutting down Stream for write"); si.send_closed.store(true, Ordering::Relaxed); } @@ -783,7 +798,7 @@ impl BParticipant { debug!("Closing all streams"); for (sid, si) in self.streams.write().await.drain() { trace!(?sid, "Shutting down Stream"); - si.b2a_msg_recv_s.close_channel(); + si.b2a_msg_recv_s.lock().await.close_channel(); } } } diff --git a/network/src/protocols.rs b/network/src/protocols.rs index 65541d4dc1..f26e2572f7 100644 --- a/network/src/protocols.rs +++ b/network/src/protocols.rs @@ -2,12 +2,13 @@ use crate::metrics::{CidFrameCache, NetworkMetrics}; use crate::{ participant::C2pFrame, - types::{Cid, Frame, Mid, Pid, Sid}, + types::{Cid, Frame}, }; use async_std::{ + io::prelude::*, net::{TcpStream, UdpSocket}, - prelude::*, }; + use futures::{ channel::{mpsc, oneshot}, future::{Fuse, FutureExt}, @@ -69,33 +70,85 @@ impl TcpProtocol { } } - /// read_except and if it fails, close the protocol - async fn read_or_close( - cid: Cid, - mut stream: &TcpStream, - mut bytes: &mut [u8], + async fn read_frame( + r: &mut R, mut end_receiver: &mut Fuse>, - w2c_cid_frame_s: &mut mpsc::UnboundedSender, - ) -> bool { + ) -> Result> { + let handle = |read_result| match read_result { + Ok(_) => Ok(()), + Err(e) => Err(Some(e)), + }; + + let mut frame_no = [0u8; 1]; match select! { - r = stream.read_exact(&mut bytes).fuse() => Some(r), + r = r.read_exact(&mut frame_no).fuse() => Some(r), _ = end_receiver => None, } { - Some(Ok(_)) => false, - Some(Err(e)) => { - info!(?e, "Closing tcp protocol due to read error"); - //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as every - // channel is holding a sender! thats why Ne need a explicit - // STOP here - w2c_cid_frame_s - .send((cid, Err(()))) - .await - .expect("Channel or Participant seems no longer to exist"); - true - }, + Some(read_result) => handle(read_result)?, None => { trace!("shutdown requested"); - true + return Err(None); + }, + }; + + match frame_no[0] { + FRAME_HANDSHAKE => { + let mut bytes = [0u8; 19]; + handle(r.read_exact(&mut bytes).await)?; + Ok(Frame::gen_handshake(bytes)) + }, + FRAME_INIT => { + let mut bytes = [0u8; 32]; + handle(r.read_exact(&mut bytes).await)?; + Ok(Frame::gen_init(bytes)) + }, + FRAME_SHUTDOWN => Ok(Frame::Shutdown), + FRAME_OPEN_STREAM => { + let mut bytes = [0u8; 10]; + handle(r.read_exact(&mut bytes).await)?; + Ok(Frame::gen_open_stream(bytes)) + }, + FRAME_CLOSE_STREAM => { + let mut bytes = [0u8; 8]; + handle(r.read_exact(&mut bytes).await)?; + Ok(Frame::gen_close_stream(bytes)) + }, + FRAME_DATA_HEADER => { + let mut bytes = [0u8; 24]; + handle(r.read_exact(&mut bytes).await)?; + Ok(Frame::gen_data_header(bytes)) + }, + FRAME_DATA => { + let mut bytes = [0u8; 18]; + handle(r.read_exact(&mut bytes).await)?; + let (mid, start, length) = Frame::gen_data(bytes); + let mut data = vec![0; length as usize]; + handle(r.read_exact(&mut data).await)?; + Ok(Frame::Data { mid, start, data }) + }, + FRAME_RAW => { + let mut bytes = [0u8; 2]; + handle(r.read_exact(&mut bytes).await)?; + let length = Frame::gen_raw(bytes); + let mut data = vec![0; length as usize]; + handle(r.read_exact(&mut data).await)?; + Ok(Frame::Raw(data)) + }, + other => { + // report a RAW frame, but cannot rely on the next 2 bytes to be a size. + // guessing 32 bytes, which might help to sort down issues + let mut data = vec![0; 32]; + //keep the first byte! + match r.read(&mut data[1..]).await { + Ok(n) => { + data.truncate(n + 1); + Ok(()) + }, + Err(e) => Err(Some(e)), + }?; + data[0] = other; + warn!(?data, "got a unexpected RAW msg"); + Ok(Frame::Raw(data)) }, } } @@ -114,131 +167,105 @@ impl TcpProtocol { .metrics .wire_in_throughput .with_label_values(&[&cid.to_string()]); - let stream = self.stream.clone(); + let mut stream = self.stream.clone(); let mut end_r = end_r.fuse(); - macro_rules! read_or_close { - ($x:expr) => { - if TcpProtocol::read_or_close(cid, &stream, $x, &mut end_r, w2c_cid_frame_s).await { - trace!("read_or_close requested a shutdown"); - break; - } - }; - } - loop { - let frame_no = { - let mut bytes = [0u8; 1]; - read_or_close!(&mut bytes); - bytes[0] - }; - let frame = match frame_no { - FRAME_HANDSHAKE => { - let mut bytes = [0u8; 19]; - read_or_close!(&mut bytes); - let magic_number = *<&[u8; 7]>::try_from(&bytes[0..7]).unwrap(); - Frame::Handshake { - magic_number, - version: [ - u32::from_le_bytes(*<&[u8; 4]>::try_from(&bytes[7..11]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&bytes[11..15]).unwrap()), - u32::from_le_bytes(*<&[u8; 4]>::try_from(&bytes[15..19]).unwrap()), - ], - } - }, - FRAME_INIT => { - let mut bytes = [0u8; 16]; - read_or_close!(&mut bytes); - let pid = Pid::from_le_bytes(bytes); - read_or_close!(&mut bytes); - let secret = u128::from_le_bytes(bytes); - Frame::Init { pid, secret } - }, - FRAME_SHUTDOWN => Frame::Shutdown, - FRAME_OPEN_STREAM => { - let mut bytes = [0u8; 10]; - read_or_close!(&mut bytes); - let sid = Sid::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[0..8]).unwrap()); - let prio = bytes[8]; - let promises = bytes[9]; - Frame::OpenStream { - sid, - prio, - promises, - } - }, - FRAME_CLOSE_STREAM => { - let mut bytes = [0u8; 8]; - read_or_close!(&mut bytes); - let sid = Sid::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[0..8]).unwrap()); - Frame::CloseStream { sid } - }, - FRAME_DATA_HEADER => { - let mut bytes = [0u8; 24]; - read_or_close!(&mut bytes); - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[0..8]).unwrap()); - let sid = Sid::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[8..16]).unwrap()); - let length = u64::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[16..24]).unwrap()); - Frame::DataHeader { mid, sid, length } - }, - FRAME_DATA => { - let mut bytes = [0u8; 18]; - read_or_close!(&mut bytes); - let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[0..8]).unwrap()); - let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&bytes[8..16]).unwrap()); - let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&bytes[16..18]).unwrap()); - let mut data = vec![0; length as usize]; + match Self::read_frame(&mut stream, &mut end_r).await { + Ok(frame) => { #[cfg(feature = "metrics")] - throughput_cache.inc_by(length as i64); - read_or_close!(&mut data); - Frame::Data { mid, start, data } + { + metrics_cache.with_label_values(&frame).inc(); + if let Frame::Data { + mid: _, + start: _, + ref data, + } = frame + { + throughput_cache.inc_by(data.len() as i64); + } + } + w2c_cid_frame_s + .send((cid, Ok(frame))) + .await + .expect("Channel or Participant seems no longer to exist"); }, - FRAME_RAW => { - let mut bytes = [0u8; 2]; - read_or_close!(&mut bytes); - let length = u16::from_le_bytes([bytes[0], bytes[1]]); - let mut data = vec![0; length as usize]; - read_or_close!(&mut data); - Frame::Raw(data) + Err(e_option) => { + if let Some(e) = e_option { + info!(?e, "Closing tcp protocol due to read error"); + //w2c_cid_frame_s is shared, dropping it wouldn't notify the receiver as + // every channel is holding a sender! thats why Ne + // need a explicit STOP here + w2c_cid_frame_s + .send((cid, Err(()))) + .await + .expect("Channel or Participant seems no longer to exist"); + } + //None is clean shutdown + break; }, - other => { - // report a RAW frame, but cannot rely on the next 2 bytes to be a size. - // guessing 32 bytes, which might help to sort down issues - let mut data = vec![0; 32]; - //keep the first byte! - read_or_close!(&mut data[1..]); - data[0] = other; - warn!(?data, "got a unexpected RAW msg"); - Frame::Raw(data) - }, - }; - #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - w2c_cid_frame_s - .send((cid, Ok(frame))) - .await - .expect("Channel or Participant seems no longer to exist"); + } } trace!("Shutting down tcp read()"); } - /// read_except and if it fails, close the protocol - async fn write_or_close( - stream: &mut TcpStream, - bytes: &[u8], - c2w_frame_r: &mut mpsc::UnboundedReceiver, - ) -> bool { - match stream.write_all(&bytes).await { - Err(e) => { - info!( - ?e, - "Got an error writing to tcp, going to close this channel" - ); - c2w_frame_r.close(); - true + pub async fn write_frame( + w: &mut W, + frame: Frame, + ) -> Result<(), std::io::Error> { + match frame { + Frame::Handshake { + magic_number, + version, + } => { + w.write_all(&FRAME_HANDSHAKE.to_be_bytes()).await?; + w.write_all(&magic_number).await?; + w.write_all(&version[0].to_le_bytes()).await?; + w.write_all(&version[1].to_le_bytes()).await?; + w.write_all(&version[2].to_le_bytes()).await?; }, - _ => false, - } + Frame::Init { pid, secret } => { + w.write_all(&FRAME_INIT.to_be_bytes()).await?; + w.write_all(&pid.to_le_bytes()).await?; + w.write_all(&secret.to_le_bytes()).await?; + }, + Frame::Shutdown => { + w.write_all(&FRAME_SHUTDOWN.to_be_bytes()).await?; + }, + Frame::OpenStream { + sid, + prio, + promises, + } => { + w.write_all(&FRAME_OPEN_STREAM.to_be_bytes()).await?; + w.write_all(&sid.to_le_bytes()).await?; + w.write_all(&prio.to_le_bytes()).await?; + w.write_all(&promises.to_le_bytes()).await?; + }, + Frame::CloseStream { sid } => { + w.write_all(&FRAME_CLOSE_STREAM.to_be_bytes()).await?; + w.write_all(&sid.to_le_bytes()).await?; + }, + Frame::DataHeader { mid, sid, length } => { + w.write_all(&FRAME_DATA_HEADER.to_be_bytes()).await?; + w.write_all(&mid.to_le_bytes()).await?; + w.write_all(&sid.to_le_bytes()).await?; + w.write_all(&length.to_le_bytes()).await?; + }, + Frame::Data { mid, start, data } => { + w.write_all(&FRAME_DATA.to_be_bytes()).await?; + w.write_all(&mid.to_le_bytes()).await?; + w.write_all(&start.to_le_bytes()).await?; + w.write_all(&(data.len() as u16).to_le_bytes()).await?; + w.write_all(&data).await?; + }, + Frame::Raw(data) => { + w.write_all(&FRAME_RAW.to_be_bytes()).await?; + w.write_all(&(data.len() as u16).to_le_bytes()).await?; + w.write_all(&data).await?; + }, + }; + Ok(()) } pub async fn write_to_wire(&self, cid: Cid, mut c2w_frame_r: mpsc::UnboundedReceiver) { @@ -254,72 +281,27 @@ impl TcpProtocol { #[cfg(not(feature = "metrics"))] let _cid = cid; - macro_rules! write_or_close { - ($x:expr) => { - if TcpProtocol::write_or_close(&mut stream, $x, &mut c2w_frame_r).await { - trace!("write_or_close requested a shutdown"); - break; - } - }; - } - while let Some(frame) = c2w_frame_r.next().await { #[cfg(feature = "metrics")] - metrics_cache.with_label_values(&frame).inc(); - match frame { - Frame::Handshake { - magic_number, - version, - } => { - write_or_close!(&FRAME_HANDSHAKE.to_be_bytes()); - write_or_close!(&magic_number); - write_or_close!(&version[0].to_le_bytes()); - write_or_close!(&version[1].to_le_bytes()); - write_or_close!(&version[2].to_le_bytes()); - }, - Frame::Init { pid, secret } => { - write_or_close!(&FRAME_INIT.to_be_bytes()); - write_or_close!(&pid.to_le_bytes()); - write_or_close!(&secret.to_le_bytes()); - }, - Frame::Shutdown => { - write_or_close!(&FRAME_SHUTDOWN.to_be_bytes()); - }, - Frame::OpenStream { - sid, - prio, - promises, - } => { - write_or_close!(&FRAME_OPEN_STREAM.to_be_bytes()); - write_or_close!(&sid.to_le_bytes()); - write_or_close!(&prio.to_le_bytes()); - write_or_close!(&promises.to_le_bytes()); - }, - Frame::CloseStream { sid } => { - write_or_close!(&FRAME_CLOSE_STREAM.to_be_bytes()); - write_or_close!(&sid.to_le_bytes()); - }, - Frame::DataHeader { mid, sid, length } => { - write_or_close!(&FRAME_DATA_HEADER.to_be_bytes()); - write_or_close!(&mid.to_le_bytes()); - write_or_close!(&sid.to_le_bytes()); - write_or_close!(&length.to_le_bytes()); - }, - Frame::Data { mid, start, data } => { - #[cfg(feature = "metrics")] + { + metrics_cache.with_label_values(&frame).inc(); + if let Frame::Data { + mid: _, + start: _, + ref data, + } = frame + { throughput_cache.inc_by(data.len() as i64); - write_or_close!(&FRAME_DATA.to_be_bytes()); - write_or_close!(&mid.to_le_bytes()); - write_or_close!(&start.to_le_bytes()); - write_or_close!(&(data.len() as u16).to_le_bytes()); - write_or_close!(&data); - }, - Frame::Raw(data) => { - write_or_close!(&FRAME_RAW.to_be_bytes()); - write_or_close!(&(data.len() as u16).to_le_bytes()); - write_or_close!(&data); - }, + } } + if let Err(e) = Self::write_frame(&mut stream, frame).await { + info!( + ?e, + "Got an error writing to tcp, going to close this channel" + ); + c2w_frame_r.close(); + break; + }; } trace!("shutting down tcp write()"); } @@ -372,81 +354,22 @@ impl UdpProtocol { let frame_no = bytes[0]; let frame = match frame_no { FRAME_HANDSHAKE => { - let bytes = &bytes[1..20]; - let magic_number = [ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - ]; - Frame::Handshake { - magic_number, - version: [ - u32::from_le_bytes([bytes[7], bytes[8], bytes[9], bytes[10]]), - u32::from_le_bytes([bytes[11], bytes[12], bytes[13], bytes[14]]), - u32::from_le_bytes([bytes[15], bytes[16], bytes[17], bytes[18]]), - ], - } - }, - FRAME_INIT => { - let pid = Pid::from_le_bytes([ - bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], - bytes[15], bytes[16], - ]); - let secret = u128::from_le_bytes([ - bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], - bytes[23], bytes[24], bytes[25], bytes[26], bytes[27], bytes[28], - bytes[29], bytes[30], bytes[31], bytes[32], - ]); - Frame::Init { pid, secret } + Frame::gen_handshake(*<&[u8; 19]>::try_from(&bytes[1..20]).unwrap()) }, + FRAME_INIT => Frame::gen_init(*<&[u8; 32]>::try_from(&bytes[1..33]).unwrap()), FRAME_SHUTDOWN => Frame::Shutdown, FRAME_OPEN_STREAM => { - let bytes = &bytes[1..11]; - let sid = Sid::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - bytes[7], - ]); - let prio = bytes[8]; - let promises = bytes[9]; - Frame::OpenStream { - sid, - prio, - promises, - } + Frame::gen_open_stream(*<&[u8; 10]>::try_from(&bytes[1..11]).unwrap()) }, FRAME_CLOSE_STREAM => { - let bytes = &bytes[1..9]; - let sid = Sid::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - bytes[7], - ]); - Frame::CloseStream { sid } + Frame::gen_close_stream(*<&[u8; 8]>::try_from(&bytes[1..9]).unwrap()) }, FRAME_DATA_HEADER => { - let bytes = &bytes[1..25]; - let mid = Mid::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - bytes[7], - ]); - let sid = Sid::from_le_bytes([ - bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], - bytes[15], - ]); - let length = u64::from_le_bytes([ - bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], - bytes[22], bytes[23], - ]); - Frame::DataHeader { mid, sid, length } + Frame::gen_data_header(*<&[u8; 24]>::try_from(&bytes[1..25]).unwrap()) }, FRAME_DATA => { - let mid = Mid::from_le_bytes([ - bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - bytes[8], - ]); - let start = u64::from_le_bytes([ - bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], - bytes[16], - ]); - let length = u16::from_le_bytes([bytes[17], bytes[18]]); + let (mid, start, length) = + Frame::gen_data(*<&[u8; 18]>::try_from(&bytes[1..19]).unwrap()); let mut data = vec![0; length as usize]; #[cfg(feature = "metrics")] throughput_cache.inc_by(length as i64); @@ -454,7 +377,7 @@ impl UdpProtocol { Frame::Data { mid, start, data } }, FRAME_RAW => { - let length = u16::from_le_bytes([bytes[1], bytes[2]]); + let length = Frame::gen_raw(*<&[u8; 2]>::try_from(&bytes[1..3]).unwrap()); let mut data = vec![0; length as usize]; data.copy_from_slice(&bytes[3..]); Frame::Raw(data) @@ -648,7 +571,6 @@ mod tests { .await .unwrap(); client.flush(); - //handle data let (mut w2c_cid_frame_s, mut w2c_cid_frame_r) = mpsc::unbounded::(); let (read_stop_sender, read_stop_receiver) = oneshot::channel(); diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index ae53f5d4e7..b352463166 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -7,10 +7,7 @@ use crate::{ protocols::{Protocols, TcpProtocol, UdpProtocol}, types::Pid, }; -use async_std::{ - io, net, - sync::{Mutex, RwLock}, -}; +use async_std::{io, net, sync::Mutex}; use futures::{ channel::{mpsc, oneshot}, executor::ThreadPool, @@ -76,9 +73,9 @@ pub struct Scheduler { pool: Arc, run_channels: Option, participant_channels: Arc>>, - participants: Arc>>, + participants: Arc>>, channel_ids: Arc, - channel_listener: RwLock>>, + channel_listener: Mutex>>, #[cfg(feature = "metrics")] metrics: Arc, } @@ -136,9 +133,9 @@ impl Scheduler { pool: Arc::new(ThreadPool::new().unwrap()), run_channels, participant_channels: Arc::new(Mutex::new(Some(participant_channels))), - participants: Arc::new(RwLock::new(HashMap::new())), + participants: Arc::new(Mutex::new(HashMap::new())), channel_ids: Arc::new(AtomicU64::new(0)), - channel_listener: RwLock::new(HashMap::new()), + channel_listener: Mutex::new(HashMap::new()), #[cfg(feature = "metrics")] metrics, }, @@ -180,7 +177,7 @@ impl Scheduler { .inc(); let (end_sender, end_receiver) = oneshot::channel::<()>(); self.channel_listener - .write() + .lock() .await .insert(address.clone(), end_sender); self.channel_creator(address, end_receiver, s2a_listen_result_s) @@ -273,7 +270,7 @@ impl Scheduler { // 3. Participant will try to access the BParticipant senders and receivers with // their next api action, it will fail and be closed then. trace!(?pid, "Got request to close participant"); - if let Some(mut pi) = self.participants.write().await.remove(&pid) { + if let Some(mut pi) = self.participants.lock().await.remove(&pid) { let (finished_sender, finished_receiver) = oneshot::channel(); pi.s2b_shutdown_bparticipant_s .take() @@ -310,7 +307,7 @@ impl Scheduler { a2s_scheduler_shutdown_r.await.unwrap(); self.closed.store(true, Ordering::Relaxed); debug!("Shutting down all BParticipants gracefully"); - let mut participants = self.participants.write().await; + let mut participants = self.participants.lock().await; let waitings = participants .drain() .map(|(pid, mut pi)| { @@ -336,7 +333,7 @@ impl Scheduler { }; } debug!("shutting down protocol listeners"); - for (addr, end_channel_sender) in self.channel_listener.write().await.drain() { + for (addr, end_channel_sender) in self.channel_listener.lock().await.drain() { trace!(?addr, "stopping listen on protocol"); if let Err(e) = end_channel_sender.send(()) { warn!(?addr, ?e, "listener crashed/disconnected already"); @@ -531,7 +528,7 @@ impl Scheduler { ?pid, "Detected that my channel is ready!, activating it :)" ); - let mut participants = participants.write().await; + let mut participants = participants.lock().await; if !participants.contains_key(&pid) { debug!(?cid, "New participant connected via a channel"); let ( diff --git a/network/src/types.rs b/network/src/types.rs index 51fd1843e3..527fc368ae 100644 --- a/network/src/types.rs +++ b/network/src/types.rs @@ -1,4 +1,5 @@ use rand::Rng; +use std::convert::TryFrom; pub type Mid = u64; pub type Cid = u64; @@ -124,6 +125,58 @@ impl Frame { #[cfg(feature = "metrics")] pub fn get_string(&self) -> &str { Self::int_to_string(self.get_int()) } + + pub fn gen_handshake(buf: [u8; 19]) -> Self { + let magic_number = *<&[u8; 7]>::try_from(&buf[0..7]).unwrap(); + Frame::Handshake { + magic_number, + version: [ + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[7..11]).unwrap()), + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[11..15]).unwrap()), + u32::from_le_bytes(*<&[u8; 4]>::try_from(&buf[15..19]).unwrap()), + ], + } + } + + pub fn gen_init(buf: [u8; 32]) -> Self { + Frame::Init { + pid: Pid::from_le_bytes(*<&[u8; 16]>::try_from(&buf[0..16]).unwrap()), + secret: u128::from_le_bytes(*<&[u8; 16]>::try_from(&buf[16..32]).unwrap()), + } + } + + pub fn gen_open_stream(buf: [u8; 10]) -> Self { + Frame::OpenStream { + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + prio: buf[8], + promises: buf[9], + } + } + + pub fn gen_close_stream(buf: [u8; 8]) -> Self { + Frame::CloseStream { + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + } + } + + pub fn gen_data_header(buf: [u8; 24]) -> Self { + Frame::DataHeader { + mid: Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()), + sid: Sid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()), + length: u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[16..24]).unwrap()), + } + } + + pub fn gen_data(buf: [u8; 18]) -> (Mid, u64, u16) { + let mid = Mid::from_le_bytes(*<&[u8; 8]>::try_from(&buf[0..8]).unwrap()); + let start = u64::from_le_bytes(*<&[u8; 8]>::try_from(&buf[8..16]).unwrap()); + let length = u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[16..18]).unwrap()); + (mid, start, length) + } + + pub fn gen_raw(buf: [u8; 2]) -> u16 { + u16::from_le_bytes(*<&[u8; 2]>::try_from(&buf[0..2]).unwrap()) + } } impl Pid { diff --git a/network/tests/closing.rs b/network/tests/closing.rs index 0a9a51ed93..c40762e9ec 100644 --- a/network/tests/closing.rs +++ b/network/tests/closing.rs @@ -42,11 +42,7 @@ fn close_participant() { let (_n_a, p1_a, mut s1_a, _n_b, p1_b, mut s1_b) = block_on(network_participant_stream(tcp())); block_on(p1_a.disconnect()).unwrap(); - //As no more read/write is run disconnect is successful or already disconnected - match block_on(p1_b.disconnect()) { - Ok(_) | Err(ParticipantError::ParticipantDisconnected) => (), - e => panic!("wrong disconnect type {:?}", e), - }; + block_on(p1_b.disconnect()).unwrap(); assert_eq!(s1_a.send("Hello World"), Err(StreamError::StreamClosed)); assert_eq!( @@ -285,6 +281,7 @@ fn failed_stream_open_after_remote_part_is_closed() { #[test] fn open_participant_before_remote_part_is_closed() { + let (_, _) = helper::setup(false, 0); let (n_a, f) = Network::new(Pid::fake(0)); std::thread::spawn(f); let (n_b, f) = Network::new(Pid::fake(1)); @@ -305,6 +302,7 @@ fn open_participant_before_remote_part_is_closed() { #[test] fn open_participant_after_remote_part_is_closed() { + let (_, _) = helper::setup(false, 0); let (n_a, f) = Network::new(Pid::fake(0)); std::thread::spawn(f); let (n_b, f) = Network::new(Pid::fake(1)); @@ -325,6 +323,7 @@ fn open_participant_after_remote_part_is_closed() { #[test] fn close_network_scheduler_completely() { + let (_, _) = helper::setup(false, 0); let (n_a, f) = Network::new(Pid::fake(0)); let ha = std::thread::spawn(f); let (n_b, f) = Network::new(Pid::fake(1));