Merge branch 'xMAC94x/protocol_errors' into 'master'

Have a clear error for when the I/O closes and when some protocol is violated....

See merge request veloren/veloren!2082
This commit is contained in:
Marcel 2021-04-12 22:44:49 +00:00
commit ef171478f6
4 changed files with 87 additions and 69 deletions

View File

@ -11,13 +11,21 @@ pub enum InitProtocolError {
/// When you return closed you must stay closed! /// When you return closed you must stay closed!
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum ProtocolError { pub enum ProtocolError {
/// Closed indicates the underlying I/O got closed
/// e.g. the TCP, UDP or MPSC connection is dropped by the OS
Closed, Closed,
/// Violated indicates the veloren_network_protocol was violated
/// the underlying I/O connection is still valid, but the remote side
/// send WRONG (e.g. Invalid, or wrong order) data on the protocol layer.
Violated,
} }
impl From<ProtocolError> for InitProtocolError { impl From<ProtocolError> for InitProtocolError {
fn from(err: ProtocolError) -> Self { fn from(err: ProtocolError) -> Self {
match err { match err {
ProtocolError::Closed => InitProtocolError::Closed, ProtocolError::Closed => InitProtocolError::Closed,
// not possible as the Init has raw access to the I/O
ProtocolError::Violated => InitProtocolError::Closed,
} }
} }
} }
@ -46,6 +54,7 @@ impl core::fmt::Display for ProtocolError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
ProtocolError::Closed => write!(f, "Channel closed"), ProtocolError::Closed => write!(f, "Channel closed"),
ProtocolError::Violated => write!(f, "Channel protocol violated"),
} }
} }
} }

View File

@ -206,10 +206,12 @@ impl OTFrame {
} }
impl ITFrame { impl ITFrame {
pub(crate) fn read_frame(bytes: &mut BytesMut) -> Option<Self> { /// Err => cannot recover
/// Ok(None) => waiting for more data
pub(crate) fn read_frame(bytes: &mut BytesMut) -> Result<Option<Self>, ()> {
let frame_no = match bytes.first() { let frame_no = match bytes.first() {
Some(&f) => f, Some(&f) => f,
None => return None, None => return Ok(None),
}; };
let size = match frame_no { let size = match frame_no {
FRAME_SHUTDOWN => TCP_SHUTDOWN_CNS, FRAME_SHUTDOWN => TCP_SHUTDOWN_CNS,
@ -218,15 +220,15 @@ impl ITFrame {
FRAME_DATA_HEADER => TCP_DATA_HEADER_CNS, FRAME_DATA_HEADER => TCP_DATA_HEADER_CNS,
FRAME_DATA => { FRAME_DATA => {
if bytes.len() < 9 + 1 + 1 { if bytes.len() < 9 + 1 + 1 {
return None; return Ok(None);
} }
u16::from_le_bytes([bytes[8 + 1], bytes[9 + 1]]) as usize + TCP_DATA_CNS u16::from_le_bytes([bytes[8 + 1], bytes[9 + 1]]) as usize + TCP_DATA_CNS
}, },
_ => return None, _ => return Err(()),
}; };
if bytes.len() < size + 1 { if bytes.len() < size + 1 {
return None; return Ok(None);
} }
let frame = match frame_no { let frame = match frame_no {
@ -270,7 +272,7 @@ impl ITFrame {
}, },
_ => unreachable!("Frame::to_frame should be handled before!"), _ => unreachable!("Frame::to_frame should be handled before!"),
}; };
Some(frame) Ok(Some(frame))
} }
} }
@ -393,7 +395,7 @@ mod tests {
for frame in get_otframes() { for frame in get_otframes() {
println!("frame: {:?}", &frame); println!("frame: {:?}", &frame);
assert_eq!(frame.clone(), dupl(frame).expect("NONE")); assert_eq!(frame.clone(), dupl(frame).expect("ERR").expect("NONE"));
} }
} }
@ -416,7 +418,7 @@ mod tests {
// compare // compare
for (f, fd) in frames.drain(..).zip(framesd.drain(..)) { for (f, fd) in frames.drain(..).zip(framesd.drain(..)) {
println!("frame: {:?}", &f); println!("frame: {:?}", &f);
assert_eq!(f, fd.expect("NONE")); assert_eq!(f, fd.expect("ERR").expect("NONE"));
} }
} }
@ -430,7 +432,7 @@ mod tests {
assert_eq!(buffer.len(), SIZE); assert_eq!(buffer.len(), SIZE);
let mut deque = buffer.iter().copied().collect(); let mut deque = buffer.iter().copied().collect();
let frame2 = ITFrame::read_frame(&mut deque); let frame2 = ITFrame::read_frame(&mut deque);
assert_eq!(frame1, frame2.expect("NONE")); assert_eq!(frame1, frame2.expect("ERR").expect("NONE"));
} }
#[test] #[test]
@ -516,13 +518,13 @@ mod tests {
OTFrame::write_bytes(frame1, &mut buffer); OTFrame::write_bytes(frame1, &mut buffer);
buffer.truncate(6); // simulate partial retrieve buffer.truncate(6); // simulate partial retrieve
let frame1d = ITFrame::read_frame(&mut buffer); let frame1d = ITFrame::read_frame(&mut buffer);
assert_eq!(frame1d, None); assert_eq!(frame1d, Ok(None));
} }
#[test] #[test]
fn frame_rubish() { fn frame_rubish() {
let mut buffer = BytesMut::from(&b"dtrgwcser"[..]); let mut buffer = BytesMut::from(&b"dtrgwcser"[..]);
assert_eq!(ITFrame::read_frame(&mut buffer), None); assert_eq!(ITFrame::read_frame(&mut buffer), Err(()));
} }
#[test] #[test]
@ -537,7 +539,7 @@ mod tests {
OTFrame::write_bytes(frame1, &mut buffer); OTFrame::write_bytes(frame1, &mut buffer);
buffer[9] = 255; buffer[9] = 255;
let framed = ITFrame::read_frame(&mut buffer); let framed = ITFrame::read_frame(&mut buffer);
assert_eq!(framed, None); assert_eq!(framed, Ok(None));
} }
#[test] #[test]
@ -554,13 +556,13 @@ mod tests {
let framed = ITFrame::read_frame(&mut buffer); let framed = ITFrame::read_frame(&mut buffer);
assert_eq!( assert_eq!(
framed, framed,
Some(ITFrame::Data { Ok(Some(ITFrame::Data {
mid: 7u64, mid: 7u64,
data: BytesMut::from(&b"foo"[..]), data: BytesMut::from(&b"foo"[..]),
}) }))
); );
//next = Invalid => Empty //next = Invalid => Empty
let framed = ITFrame::read_frame(&mut buffer); let framed = ITFrame::read_frame(&mut buffer);
assert_eq!(framed, None); assert_eq!(framed, Err(()));
} }
} }

