diff --git a/network/protocol/src/error.rs b/network/protocol/src/error.rs index 089208b24f..27199197d2 100644 --- a/network/protocol/src/error.rs +++ b/network/protocol/src/error.rs @@ -11,13 +11,21 @@ pub enum InitProtocolError { /// When you return closed you must stay closed! #[derive(Debug, PartialEq)] pub enum ProtocolError { + /// Closed indicates the underlying I/O got closed + /// e.g. the TCP, UDP or MPSC connection is dropped by the OS Closed, + /// Violated indicates the veloren_network_protocol was violated + /// the underlying I/O connection is still valid, but the remote side + /// send WRONG (e.g. Invalid, or wrong order) data on the protocol layer. + Violated, } impl From for InitProtocolError { fn from(err: ProtocolError) -> Self { match err { ProtocolError::Closed => InitProtocolError::Closed, + // not possible as the Init has raw access to the I/O + ProtocolError::Violated => InitProtocolError::Closed, } } } @@ -46,6 +54,7 @@ impl core::fmt::Display for ProtocolError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { ProtocolError::Closed => write!(f, "Channel closed"), + ProtocolError::Violated => write!(f, "Channel protocol violated"), } } } diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs index 20a0439953..aa44af14ed 100644 --- a/network/protocol/src/frame.rs +++ b/network/protocol/src/frame.rs @@ -206,10 +206,12 @@ impl OTFrame { } impl ITFrame { - pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option { + /// Err => cannot recover + /// Ok(None) => waiting for more data + pub(crate) fn read_frame(bytes: &mut BytesMut) -> Result, ()> { let frame_no = match bytes.first() { Some(&f) => f, - None => return None, + None => return Ok(None), }; let size = match frame_no { FRAME_SHUTDOWN => TCP_SHUTDOWN_CNS, @@ -218,15 +220,15 @@ impl ITFrame { FRAME_DATA_HEADER => TCP_DATA_HEADER_CNS, FRAME_DATA => { if bytes.len() < 9 + 1 + 1 { - return None; + return Ok(None); } u16::from_le_bytes([bytes[8 + 1], bytes[9 + 1]]) as usize + TCP_DATA_CNS }, - _ => return None, + _ => return Err(()), }; if bytes.len() < size + 1 { - return None; + return Ok(None); } let frame = match frame_no { @@ -270,7 +272,7 @@ impl ITFrame { }, _ => unreachable!("Frame::to_frame should be handled before!"), }; - Some(frame) + Ok(Some(frame)) } } @@ -393,7 +395,7 @@ mod tests { for frame in get_otframes() { println!("frame: {:?}", &frame); - assert_eq!(frame.clone(), dupl(frame).expect("NONE")); + assert_eq!(frame.clone(), dupl(frame).expect("ERR").expect("NONE")); } } @@ -416,7 +418,7 @@ mod tests { // compare for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { println!("frame: {:?}", &f); - assert_eq!(f, fd.expect("NONE")); + assert_eq!(f, fd.expect("ERR").expect("NONE")); } } @@ -430,7 +432,7 @@ mod tests { assert_eq!(buffer.len(), SIZE); let mut deque = buffer.iter().copied().collect(); let frame2 = ITFrame::read_frame(&mut deque); - assert_eq!(frame1, frame2.expect("NONE")); + assert_eq!(frame1, frame2.expect("ERR").expect("NONE")); } #[test] @@ -516,13 +518,13 @@ mod tests { OTFrame::write_bytes(frame1, &mut buffer); buffer.truncate(6); // simulate partial retrieve let frame1d = ITFrame::read_frame(&mut buffer); - assert_eq!(frame1d, None); + assert_eq!(frame1d, Ok(None)); } #[test] fn frame_rubish() { let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); - assert_eq!(ITFrame::read_frame(&mut buffer), None); + assert_eq!(ITFrame::read_frame(&mut buffer), Err(())); } #[test] @@ -537,7 +539,7 @@ mod tests { OTFrame::write_bytes(frame1, &mut buffer); buffer[9] = 255; let framed = ITFrame::read_frame(&mut buffer); - assert_eq!(framed, None); + assert_eq!(framed, Ok(None)); } #[test] @@ -554,13 +556,13 @@ mod tests { let framed = ITFrame::read_frame(&mut buffer); assert_eq!( framed, - Some(ITFrame::Data { + Ok(Some(ITFrame::Data { mid: 7u64, data: BytesMut::from(&b"foo"[..]), - }) + })) ); //next = Invalid => Empty let framed = ITFrame::read_frame(&mut buffer); - assert_eq!(framed, None); + assert_eq!(framed, Err(())); } } diff --git a/network/protocol/src/mpsc.rs b/network/protocol/src/mpsc.rs index 1f3219bab2..f4a5eee1cd 100644 --- a/network/protocol/src/mpsc.rs +++ b/network/protocol/src/mpsc.rs @@ -142,7 +142,7 @@ where } Ok(e) }, - MpscMsg::InitFrame(_) => Err(ProtocolError::Closed), + MpscMsg::InitFrame(_) => Err(ProtocolError::Violated), } } } @@ -164,7 +164,7 @@ where { async fn recv(&mut self) -> Result { match self.sink.recv().await? { - MpscMsg::Event(_) => Err(ProtocolError::Closed), + MpscMsg::Event(_) => Err(ProtocolError::Violated), MpscMsg::InitFrame(f) => Ok(f), } } diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index e78741fc01..43d14e2a1e 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -231,60 +231,67 @@ where { async fn recv(&mut self) -> Result { 'outer: loop { - while let Some(frame) = ITFrame::read_frame(&mut self.buffer) { - #[cfg(feature = "trace_pedantic")] - trace!(?frame, "recv"); - match frame { - ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), - ITFrame::OpenStream { - sid, - prio, - promises, - guaranteed_bandwidth, - } => { - break 'outer Ok(ProtocolEvent::OpenStream { - sid, - prio: prio.min(crate::types::HIGHEST_PRIO), - promises, - guaranteed_bandwidth, - }); - }, - ITFrame::CloseStream { sid } => { - break 'outer Ok(ProtocolEvent::CloseStream { sid }); - }, - ITFrame::DataHeader { sid, mid, length } => { - let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); - self.metrics.rmsg_ib(sid, length); - self.incoming.insert(mid, m); - }, - ITFrame::Data { mid, data } => { - self.metrics.rdata_frames_b(data.len() as u64); - let m = match self.incoming.get_mut(&mid) { - Some(m) => m, - None => { - info!( - ?mid, - "protocol violation by remote side: send Data before Header" - ); - break 'outer Err(ProtocolError::Closed); + loop { + match ITFrame::read_frame(&mut self.buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth, + }); + }, + ITFrame::CloseStream { sid } => { + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { mid, data } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + info!( + ?mid, + "protocol violation by remote side: send Data before \ + Header" + ); + break 'outer Err(ProtocolError::Violated); + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + data: m.data.freeze(), + }); + } }, }; - m.data.extend_from_slice(&data); - if m.data.len() == m.length as usize { - // finished, yay - let m = self.incoming.remove(&mid).unwrap(); - self.metrics.rmsg_ob( - m.sid, - RemoveReason::Finished, - m.data.len() as u64, - ); - break 'outer Ok(ProtocolEvent::Message { - sid: m.sid, - data: m.data.freeze(), - }); - } }, - }; + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } } let chunk = self.sink.recv().await?; if self.buffer.is_empty() { @@ -321,7 +328,7 @@ where return Ok(frame); } } - Err(ProtocolError::Closed) + Err(ProtocolError::Violated) } }