better protocol versioning

This commit is contained in:
crabman 2024-04-30 15:48:19 +00:00
parent e07b9b4c3a
commit 78fd92fae5
No known key found for this signature in database
4 changed files with 67 additions and 36 deletions

View File

@ -9,8 +9,8 @@ edition = "2021"
[features] [features]
server = ["dep:rand"] server = ["dep:rand"]
client = ["tokio/time"] client = ["tokio/time"]
example = ["tokio/macros", "tokio/rt-multi-thread", "dep:tracing-subscriber"] example = ["tokio/macros", "tokio/rt-multi-thread", "dep:tracing-subscriber", "client", "server"]
default = ["server", "client"] default = []
[dependencies] [dependencies]
tokio = { workspace = true, features = ["net", "sync"] } tokio = { workspace = true, features = ["net", "sync"] }

View File

@ -21,21 +21,26 @@ pub enum QueryClientError {
Io(tokio::io::Error), Io(tokio::io::Error),
Protocol(protocol::Error), Protocol(protocol::Error),
InvalidResponse, InvalidResponse,
InvalidVersion,
Timeout, Timeout,
ChallengeFailed, ChallengeFailed,
RequestTooLarge, RequestTooLarge,
} }
struct ClientInitData {
p: u64,
#[allow(dead_code)]
server_max_version: u16,
}
/// The `p` field has to be requested from the server each time this client is /// The `p` field has to be requested from the server each time this client is
/// constructed, if possible reuse this! /// constructed, if possible reuse this!
pub struct QueryClient { pub struct QueryClient {
pub addr: SocketAddr, pub addr: SocketAddr,
p: u64, init: Option<ClientInitData>,
} }
impl QueryClient { impl QueryClient {
pub fn new(addr: SocketAddr) -> Self { Self { addr, p: 0 } } pub fn new(addr: SocketAddr) -> Self { Self { addr, init: None } }
pub async fn server_info(&mut self) -> Result<(ServerInfo, Duration), QueryClientError> { pub async fn server_info(&mut self) -> Result<(ServerInfo, Duration), QueryClientError> {
self.send_query(QueryServerRequest::ServerInfo) self.send_query(QueryServerRequest::ServerInfo)
@ -73,10 +78,20 @@ impl QueryClient {
// 2 extra bytes for version information, currently unused // 2 extra bytes for version information, currently unused
buf.extend(VERSION.to_le_bytes()); buf.extend(VERSION.to_le_bytes());
buf.extend({ buf.extend({
let request_data = <RawQueryServerRequest as Parcel>::raw_bytes( let request_data = if let Some(init) = &self.init {
&RawQueryServerRequest { p: self.p, request }, // TODO: Use the maximum version supported by both the client and server once
// new protocol versions are added
<RawQueryServerRequest as Parcel>::raw_bytes(
&RawQueryServerRequest { p: init.p, request },
&Default::default(), &Default::default(),
)?; )?
} else {
// TODO: Use the legacy version here once new protocol versions are added
<RawQueryServerRequest as Parcel>::raw_bytes(
&RawQueryServerRequest { p: 0, request },
&Default::default(),
)?
};
if request_data.len() > MAX_REQUEST_SIZE { if request_data.len() > MAX_REQUEST_SIZE {
warn!( warn!(
?request, ?request,
@ -103,14 +118,9 @@ impl QueryClient {
Err(QueryClientError::InvalidResponse)? Err(QueryClientError::InvalidResponse)?
} }
// FIXME: Allow lower versions once proper versioning is added.
if u16::from_le_bytes(buf[..2].try_into().unwrap()) != VERSION {
Err(QueryClientError::InvalidVersion)?
}
let packet = <RawQueryServerResponse as Parcel>::read( let packet = <RawQueryServerResponse as Parcel>::read(
// TODO: Remove this padding once version information is added to packets // TODO: Remove this padding once version information is added to packets
&mut io::Cursor::new(&buf[2..buf_len]), &mut io::Cursor::new(&buf[..buf_len]),
&Default::default(), &Default::default(),
)?; )?;
@ -118,9 +128,12 @@ impl QueryClient {
RawQueryServerResponse::Response(response) => { RawQueryServerResponse::Response(response) => {
return Ok((response, query_sent.elapsed())); return Ok((response, query_sent.elapsed()));
}, },
RawQueryServerResponse::P(p) => { RawQueryServerResponse::Init(init) => {
trace!(?p, "Resetting p"); trace!(?init, "Resetting p");
self.p = p self.init = Some(ClientInitData {
p: init.p,
server_max_version: init.max_supported_version,
});
}, },
} }
} }

View File

@ -11,7 +11,7 @@ pub(crate) const MAX_RESPONSE_SIZE: usize = 256;
#[derive(Protocol, Debug, Clone, Copy)] #[derive(Protocol, Debug, Clone, Copy)]
pub(crate) struct RawQueryServerRequest { pub(crate) struct RawQueryServerRequest {
/// See comment on [`RawQueryServerResponse::P`] /// See comment on [`Init::p`]
pub p: u64, pub p: u64,
pub request: QueryServerRequest, pub request: QueryServerRequest,
} }
@ -27,17 +27,28 @@ pub enum QueryServerRequest {
} }
#[derive(Protocol, Debug, Clone, Copy)] #[derive(Protocol, Debug, Clone, Copy)]
#[protocol(discriminant = "integer")] pub(crate) struct Init {
#[protocol(discriminator(u8))]
pub(crate) enum RawQueryServerResponse {
Response(QueryServerResponse),
/// This is used as a challenge to prevent IP address spoofing by verifying /// This is used as a challenge to prevent IP address spoofing by verifying
/// that the client can receive from the source address. /// that the client can receive from the source address.
/// ///
/// Any request to the server must include this value to be processed, /// Any request to the server must include this value to be processed,
/// otherwise this response will be returned (giving clients a value to pass /// otherwise this response will be returned (giving clients a value to pass
/// for later requests). /// for later requests).
P(u64), pub p: u64,
/// The maximum supported protocol version by the server. The first request
/// to a server must always be done in the V0 protocol to query this value.
/// Following requests (when the version is known), can be done in the
/// maximum version or below, responses will be sent in the same version as
/// the requests.
pub max_supported_version: u16,
}
#[derive(Protocol, Debug, Clone, Copy)]
#[protocol(discriminant = "integer")]
#[protocol(discriminator(u8))]
pub(crate) enum RawQueryServerResponse {
Response(QueryServerResponse),
Init(Init),
} }
#[derive(Protocol, Debug, Clone, Copy)] #[derive(Protocol, Debug, Clone, Copy)]

View File

@ -14,8 +14,9 @@ use tracing::{debug, error, trace};
use crate::{ use crate::{
proto::{ proto::{
QueryServerRequest, QueryServerResponse, RawQueryServerRequest, RawQueryServerResponse, Init, QueryServerRequest, QueryServerResponse, RawQueryServerRequest,
ServerInfo, MAX_REQUEST_SIZE, MAX_RESPONSE_SIZE, VELOREN_HEADER, VERSION, RawQueryServerResponse, ServerInfo, MAX_REQUEST_SIZE, MAX_RESPONSE_SIZE, VELOREN_HEADER,
VERSION,
}, },
ratelimit::{RateLimiter, ReducedIpAddr}, ratelimit::{RateLimiter, ReducedIpAddr},
}; };
@ -183,7 +184,16 @@ impl QueryServer {
}; };
if real_p != client_p { if real_p != client_p {
Self::send_response(RawQueryServerResponse::P(real_p), remote, socket, metrics).await; Self::send_response(
RawQueryServerResponse::Init(Init {
p: real_p,
max_supported_version: VERSION,
}),
remote,
socket,
metrics,
)
.await;
return; return;
} }
@ -225,30 +235,27 @@ impl QueryServer {
socket: &UdpSocket, socket: &UdpSocket,
metrics: &mut Metrics, metrics: &mut Metrics,
) { ) {
// TODO: Remove this extra padding once we add version information to requests // TODO: Once more versions are added, send the packet in the same version as
let mut buf = Vec::from(VERSION.to_ne_bytes()); // the request here.
match <RawQueryServerResponse as Parcel>::raw_bytes(&response, &Default::default()) { match <RawQueryServerResponse as Parcel>::raw_bytes(&response, &Default::default()) {
Ok(data) => { Ok(data) => {
buf.extend(data); if data.len() > MAX_RESPONSE_SIZE {
if buf.len() > MAX_RESPONSE_SIZE {
error!( error!(
?MAX_RESPONSE_SIZE, ?MAX_RESPONSE_SIZE,
"Attempted to send a response larger than the maximum allowed size (size: \ "Attempted to send a response larger than the maximum allowed size (size: \
{}, response: {response:?})", {}, response: {response:?})",
buf.len() data.len()
); );
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
panic!( panic!(
"Attempted to send a response larger than the maximum allowed size (size: \ "Attempted to send a response larger than the maximum allowed size (size: \
{}, max: {}, response: {response:?})", {}, max: {}, response: {response:?})",
buf.len(), data.len(),
MAX_RESPONSE_SIZE MAX_RESPONSE_SIZE
); );
} }
match socket.send_to(&buf, addr).await { match socket.send_to(&data, addr).await {
Ok(_) => { Ok(_) => {
metrics.sent_responses += 1; metrics.sent_responses += 1;
}, },