export unterlying errors via network crate, to generate more detailed logs

This commit is contained in:
Marcel Märtens 2022-05-20 12:20:57 +02:00
parent 8e0295ff47
commit e194a2e334
12 changed files with 298 additions and 155 deletions

2
Cargo.lock generated
View File

@ -6690,7 +6690,7 @@ dependencies = [
[[package]] [[package]]
name = "veloren-network-protocol" name = "veloren-network-protocol"
version = "0.6.0" version = "0.6.1"
dependencies = [ dependencies = [
"async-channel", "async-channel",
"async-trait", "async-trait",

View File

@ -1,7 +1,7 @@
[package] [package]
name = "veloren-network-protocol" name = "veloren-network-protocol"
description = "pure Protocol without any I/O itself" description = "pure Protocol without any I/O itself"
version = "0.6.0" version = "0.6.1"
authors = ["Marcel Märtens <marcel.cochem@googlemail.com>"] authors = ["Marcel Märtens <marcel.cochem@googlemail.com>"]
edition = "2021" edition = "2021"
@ -24,7 +24,7 @@ rand = { version = "0.8" }
# async traits # async traits
async-trait = "0.1.42" async-trait = "0.1.42"
bytes = "^1" bytes = "^1"
hashbrown = { version = ">=0.9, <0.13" } hashbrown = { version = ">=0.12, <0.13" }
[dev-dependencies] [dev-dependencies]
async-channel = "1.5.1" async-channel = "1.5.1"

View File

@ -271,73 +271,88 @@ mod utils {
#[async_trait] #[async_trait]
impl UnreliableDrain for ACDrain { impl UnreliableDrain for ACDrain {
type CustomErr = ();
type DataFormat = MpscMsg; type DataFormat = MpscMsg;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for ACSink { impl UnreliableSink for ACSink {
type CustomErr = ();
type DataFormat = MpscMsg; type DataFormat = MpscMsg;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver self.receiver
.recv() .recv()
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableDrain for TcpDrain { impl UnreliableDrain for TcpDrain {
type CustomErr = ();
type DataFormat = BytesMut; type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for TcpSink { impl UnreliableSink for TcpSink {
type CustomErr = ();
type DataFormat = BytesMut; type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver self.receiver
.recv() .recv()
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableDrain for QuicDrain { impl UnreliableDrain for QuicDrain {
type CustomErr = ();
type DataFormat = QuicDataFormat; type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for QuicSink { impl UnreliableSink for QuicSink {
type CustomErr = ();
type DataFormat = QuicDataFormat; type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver self.receiver
.recv() .recv()
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
} }

View File

@ -2,38 +2,50 @@
/// ///
/// [`InitProtocol`]: crate::InitProtocol /// [`InitProtocol`]: crate::InitProtocol
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum InitProtocolError { pub enum InitProtocolError<E: std::fmt::Debug + Send> {
Closed, Custom(E),
/// expected Handshake, didn't get handshake
NotHandshake,
/// expected Id, didn't get id
NotId,
WrongMagicNumber([u8; 7]), WrongMagicNumber([u8; 7]),
WrongVersion([u32; 3]), WrongVersion([u32; 3]),
} }
/// When you return closed you must stay closed! /// When you return closed you must stay closed!
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum ProtocolError { pub enum ProtocolError<E: std::fmt::Debug + Send> {
/// Closed indicates the underlying I/O got closed /// Custom Error on the underlying I/O,
/// e.g. the TCP, UDP or MPSC connection is dropped by the OS /// e.g. the TCP, UDP or MPSC connection is dropped by the OS
Closed, Custom(E),
/// Violated indicates the veloren_network_protocol was violated /// Violated indicates the veloren_network_protocol was violated
/// the underlying I/O connection is still valid, but the remote side /// the underlying I/O connection is still valid, but the remote side
/// send WRONG (e.g. Invalid, or wrong order) data on the protocol layer. /// send WRONG (e.g. Invalid, or wrong order) data on the protocol layer.
Violated, Violated,
} }
impl From<ProtocolError> for InitProtocolError { impl<E: std::fmt::Debug + Send> From<ProtocolError<E>> for InitProtocolError<E> {
fn from(err: ProtocolError) -> Self { fn from(err: ProtocolError<E>) -> Self {
match err { match err {
ProtocolError::Closed => InitProtocolError::Closed, ProtocolError::Custom(e) => InitProtocolError::Custom(e),
// not possible as the Init has raw access to the I/O ProtocolError::Violated => {
ProtocolError::Violated => InitProtocolError::Closed, unreachable!("not possible as the Init has raw access to the I/O")
},
} }
} }
} }
impl core::fmt::Display for InitProtocolError { impl<E: std::fmt::Debug + Send> core::fmt::Display for InitProtocolError<E> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
InitProtocolError::Closed => write!(f, "Channel closed"), InitProtocolError::Custom(e) => write!(f, "custom: {:?}", e),
InitProtocolError::NotHandshake => write!(
f,
"Remote send something which couldn't be parsed as a handshake"
),
InitProtocolError::NotId => {
write!(f, "Remote send something which couldn't be parsed as an id")
},
InitProtocolError::WrongMagicNumber(r) => write!( InitProtocolError::WrongMagicNumber(r) => write!(
f, f,
"Magic Number doesn't match, remote side send '{:?}' instead of '{:?}'", "Magic Number doesn't match, remote side send '{:?}' instead of '{:?}'",
@ -50,14 +62,14 @@ impl core::fmt::Display for InitProtocolError {
} }
} }
impl core::fmt::Display for ProtocolError { impl<E: std::fmt::Debug + Send> core::fmt::Display for ProtocolError<E> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
ProtocolError::Closed => write!(f, "Channel closed"), ProtocolError::Custom(e) => write!(f, "Channel custom close: {:?}", e),
ProtocolError::Violated => write!(f, "Channel protocol violated"), ProtocolError::Violated => write!(f, "Channel protocol violated"),
} }
} }
} }
impl std::error::Error for InitProtocolError {} impl<E: std::fmt::Debug + Send> std::error::Error for InitProtocolError<E> {}
impl std::error::Error for ProtocolError {} impl<E: std::fmt::Debug + Send> std::error::Error for ProtocolError<E> {}

View File

@ -21,7 +21,8 @@ use tracing::{debug, error, info, trace};
/// [`RecvProtocol`]: crate::RecvProtocol /// [`RecvProtocol`]: crate::RecvProtocol
#[async_trait] #[async_trait]
pub trait ReliableDrain { pub trait ReliableDrain {
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>; type CustomErr: std::fmt::Debug + Send;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>>;
} }
/// Implement this for auto Handshake with [`ReliableDrain`]. See /// Implement this for auto Handshake with [`ReliableDrain`]. See
@ -30,21 +31,25 @@ pub trait ReliableDrain {
/// [`ReliableDrain`]: crate::ReliableDrain /// [`ReliableDrain`]: crate::ReliableDrain
#[async_trait] #[async_trait]
pub trait ReliableSink { pub trait ReliableSink {
async fn recv(&mut self) -> Result<InitFrame, ProtocolError>; type CustomErr: std::fmt::Debug + Send;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>>;
} }
#[async_trait] #[async_trait]
impl<D, S> InitProtocol for (D, S) impl<D, S, E> InitProtocol for (D, S)
where where
D: ReliableDrain + Send, D: ReliableDrain<CustomErr = E> + Send,
S: ReliableSink + Send, S: ReliableSink<CustomErr = E> + Send,
E: std::fmt::Debug + Send,
{ {
type CustomErr = E;
async fn initialize( async fn initialize(
&mut self, &mut self,
initializer: bool, initializer: bool,
local_pid: Pid, local_pid: Pid,
local_secret: u128, local_secret: u128,
) -> Result<(Pid, Sid, u128), InitProtocolError> { ) -> Result<(Pid, Sid, u128), InitProtocolError<E>> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
const WRONG_NUMBER: &str = "Handshake does not contain the magic number required by \ const WRONG_NUMBER: &str = "Handshake does not contain the magic number required by \
veloren server.\nWe are not sure if you are a valid veloren \ veloren server.\nWe are not sure if you are a valid veloren \
@ -122,11 +127,11 @@ where
Ok(string) => error!(?string, ERR_S), Ok(string) => error!(?string, ERR_S),
_ => error!(?bytes, ERR_S), _ => error!(?bytes, ERR_S),
} }
Err(InitProtocolError::Closed) Err(InitProtocolError::NotHandshake)
}, },
_ => { _ => {
info!("Handshake failed"); info!("Handshake failed");
Err(InitProtocolError::Closed) Err(InitProtocolError::NotHandshake)
}, },
}?; }?;
@ -152,11 +157,11 @@ where
Ok(string) => error!(?string, ERR_S), Ok(string) => error!(?string, ERR_S),
_ => error!(?bytes, ERR_S), _ => error!(?bytes, ERR_S),
} }
Err(InitProtocolError::Closed) Err(InitProtocolError::NotId)
}, },
_ => { _ => {
info!("Handshake failed"); info!("Handshake failed");
Err(InitProtocolError::Closed) Err(InitProtocolError::NotId)
}, },
} }
} }
@ -176,7 +181,7 @@ mod tests {
let _ = p2; let _ = p2;
}); });
let (r1, _) = tokio::join!(r1, r2); let (r1, _) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); assert_eq!(r1.unwrap(), Err(InitProtocolError::Custom(())));
} }
#[tokio::test] #[tokio::test]
@ -191,7 +196,7 @@ mod tests {
}) })
.await?; .await?;
let _ = p2.1.recv().await?; let _ = p2.1.recv().await?;
Result::<(), InitProtocolError>::Ok(()) Result::<(), InitProtocolError<()>>::Ok(())
}); });
let (r1, r2) = tokio::join!(r1, r2); let (r1, r2) = tokio::join!(r1, r2);
assert_eq!( assert_eq!(
@ -218,7 +223,7 @@ mod tests {
}); });
let (r1, r2) = tokio::join!(r1, r2); let (r1, r2) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2]))); assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2])));
assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed)); assert_eq!(r2.unwrap(), Err(InitProtocolError::Custom(())));
} }
#[tokio::test] #[tokio::test]
@ -234,10 +239,10 @@ mod tests {
.await?; .await?;
let _ = p2.1.recv().await?; let _ = p2.1.recv().await?;
p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?; p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?;
Result::<(), InitProtocolError>::Ok(()) Result::<(), InitProtocolError<()>>::Ok(())
}); });
let (r1, r2) = tokio::join!(r1, r2); let (r1, r2) = tokio::join!(r1, r2);
assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed)); assert_eq!(r1.unwrap(), Err(InitProtocolError::NotId));
assert_eq!(r2.unwrap(), Ok(())); assert_eq!(r2.unwrap(), Ok(()));
} }
} }

