mirror of
https://gitlab.com/veloren/veloren.git
synced 2024-08-30 18:12:32 +00:00
228 lines
8.3 KiB
Rust
228 lines
8.3 KiB
Rust
|
use crate::{
|
||
|
frame::InitFrame,
|
||
|
types::{
|
||
|
Pid, Sid, STREAM_ID_OFFSET1, STREAM_ID_OFFSET2, VELOREN_MAGIC_NUMBER,
|
||
|
VELOREN_NETWORK_VERSION,
|
||
|
},
|
||
|
InitProtocol, InitProtocolError, ProtocolError,
|
||
|
};
|
||
|
use async_trait::async_trait;
|
||
|
use tracing::{debug, error, info, trace};
|
||
|
|
||
|
// Protocols might define a Reliable Variant for auto Handshake discovery
|
||
|
// this doesn't need to be effective
|
||
|
#[async_trait]
|
||
|
pub trait ReliableDrain {
|
||
|
async fn send(&mut self, frame: InitFrame) -> Result<(), ProtocolError>;
|
||
|
}
|
||
|
|
||
|
#[async_trait]
|
||
|
pub trait ReliableSink {
|
||
|
async fn recv(&mut self) -> Result<InitFrame, ProtocolError>;
|
||
|
}
|
||
|
|
||
|
#[async_trait]
|
||
|
impl<D, S> InitProtocol for (D, S)
|
||
|
where
|
||
|
D: ReliableDrain + Send,
|
||
|
S: ReliableSink + Send,
|
||
|
{
|
||
|
async fn initialize(
|
||
|
&mut self,
|
||
|
initializer: bool,
|
||
|
local_pid: Pid,
|
||
|
local_secret: u128,
|
||
|
) -> Result<(Pid, Sid, u128), InitProtocolError> {
|
||
|
#[cfg(debug_assertions)]
|
||
|
const WRONG_NUMBER: &'static [u8] = "Handshake does not contain the magic number required \
|
||
|
by veloren server.\nWe are not sure if you are a \
|
||
|
valid veloren client.\nClosing the connection"
|
||
|
.as_bytes();
|
||
|
#[cfg(debug_assertions)]
|
||
|
const WRONG_VERSION: &'static str = "Handshake does contain a correct magic number, but \
|
||
|
invalid version.\nWe don't know how to communicate \
|
||
|
with you.\nClosing the connection";
|
||
|
const ERR_S: &str = "Got A Raw Message, these are usually Debug Messages indicating that \
|
||
|
something went wrong on network layer and connection will be closed";
|
||
|
|
||
|
let drain = &mut self.0;
|
||
|
let sink = &mut self.1;
|
||
|
|
||
|
if initializer {
|
||
|
drain
|
||
|
.send(InitFrame::Handshake {
|
||
|
magic_number: VELOREN_MAGIC_NUMBER,
|
||
|
version: VELOREN_NETWORK_VERSION,
|
||
|
})
|
||
|
.await?;
|
||
|
}
|
||
|
|
||
|
match sink.recv().await? {
|
||
|
InitFrame::Handshake {
|
||
|
magic_number,
|
||
|
version,
|
||
|
} => {
|
||
|
trace!(?magic_number, ?version, "Recv handshake");
|
||
|
if magic_number != VELOREN_MAGIC_NUMBER {
|
||
|
error!(?magic_number, "Connection with invalid magic_number");
|
||
|
#[cfg(debug_assertions)]
|
||
|
drain.send(InitFrame::Raw(WRONG_NUMBER.to_vec())).await?;
|
||
|
Err(InitProtocolError::WrongMagicNumber(magic_number))
|
||
|
} else if version != VELOREN_NETWORK_VERSION {
|
||
|
error!(?version, "Connection with wrong network version");
|
||
|
#[cfg(debug_assertions)]
|
||
|
drain
|
||
|
.send(InitFrame::Raw(
|
||
|
format!(
|
||
|
"{} Our Version: {:?}\nYour Version: {:?}\nClosing the connection",
|
||
|
WRONG_VERSION, VELOREN_NETWORK_VERSION, version,
|
||
|
)
|
||
|
.as_bytes()
|
||
|
.to_vec(),
|
||
|
))
|
||
|
.await?;
|
||
|
Err(InitProtocolError::WrongVersion(version))
|
||
|
} else {
|
||
|
trace!("Handshake Frame completed");
|
||
|
if initializer {
|
||
|
drain
|
||
|
.send(InitFrame::Init {
|
||
|
pid: local_pid,
|
||
|
secret: local_secret,
|
||
|
})
|
||
|
.await?;
|
||
|
} else {
|
||
|
drain
|
||
|
.send(InitFrame::Handshake {
|
||
|
magic_number: VELOREN_MAGIC_NUMBER,
|
||
|
version: VELOREN_NETWORK_VERSION,
|
||
|
})
|
||
|
.await?;
|
||
|
}
|
||
|
Ok(())
|
||
|
}
|
||
|
},
|
||
|
InitFrame::Raw(bytes) => {
|
||
|
match std::str::from_utf8(bytes.as_slice()) {
|
||
|
Ok(string) => error!(?string, ERR_S),
|
||
|
_ => error!(?bytes, ERR_S),
|
||
|
}
|
||
|
Err(InitProtocolError::Closed)
|
||
|
},
|
||
|
_ => {
|
||
|
info!("Handshake failed");
|
||
|
Err(InitProtocolError::Closed)
|
||
|
},
|
||
|
}?;
|
||
|
|
||
|
match sink.recv().await? {
|
||
|
InitFrame::Init { pid, secret } => {
|
||
|
debug!(?pid, "Participant send their ID");
|
||
|
let stream_id_offset = if initializer {
|
||
|
STREAM_ID_OFFSET1
|
||
|
} else {
|
||
|
drain
|
||
|
.send(InitFrame::Init {
|
||
|
pid: local_pid,
|
||
|
secret: local_secret,
|
||
|
})
|
||
|
.await?;
|
||
|
STREAM_ID_OFFSET2
|
||
|
};
|
||
|
info!(?pid, "This Handshake is now configured!");
|
||
|
Ok((pid, stream_id_offset, secret))
|
||
|
},
|
||
|
InitFrame::Raw(bytes) => {
|
||
|
match std::str::from_utf8(bytes.as_slice()) {
|
||
|
Ok(string) => error!(?string, ERR_S),
|
||
|
_ => error!(?bytes, ERR_S),
|
||
|
}
|
||
|
Err(InitProtocolError::Closed)
|
||
|
},
|
||
|
_ => {
|
||
|
info!("Handshake failed");
|
||
|
Err(InitProtocolError::Closed)
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mpsc::test_utils::*, InitProtocolError};

    /// The peer goes away without sending anything: the initializer must see `Closed`.
    #[tokio::test]
    async fn handshake_drop_start() {
        let [mut p1, p2] = ac_bound(10, None);
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move {
            // Explicitly move the peer endpoint into this task and drop it.
            drop(p2);
        });
        let (r1, _) = tokio::join!(r1, r2);
        assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed));
    }

    /// The peer answers with a bad magic number: the initializer must report it.
    #[tokio::test]
    async fn handshake_wrong_magic_number() {
        let [mut p1, mut p2] = ac_bound(10, None);
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move {
            let _ = p2.1.recv().await?;
            p2.0.send(InitFrame::Handshake {
                magic_number: *b"woopsie",
                version: VELOREN_NETWORK_VERSION,
            })
            .await?;
            let _ = p2.1.recv().await?;
            Result::<(), InitProtocolError>::Ok(())
        });
        let (r1, r2) = tokio::join!(r1, r2);
        assert_eq!(
            r1.unwrap(),
            Err(InitProtocolError::WrongMagicNumber(*b"woopsie"))
        );
        // `initialize` only sends the explanatory `Raw` frame when built with
        // `debug_assertions`; without it the peer's last `recv` finds the channel closed.
        let expected_peer_result = if cfg!(debug_assertions) {
            Ok(())
        } else {
            Err(InitProtocolError::Closed)
        };
        assert_eq!(r2.unwrap(), expected_peer_result);
    }

    /// The peer answers with a bad version: the initializer must report it and close.
    #[tokio::test]
    async fn handshake_wrong_version() {
        let [mut p1, mut p2] = ac_bound(10, None);
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move {
            let _ = p2.1.recv().await?;
            p2.0.send(InitFrame::Handshake {
                magic_number: VELOREN_MAGIC_NUMBER,
                version: [0, 1, 2],
            })
            .await?;
            let _ = p2.1.recv().await?;
            let _ = p2.1.recv().await?; //this should be closed now
            Ok(())
        });
        let (r1, r2) = tokio::join!(r1, r2);
        assert_eq!(r1.unwrap(), Err(InitProtocolError::WrongVersion([0, 1, 2])));
        assert_eq!(r2.unwrap(), Err(InitProtocolError::Closed));
    }

    /// The peer sends a `Raw` frame instead of `Init`: the initializer must see `Closed`.
    #[tokio::test]
    async fn handshake_unexpected_raw() {
        let [mut p1, mut p2] = ac_bound(10, None);
        let r1 = tokio::spawn(async move { p1.initialize(true, Pid::fake(2), 1337).await });
        let r2 = tokio::spawn(async move {
            let _ = p2.1.recv().await?;
            p2.0.send(InitFrame::Handshake {
                magic_number: VELOREN_MAGIC_NUMBER,
                version: VELOREN_NETWORK_VERSION,
            })
            .await?;
            let _ = p2.1.recv().await?;
            p2.0.send(InitFrame::Raw(b"Hello World".to_vec())).await?;
            Result::<(), InitProtocolError>::Ok(())
        });
        let (r1, r2) = tokio::join!(r1, r2);
        assert_eq!(r1.unwrap(), Err(InitProtocolError::Closed));
        assert_eq!(r2.unwrap(), Ok(()));
    }
}
|