mirror of
https://gitlab.com/veloren/veloren.git
synced 2024-08-30 18:12:32 +00:00
move connect code to channel and get rid of unwraps
This commit is contained in:
parent
95b186e29a
commit
4afadf57dc
@ -222,7 +222,9 @@ where
|
||||
if is_reliable(&promises) {
|
||||
self.reliable_buffers.insert(sid, BytesMut::new());
|
||||
//Send a empty message to notify local drain of stream
|
||||
self.drain.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)).await?;
|
||||
self.drain
|
||||
.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid))
|
||||
.await?;
|
||||
}
|
||||
event.to_frame().write_bytes(&mut self.main_buffer);
|
||||
self.drain
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::api::NetworkConnectError;
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use network_protocol::{
|
||||
@ -9,10 +10,12 @@ use network_protocol::{
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net,
|
||||
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
sync::mpsc,
|
||||
sync::{mpsc, oneshot},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, trace};
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Debug)]
|
||||
@ -40,6 +43,23 @@ pub(crate) enum RecvProtocols {
|
||||
}
|
||||
|
||||
impl Protocols {
|
||||
const MPSC_CHANNEL_BOUND: usize = 1000;
|
||||
|
||||
pub(crate) async fn with_tcp_connect(
|
||||
addr: std::net::SocketAddr,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
) -> Result<Self, NetworkConnectError> {
|
||||
let stream = match net::TcpStream::connect(addr).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
return Err(crate::api::NetworkConnectError::Io(e));
|
||||
},
|
||||
};
|
||||
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap());
|
||||
Ok(Protocols::new_tcp(stream, cid, metrics))
|
||||
}
|
||||
|
||||
pub(crate) fn new_tcp(
|
||||
stream: tokio::net::TcpStream,
|
||||
cid: Cid,
|
||||
@ -59,6 +79,49 @@ impl Protocols {
|
||||
Protocols::Tcp((sp, rp))
|
||||
}
|
||||
|
||||
pub(crate) async fn with_mpsc_connect(
|
||||
addr: u64,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
) -> Result<Self, NetworkConnectError> {
|
||||
let mpsc_s = match crate::scheduler::MPSC_POOL.lock().await.get(&addr) {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
return Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::NotConnected,
|
||||
"no mpsc listen on this addr",
|
||||
)));
|
||||
},
|
||||
};
|
||||
let (remote_to_local_s, remote_to_local_r) = mpsc::channel(Self::MPSC_CHANNEL_BOUND);
|
||||
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
|
||||
if mpsc_s
|
||||
.send((remote_to_local_s, local_to_remote_oneshot_s))
|
||||
.is_err()
|
||||
{
|
||||
return Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"mpsc pipe broke during connect",
|
||||
)));
|
||||
}
|
||||
let local_to_remote_s = match local_to_remote_oneshot_r.await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
e,
|
||||
)));
|
||||
},
|
||||
};
|
||||
info!(?addr, "Connecting Mpsc");
|
||||
Ok(Self::new_mpsc(
|
||||
local_to_remote_s,
|
||||
remote_to_local_r,
|
||||
cid,
|
||||
metrics,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn new_mpsc(
|
||||
sender: mpsc::Sender<MpscMsg>,
|
||||
receiver: mpsc::Receiver<MpscMsg>,
|
||||
@ -72,6 +135,46 @@ impl Protocols {
|
||||
Protocols::Mpsc((sp, rp))
|
||||
}
|
||||
|
||||
pub(crate) async fn with_quic_connect(
|
||||
addr: std::net::SocketAddr,
|
||||
config: quinn::ClientConfig,
|
||||
name: String,
|
||||
cid: Cid,
|
||||
metrics: Arc<ProtocolMetrics>,
|
||||
) -> Result<Self, NetworkConnectError> {
|
||||
let config = config.clone();
|
||||
let endpoint = quinn::Endpoint::builder();
|
||||
let (endpoint, _) = match endpoint.bind(&"[::]:0".parse().unwrap()) {
|
||||
Ok(e) => e,
|
||||
Err(quinn::EndpointError::Socket(e)) => return Err(NetworkConnectError::Io(e)),
|
||||
};
|
||||
|
||||
info!("Connecting Quic to: {}", &addr);
|
||||
let connecting = endpoint.connect_with(config, &addr, &name).map_err(|e| {
|
||||
trace!(?e, "error setting up quic");
|
||||
NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
e,
|
||||
))
|
||||
})?;
|
||||
let connection = connecting.await.map_err(|e| {
|
||||
trace!(?e, "error with quic connection");
|
||||
NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
e,
|
||||
))
|
||||
})?;
|
||||
Protocols::new_quic(connection, false, cid, metrics)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
trace!(?e, "error with quic");
|
||||
NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
e,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
pub(crate) async fn new_quic(
|
||||
mut connection: quinn::NewConnection,
|
||||
@ -84,11 +187,15 @@ impl Protocols {
|
||||
let (sendstream, recvstream) = if listen {
|
||||
connection.connection.open_bi().await?
|
||||
} else {
|
||||
connection.bi_streams.next().await.expect("none").expect("dasdasd")
|
||||
connection
|
||||
.bi_streams
|
||||
.next()
|
||||
.await
|
||||
.ok_or_else(|| quinn::ConnectionError::LocallyClosed)??
|
||||
};
|
||||
let (recvstreams_s,recvstreams_r) = mpsc::unbounded_channel();
|
||||
let (recvstreams_s, recvstreams_r) = mpsc::unbounded_channel();
|
||||
let streams_s_clone = recvstreams_s.clone();
|
||||
let (sendstreams_s,sendstreams_r) = mpsc::unbounded_channel();
|
||||
let (sendstreams_s, sendstreams_r) = mpsc::unbounded_channel();
|
||||
let sp = QuicSendProtocol::new(
|
||||
QuicDrain {
|
||||
con: connection.connection.clone(),
|
||||
@ -261,7 +368,12 @@ impl UnreliableSink for MpscSink {
|
||||
///////////////////////////////////////
|
||||
//// QUIC
|
||||
#[cfg(feature = "quic")]
|
||||
type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<Sid>);
|
||||
type QuicStream = (
|
||||
BytesMut,
|
||||
Result<Option<usize>, quinn::ReadError>,
|
||||
quinn::RecvStream,
|
||||
Option<Sid>,
|
||||
);
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
#[derive(Debug)]
|
||||
@ -284,7 +396,11 @@ pub struct QuicSink {
|
||||
}
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
fn spawn_new(mut recvstream: quinn::RecvStream, sid: Option<Sid>, streams_s: &mpsc::UnboundedSender<QuicStream>) {
|
||||
fn spawn_new(
|
||||
mut recvstream: quinn::RecvStream,
|
||||
sid: Option<Sid>,
|
||||
streams_s: &mpsc::UnboundedSender<QuicStream>,
|
||||
) {
|
||||
let streams_s_clone = streams_s.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buffer = BytesMut::new();
|
||||
@ -301,19 +417,16 @@ impl UnreliableDrain for QuicDrain {
|
||||
|
||||
async fn send(&mut self, data: Self::DataFormat) -> Result<(), ProtocolError> {
|
||||
match match data.stream {
|
||||
QuicDataFormatStream::Main => {
|
||||
self.main.write_all(&data.data).await
|
||||
},
|
||||
QuicDataFormatStream::Main => self.main.write_all(&data.data).await,
|
||||
QuicDataFormatStream::Unreliable => unimplemented!(),
|
||||
QuicDataFormatStream::Reliable(sid) => {
|
||||
use std::collections::hash_map::Entry;
|
||||
tracing::trace!(?sid, "Reliable");
|
||||
match self.reliables.entry(sid) {
|
||||
Entry::Occupied(mut occupied) => {
|
||||
occupied.get_mut().write_all(&data.data).await
|
||||
},
|
||||
Entry::Occupied(mut occupied) => occupied.get_mut().write_all(&data.data).await,
|
||||
Entry::Vacant(vacant) => {
|
||||
// IF the buffer is empty this was created localy and WE are allowed to open_bi(), if not, we NEED to block on sendstreams_r
|
||||
// IF the buffer is empty this was created localy and WE are allowed to
|
||||
// open_bi(), if not, we NEED to block on sendstreams_r
|
||||
if data.data.is_empty() {
|
||||
match self.con.open_bi().await {
|
||||
Ok((mut sendstream, recvstream)) => {
|
||||
@ -327,14 +440,17 @@ impl UnreliableDrain for QuicDrain {
|
||||
Err(_) => return Err(ProtocolError::Closed),
|
||||
}
|
||||
} else {
|
||||
let sendstream = self.sendstreams_r.recv().await.ok_or(ProtocolError::Closed)?;
|
||||
let sendstream = self
|
||||
.sendstreams_r
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(ProtocolError::Closed)?;
|
||||
vacant.insert(sendstream).write_all(&data.data).await
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
{
|
||||
} {
|
||||
Ok(()) => Ok(()),
|
||||
Err(_) => Err(ProtocolError::Closed),
|
||||
}
|
||||
@ -391,7 +507,6 @@ impl UnreliableSink for QuicSink {
|
||||
Err(_) => Err(ProtocolError::Closed),
|
||||
}?;
|
||||
|
||||
|
||||
let streams_s_clone = self.recvstreams_s.clone();
|
||||
tokio::spawn(async move {
|
||||
buffer.resize(1500, 0u8);
|
||||
|
@ -34,7 +34,7 @@ use tracing::*;
|
||||
// - c: channel/handshake
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = {
|
||||
pub(crate) static ref MPSC_POOL: Mutex<HashMap<u64, mpsc::UnboundedSender<S2sMpscConnect>>> = {
|
||||
Mutex::new(HashMap::new())
|
||||
};
|
||||
}
|
||||
@ -197,94 +197,23 @@ impl Scheduler {
|
||||
let cid = self.channel_ids.fetch_add(1, Ordering::Relaxed);
|
||||
let metrics = Arc::clone(&self.protocol_metrics);
|
||||
self.metrics.connect_request(&addr);
|
||||
let (protocol, handshake) = match addr {
|
||||
ConnectAddr::Tcp(addr) => {
|
||||
let stream = match net::TcpStream::connect(addr).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
pid_sender.send(Err(NetworkConnectError::Io(e))).unwrap();
|
||||
continue;
|
||||
},
|
||||
};
|
||||
info!("Connecting Tcp to: {}", stream.peer_addr().unwrap());
|
||||
(Protocols::new_tcp(stream, cid, metrics), false)
|
||||
},
|
||||
let protocol = match addr {
|
||||
ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, cid, metrics).await,
|
||||
#[cfg(feature = "quic")]
|
||||
ConnectAddr::Quic(addr, ref config, name) => {
|
||||
let config = config.clone();
|
||||
let endpoint = quinn::Endpoint::builder();
|
||||
let (endpoint, _) = endpoint.bind(&"[::]:0".parse().unwrap()).expect("FIXME");
|
||||
|
||||
let connecting = endpoint.connect_with(config, &addr, &name).expect("FIXME");
|
||||
let connection = connecting.await.expect("FIXME");
|
||||
(
|
||||
Protocols::new_quic(connection, false, cid, metrics).await.unwrap(),
|
||||
false,
|
||||
)
|
||||
//pid_sender.send(Ok(())).unwrap();
|
||||
Protocols::with_quic_connect(addr, config.clone(), name, cid, metrics).await
|
||||
},
|
||||
ConnectAddr::Mpsc(addr) => {
|
||||
let mpsc_s = match MPSC_POOL.lock().await.get(&addr) {
|
||||
Some(s) => s.clone(),
|
||||
None => {
|
||||
pid_sender
|
||||
.send(Err(NetworkConnectError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::NotConnected,
|
||||
"no mpsc listen on this addr",
|
||||
))))
|
||||
.unwrap();
|
||||
ConnectAddr::Mpsc(addr) => Protocols::with_mpsc_connect(addr, cid, metrics).await,
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
let protocol = match protocol {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
pid_sender.send(Err(e)).unwrap();
|
||||
continue;
|
||||
},
|
||||
};
|
||||
let (remote_to_local_s, remote_to_local_r) =
|
||||
mpsc::channel(Self::MPSC_CHANNEL_BOUND);
|
||||
let (local_to_remote_oneshot_s, local_to_remote_oneshot_r) = oneshot::channel();
|
||||
mpsc_s
|
||||
.send((remote_to_local_s, local_to_remote_oneshot_s))
|
||||
.unwrap();
|
||||
let local_to_remote_s = local_to_remote_oneshot_r.await.unwrap();
|
||||
info!(?addr, "Connecting Mpsc");
|
||||
(
|
||||
Protocols::new_mpsc(local_to_remote_s, remote_to_local_r, cid, metrics),
|
||||
false,
|
||||
)
|
||||
},
|
||||
/* */
|
||||
//ProtocolConnectAddr::Udp(addr) => {
|
||||
//#[cfg(feature = "metrics")]
|
||||
//self.metrics
|
||||
//.connect_requests_total
|
||||
//.with_label_values(&["udp"])
|
||||
//.inc();
|
||||
//let socket = match net::UdpSocket::bind("0.0.0.0:0").await {
|
||||
//Ok(socket) => Arc::new(socket),
|
||||
//Err(e) => {
|
||||
//pid_sender.send(Err(e)).unwrap();
|
||||
//continue;
|
||||
//},
|
||||
//};
|
||||
//if let Err(e) = socket.connect(addr).await {
|
||||
//pid_sender.send(Err(e)).unwrap();
|
||||
//continue;
|
||||
//};
|
||||
//info!("Connecting Udp to: {}", addr);
|
||||
//let (udp_data_sender, udp_data_receiver) = mpsc::unbounded_channel::<Vec<u8>>();
|
||||
//let protocol = UdpProtocol::new(
|
||||
//Arc::clone(&socket),
|
||||
//addr,
|
||||
//#[cfg(feature = "metrics")]
|
||||
//Arc::clone(&self.metrics),
|
||||
//udp_data_receiver,
|
||||
//);
|
||||
//self.runtime.spawn(
|
||||
//Self::udp_single_channel_connect(Arc::clone(&socket), udp_data_sender)
|
||||
//.instrument(tracing::info_span!("udp", ?addr)),
|
||||
//);
|
||||
//(Protocols::Udp(protocol), true)
|
||||
//},
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
self.init_protocol(protocol, cid, Some(pid_sender), handshake)
|
||||
self.init_protocol(protocol, cid, Some(pid_sender), false)
|
||||
.await;
|
||||
}
|
||||
trace!("Stop connect_mgr");
|
||||
|
Loading…
Reference in New Issue
Block a user