From 01992c05c66f50695c6182c4a061e76c7679e066 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcel=20M=C3=A4rtens?= <marcel.cochem@googlemail.com>
Date: Mon, 19 Apr 2021 16:49:23 +0200
Subject: [PATCH] QuicSink and QuicDrain do work now. When local SendProtocol
 is opening a Stream, it will send a empty message to QuicDrain which will
 then know that its time to open a quic stream. It will open a QuicStream and
 send its SID over to remote. The RecvStream will be send to local QuicSink
 RemoteRecv will notice a new BiStream was opened and read its Sid. It will
 now start listening on it. while remote main will get the information that a
 stream was opened and will notice the frontend. in participant remote Recv is
 synced with remote send (without triggering a empty message!). RemoteRecv
 Sink will send the sendstream to RemoteSend Drain and it will be used when a
 first message is send on this stream.

---
 network/protocol/src/quic.rs  | 43 ++++++++++---------
 network/protocol/src/types.rs |  2 +
 network/src/channel.rs        | 78 ++++++++++++++++++++++-------------
 3 files changed, 72 insertions(+), 51 deletions(-)

diff --git a/network/protocol/src/quic.rs b/network/protocol/src/quic.rs
index e656fdf5a1..0e76e1fe32 100644
--- a/network/protocol/src/quic.rs
+++ b/network/protocol/src/quic.rs
@@ -23,7 +23,7 @@ use tracing::trace;
 #[derive(PartialEq)]
 pub enum QuicDataFormatStream {
     Main,
-    Reliable(u64),
+    Reliable(Sid),
     Unreliable,
 }
 
@@ -40,9 +40,9 @@ impl QuicDataFormat {
         }
     }
 
-    fn with_reliable(buffer: &mut BytesMut, id: u64) -> Self {
+    fn with_reliable(buffer: &mut BytesMut, sid: Sid) -> Self {
         Self {
-            stream: QuicDataFormatStream::Reliable(id),
+            stream: QuicDataFormatStream::Reliable(sid),
             data: buffer.split(),
         }
     }
@@ -88,13 +88,19 @@ where
     main_buffer: BytesMut,
     unreliable_buffer: BytesMut,
     reliable_buffers: SortedVec<Sid, BytesMut>,
-    pending_reliable_buffers: Vec<(u64, BytesMut)>,
+    pending_reliable_buffers: Vec<(Sid, BytesMut)>,
     itmsg_allocator: BytesMut,
     incoming: HashMap<Mid, ITMessage>,
     sink: S,
     metrics: ProtocolMetricCache,
 }
 
+fn is_reliable(p: &Promises) -> bool {
+    p.contains(Promises::ORDERED)
+        || p.contains(Promises::CONSISTENCY)
+        || p.contains(Promises::GUARANTEED_DELIVERY)
+}
+
 impl<D> QuicSendProtocol<D>
 where
     D: UnreliableDrain<DataFormat = QuicDataFormat>,
