export unterlying errors via network crate, to generate more detailed logs

This commit is contained in:
Marcel Märtens 2022-05-20 12:20:57 +02:00
parent 8e0295ff47
commit e194a2e334
12 changed files with 298 additions and 155 deletions

2
Cargo.lock generated
View File

@ -6690,7 +6690,7 @@ dependencies = [
[[package]]
name = "veloren-network-protocol"
version = "0.6.0"
version = "0.6.1"
dependencies = [
"async-channel",
"async-trait",

View File

@ -1,7 +1,7 @@
[package]
name = "veloren-network-protocol"
description = "pure Protocol without any I/O itself"
version = "0.6.0"
version = "0.6.1"
authors = ["Marcel Märtens <marcel.cochem@googlemail.com>"]
edition = "2021"
@ -24,7 +24,7 @@ rand = { version = "0.8" }
# async traits
async-trait = "0.1.42"
bytes = "^1"
hashbrown = { version = ">=0.9, <0.13" }
hashbrown = { version = ">=0.12, <0.13" }
[dev-dependencies]
async-channel = "1.5.1"

View File

@ -271,73 +271,88 @@ mod utils {
#[async_trait]
impl UnreliableDrain for ACDrain {
type CustomErr = ();
type DataFormat = MpscMsg;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableSink for ACSink {
type CustomErr = ();
type DataFormat = MpscMsg;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableDrain for TcpDrain {
type CustomErr = ();
type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableSink for TcpSink {
type CustomErr = ();
type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableDrain for QuicDrain {
type CustomErr = ();
type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableSink for QuicSink {
type CustomErr = ();
type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
}

View File

@ -2,38 +2,50 @@
///
/// [`InitProtocol`]: crate::InitProtocol
#[derive(Debug, PartialEq)]
pub enum InitProtocolError {
Closed,
pub enum InitProtocolError<E: std::fmt::Debug + Send> {
Custom(E),
/// expected Handshake, didn't get handshake
NotHandshake,
/// expected Id, didn't get id
NotId,
WrongMagicNumber([u8; 7]),
WrongVersion([u32; 3]),
}
/// When you return closed you must stay closed!
#[derive(Debug, PartialEq)]
pub enum ProtocolError {
/// Closed indicates the underlying I/O got closed
pub enum ProtocolError<E: std::fmt::Debug + Send> {
/// Custom Error on the underlying I/O,
/// e.g. the TCP, UDP or MPSC connection is dropped by the OS
Closed,
Custom(E),
/// Violated indicates the veloren_network_protocol was violated
/// the underlying I/O connection is still valid, but the remote side
/// send WRONG (e.g. Invalid, or wrong order) data on the protocol layer.
Violated,
}
impl From<ProtocolError> for InitProtocolError {
fn from(err: ProtocolError) -> Self {
impl<E: std::fmt::Debug + Send> From<ProtocolError<E>> for InitProtocolError<E> {
fn from(err: ProtocolError<E>) -> Self {
match err {
ProtocolError::Closed => InitProtocolError::Closed,
// not possible as the Init has raw access to the I/O
ProtocolError::Violated => InitProtocolError::Closed,
ProtocolError::Custom(e) => InitProtocolError::Custom(e),
ProtocolError::Violated => {
unreachable!("not possible as the Init has raw access to the I/O")
},
}
}
}
impl core::fmt::Display for InitProtocolError {
impl<E: std::fmt::Debug + Send> core::fmt::Display for InitProtocolError<E> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
InitProtocolError::Closed => write!(f, "Channel closed"),
InitProtocolError::Custom(e) => write!(f, "custom: {:?}", e),
InitProtocolError::NotHandshake => write!(
f,
"Remote send something which couldn't be parsed as a handshake"
),
InitProtocolError::NotId => {
write!(f, "Remote send something which couldn't be parsed as an id")
},
InitProtocolError::WrongMagicNumber(r) => write!(
f,
"Magic Number doesn't match, remote side send '{:?}' instead of '{:?}'",
@ -50,14 +62,14 @@ impl core::fmt::Display for InitProtocolError {
}
}
impl core::fmt::Display for ProtocolError {
impl<E: std::fmt::Debug + Send> core::fmt::Display for ProtocolError<E> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
ProtocolError::Closed => write!(f, "Channel closed"),
ProtocolError::Custom(e) => write!(f, "Channel custom close: {:?}", e),
ProtocolError::Violated => write!(f, "Channel protocol violated"),
}
}
}
impl std::error::Error for InitProtocolError {}
impl std::error::Error for ProtocolError {}
impl<E: std::fmt::Debug + Send> std::error::Error for InitProtocolError<E> {}
impl<E: std::fmt::Debug + Send> std::error::Error for ProtocolError<E> {}

View File

@ -21,7 +21,8 @@ use tracing::{debug, error, info, trace};
/// [`RecvProtocol`]: crate::RecvProtocol
#[async_trait]
pub trait ReliableDrain {
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>;
type CustomErr: std::fmt::Debug + Send;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>>;
}
/// Implement this for auto Handshake with [`ReliableDrain`]. See
@ -30,21 +31,25 @@ pub trait ReliableDrain {
/// [`ReliableDrain`]: crate::ReliableDrain
#[async_trait]
pub trait ReliableSink {
async fn recv(&mut self) -> Result<InitFrame, ProtocolError>;
type CustomErr: std::fmt::Debug + Send;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>>;
}
#[async_trait]
impl<D, S> InitProtocol for (D, S)
impl<D, S, E> InitProtocol for (D, S)
where
D: ReliableDrain + Send,
S: ReliableSink + Send,
D: ReliableDrain<CustomErr = E> + Send,
S: ReliableSink<CustomErr = E> + Send,
E: std::fmt::Debug + Send,
{
type CustomErr = E;
async fn initialize(
&mut self,
initializer: bool,
local_pid: Pid,
local_secret: u128,
) -> Result<(Pid, Sid, u128), InitProtocolError> {
) -> Result<(Pid, Sid, u128), InitProtocolError<E>> {
#[cfg(debug_assertions)]
const WRONG_NUMBER: &str = "Handshake does not contain the magic number required by \
veloren server.\nWe are not sure if you are a valid veloren \
@ -122,11 +127,11 @@ where
Ok(string) => error!(?string, ERR_S),
_ => error!(?bytes, ERR_S),
}
Err(InitProtocolError::Closed)
Err(InitProtocolError::NotHandshake)
},
_ => {
info!("Handshake failed");
Err(InitProtocolError::Closed)
Err(InitProtocolError::NotHandshake)
},
}?;
@ -152,11 +157,11 @@ where
Ok(string) => error!(?string, ERR_S),
_ => error!(?bytes, ERR_S),
}
Err(InitProtocolError::Closed)
Err(InitProtocolError::NotId)
},
_ => {
info!("Handshake failed");
Err(InitProtocolError::Closed)
Err(InitProtocolError::NotId)
},
}
}
@ -176,7 +181,7 @@ mod tests {
let _ = p2;
});
let (r1, _) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed));
assert_eq!(r1.unwrap(), Err(InitProtocolError::Custom(())));
}
#[tokio::test]
@ -191,7 +196,7 @@ mod tests {
})
.await?;
let _ = p2.1.recv().await?;
Result::<(), InitProtocolError>::Ok(())
Result::<(), InitProtocolError<()>>::Ok(())
});
let (r1, r2) = tokio::join!(r1, r2);
assert_eq!(
@ -218,7 +223,7 @@ mod tests {
});
let (r1, r2) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2])));
assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed));
assert_eq!(r2.unwrap(), Err(InitProtocolError::Custom(())));
}
#[tokio::test]
@ -234,10 +239,10 @@ mod tests {
.await?;
let _ = p2.1.recv().await?;
p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?;
Result::<(), InitProtocolError>::Ok(())
Result::<(), InitProtocolError<()>>::Ok(())
});
let (r1, r2) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed));
assert_eq!(r1.unwrap(), Err(InitProtocolError::NotId));
assert_eq!(r2.unwrap(), Ok(()));
}
}