View File

@ -86,12 +86,14 @@ use async_trait::async_trait;
/// Handshake: Used to connect 2 Channels. /// Handshake: Used to connect 2 Channels.
#[async_trait] #[async_trait]
pub trait InitProtocol { pub trait InitProtocol {
type CustomErr: std::fmt::Debug + Send;
async fn initialize( async fn initialize(
&mut self, &mut self,
initializer: bool, initializer: bool,
local_pid: Pid, local_pid: Pid,
secret: u128, secret: u128,
) -> Result<(Pid, Sid, u128), InitProtocolError>; ) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>>;
} }
/// Generic Network Send Protocol. /// Generic Network Send Protocol.
@ -101,18 +103,20 @@ pub trait InitProtocol {
/// ///
/// A `Stream` MUST be bound to a specific Channel. You MUST NOT switch the /// A `Stream` MUST be bound to a specific Channel. You MUST NOT switch the
/// channel to send a stream mid air. We will provide takeover options for /// channel to send a stream mid air. We will provide takeover options for
/// Channel closure in the future to allow keeping a `Stream` over a broker /// Channel closure in the future to allow keeping a `Stream` over a broken
/// Channel. /// Channel.
/// ///
/// [`ProtocolEvent`]: crate::ProtocolEvent /// [`ProtocolEvent`]: crate::ProtocolEvent
#[async_trait] #[async_trait]
pub trait SendProtocol { pub trait SendProtocol {
type CustomErr: std::fmt::Debug + Send;
/// YOU MUST inform the `SendProtocol` by any Stream Open BEFORE using it in /// YOU MUST inform the `SendProtocol` by any Stream Open BEFORE using it in
/// `send` and Stream Close AFTER using it in `send` via this fn. /// `send` and Stream Close AFTER using it in `send` via this fn.
fn notify_from_recv(&mut self, event: ProtocolEvent); fn notify_from_recv(&mut self, event: ProtocolEvent);
/// Send a Event via this Protocol. The `SendProtocol` MAY require `flush` /// Send a Event via this Protocol. The `SendProtocol` MAY require `flush`
/// to be called before actual data is send to the respective `Sink`. /// to be called before actual data is send to the respective `Sink`.
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError>; async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>>;
/// Flush all buffered messages according to their [`Prio`] and /// Flush all buffered messages according to their [`Prio`] and
/// [`Bandwidth`]. provide the current bandwidth budget (per second) as /// [`Bandwidth`]. provide the current bandwidth budget (per second) as
/// well as the `dt` since last call. According to the budget the /// well as the `dt` since last call. According to the budget the
@ -124,7 +128,7 @@ pub trait SendProtocol {
&mut self, &mut self,
bandwidth: Bandwidth, bandwidth: Bandwidth,
dt: std::time::Duration, dt: std::time::Duration,
) -> Result<Bandwidth, ProtocolError>; ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>>;
} }
/// Generic Network Recv Protocol. See: [`SendProtocol`] /// Generic Network Recv Protocol. See: [`SendProtocol`]
@ -132,9 +136,11 @@ pub trait SendProtocol {
/// [`SendProtocol`]: crate::SendProtocol /// [`SendProtocol`]: crate::SendProtocol
#[async_trait] #[async_trait]
pub trait RecvProtocol { pub trait RecvProtocol {
type CustomErr: std::fmt::Debug + Send;
/// Either recv an event or fail the Protocol, once the Recv side is closed /// Either recv an event or fail the Protocol, once the Recv side is closed
/// it cannot recover from the error. /// it cannot recover from the error.
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError>; async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>>;
} }
/// This crate makes use of UnreliableDrains, they are expected to provide the /// This crate makes use of UnreliableDrains, they are expected to provide the
@ -147,8 +153,9 @@ pub trait RecvProtocol {
/// [`async-channel`]: async-channel /// [`async-channel`]: async-channel
#[async_trait] #[async_trait]
pub trait UnreliableDrain: Send { pub trait UnreliableDrain: Send {
type CustomErr: std::fmt::Debug + Send;
type DataFormat; type DataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError>; async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>>;
} }
/// Sink counterpart of [`UnreliableDrain`] /// Sink counterpart of [`UnreliableDrain`]
@ -156,6 +163,7 @@ pub trait UnreliableDrain: Send {
/// [`UnreliableDrain`]: crate::UnreliableDrain /// [`UnreliableDrain`]: crate::UnreliableDrain
#[async_trait] #[async_trait]
pub trait UnreliableSink: Send { pub trait UnreliableSink: Send {
type CustomErr: std::fmt::Debug + Send;
type DataFormat; type DataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError>; async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>>;
} }

View File

@ -82,9 +82,11 @@ impl<D> SendProtocol for MpscSendProtocol<D>
where where
D: UnreliableDrain<DataFormat = MpscMsg>, D: UnreliableDrain<DataFormat = MpscMsg>,
{ {
type CustomErr = D::CustomErr;
fn notify_from_recv(&mut self, _event: ProtocolEvent) {} fn notify_from_recv(&mut self, _event: ProtocolEvent) {}
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
#[cfg(feature = "trace_pedantic")] #[cfg(feature = "trace_pedantic")]
trace!(?event, "send"); trace!(?event, "send");
match &event { match &event {
@ -113,7 +115,11 @@ where
} }
} }
async fn flush(&mut self, _: Bandwidth, _: Duration) -> Result<Bandwidth, ProtocolError> { async fn flush(
&mut self,
_: Bandwidth,
_: Duration,
) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
Ok(0) Ok(0)
} }
} }
@ -123,7 +129,9 @@ impl<S> RecvProtocol for MpscRecvProtocol<S>
where where
S: UnreliableSink<DataFormat = MpscMsg>, S: UnreliableSink<DataFormat = MpscMsg>,
{ {
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
let event = self.sink.recv().await?; let event = self.sink.recv().await?;
#[cfg(feature = "trace_pedantic")] #[cfg(feature = "trace_pedantic")]
trace!(?event, "recv"); trace!(?event, "recv");
@ -153,7 +161,9 @@ impl<D> ReliableDrain for MpscSendProtocol<D>
where where
D: UnreliableDrain<DataFormat = MpscMsg>, D: UnreliableDrain<DataFormat = MpscMsg>,
{ {
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { type CustomErr = D::CustomErr;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
self.drain.send(MpscMsg::InitFrame(frame)).await self.drain.send(MpscMsg::InitFrame(frame)).await
} }
} }
@ -163,7 +173,9 @@ impl<S> ReliableSink for MpscRecvProtocol<S>
where where
S: UnreliableSink<DataFormat = MpscMsg>, S: UnreliableSink<DataFormat = MpscMsg>,
{ {
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> { type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
match self.sink.recv().await? { match self.sink.recv().await? {
MpscMsg::Event(_) => Err(ProtocolError::Violated), MpscMsg::Event(_) => Err(ProtocolError::Violated),
MpscMsg::InitFrame(f) => Ok(f), MpscMsg::InitFrame(f) => Ok(f),
@ -209,25 +221,30 @@ pub mod test_utils {
#[async_trait] #[async_trait]
impl UnreliableDrain for ACDrain { impl UnreliableDrain for ACDrain {
type CustomErr = ();
type DataFormat = MpscMsg; type DataFormat = MpscMsg;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for ACSink { impl UnreliableSink for ACSink {
type CustomErr = ();
type DataFormat = MpscMsg; type DataFormat = MpscMsg;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver self.receiver
.recv() .recv()
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
} }

View File

@ -147,7 +147,9 @@ where
} }
} }
async fn recv_into_stream(&mut self) -> Result<QuicDataFormatStream, ProtocolError> { async fn recv_into_stream(
&mut self,
) -> Result<QuicDataFormatStream, ProtocolError<S::CustomErr>> {
let chunk = self.sink.recv().await?; let chunk = self.sink.recv().await?;
let buffer = match chunk.stream { let buffer = match chunk.stream {
QuicDataFormatStream::Main => &mut self.main_buffer, QuicDataFormatStream::Main => &mut self.main_buffer,
@ -181,6 +183,8 @@ impl<D> SendProtocol for QuicSendProtocol<D>
where where
D: UnreliableDrain<DataFormat = QuicDataFormat>, D: UnreliableDrain<DataFormat = QuicDataFormat>,
{ {
type CustomErr = D::CustomErr;
fn notify_from_recv(&mut self, event: ProtocolEvent) { fn notify_from_recv(&mut self, event: ProtocolEvent) {
match event { match event {
ProtocolEvent::OpenStream { ProtocolEvent::OpenStream {
@ -206,7 +210,7 @@ where
} }
} }
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
#[cfg(feature = "trace_pedantic")] #[cfg(feature = "trace_pedantic")]
trace!(?event, "send"); trace!(?event, "send");
match event { match event {
@ -268,7 +272,7 @@ where
&mut self, &mut self,
bandwidth: Bandwidth, bandwidth: Bandwidth,
dt: Duration, dt: Duration,
) -> Result</* actual */ Bandwidth, ProtocolError> { ) -> Result</* actual */ Bandwidth, ProtocolError<Self::CustomErr>> {
let (frames, _) = self.store.grab(bandwidth, dt); let (frames, _) = self.store.grab(bandwidth, dt);
//Todo: optimize reserve //Todo: optimize reserve
let mut data_frames = 0; let mut data_frames = 0;
@ -343,7 +347,9 @@ impl<S> RecvProtocol for QuicRecvProtocol<S>
where where
S: UnreliableSink<DataFormat = QuicDataFormat>, S: UnreliableSink<DataFormat = QuicDataFormat>,
{ {
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
'outer: loop { 'outer: loop {
match ITFrame::read_frame(&mut self.main_buffer) { match ITFrame::read_frame(&mut self.main_buffer) {
Ok(Some(frame)) => { Ok(Some(frame)) => {
@ -484,7 +490,9 @@ impl<D> ReliableDrain for QuicSendProtocol<D>
where where
D: UnreliableDrain<DataFormat = QuicDataFormat>, D: UnreliableDrain<DataFormat = QuicDataFormat>,
{ {
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { type CustomErr = D::CustomErr;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
self.main_buffer.reserve(500); self.main_buffer.reserve(500);
frame.write_bytes(&mut self.main_buffer); frame.write_bytes(&mut self.main_buffer);
self.drain self.drain
@ -498,7 +506,9 @@ impl<S> ReliableSink for QuicRecvProtocol<S>
where where
S: UnreliableSink<DataFormat = QuicDataFormat>, S: UnreliableSink<DataFormat = QuicDataFormat>,
{ {
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> { type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
while self.main_buffer.len() < 100 { while self.main_buffer.len() < 100 {
if self.recv_into_stream().await? == QuicDataFormatStream::Main { if self.recv_into_stream().await? == QuicDataFormatStream::Main {
if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) { if let Some(frame) = InitFrame::read_frame(&mut self.main_buffer) {
@ -564,9 +574,13 @@ mod test_utils {
#[async_trait] #[async_trait]
impl UnreliableDrain for QuicDrain { impl UnreliableDrain for QuicDrain {
type CustomErr = ();
type DataFormat = QuicDataFormat; type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
use rand::Rng; use rand::Rng;
if matches!(data.stream, QuicDataFormatStream::Unreliable) if matches!(data.stream, QuicDataFormatStream::Unreliable)
&& rand::thread_rng().gen::<f32>() < self.drop_ratio && rand::thread_rng().gen::<f32>() < self.drop_ratio
@ -576,19 +590,20 @@ mod test_utils {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for QuicSink { impl UnreliableSink for QuicSink {
type CustomErr = ();
type DataFormat = QuicDataFormat; type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver self.receiver
.recv() .recv()
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
} }
@ -865,7 +880,7 @@ mod tests {
drop(s); drop(s);
let e = e.await.unwrap(); let e = e.await.unwrap();
assert_eq!(e, Err(ProtocolError::Closed)); assert_eq!(e, Err(ProtocolError::Custom(())));
} }
#[tokio::test] #[tokio::test]

View File

@ -100,6 +100,8 @@ impl<D> SendProtocol for TcpSendProtocol<D>
where where
D: UnreliableDrain<DataFormat = BytesMut>, D: UnreliableDrain<DataFormat = BytesMut>,
{ {
type CustomErr = D::CustomErr;
fn notify_from_recv(&mut self, event: ProtocolEvent) { fn notify_from_recv(&mut self, event: ProtocolEvent) {
match event { match event {
ProtocolEvent::OpenStream { ProtocolEvent::OpenStream {
@ -122,7 +124,7 @@ where
} }
} }
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
#[cfg(feature = "trace_pedantic")] #[cfg(feature = "trace_pedantic")]
trace!(?event, "send"); trace!(?event, "send");
match event { match event {
@ -170,7 +172,7 @@ where
&mut self, &mut self,
bandwidth: Bandwidth, bandwidth: Bandwidth,
dt: Duration, dt: Duration,
) -> Result</* actual */ Bandwidth, ProtocolError> { ) -> Result</* actual */ Bandwidth, ProtocolError<Self::CustomErr>> {
let (frames, total_bytes) = self.store.grab(bandwidth, dt); let (frames, total_bytes) = self.store.grab(bandwidth, dt);
self.buffer.reserve(total_bytes as usize); self.buffer.reserve(total_bytes as usize);
let mut data_frames = 0; let mut data_frames = 0;
@ -228,7 +230,9 @@ impl<S> RecvProtocol for TcpRecvProtocol<S>
where where
S: UnreliableSink<DataFormat = BytesMut>, S: UnreliableSink<DataFormat = BytesMut>,
{ {
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
'outer: loop { 'outer: loop {
loop { loop {
match ITFrame::read_frame(&mut self.buffer) { match ITFrame::read_frame(&mut self.buffer) {
@ -307,7 +311,9 @@ impl<D> ReliableDrain for TcpSendProtocol<D>
where where
D: UnreliableDrain<DataFormat = BytesMut>, D: UnreliableDrain<DataFormat = BytesMut>,
{ {
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError> { type CustomErr = D::CustomErr;
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError<Self::CustomErr>> {
let mut buffer = BytesMut::with_capacity(500); let mut buffer = BytesMut::with_capacity(500);
frame.write_bytes(&mut buffer); frame.write_bytes(&mut buffer);
self.drain.send(buffer).await self.drain.send(buffer).await
@ -319,7 +325,9 @@ impl<S> ReliableSink for TcpRecvProtocol<S>
where where
S: UnreliableSink<DataFormat = BytesMut>, S: UnreliableSink<DataFormat = BytesMut>,
{ {
async fn recv(&mut self) -> Result<InitFrame, ProtocolError> { type CustomErr = S::CustomErr;
async fn recv(&mut self) -> Result<InitFrame, ProtocolError<Self::CustomErr>> {
while self.buffer.len() < 100 { while self.buffer.len() < 100 {
let chunk = self.sink.recv().await?; let chunk = self.sink.recv().await?;
self.buffer.extend_from_slice(&chunk); self.buffer.extend_from_slice(&chunk);
@ -371,25 +379,30 @@ mod test_utils {
#[async_trait] #[async_trait]
impl UnreliableDrain for TcpDrain { impl UnreliableDrain for TcpDrain {
type CustomErr = ();
type DataFormat = BytesMut; type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(
&mut self,
data: Self::DataFormat,
) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for TcpSink { impl UnreliableSink for TcpSink {
type CustomErr = ();
type DataFormat = BytesMut; type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver self.receiver
.recv() .recv()
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|_| ProtocolError::Custom(()))
} }
} }
} }
@ -659,7 +672,7 @@ mod tests {
drop(s); drop(s);
let e = e.await.unwrap(); let e = e.await.unwrap();
assert_eq!(e, Err(ProtocolError::Closed)); assert_eq!(e, Err(ProtocolError::Custom(())));
} }
#[tokio::test] #[tokio::test]

View File

@ -1,4 +1,5 @@
use crate::{ use crate::{
channel::ProtocolsError,
message::{partial_eq_bincode, Message}, message::{partial_eq_bincode, Message},
participant::{A2bStreamOpen, S2bShutdownBparticipant}, participant::{A2bStreamOpen, S2bShutdownBparticipant},
scheduler::{A2sConnect, Scheduler}, scheduler::{A2sConnect, Scheduler},
@ -106,7 +107,7 @@ pub enum NetworkError {
pub enum NetworkConnectError { pub enum NetworkConnectError {
/// Either a Pid UUID clash or you are trying to hijack a connection /// Either a Pid UUID clash or you are trying to hijack a connection
InvalidSecret, InvalidSecret,
Handshake(InitProtocolError), Handshake(InitProtocolError<ProtocolsError>),
Io(std::io::Error), Io(std::io::Error),
} }

View File

@ -193,7 +193,7 @@ impl Protocols {
metrics: Arc<ProtocolMetrics>, metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>, s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> { ) -> io::Result<()> {
let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel();
MPSC_POOL.lock().await.insert(addr, mpsc_s); MPSC_POOL.lock().await.insert(addr, mpsc_s);
trace!(?addr, "Mpsc Listener bound"); trace!(?addr, "Mpsc Listener bound");
@ -255,26 +255,17 @@ impl Protocols {
info!("Connecting Quic to: {}", &addr); info!("Connecting Quic to: {}", &addr);
let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| { let connecting = endpoint.connect_with(config, addr, &name).map_err(|e| {
trace!(?e, "error setting up quic"); trace!(?e, "error setting up quic");
NetworkConnectError::Io(std::io::Error::new( NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
std::io::ErrorKind::ConnectionAborted,
e,
))
})?; })?;
let connection = connecting.await.map_err(|e| { let connection = connecting.await.map_err(|e| {
trace!(?e, "error with quic connection"); trace!(?e, "error with quic connection");
NetworkConnectError::Io(std::io::Error::new( NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
std::io::ErrorKind::ConnectionAborted,
e,
))
})?; })?;
Self::new_quic(connection, false, metrics) Self::new_quic(connection, false, metrics)
.await .await
.map_err(|e| { .map_err(|e| {
trace!(?e, "error with quic"); trace!(?e, "error with quic");
NetworkConnectError::Io(std::io::Error::new( NetworkConnectError::Io(io::Error::new(io::ErrorKind::ConnectionAborted, e))
std::io::ErrorKind::ConnectionAborted,
e,
))
}) })
} }
@ -286,7 +277,7 @@ impl Protocols {
metrics: Arc<ProtocolMetrics>, metrics: Arc<ProtocolMetrics>,
s2s_stop_listening_r: oneshot::Receiver<()>, s2s_stop_listening_r: oneshot::Receiver<()>,
c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>,
) -> std::io::Result<()> { ) -> io::Result<()> {
let (_endpoint, mut listener) = match quinn::Endpoint::server(server_config, addr) { let (_endpoint, mut listener) = match quinn::Endpoint::server(server_config, addr) {
Ok(v) => v, Ok(v) => v,
Err(e) => return Err(e), Err(e) => return Err(e),
@ -378,12 +369,14 @@ impl Protocols {
#[async_trait] #[async_trait]
impl network_protocol::InitProtocol for Protocols { impl network_protocol::InitProtocol for Protocols {
type CustomErr = ProtocolsError;
async fn initialize( async fn initialize(
&mut self, &mut self,
initializer: bool, initializer: bool,
local_pid: Pid, local_pid: Pid,
secret: u128, secret: u128,
) -> Result<(Pid, Sid, u128), InitProtocolError> { ) -> Result<(Pid, Sid, u128), InitProtocolError<Self::CustomErr>> {
match self { match self {
Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await, Protocols::Tcp(p) => p.initialize(initializer, local_pid, secret).await,
Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await, Protocols::Mpsc(p) => p.initialize(initializer, local_pid, secret).await,
@ -395,6 +388,8 @@ impl network_protocol::InitProtocol for Protocols {
#[async_trait] #[async_trait]
impl network_protocol::SendProtocol for SendProtocols { impl network_protocol::SendProtocol for SendProtocols {
type CustomErr = ProtocolsError;
fn notify_from_recv(&mut self, event: ProtocolEvent) { fn notify_from_recv(&mut self, event: ProtocolEvent) {
match self { match self {
SendProtocols::Tcp(s) => s.notify_from_recv(event), SendProtocols::Tcp(s) => s.notify_from_recv(event),
@ -404,7 +399,7 @@ impl network_protocol::SendProtocol for SendProtocols {
} }
} }
async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError> { async fn send(&mut self, event: ProtocolEvent) -> Result<(), ProtocolError<Self::CustomErr>> {
match self { match self {
SendProtocols::Tcp(s) => s.send(event).await, SendProtocols::Tcp(s) => s.send(event).await,
SendProtocols::Mpsc(s) => s.send(event).await, SendProtocols::Mpsc(s) => s.send(event).await,
@ -417,7 +412,7 @@ impl network_protocol::SendProtocol for SendProtocols {
&mut self, &mut self,
bandwidth: Bandwidth, bandwidth: Bandwidth,
dt: Duration, dt: Duration,
) -> Result<Bandwidth, ProtocolError> { ) -> Result<Bandwidth, ProtocolError<Self::CustomErr>> {
match self { match self {
SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await, SendProtocols::Tcp(s) => s.flush(bandwidth, dt).await,
SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await, SendProtocols::Mpsc(s) => s.flush(bandwidth, dt).await,
@ -429,7 +424,9 @@ impl network_protocol::SendProtocol for SendProtocols {
#[async_trait] #[async_trait]
impl network_protocol::RecvProtocol for RecvProtocols { impl network_protocol::RecvProtocol for RecvProtocols {
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError> { type CustomErr = ProtocolsError;
async fn recv(&mut self) -> Result<ProtocolEvent, ProtocolError<Self::CustomErr>> {
match self { match self {
RecvProtocols::Tcp(r) => r.recv().await, RecvProtocols::Tcp(r) => r.recv().await,
RecvProtocols::Mpsc(r) => r.recv().await, RecvProtocols::Mpsc(r) => r.recv().await,
@ -439,6 +436,32 @@ impl network_protocol::RecvProtocol for RecvProtocols {
} }
} }
#[derive(Debug)]
pub enum MpscError {
Send(tokio::sync::mpsc::error::SendError<network_protocol::MpscMsg>),
Recv,
}
#[cfg(feature = "quic")]
#[derive(Debug)]
pub enum QuicError {
Send(std::io::Error),
Connection(quinn::ConnectionError),
Write(quinn::WriteError),
Read(quinn::ReadError),
InternalMpsc,
}
/// Error types for Protocols
#[derive(Debug)]
pub enum ProtocolsError {
Tcp(std::io::Error),
Udp(std::io::Error),
#[cfg(feature = "quic")]
Quic(QuicError),
Mpsc(MpscError),
}
/////////////////////////////////////// ///////////////////////////////////////
//// TCP //// TCP
#[derive(Debug)] #[derive(Debug)]
@ -454,26 +477,31 @@ pub struct TcpSink {
#[async_trait] #[async_trait]
impl UnreliableDrain for TcpDrain { impl UnreliableDrain for TcpDrain {
type CustomErr = ProtocolsError;
type DataFormat = BytesMut; type DataFormat = BytesMut;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
match self.half.write_all(&data).await { self.half
Ok(()) => Ok(()), .write_all(&data)
Err(_) => Err(ProtocolError::Closed), .await
} .map_err(|e| ProtocolError::Custom(ProtocolsError::Tcp(e)))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for TcpSink { impl UnreliableSink for TcpSink {
type CustomErr = ProtocolsError;
type DataFormat = BytesMut; type DataFormat = BytesMut;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.buffer.resize(1500, 0u8); self.buffer.resize(1500, 0u8);
match self.half.read(&mut self.buffer).await { match self.half.read(&mut self.buffer).await {
Ok(0) => Err(ProtocolError::Closed), Ok(0) => Err(ProtocolError::Custom(ProtocolsError::Tcp(io::Error::new(
io::ErrorKind::BrokenPipe,
"read returned 0 bytes",
)))),
Ok(n) => Ok(self.buffer.split_to(n)), Ok(n) => Ok(self.buffer.split_to(n)),
Err(_) => Err(ProtocolError::Closed), Err(e) => Err(ProtocolError::Custom(ProtocolsError::Tcp(e))),
} }
} }
} }
@ -492,22 +520,27 @@ pub struct MpscSink {
#[async_trait] #[async_trait]
impl UnreliableDrain for MpscDrain { impl UnreliableDrain for MpscDrain {
type CustomErr = ProtocolsError;
type DataFormat = MpscMsg; type DataFormat = MpscMsg;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
self.sender self.sender
.send(data) .send(data)
.await .await
.map_err(|_| ProtocolError::Closed) .map_err(|e| ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Send(e))))
} }
} }
#[async_trait] #[async_trait]
impl UnreliableSink for MpscSink { impl UnreliableSink for MpscSink {
type CustomErr = ProtocolsError;
type DataFormat = MpscMsg; type DataFormat = MpscMsg;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
self.receiver.recv().await.ok_or(ProtocolError::Closed) self.receiver
.recv()
.await
.ok_or(ProtocolError::Custom(ProtocolsError::Mpsc(MpscError::Recv)))
} }
} }
@ -560,10 +593,11 @@ fn spawn_new(
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
#[async_trait] #[async_trait]
impl UnreliableDrain for QuicDrain { impl UnreliableDrain for QuicDrain {
type CustomErr = ProtocolsError;
type DataFormat = QuicDataFormat; type DataFormat = QuicDataFormat;
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> { async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError<Self::CustomErr>> {
match match data.stream { match data.stream {
QuicDataFormatStream::Main => self.main.write_all(&data.data).await, QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
QuicDataFormatStream::Unreliable => unimplemented!(), QuicDataFormatStream::Unreliable => unimplemented!(),
QuicDataFormatStream::Reliable(sid) => { QuicDataFormatStream::Reliable(sid) => {
@ -575,41 +609,43 @@ impl UnreliableDrain for QuicDrain {
// IF the buffer is empty this was created localy and WE are allowed to // IF the buffer is empty this was created localy and WE are allowed to
// open_bi(), if not, we NEED to block on sendstreams_r // open_bi(), if not, we NEED to block on sendstreams_r
if data.data.is_empty() { if data.data.is_empty() {
match self.con.open_bi().await { let (mut sendstream, recvstream) =
Ok((mut sendstream, recvstream)) => { self.con.open_bi().await.map_err(|e| {
ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Connection(e),
))
})?;
// send SID as first msg // send SID as first msg
if sendstream.write_u64(sid.get_u64()).await.is_err() { sendstream.write_u64(sid.get_u64()).await.map_err(|e| {
return Err(ProtocolError::Closed); ProtocolError::Custom(ProtocolsError::Quic(QuicError::Send(e)))
} })?;
spawn_new(recvstream, Some(sid), &self.recvstreams_s); spawn_new(recvstream, Some(sid), &self.recvstreams_s);
vacant.insert(sendstream).write_all(&data.data).await vacant.insert(sendstream).write_all(&data.data).await
},
Err(_) => return Err(ProtocolError::Closed),
}
} else { } else {
let sendstream = self let sendstream =
.sendstreams_r self.sendstreams_r
.recv() .recv()
.await .await
.ok_or(ProtocolError::Closed)?; .ok_or(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::InternalMpsc,
)))?;
vacant.insert(sendstream).write_all(&data.data).await vacant.insert(sendstream).write_all(&data.data).await
} }
}, },
} }
}, },
} {
Ok(()) => Ok(()),
Err(_) => Err(ProtocolError::Closed),
} }
.map_err(|e| ProtocolError::Custom(ProtocolsError::Quic(QuicError::Write(e))))
} }
} }
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
#[async_trait] #[async_trait]
impl UnreliableSink for QuicSink { impl UnreliableSink for QuicSink {
type CustomErr = ProtocolsError;
type DataFormat = QuicDataFormat; type DataFormat = QuicDataFormat;
async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError> { async fn recv(&mut self) -> Result<Self::DataFormat, ProtocolError<Self::CustomErr>> {
let (mut buffer, result, mut recvstream, id) = loop { let (mut buffer, result, mut recvstream, id) = loop {
use futures_util::FutureExt; use futures_util::FutureExt;
// first handle all bi streams! // first handle all bi streams!
@ -620,20 +656,20 @@ impl UnreliableSink for QuicSink {
}; };
if let Some(remote_stream) = a { if let Some(remote_stream) = a {
match remote_stream { let (sendstream, mut recvstream) = remote_stream.map_err(|e| {
Ok((sendstream, mut recvstream)) => { ProtocolError::Custom(ProtocolsError::Quic(QuicError::Connection(e)))
})?;
let sid = match recvstream.read_u64().await { let sid = match recvstream.read_u64().await {
Ok(u64::MAX) => None, //unreliable Ok(u64::MAX) => None, //unreliable
Ok(sid) => Some(Sid::new(sid)), Ok(sid) => Some(Sid::new(sid)),
Err(_) => return Err(ProtocolError::Violated), Err(_) => return Err(ProtocolError::Violated),
}; };
if self.sendstreams_s.send(sendstream).is_err() { if self.sendstreams_s.send(sendstream).is_err() {
return Err(ProtocolError::Closed); return Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::InternalMpsc,
)));
} }
spawn_new(recvstream, sid, &self.recvstreams_s); spawn_new(recvstream, sid, &self.recvstreams_s);
},
Err(_) => return Err(ProtocolError::Closed),
}
} }
if let Some(data) = b { if let Some(data) = b {
@ -642,7 +678,12 @@ impl UnreliableSink for QuicSink {
}; };
let r = match result { let r = match result {
Ok(Some(0)) => Err(ProtocolError::Closed), Ok(Some(0)) => Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Send(io::Error::new(
io::ErrorKind::BrokenPipe,
"read returned 0 bytes",
)),
))),
Ok(Some(n)) => Ok(QuicDataFormat { Ok(Some(n)) => Ok(QuicDataFormat {
stream: match id { stream: match id {
Some(id) => QuicDataFormatStream::Reliable(id), Some(id) => QuicDataFormatStream::Reliable(id),
@ -650,8 +691,15 @@ impl UnreliableSink for QuicSink {
}, },
data: buffer.split_to(n), data: buffer.split_to(n),
}), }),
Ok(None) => Err(ProtocolError::Closed), Ok(None) => Err(ProtocolError::Custom(ProtocolsError::Quic(
Err(_) => Err(ProtocolError::Closed), QuicError::Send(io::Error::new(
io::ErrorKind::BrokenPipe,
"read returned None",
)),
))),
Err(e) => Err(ProtocolError::Custom(ProtocolsError::Quic(
QuicError::Read(e),
))),
}?; }?;
let streams_s_clone = self.recvstreams_s.clone(); let streams_s_clone = self.recvstreams_s.clone();
@ -739,6 +787,15 @@ mod tests {
drop(s); drop(s);
let e = e.await.unwrap(); let e = e.await.unwrap();
assert!(e.is_err()); assert!(e.is_err());
assert_eq!(e.unwrap_err(), ProtocolError::Closed); assert!(matches!(e, Err(..)));
let e = e.unwrap_err();
assert!(matches!(e, ProtocolError::Custom(..)));
assert!(matches!(e, ProtocolError::Custom(ProtocolsError::Tcp(_))));
match e {
ProtocolError::Custom(ProtocolsError::Tcp(e)) => {
assert_eq!(e.kind(), io::ErrorKind::BrokenPipe)
},
_ => panic!("invalid error"),
}
} }
} }

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
api::{ParticipantError, Stream}, api::{ParticipantError, Stream},
channel::{Protocols, RecvProtocols, SendProtocols}, channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols},
metrics::NetworkMetrics, metrics::NetworkMetrics,
util::DeferredTracer, util::DeferredTracer,
}; };
@ -371,7 +371,7 @@ impl BParticipant {
self.metrics self.metrics
.participant_bandwidth(&self.remote_pid_string, part_bandwidth); .participant_bandwidth(&self.remote_pid_string, part_bandwidth);
let _ = b2a_bandwidth_stats_s.send(part_bandwidth); let _ = b2a_bandwidth_stats_s.send(part_bandwidth);
let r: Result<(), network_protocol::ProtocolError> = Ok(()); let r: Result<(), network_protocol::ProtocolError<ProtocolsError>> = Ok(());
r r
} }
.await; .await;