From 9638581d09a0c2e1cde5bd3bb26d3246c1671906 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= Date: Wed, 7 Apr 2021 19:11:16 +0200 Subject: [PATCH] ITFrame::read_frame now throws an error when the frame_no is invalid. This will be catched by the respective protocols, e.g. tcp and cause a violation --- network/protocol/src/frame.rs | 32 +++++----- network/protocol/src/tcp.rs | 109 ++++++++++++++++++---------------- 2 files changed, 75 insertions(+), 66 deletions(-) diff --git a/network/protocol/src/frame.rs b/network/protocol/src/frame.rs index 20a0439953..aa44af14ed 100644 --- a/network/protocol/src/frame.rs +++ b/network/protocol/src/frame.rs @@ -206,10 +206,12 @@ impl OTFrame { } impl ITFrame { - pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option { + /// Err => cannot recover + /// Ok(None) => waiting for more data + pub(crate) fn read_frame(bytes: &mut BytesMut) -> Result, ()> { let frame_no = match bytes.first() { Some(&f) => f, - None => return None, + None => return Ok(None), }; let size = match frame_no { FRAME_SHUTDOWN => TCP_SHUTDOWN_CNS, @@ -218,15 +220,15 @@ impl ITFrame { FRAME_DATA_HEADER => TCP_DATA_HEADER_CNS, FRAME_DATA => { if bytes.len() < 9 + 1 + 1 { - return None; + return Ok(None); } u16::from_le_bytes([bytes[8 + 1], bytes[9 + 1]]) as usize + TCP_DATA_CNS }, - _ => return None, + _ => return Err(()), }; if bytes.len() < size + 1 { - return None; + return Ok(None); } let frame = match frame_no { @@ -270,7 +272,7 @@ impl ITFrame { }, _ => unreachable!("Frame::to_frame should be handled before!"), }; - Some(frame) + Ok(Some(frame)) } } @@ -393,7 +395,7 @@ mod tests { for frame in get_otframes() { println!("frame: {:?}", &frame); - assert_eq!(frame.clone(), dupl(frame).expect("NONE")); + assert_eq!(frame.clone(), dupl(frame).expect("ERR").expect("NONE")); } } @@ -416,7 +418,7 @@ mod tests { // compare for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { println!("frame: {:?}", &f); - assert_eq!(f, fd.expect("NONE")); + assert_eq!(f, fd.expect("ERR").expect("NONE")); } } @@ -430,7 +432,7 @@ mod tests { assert_eq!(buffer.len(), SIZE); let mut deque = buffer.iter().copied().collect(); let frame2 = ITFrame::read_frame(&mut deque); - assert_eq!(frame1, frame2.expect("NONE")); + assert_eq!(frame1, frame2.expect("ERR").expect("NONE")); } #[test] @@ -516,13 +518,13 @@ mod tests { OTFrame::write_bytes(frame1, &mut buffer); buffer.truncate(6); // simulate partial retrieve let frame1d = ITFrame::read_frame(&mut buffer); - assert_eq!(frame1d, None); + assert_eq!(frame1d, Ok(None)); } #[test] fn frame_rubish() { let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); - assert_eq!(ITFrame::read_frame(&mut buffer), None); + assert_eq!(ITFrame::read_frame(&mut buffer), Err(())); } #[test] @@ -537,7 +539,7 @@ mod tests { OTFrame::write_bytes(frame1, &mut buffer); buffer[9] = 255; let framed = ITFrame::read_frame(&mut buffer); - assert_eq!(framed, None); + assert_eq!(framed, Ok(None)); } #[test] @@ -554,13 +556,13 @@ mod tests { let framed = ITFrame::read_frame(&mut buffer); assert_eq!( framed, - Some(ITFrame::Data { + Ok(Some(ITFrame::Data { mid: 7u64, data: BytesMut::from(&b"foo"[..]), - }) + })) ); //next = Invalid => Empty let framed = ITFrame::read_frame(&mut buffer); - assert_eq!(framed, None); + assert_eq!(framed, Err(())); } } diff --git a/network/protocol/src/tcp.rs b/network/protocol/src/tcp.rs index 4e80736342..43d14e2a1e 100644 --- a/network/protocol/src/tcp.rs +++ b/network/protocol/src/tcp.rs @@ -231,60 +231,67 @@ where { async fn recv(&mut self) -> Result { 'outer: loop { - while let Some(frame) = ITFrame::read_frame(&mut self.buffer) { - #[cfg(feature = "trace_pedantic")] - trace!(?frame, "recv"); - match frame { - ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), - ITFrame::OpenStream { - sid, - prio, - promises, - guaranteed_bandwidth, - } => { - break 'outer Ok(ProtocolEvent::OpenStream { - sid, - prio: prio.min(crate::types::HIGHEST_PRIO), - promises, - guaranteed_bandwidth, - }); - }, - ITFrame::CloseStream { sid } => { - break 'outer Ok(ProtocolEvent::CloseStream { sid }); - }, - ITFrame::DataHeader { sid, mid, length } => { - let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); - self.metrics.rmsg_ib(sid, length); - self.incoming.insert(mid, m); - }, - ITFrame::Data { mid, data } => { - self.metrics.rdata_frames_b(data.len() as u64); - let m = match self.incoming.get_mut(&mid) { - Some(m) => m, - None => { - info!( - ?mid, - "protocol violation by remote side: send Data before Header" - ); - break 'outer Err(ProtocolError::Violated); + loop { + match ITFrame::read_frame(&mut self.buffer) { + Ok(Some(frame)) => { + #[cfg(feature = "trace_pedantic")] + trace!(?frame, "recv"); + match frame { + ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), + ITFrame::OpenStream { + sid, + prio, + promises, + guaranteed_bandwidth, + } => { + break 'outer Ok(ProtocolEvent::OpenStream { + sid, + prio: prio.min(crate::types::HIGHEST_PRIO), + promises, + guaranteed_bandwidth, + }); + }, + ITFrame::CloseStream { sid } => { + break 'outer Ok(ProtocolEvent::CloseStream { sid }); + }, + ITFrame::DataHeader { sid, mid, length } => { + let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); + self.metrics.rmsg_ib(sid, length); + self.incoming.insert(mid, m); + }, + ITFrame::Data { mid, data } => { + self.metrics.rdata_frames_b(data.len() as u64); + let m = match self.incoming.get_mut(&mid) { + Some(m) => m, + None => { + info!( + ?mid, + "protocol violation by remote side: send Data before \ + Header" + ); + break 'outer Err(ProtocolError::Violated); + }, + }; + m.data.extend_from_slice(&data); + if m.data.len() == m.length as usize { + // finished, yay + let m = self.incoming.remove(&mid).unwrap(); + self.metrics.rmsg_ob( + m.sid, + RemoveReason::Finished, + m.data.len() as u64, + ); + break 'outer Ok(ProtocolEvent::Message { + sid: m.sid, + data: m.data.freeze(), + }); + } }, }; - m.data.extend_from_slice(&data); - if m.data.len() == m.length as usize { - // finished, yay - let m = self.incoming.remove(&mid).unwrap(); - self.metrics.rmsg_ob( - m.sid, - RemoveReason::Finished, - m.data.len() as u64, - ); - break 'outer Ok(ProtocolEvent::Message { - sid: m.sid, - data: m.data.freeze(), - }); - } }, - }; + Ok(None) => break, //inner => read more data + Err(()) => return Err(ProtocolError::Violated), + } } let chunk = self.sink.recv().await?; if self.buffer.is_empty() {