better protocol versioning

This commit is contained in:
crabman 2024-04-30 15:48:19 +00:00
parent e07b9b4c3a
commit 78fd92fae5
No known key found for this signature in database
4 changed files with 67 additions and 36 deletions

View File

@ -9,8 +9,8 @@ edition = "2021"
[features]
server = ["dep:rand"]
client = ["tokio/time"]
example = ["tokio/macros", "tokio/rt-multi-thread", "dep:tracing-subscriber"]
default = ["server", "client"]
example = ["tokio/macros", "tokio/rt-multi-thread", "dep:tracing-subscriber", "client", "server"]
default = []
[dependencies]
tokio = { workspace = true, features = ["net", "sync"] }

View File

@ -21,21 +21,26 @@ pub enum QueryClientError {
Io(tokio::io::Error),
Protocol(protocol::Error),
InvalidResponse,
InvalidVersion,
Timeout,
ChallengeFailed,
RequestTooLarge,
}
struct ClientInitData {
p: u64,
#[allow(dead_code)]
server_max_version: u16,
}
/// The `p` field has to be requested from the server each time this client is
/// constructed, if possible reuse this!
pub struct QueryClient {
pub addr: SocketAddr,
p: u64,
init: Option<ClientInitData>,
}
impl QueryClient {
pub fn new(addr: SocketAddr) -> Self { Self { addr, p: 0 } }
pub fn new(addr: SocketAddr) -> Self { Self { addr, init: None } }
pub async fn server_info(&mut self) -> Result<(ServerInfo, Duration), QueryClientError> {
self.send_query(QueryServerRequest::ServerInfo)
@ -73,10 +78,20 @@ impl QueryClient {
// 2 extra bytes for version information, currently unused
buf.extend(VERSION.to_le_bytes());
buf.extend({
let request_data = <RawQueryServerRequest as Parcel>::raw_bytes(
&RawQueryServerRequest { p: self.p, request },
&Default::default(),
)?;
let request_data = if let Some(init) = &self.init {
// TODO: Use the maximum version supported by both the client and server once
// new protocol versions are added
<RawQueryServerRequest as Parcel>::raw_bytes(
&RawQueryServerRequest { p: init.p, request },
&Default::default(),
)?
} else {
// TODO: Use the legacy version here once new protocol versions are added
<RawQueryServerRequest as Parcel>::raw_bytes(
&RawQueryServerRequest { p: 0, request },
&Default::default(),
)?
};
if request_data.len() > MAX_REQUEST_SIZE {
warn!(
?request,
@ -103,14 +118,9 @@ impl QueryClient {
Err(QueryClientError::InvalidResponse)?
}
// FIXME: Allow lower versions once proper versioning is added.
if u16::from_le_bytes(buf[..2].try_into().unwrap()) != VERSION {
Err(QueryClientError::InvalidVersion)?
}
let packet = <RawQueryServerResponse as Parcel>::read(
// TODO: Remove this padding once version information is added to packets
&mut io::Cursor::new(&buf[2..buf_len]),
&mut io::Cursor::new(&buf[..buf_len]),
&Default::default(),
)?;
@ -118,9 +128,12 @@ impl QueryClient {
RawQueryServerResponse::Response(response) => {
return Ok((response, query_sent.elapsed()));
},
RawQueryServerResponse::P(p) => {
trace!(?p, "Resetting p");
self.p = p
RawQueryServerResponse::Init(init) => {
trace!(?init, "Resetting p");
self.init = Some(ClientInitData {
p: init.p,
server_max_version: init.max_supported_version,
});
},
}
}

View File

@ -11,7 +11,7 @@ pub(crate) const MAX_RESPONSE_SIZE: usize = 256;
#[derive(Protocol, Debug, Clone, Copy)]
pub(crate) struct RawQueryServerRequest {
/// See comment on [`RawQueryServerResponse::P`]
/// See comment on [`Init::p`]
pub p: u64,
pub request: QueryServerRequest,
}
@ -27,17 +27,28 @@ pub enum QueryServerRequest {
}
#[derive(Protocol, Debug, Clone, Copy)]
#[protocol(discriminant = "integer")]
#[protocol(discriminator(u8))]
pub(crate) enum RawQueryServerResponse {
Response(QueryServerResponse),
pub(crate) struct Init {
/// This is used as a challenge to prevent IP address spoofing by verifying
/// that the client can receive from the source address.
///
/// Any request to the server must include this value to be processed,
/// otherwise this response will be returned (giving clients a value to pass
/// for later requests).
P(u64),
pub p: u64,
/// The maximum supported protocol version by the server. The first request
/// to a server must always be done in the V0 protocol to query this value.
/// Following requests (when the version is known), can be done in the
/// maximum version or below, responses will be sent in the same version as
/// the requests.
pub max_supported_version: u16,
}
#[derive(Protocol, Debug, Clone, Copy)]
#[protocol(discriminant = "integer")]
#[protocol(discriminator(u8))]
pub(crate) enum RawQueryServerResponse {
Response(QueryServerResponse),
Init(Init),
}
#[derive(Protocol, Debug, Clone, Copy)]

View File

@ -14,8 +14,9 @@ use tracing::{debug, error, trace};
use crate::{
proto::{
QueryServerRequest, QueryServerResponse, RawQueryServerRequest, RawQueryServerResponse,
ServerInfo, MAX_REQUEST_SIZE, MAX_RESPONSE_SIZE, VELOREN_HEADER, VERSION,
Init, QueryServerRequest, QueryServerResponse, RawQueryServerRequest,
RawQueryServerResponse, ServerInfo, MAX_REQUEST_SIZE, MAX_RESPONSE_SIZE, VELOREN_HEADER,
VERSION,
},
ratelimit::{RateLimiter, ReducedIpAddr},
};
@ -183,7 +184,16 @@ impl QueryServer {
};
if real_p != client_p {
Self::send_response(RawQueryServerResponse::P(real_p), remote, socket, metrics).await;
Self::send_response(
RawQueryServerResponse::Init(Init {
p: real_p,
max_supported_version: VERSION,
}),
remote,
socket,
metrics,
)
.await;
return;
}
@ -225,30 +235,27 @@ impl QueryServer {
socket: &UdpSocket,
metrics: &mut Metrics,
) {
// TODO: Remove this extra padding once we add version information to requests
let mut buf = Vec::from(VERSION.to_ne_bytes());
// TODO: Once more versions are added, send the packet in the same version as
// the request here.
match <RawQueryServerResponse as Parcel>::raw_bytes(&response, &Default::default()) {
Ok(data) => {
buf.extend(data);
if buf.len() > MAX_RESPONSE_SIZE {
if data.len() > MAX_RESPONSE_SIZE {
error!(
?MAX_RESPONSE_SIZE,
"Attempted to send a response larger than the maximum allowed size (size: \
{}, response: {response:?})",
buf.len()
data.len()
);
#[cfg(debug_assertions)]
panic!(
"Attempted to send a response larger than the maximum allowed size (size: \
{}, max: {}, response: {response:?})",
buf.len(),
data.len(),
MAX_RESPONSE_SIZE
);
}
match socket.send_to(&buf, addr).await {
match socket.send_to(&data, addr).await {
Ok(_) => {
metrics.sent_responses += 1;
},