diff --git a/network/src/channel.rs b/network/src/channel.rs index 92fb798ec3..4fd31f1c21 100644 --- a/network/src/channel.rs +++ b/network/src/channel.rs @@ -1,4 +1,4 @@ -use crate::api::NetworkConnectError; +use crate::api::{ConnectAddr, NetworkConnectError}; use async_trait::async_trait; use bytes::BytesMut; use futures_util::FutureExt; @@ -65,6 +65,7 @@ pub(crate) type C2cMpscConnect = ( mpsc::Sender, oneshot::Sender>, ); +pub(crate) type C2sProtocol = (Protocols, ConnectAddr, Cid); impl Protocols { const MPSC_CHANNEL_BOUND: usize = 1000; @@ -92,7 +93,7 @@ impl Protocols { cids: Arc, metrics: Arc, s2s_stop_listening_r: oneshot::Receiver<()>, - c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + c2s_protocol_s: mpsc::UnboundedSender, ) -> std::io::Result<()> { use socket2::{Domain, Socket, Type}; let domain = Domain::for_address(addr); @@ -132,7 +133,11 @@ impl Protocols { let cid = cids.fetch_add(1, Ordering::Relaxed); info!(?remote_addr, ?cid, "Accepting Tcp from"); let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); - let _ = c2s_protocol_s.send((Self::new_tcp(stream, metrics.clone()), cid)); + let _ = c2s_protocol_s.send(( + Self::new_tcp(stream, metrics.clone()), + ConnectAddr::Tcp(remote_addr), + cid, + )); } }); Ok(()) @@ -192,7 +197,7 @@ impl Protocols { cids: Arc, metrics: Arc, s2s_stop_listening_r: oneshot::Receiver<()>, - c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + c2s_protocol_s: mpsc::UnboundedSender, ) -> io::Result<()> { let (mpsc_s, mut mpsc_r) = mpsc::unbounded_channel(); MPSC_POOL.lock().await.insert(addr, mpsc_s); @@ -214,6 +219,7 @@ impl Protocols { let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); let _ = c2s_protocol_s.send(( Self::new_mpsc(local_to_remote_s, remote_to_local_r, metrics.clone()), + ConnectAddr::Mpsc(addr), cid, )); } @@ -276,7 +282,7 @@ impl Protocols { cids: Arc, metrics: Arc, s2s_stop_listening_r: oneshot::Receiver<()>, - c2s_protocol_s: mpsc::UnboundedSender<(Self, Cid)>, + c2s_protocol_s: mpsc::UnboundedSender, ) -> io::Result<()> { let (_endpoint, mut listener) = match quinn::Endpoint::server(server_config, addr) { Ok(v) => v, @@ -303,7 +309,16 @@ impl Protocols { let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&metrics)); match Protocols::new_quic(connection, true, metrics).await { Ok(quic) => { - let _ = c2s_protocol_s.send((quic, cid)); + // TODO: we cannot guess the client hostname in quic server here. + // though we need it for the certificate to be validated, in the future + // this will either go away with new auth, or we have to do something like + // a reverse DNS lookup + let connect_addr = ConnectAddr::Quic( + addr, + quinn::ClientConfig::with_native_roots(), + "TODO_remote_hostname".to_string(), + ); + let _ = c2s_protocol_s.send((quic, connect_addr, cid)); }, Err(e) => { trace!(?e, "failed to start quic"); diff --git a/network/src/participant.rs b/network/src/participant.rs index 2f8f38bca6..9564f55a81 100644 --- a/network/src/participant.rs +++ b/network/src/participant.rs @@ -1,5 +1,5 @@ use crate::{ - api::{ParticipantError, Stream}, + api::{ConnectAddr, ParticipantError, Stream}, channel::{Protocols, ProtocolsError, RecvProtocols, SendProtocols}, metrics::NetworkMetrics, util::DeferredTracer, @@ -27,7 +27,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; pub(crate) type A2bStreamOpen = (Prio, Promises, Bandwidth, oneshot::Sender); -pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, oneshot::Sender<()>); +pub(crate) type S2bCreateChannel = (Cid, Sid, Protocols, ConnectAddr, oneshot::Sender<()>); pub(crate) type S2bShutdownBparticipant = (Duration, oneshot::Sender>); pub(crate) type B2sPrioStatistic = (Pid, u64, u64); @@ -36,6 +36,7 @@ pub(crate) type B2sPrioStatistic = (Pid, u64, u64); struct ChannelInfo { cid: Cid, cid_string: String, //optimisationmetrics + remote_con_addr: ConnectAddr, } #[derive(Debug)] @@ -560,35 +561,39 @@ impl BParticipant { ) { let s2b_create_channel_r = UnboundedReceiverStream::new(s2b_create_channel_r); s2b_create_channel_r - .for_each_concurrent(None, |(cid, _, protocol, b2s_create_channel_done_s)| { - // This channel is now configured, and we are running it in scope of the - // participant. - let channels = Arc::clone(&self.channels); - let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); - let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); - async move { - let mut lock = channels.write().await; - let mut channel_no = lock.len(); - lock.insert( - cid, - Mutex::new(ChannelInfo { + .for_each_concurrent( + None, + |(cid, _, protocol, remote_con_addr, b2s_create_channel_done_s)| { + // This channel is now configured, and we are running it in scope of the + // participant. + let channels = Arc::clone(&self.channels); + let b2b_add_send_protocol_s = b2b_add_send_protocol_s.clone(); + let b2b_add_recv_protocol_s = b2b_add_recv_protocol_s.clone(); + async move { + let mut lock = channels.write().await; + let mut channel_no = lock.len(); + lock.insert( cid, - cid_string: cid.to_string(), - }), - ); - drop(lock); - let (send, recv) = protocol.split(); - b2b_add_send_protocol_s.send((cid, send)).unwrap(); - b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); - b2s_create_channel_done_s.send(()).unwrap(); - if channel_no > 5 { - debug!(?channel_no, "metrics will overwrite channel #5"); - channel_no = 5; + Mutex::new(ChannelInfo { + cid, + cid_string: cid.to_string(), + remote_con_addr, + }), + ); + drop(lock); + let (send, recv) = protocol.split(); + b2b_add_send_protocol_s.send((cid, send)).unwrap(); + b2b_add_recv_protocol_s.send((cid, recv)).unwrap(); + b2s_create_channel_done_s.send(()).unwrap(); + if channel_no > 5 { + debug!(?channel_no, "metrics will overwrite channel #5"); + channel_no = 5; + } + self.metrics + .channels_connected(&self.remote_pid_string, channel_no, cid); } - self.metrics - .channels_connected(&self.remote_pid_string, channel_no, cid); - } - }) + }, + ) .await; trace!("Stop create_channel_mgr"); self.shutdown_barrier @@ -836,7 +841,7 @@ mod tests { let p1 = Protocols::new_mpsc(s1, r2, metrics); let (complete_s, complete_r) = oneshot::channel(); create_channel - .send((cid, Sid::new(0), p1, complete_s)) + .send((cid, Sid::new(0), p1, ConnectAddr::Mpsc(42), complete_s)) .unwrap(); complete_r.await.unwrap(); let metrics = ProtocolMetricCache::new(&cid.to_string(), met); diff --git a/network/src/scheduler.rs b/network/src/scheduler.rs index 48d51be32e..63a37356c5 100644 --- a/network/src/scheduler.rs +++ b/network/src/scheduler.rs @@ -223,8 +223,8 @@ impl Scheduler { }; let _ = s2a_listen_result_s.send(res); - while let Some((prot, cid)) = c2s_protocol_r.recv().await { - self.init_protocol(prot, cid, None, true).await; + while let Some((prot, con_addr, cid)) = c2s_protocol_r.recv().await { + self.init_protocol(prot, con_addr, cid, None, true).await; } } }) @@ -239,7 +239,7 @@ impl Scheduler { let metrics = ProtocolMetricCache::new(&cid.to_string(), Arc::clone(&self.protocol_metrics)); self.metrics.connect_request(&addr); - let protocol = match addr { + let protocol = match addr.clone() { ConnectAddr::Tcp(addr) => Protocols::with_tcp_connect(addr, metrics).await, #[cfg(feature = "quic")] ConnectAddr::Quic(addr, ref config, name) => { @@ -255,7 +255,7 @@ impl Scheduler { continue; }, }; - self.init_protocol(protocol, cid, Some(pid_sender), false) + self.init_protocol(protocol, addr, cid, Some(pid_sender), false) .await; } trace!("Stop connect_mgr"); @@ -375,6 +375,7 @@ impl Scheduler { async fn init_protocol( &self, mut protocol: Protocols, + con_addr: ConnectAddr, //address necessary to connect to the remote cid: Cid, s2a_return_pid_s: Option>>, send_handshake: bool, @@ -451,7 +452,7 @@ impl Scheduler { oneshot::channel(); //From now on wire connects directly with bparticipant! s2b_create_channel_s - .send((cid, sid, protocol, b2s_create_channel_done_s)) + .send((cid, sid, protocol, con_addr, b2s_create_channel_done_s)) .unwrap(); b2s_create_channel_done_r.await.unwrap(); if let Some(pid_oneshot) = s2a_return_pid_s {