View File

@ -142,7 +142,7 @@ where
} }
Ok(e) Ok(e)
}, },
MpscMsg::InitFrame(_) => Err(ProtocolError::Closed), MpscMsg::InitFrame(_) => Err(ProtocolError::Violated),
} }
} }
} }
@ -164,7 +164,7 @@ where
{ {
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> { async fn recv(&mut self) -> Result<InitFrame, ProtocolError> {
match self.sink.recv().await? { match self.sink.recv().await? {
MpscMsg::Event(_) => Err(ProtocolError::Closed), MpscMsg::Event(_) => Err(ProtocolError::Violated),
MpscMsg::InitFrame(f) => Ok(f), MpscMsg::InitFrame(f) => Ok(f),
} }
} }

View File

@ -231,60 +231,67 @@ where
{ {
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
'outer: loop { 'outer: loop {
while let Some(frame) = ITFrame::read_frame(&mut self.buffer) { loop {
#[cfg(feature = "trace_pedantic")] match ITFrame::read_frame(&mut self.buffer) {
trace!(?frame, "recv"); Ok(Some(frame)) => {
match frame { #[cfg(feature = "trace_pedantic")]
ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown), trace!(?frame, "recv");
ITFrame::OpenStream { match frame {
sid, ITFrame::Shutdown => break 'outer Ok(ProtocolEvent::Shutdown),
prio, ITFrame::OpenStream {
promises, sid,
guaranteed_bandwidth, prio,
} => { promises,
break 'outer Ok(ProtocolEvent::OpenStream { guaranteed_bandwidth,
sid, } => {
prio: prio.min(crate::types::HIGHEST_PRIO), break 'outer Ok(ProtocolEvent::OpenStream {
promises, sid,
guaranteed_bandwidth, prio: prio.min(crate::types::HIGHEST_PRIO),
}); promises,
}, guaranteed_bandwidth,
ITFrame::CloseStream { sid } => { });
break 'outer Ok(ProtocolEvent::CloseStream { sid }); },
}, ITFrame::CloseStream { sid } => {
ITFrame::DataHeader { sid, mid, length } => { break 'outer Ok(ProtocolEvent::CloseStream { sid });
let m = ITMessage::new(sid, length, &mut self.itmsg_allocator); },
self.metrics.rmsg_ib(sid, length); ITFrame::DataHeader { sid, mid, length } => {
self.incoming.insert(mid, m); let m = ITMessage::new(sid, length, &mut self.itmsg_allocator);
}, self.metrics.rmsg_ib(sid, length);
ITFrame::Data { mid, data } => { self.incoming.insert(mid, m);
self.metrics.rdata_frames_b(data.len() as u64); },
let m = match self.incoming.get_mut(&mid) { ITFrame::Data { mid, data } => {
Some(m) => m, self.metrics.rdata_frames_b(data.len() as u64);
None => { let m = match self.incoming.get_mut(&mid) {
info!( Some(m) => m,
?mid, None => {
"protocol violation by remote side: send Data before Header" info!(
); ?mid,
break 'outer Err(ProtocolError::Closed); "protocol violation by remote side: send Data before \
Header"
);
break 'outer Err(ProtocolError::Violated);
},
};
m.data.extend_from_slice(&data);
if m.data.len() == m.length as usize {
// finished, yay
let m = self.incoming.remove(&mid).unwrap();
self.metrics.rmsg_ob(
m.sid,
RemoveReason::Finished,
m.data.len() as u64,
);
break 'outer Ok(ProtocolEvent::Message {
sid: m.sid,
data: m.data.freeze(),
});
}
}, },
}; };
m.data.extend_from_slice(&data);
if m.data.len() == m.length as usize {
// finished, yay
let m = self.incoming.remove(&mid).unwrap();
self.metrics.rmsg_ob(
m.sid,
RemoveReason::Finished,
m.data.len() as u64,
);
break 'outer Ok(ProtocolEvent::Message {
sid: m.sid,
data: m.data.freeze(),
});
}
}, },
}; Ok(None) => break, //inner => read more data
Err(()) => return Err(ProtocolError::Violated),
}
} }
let chunk = self.sink.recv().await?; let chunk = self.sink.recv().await?;
if self.buffer.is_empty() { if self.buffer.is_empty() {
@ -321,7 +328,7 @@ where
return Ok(frame); return Ok(frame);
} }
} }
Err(ProtocolError::Closed) Err(ProtocolError::Violated)
} }
} }