@@ -148,8 +154,8 @@ where
             QuicDataFormatStream::Main => &mut self.main_buffer,
             QuicDataFormatStream::Unreliable => &mut self.unreliable_buffer,
             QuicDataFormatStream::Reliable(id) => {
-                match self.reliable_buffers.data.get_mut(id as usize) {
-                    Some((_, buffer)) => buffer,
+                match self.reliable_buffers.get_mut(&id) {
+                    Some(buffer) => buffer,
                     None => {
                         self.pending_reliable_buffers.push((id, BytesMut::new()));
                         //Violated but will never happen
@@ -186,10 +192,7 @@ where
             } => {
                 self.store
                     .open_stream(sid, prio, promises, guaranteed_bandwidth);
-                if promises.contains(Promises::ORDERED)
-                    || promises.contains(Promises::CONSISTENCY)
-                    || promises.contains(Promises::GUARANTEED_DELIVERY)
-                {
+                if is_reliable(&promises) {
                     self.reliable_buffers.insert(sid, BytesMut::new());
                 }
             },
@@ -216,11 +219,10 @@ where
             } => {
                 self.store
                     .open_stream(sid, prio, promises, guaranteed_bandwidth);
-                if promises.contains(Promises::ORDERED)
-                    || promises.contains(Promises::CONSISTENCY)
-                    || promises.contains(Promises::GUARANTEED_DELIVERY)
-                {
+                if is_reliable(&promises) {
                     self.reliable_buffers.insert(sid, BytesMut::new());
+                    //Send a empty message to notify local drain of stream
+                    self.drain.send(QuicDataFormat::with_reliable(&mut BytesMut::new(), sid)).await?;
                 }
                 event.to_frame().write_bytes(&mut self.main_buffer);
                 self.drain
@@ -284,10 +286,10 @@ where
                 },
             }
         }
-        for (id, (_, buffer)) in self.reliable_buffers.data.iter_mut().enumerate() {
+        for (sid, buffer) in self.reliable_buffers.data.iter_mut() {
             if !buffer.is_empty() {
                 self.drain
-                    .send(QuicDataFormat::with_reliable(buffer, id as u64))
+                    .send(QuicDataFormat::with_reliable(buffer, *sid))
                     .await?;
             }
         }
@@ -354,10 +356,7 @@ where
                             promises,
                             guaranteed_bandwidth,
                         } => {
-                            if promises.contains(Promises::ORDERED)
-                                || promises.contains(Promises::CONSISTENCY)
-                                || promises.contains(Promises::GUARANTEED_DELIVERY)
-                            {
+                            if is_reliable(&promises) {
                                 self.reliable_buffers.insert(sid, BytesMut::new());
                             }
                             break 'outer Ok(ProtocolEvent::OpenStream {
@@ -808,7 +807,7 @@ mod tests {
             length: (DATA1.len() + DATA2.len()) as u64,
         }
         .write_bytes(&mut bytes);
-        s.send(QuicDataFormat::with_reliable(&mut bytes, 0))
+        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
             .await
             .unwrap();
 
@@ -822,7 +821,7 @@ mod tests {
             data: Bytes::from(&DATA2[..]),
         }
         .write_bytes(&mut bytes);
-        s.send(QuicDataFormat::with_reliable(&mut bytes, 0))
+        s.send(QuicDataFormat::with_reliable(&mut bytes, sid))
             .await
             .unwrap();
 
diff --git a/network/protocol/src/types.rs b/network/protocol/src/types.rs
index dfc9142f38..2e189b412d 100644
--- a/network/protocol/src/types.rs
+++ b/network/protocol/src/types.rs
@@ -118,6 +118,8 @@ impl Pid {
 impl Sid {
     pub const fn new(internal: u64) -> Self { Self { internal } }
 
+    pub fn get_u64(&self) -> u64 { self.internal }
+
     #[inline]
     pub(crate) fn from_bytes(bytes: &mut BytesMut) -> Self {
         Self {
diff --git a/network/src/channel.rs b/network/src/channel.rs
index 9866d88da9..872c0647cf 100644
--- a/network/src/channel.rs
+++ b/network/src/channel.rs
@@ -86,24 +86,27 @@ impl Protocols {
         } else {
             connection.bi_streams.next().await.expect("none").expect("dasdasd")
         };
-        let (streams_s,streams_r) = mpsc::unbounded_channel();
-        let streams_s_clone = streams_s.clone();
+        let (recvstreams_s,recvstreams_r) = mpsc::unbounded_channel();
+        let streams_s_clone = recvstreams_s.clone();
+        let (sendstreams_s,sendstreams_r) = mpsc::unbounded_channel();
         let sp = QuicSendProtocol::new(
             QuicDrain {
                 con: connection.connection.clone(),
                 main: sendstream,
                 reliables: std::collections::HashMap::new(),
-                streams_s: streams_s_clone,
+                recvstreams_s: streams_s_clone,
+                sendstreams_r,
             },
             metrics.clone(),
         );
-        spawn_new(recvstream, None, &streams_s);
+        spawn_new(recvstream, None, &recvstreams_s);
         let rp = QuicRecvProtocol::new(
             QuicSink {
                 con: connection.connection,
                 bi: connection.bi_streams,
-                streams_r,
-                streams_s,
+                recvstreams_r,
+                recvstreams_s,
+                sendstreams_s,
             },
             metrics,
         );
@@ -258,15 +261,16 @@ impl UnreliableSink for MpscSink {
 ///////////////////////////////////////
 //// QUIC
 #[cfg(feature = "quic")]
-type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<u64>);
+type QuicStream = (BytesMut, Result<Option<usize>, quinn::ReadError>, quinn::RecvStream, Option<Sid>);
 
 #[cfg(feature = "quic")]
 #[derive(Debug)]
 pub struct QuicDrain {
     con: quinn::Connection,
     main: quinn::SendStream,
-    reliables: std::collections::HashMap<u64, quinn::SendStream>,
-    streams_s: mpsc::UnboundedSender<QuicStream>,
+    reliables: std::collections::HashMap<Sid, quinn::SendStream>,
+    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
+    sendstreams_r: mpsc::UnboundedReceiver<quinn::SendStream>,
 }
 
 #[cfg(feature = "quic")]
@@ -274,18 +278,19 @@ pub struct QuicDrain {
 pub struct QuicSink {
     con: quinn::Connection,
     bi: quinn::IncomingBiStreams,
-    streams_r: mpsc::UnboundedReceiver<QuicStream>,
-    streams_s: mpsc::UnboundedSender<QuicStream>,
+    recvstreams_r: mpsc::UnboundedReceiver<QuicStream>,
+    recvstreams_s: mpsc::UnboundedSender<QuicStream>,
+    sendstreams_s: mpsc::UnboundedSender<quinn::SendStream>,
 }
 
 #[cfg(feature = "quic")]
-fn spawn_new(mut recvstream: quinn::RecvStream, id: Option<u64>, streams_s: &mpsc::UnboundedSender<QuicStream>) {
+fn spawn_new(mut recvstream: quinn::RecvStream, sid: Option<Sid>, streams_s: &mpsc::UnboundedSender<QuicStream>) {
     let streams_s_clone = streams_s.clone();
     tokio::spawn(async move {
         let mut buffer = BytesMut::new();
         buffer.resize(1500, 0u8);
         let r = recvstream.read(&mut buffer).await;
-        let _ = streams_s_clone.send((buffer, r, recvstream, id));
+        let _ = streams_s_clone.send((buffer, r, recvstream, sid));
     });
 }
 
@@ -300,20 +305,30 @@ impl UnreliableDrain for QuicDrain {
                 self.main.write_all(&data.data).await
             },
             QuicDataFormatStream::Unreliable => unimplemented!(),
-            QuicDataFormatStream::Reliable(id) => {
+            QuicDataFormatStream::Reliable(sid) => {
                 use std::collections::hash_map::Entry;
-                match self.reliables.entry(id) {
+                tracing::trace!(?sid, "Reliable");
+                match self.reliables.entry(sid) {
                     Entry::Occupied(mut occupied) => {
                         occupied.get_mut().write_all(&data.data).await
                     },
                     Entry::Vacant(vacant) => {
-                        match self.con.open_bi().await {
-                            Ok((sendstream, recvstream)) => {
-                                let id = Some(0); //TODO FIXME
-                                spawn_new(recvstream, id, &self.streams_s);
-                                vacant.insert(sendstream).write_all(&data.data).await
-                            },
-                            Err(_) => return Err(ProtocolError::Closed),
+                        // IF the buffer is empty this was created localy and WE are allowed to open_bi(), if not, we NEED to block on sendstreams_r
+                        if data.data.is_empty() {
+                            match self.con.open_bi().await {
+                                Ok((mut sendstream, recvstream)) => {
+                                    // send SID as first msg
+                                    if sendstream.write_u64(sid.get_u64()).await.is_err() {
+                                        return Err(ProtocolError::Closed);
+                                    }
+                                    spawn_new(recvstream, Some(sid), &self.recvstreams_s);
+                                    vacant.insert(sendstream).write_all(&data.data).await
+                                },
+                                Err(_) => return Err(ProtocolError::Closed),
+                            }
+                        } else {
+                            let sendstream = self.sendstreams_r.recv().await.ok_or(ProtocolError::Closed)?;
+                            vacant.insert(sendstream).write_all(&data.data).await
                         }
                     },
                 }
@@ -338,16 +353,21 @@ impl UnreliableSink for QuicSink {
             let (a, b) = tokio::select! {
                 biased;
                 Some(n) = self.bi.next().fuse() => (Some(n), None),
-                Some(n) = self.streams_r.recv().fuse() => (None, Some(n)),
+                Some(n) = self.recvstreams_r.recv().fuse() => (None, Some(n)),
             };
 
             if let Some(remote_stream) = a {
                 match remote_stream {
-                    Ok((sendstream, recvstream)) => {
-                        //FIXME TODO
-                        let id = Some(0); // get real ID
-                        drop(sendstream); // not drop it!
-                        spawn_new(recvstream, id, &self.streams_s);
+                    Ok((sendstream, mut recvstream)) => {
+                        let sid = match recvstream.read_u64().await {
+                            Ok(u64::MAX) => None, //unreliable
+                            Ok(sid) => Some(Sid::new(sid)),
+                            Err(_) => return Err(ProtocolError::Violated),
+                        };
+                        if self.sendstreams_s.send(sendstream).is_err() {
+                            return Err(ProtocolError::Closed);
+                        }
+                        spawn_new(recvstream, sid, &self.recvstreams_s);
                     },
                     Err(_) => return Err(ProtocolError::Closed),
                 }
@@ -372,7 +392,7 @@ impl UnreliableSink for QuicSink {
         }?;
 
 
-        let streams_s_clone = self.streams_s.clone();
+        let streams_s_clone = self.recvstreams_s.clone();
         tokio::spawn(async move {
             buffer.resize(1500, 0u8);
             let r = recvstream.read(&mut buffer).await;