View File

@ -86,12 +86,14 @@ use async_trait::async_trait;
/// Handshake: Used to connect 2 Channels.
#[async_trait]
pub trait InitProtocol {
type CustomErr: std::fmt::Debug + Send;
async fn initialize(
&mut self,
initializer: bool,
local_pid: Pid,
secret: u128,
) -> Result<(Pid, Sid, u128), InitProtocolError>;
) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>>;
}
/// Generic Network Send Protocol.
@ -101,18 +103,20 @@ pub trait InitProtocol {
///
/// A `Stream` MUST be bound to a specific Channel. You MUST NOT switch the
/// channel to send a stream mid air. We will provide takeover options for
/// Channel closure in the future to allow keeping a `Stream` over a broker
/// Channel closure in the future to allow keeping a `Stream` over a broken
/// Channel.
///
/// [`ProtocolEvent`]: crate::ProtocolEvent
#[async_trait]
pub trait SendProtocol {
type CustomErr: std::fmt::Debug + Send;
/// YOU MUST inform the `SendProtocol` by any Stream Open BEFORE using it in
/// `send` and Stream Close AFTER using it in `send` via this fn.
fn notify_from_recv(&mut self, event: ProtocolEvent);
/// Send a Event via this Protocol. The `SendProtocol` MAY require `flush`
/// to be called before actual data is send to the respective `Sink`.
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>;
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>>;
/// Flush all buffered messages according to their [`Prio`] and
/// [`Bandwidth`]. provide the current bandwidth budget (per second) as
/// well as the `dt` since last call. According to the budget the
@ -124,7 +128,7 @@ pub trait SendProtocol {
&mut self,
bandwidth: Bandwidth,
dt: std::time::Duration,
) -> Result<Bandwidth, ProtocolError>;
) -> Result<Bandwidth, ProtocolError<Self::CustomErr>>;
}
/// Generic Network Recv Protocol. See: [`SendProtocol`]
@ -132,9 +136,11 @@ pub trait SendProtocol {
/// [`SendProtocol`]: crate::SendProtocol
#[async_trait]
pub trait RecvProtocol {
type CustomErr: std::fmt::Debug + Send;
/// Either recv an event or fail the Protocol, once the Recv side is closed
/// it cannot recover from the error.
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError>;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>>;
}
/// This crate makes use of UnreliableDrains, they are expected to provide the
@ -147,8 +153,9 @@ pub trait RecvProtocol {
/// [`async-channel`]: async-channel
#[async_trait]
pub trait UnreliableDrain: Send {
type CustomErr: std::fmt::Debug + Send;
type DataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>>;
}
/// Sink counterpart of [`UnreliableDrain`]
@ -156,6 +163,7 @@ pub trait UnreliableDrain: Send {
/// [`UnreliableDrain`]: crate::UnreliableDrain
#[async_trait]
pub trait UnreliableSink: Send {
type CustomErr: std::fmt::Debug + Send;
type DataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError>;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>>;
}

View File

@ -82,9 +82,11 @@ impl<D> SendProtocol for MpscSendProtocol<D>
where
D: UnreliableDrain<DataFormat = MpscMsg>,
{
type CustomErr = D::CustomErr;
fn notify_from_recv(&mut self, _event: ProtocolEvent) {}
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> {
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
#[cfg(feature = "trace_pedantic")]
trace!(?event, "send");
match &event {
@ -113,7 +115,11 @@ where
}
}
async fn flush(&mut self, _: Bandwidth, _: Duration) -> Result<Bandwidth, ProtocolError> {
async fn flush(
&mut self,
_: Bandwidth,
_: Duration,
) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
Ok(0)
}
}
@ -123,7 +129,9 @@ impl<S> RecvProtocol for MpscRecvProtocol<S>
where
S: UnreliableSink<DataFormat = MpscMsg>,
{
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
let event = self.sink.recv().await?;
#[cfg(feature = "trace_pedantic")]
trace!(?event, "recv");
@ -153,7 +161,9 @@ impl<D> ReliableDrain for MpscSendProtocol<D>
where
D: UnreliableDrain<DataFormat = MpscMsg>,
{
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> {
type CustomErr = D::CustomErr;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
self.drain.send(MpscMsg::InitFrame(frame)).await
}
}
@ -163,7 +173,9 @@ impl<S> ReliableSink for MpscRecvProtocol<S>
where
S: UnreliableSink<DataFormat = MpscMsg>,
{
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> {
type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
match self.sink.recv().await? {
MpscMsg::Event(_) => Err(ProtocolError::Violated),
MpscMsg::InitFrame(f) => Ok(f),
@ -209,25 +221,30 @@ pub mod test_utils {
#[async_trait]
impl UnreliableDrain for ACDrain {
type CustomErr = ();
type DataFormat = MpscMsg;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableSink for ACSink {
type CustomErr = ();
type DataFormat = MpscMsg;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
}

View File

@ -147,7 +147,9 @@ where
}
}
async fn recv_into_stream(&mut self) -> Result<QuicDataFormatStream, ProtocolError> {
async fn recv_into_stream(
&mut self,
) -> Result<QuicDataFormatStream, ProtocolError<S::CustomErr>> {
let chunk = self.sink.recv().await?;
let buffer = match chunk.stream {
QuicDataFormatStream::Main => &mut self.main_buffer,
@ -181,6 +183,8 @@ impl<D> SendProtocol for QuicSendProtocol<D>
where
D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
type CustomErr = D::CustomErr;
fn notify_from_recv(&mut self, event: ProtocolEvent) {
match event {
ProtocolEvent::OpenStream {
@ -206,7 +210,7 @@ where
}
}
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> {
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
#[cfg(feature = "trace_pedantic")]
trace!(?event, "send");
match event {
@ -268,7 +272,7 @@ where
&mut self,
bandwidth: Bandwidth,
dt: Duration,
) -> Result</* actual */ Bandwidth, ProtocolError> {
) -> Result</* actual */ Bandwidth, ProtocolError<Self::CustomErr>> {
let (frames, _) = self.store.grab(bandwidth, dt);
//Todo: optimize reserve
let mut data_frames = 0;
@ -343,7 +347,9 @@ impl<S> RecvProtocol for QuicRecvProtocol<S>
where
S: UnreliableSink<DataFormat = QuicDataFormat>,
{
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
'outer: loop {
match ITFrame::read_frame(&mut self.main_buffer) {
Ok(Some(frame)) => {
@ -484,7 +490,9 @@ impl<D> ReliableDrain for QuicSendProtocol<D>
where
D: UnreliableDrain<DataFormat = QuicDataFormat>,
{
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> {
type CustomErr = D::CustomErr;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
self.main_buffer.reserve(500);
frame.write_bytes(&mut self.main_buffer);
self.drain
@ -498,7 +506,9 @@ impl<S> ReliableSink for QuicRecvProtocol<S>
where
S: UnreliableSink<DataFormat = QuicDataFormat>,
{
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> {
type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
while self.main_buffer.len() < 100 {
if self.recv_into_stream().await? == QuicDataFormatStream::Main {
if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) {
@ -564,9 +574,13 @@ mod test_utils {
#[async_trait]
impl UnreliableDrain for QuicDrain {
type CustomErr = ();
type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
use rand::Rng;
if matches!(data.stream, QuicDataFormatStream::Unreliable)
&& rand::thread_rng().gen::<f32>() < self.drop_ratio
@ -576,19 +590,20 @@ mod test_utils {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableSink for QuicSink {
type CustomErr = ();
type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
}
@ -865,7 +880,7 @@ mod tests {
drop(s);
let e = e.await.unwrap();
assert_eq!(e, Err(ProtocolError::Closed));
assert_eq!(e, Err(ProtocolError::Custom(())));
}
#[tokio::test]

View File

@ -100,6 +100,8 @@ impl<D> SendProtocol for TcpSendProtocol<D>
where
D: UnreliableDrain<DataFormat = BytesMut>,
{
type CustomErr = D::CustomErr;
fn notify_from_recv(&mut self, event: ProtocolEvent) {
match event {
ProtocolEvent::OpenStream {
@ -122,7 +124,7 @@ where
}
}
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> {
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
#[cfg(feature = "trace_pedantic")]
trace!(?event, "send");
match event {
@ -170,7 +172,7 @@ where
&mut self,
bandwidth: Bandwidth,
dt: Duration,
) -> Result</* actual */ Bandwidth, ProtocolError> {
) -> Result</* actual */ Bandwidth, ProtocolError<Self::CustomErr>> {
let (frames, total_bytes) = self.store.grab(bandwidth, dt);
self.buffer.reserve(total_bytes as usize);
let mut data_frames = 0;
@ -228,7 +230,9 @@ impl<S> RecvProtocol for TcpRecvProtocol<S>
where
S: UnreliableSink<DataFormat = BytesMut>,
{
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
'outer: loop {
loop {
match ITFrame::read_frame(&mut self.buffer) {
@ -307,7 +311,9 @@ impl<D> ReliableDrain for TcpSendProtocol<D>
where
D: UnreliableDrain<DataFormat = BytesMut>,
{
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> {
type CustomErr = D::CustomErr;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
let mut buffer = BytesMut::with_capacity(500);
frame.write_bytes(&mut buffer);
self.drain.send(buffer).await
@ -319,7 +325,9 @@ impl<S> ReliableSink for TcpRecvProtocol<S>
where
S: UnreliableSink<DataFormat = BytesMut>,
{
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> {
type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
while self.buffer.len() < 100 {
let chunk = self.sink.recv().await?;
self.buffer.extend_from_slice(&chunk);
@ -371,25 +379,30 @@ mod test_utils {
#[async_trait]
impl UnreliableDrain for TcpDrain {
type CustomErr = ();
type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
#[async_trait]
impl UnreliableSink for TcpSink {
type CustomErr = ();
type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|_| ProtocolError::Custom(()))
}
}
}
@ -659,7 +672,7 @@ mod tests {
drop(s);
let e = e.await.unwrap();
assert_eq!(e, Err(ProtocolError::Closed));
assert_eq!(e, Err(ProtocolError::Custom(())));
}
#[tokio::test]

View File

@ -1,4 +1,5 @@
use crate::{
channel::ProtocolsError,
message::{partial_eq_bincode, Message},
participant::{A2bStreamOpen, S2bShutdownBparticipant},
scheduler::{A2sConnect, Scheduler},
@ -106,7 +107,7 @@ pub enum NetworkError {
pub enum NetworkConnectError {
/// Either a Pid UUID clash or you are trying to hijack a connection
InvalidSecret,
Handshake(InitProtocolError),
Handshake(InitProtocolError<ProtocolsError>),
Io(std::io::Error),
}

View File

@ -193,7 +193,7 @@ impl Protocols {
metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
) -> io::Result<()> {
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
MPSC_POOL.lock().await.insert(addr, mpsc_s);
trace!(?addr, "Mpsc Listener bound");
@ -255,26 +255,17 @@ impl Protocols {
info!("Connecting Quic to: {}", &addr);
let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| {
trace!(?e, "error setting up quic");
NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
e,
))
NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
})?;
let connection = connecting.await.map_err(|e| {
trace!(?e, "error with quic connection");
NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
e,
))
NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
})?;
Self::new_quic(connection, false, metrics)
.await
.map_err(|e| {
trace!(?e, "error with quic");
NetworkConnectError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
e,
))
NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
})
}
@ -286,7 +277,7 @@ impl Protocols {
metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> {
) -> io::Result<()> {
let (_endpoint, mut listener) = match quinn::Endpoint::server(server_config, addr) {
Ok(v) => v,
Err(e) => return Err(e),
@ -378,12 +369,14 @@ impl Protocols {
#[async_trait]
impl network_protocol::InitProtocol for Protocols {
type CustomErr = ProtocolsError;
async fn initialize(
&mut self,
initializer: bool,
local_pid: Pid,
secret: u128,
) -> Result<(Pid, Sid, u128), InitProtocolError> {
) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>> {
match self {
Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
@ -395,6 +388,8 @@ impl network_protocol::InitProtocol for Protocols {
#[async_trait]
impl network_protocol::SendProtocol for SendProtocols {
type CustomErr = ProtocolsError;
fn notify_from_recv(&mut self, event: ProtocolEvent) {
match self {
SendProtocols::Tcp(s) => s.notify_from_recv(event),
@ -404,7 +399,7 @@ impl network_protocol::SendProtocol for SendProtocols {
}
}
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> {
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
match self {
SendProtocols::Tcp(s) => s.send(event).await,
SendProtocols::Mpsc(s) => s.send(event).await,
@ -417,7 +412,7 @@ impl network_protocol::SendProtocol for SendProtocols {
&mut self,
bandwidth: Bandwidth,
dt: Duration,
) -> Result<Bandwidth, ProtocolError> {
) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
match self {
SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
@ -429,7 +424,9 @@ impl network_protocol::SendProtocol for SendProtocols {
#[async_trait]
impl network_protocol::RecvProtocol for RecvProtocols {
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> {
type CustomErr = ProtocolsError;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
match self {
RecvProtocols::Tcp(r) => r.recv().await,
RecvProtocols::Mpsc(r) => r.recv().await,
@ -439,6 +436,32 @@ impl network_protocol::RecvProtocol for RecvProtocols {
}
}
#[derive(Debug)]
pub enum MpscError {
Send(tokio::sync::mpsc::error::SendError<network_protocol::MpscMsg>),
Recv,
}
#[cfg(feature = "quic")]
#[derive(Debug)]
pub enum QuicError {
Send(std::io::Error),
Connection(quinn::ConnectionError),
Write(quinn::WriteError),
Read(quinn::ReadError),
InternalMpsc,
}
/// Error types for Protocols
#[derive(Debug)]
pub enum ProtocolsError {
Tcp(std::io::Error),
Udp(std::io::Error),
#[cfg(feature = "quic")]
Quic(QuicError),
Mpsc(MpscError),
}
///////////////////////////////////////
//// TCP
#[derive(Debug)]
@ -454,26 +477,31 @@ pub struct TcpSink {
#[async_trait]
impl UnreliableDrain for TcpDrain {
type CustomErr = ProtocolsError;
type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
match self.half.write_all(&data).await {
Ok(()) => Ok(()),
Err(_) => Err(ProtocolError::Closed),
}
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
self.half
.write_all(&data)
.await
.map_err(|e| ProtocolError::Custom(ProtocolsError::Tcp(e)))
}
}
#[async_trait]
impl UnreliableSink for TcpSink {
type CustomErr = ProtocolsError;
type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.buffer.resize(1500, 0u8);
match self.half.read(&mut self.buffer).await {
Ok(0) => Err(ProtocolError::Closed),
Ok(0) => Err(ProtocolError::Custom(ProtocolsError::Tcp(io::Error::new(
io::ErrorKind::BrokenPipe,
"read returned 0 bytes",
)))),
Ok(n) => Ok(self.buffer.split_to(n)),
Err(_) => Err(ProtocolError::Closed),
Err(e) => Err(ProtocolError::Custom(ProtocolsError::Tcp(e))),
}
}
}
@ -492,22 +520,27 @@ pub struct MpscSink {
#[async_trait]
impl UnreliableDrain for MpscDrain {
type CustomErr = ProtocolsError;
type DataFormat = MpscMsg;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender
.send(data)
.await
.map_err(|_| ProtocolError::Closed)
.map_err(|e| ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Send(e))))
}
}
#[async_trait]
impl UnreliableSink for MpscSink {
type CustomErr = ProtocolsError;
type DataFormat = MpscMsg;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
self.receiver.recv().await.ok_or(ProtocolError::Closed)
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver
.recv()
.await
.ok_or(ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Recv)))
}
}
@ -560,10 +593,11 @@ fn spawn_new(
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableDrain for QuicDrain {
type CustomErr = ProtocolsError;
type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
match match data.stream {
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
match data.stream {
QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
QuicDataFormatStream::Unreliable => unimplemented!(),
QuicDataFormatStream::Reliable(sid) => {
@ -575,41 +609,43 @@ impl UnreliableDrain for QuicDrain {
// IF the buffer is empty this was created localy and WE are allowed to
// open_bi(), if not, we NEED to block on sendstreams_r
if data.data.is_empty() {
match self.con.open_bi().await {
Ok((mut sendstream, recvstream)) => {
// send SID as first msg
if sendstream.write_u64(sid.get_u64()).await.is_err() {
return Err(ProtocolError::Closed);
}
spawn_new(recvstream, Some(sid), &self.recvstreams_s);
vacant.insert(sendstream).write_all(&data.data).await
},
Err(_) => return Err(ProtocolError::Closed),
}
let (mut sendstream, recvstream) =
self.con.open_bi().await.map_err(|e| {
ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Connection(e),
))
})?;
// send SID as first msg
sendstream.write_u64(sid.get_u64()).await.map_err(|e| {
ProtocolError::Custom(ProtocolsError::Quic(QuicError::Send(e)))
})?;
spawn_new(recvstream, Some(sid), &self.recvstreams_s);
vacant.insert(sendstream).write_all(&data.data).await
} else {
let sendstream = self
.sendstreams_r
.recv()
.await
.ok_or(ProtocolError::Closed)?;
let sendstream =
self.sendstreams_r
.recv()
.await
.ok_or(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::InternalMpsc,
)))?;
vacant.insert(sendstream).write_all(&data.data).await
}
},
}
},
} {
Ok(()) => Ok(()),
Err(_) => Err(ProtocolError::Closed),
}
.map_err(|e| ProtocolError::Custom(ProtocolsError::Quic(QuicError::Write(e))))
}
}
#[cfg(feature = "quic")]
#[async_trait]
impl UnreliableSink for QuicSink {
type CustomErr = ProtocolsError;
type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> {
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
let (mut buffer, result, mut recvstream, id) = loop {
use futures_util::FutureExt;
// first handle all bi streams!
@ -620,20 +656,20 @@ impl UnreliableSink for QuicSink {
};
if let Some(remote_stream) = a {
match remote_stream {
Ok((sendstream, mut recvstream)) => {
let sid = match recvstream.read_u64().await {
Ok(u64::MAX) => None, //unreliable
Ok(sid) => Some(Sid::new(sid)),
Err(_) => return Err(ProtocolError::Violated),
};
if self.sendstreams_s.send(sendstream).is_err() {
return Err(ProtocolError::Closed);
}
spawn_new(recvstream, sid, &self.recvstreams_s);
},
Err(_) => return Err(ProtocolError::Closed),
let (sendstream, mut recvstream) = remote_stream.map_err(|e| {
ProtocolError::Custom(ProtocolsError::Quic(QuicError::Connection(e)))
})?;
let sid = match recvstream.read_u64().await {
Ok(u64::MAX) => None, //unreliable
Ok(sid) => Some(Sid::new(sid)),
Err(_) => return Err(ProtocolError::Violated),
};
if self.sendstreams_s.send(sendstream).is_err() {
return Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::InternalMpsc,
)));
}
spawn_new(recvstream, sid, &self.recvstreams_s);
}
if let Some(data) = b {
@ -642,7 +678,12 @@ impl UnreliableSink for QuicSink {
};
let r = match result {
Ok(Some(0)) => Err(ProtocolError::Closed),
Ok(Some(0)) => Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Send(io::Error::new(
io::ErrorKind::BrokenPipe,
"read returned 0 bytes",
)),
))),
Ok(Some(n)) => Ok(QuicDataFormat {
stream: match id {
Some(id) => QuicDataFormatStream::Reliable(id),
@ -650,8 +691,15 @@ impl UnreliableSink for QuicSink {
},
data: buffer.split_to(n),
}),
Ok(None) => Err(ProtocolError::Closed),
Err(_) => Err(ProtocolError::Closed),
Ok(None) => Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Send(io::Error::new(
io::ErrorKind::BrokenPipe,
"read returned None",
)),
))),
Err(e) => Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Read(e),
))),
}?;
let streams_s_clone = self.recvstreams_s.clone();
@ -739,6 +787,15 @@ mod tests {
drop(s);
let e = e.await.unwrap();
assert!(e.is_err());
assert_eq!(e.unwrap_err(), ProtocolError::Closed);
assert!(matches!(e, Err(..)));
let e = e.unwrap_err();
assert!(matches!(e, ProtocolError::Custom(..)));
assert!(matches!(e, ProtocolError::Custom(ProtocolsError::Tcp(_))));
match e {
ProtocolError::Custom(ProtocolsError::Tcp(e)) => {
assert_eq!(e.kind(), io::ErrorKind::BrokenPipe)
},
_ => panic!("invalid error"),
}
}
}

View File

@ -1,6 +1,6 @@
use crate::{
api::{ParticipantError, Stream},
channel::{Protocols, RecvProtocols, SendProtocols},
channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols},
metrics::NetworkMetrics,
util::DeferredTracer,
};
@ -371,7 +371,7 @@ impl BParticipant {
self.metrics
.participant_bandwidth(&self.remote_pid_string, part_bandwidth);
let _ = b2a_bandwidth_stats_s.send(part_bandwidth);
let r: Result<(), network_protocol::ProtocolError> = Ok(());
let r: Result<(), network_protocol::ProtocolError<ProtocolsError>> = Ok(());
r
}
.await;