Skip to content

Commit

Permalink
Make gateway latency check generic (#4759)
Browse files Browse the repository at this point in the history
* Replace concrete gateway type with trait in latency check

* Rename to ConnectableGateway
  • Loading branch information
octol authored Aug 15, 2024
1 parent ec61728 commit dff82f9
Showing 1 changed file with 40 additions and 18 deletions.
58 changes: 40 additions & 18 deletions common/client-core/src/init/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,34 @@ const MEASUREMENTS: usize = 3;
const CONN_TIMEOUT: Duration = Duration::from_millis(1500);
const PING_TIMEOUT: Duration = Duration::from_millis(1000);

struct GatewayWithLatency<'a> {
gateway: &'a gateway::Node,
// The abstraction that some of these helpers use
pub trait ConnectableGateway {
fn identity(&self) -> &identity::PublicKey;
fn clients_address(&self) -> String;
fn is_wss(&self) -> bool;
}

impl ConnectableGateway for gateway::Node {
fn identity(&self) -> &identity::PublicKey {
self.identity()
}

fn clients_address(&self) -> String {
self.clients_address()
}

fn is_wss(&self) -> bool {
self.clients_wss_port.is_some()
}
}

struct GatewayWithLatency<'a, G: ConnectableGateway> {
gateway: &'a G,
latency: Duration,
}

impl<'a> GatewayWithLatency<'a> {
fn new(gateway: &'a gateway::Node, latency: Duration) -> Self {
impl<'a, G: ConnectableGateway> GatewayWithLatency<'a, G> {
fn new(gateway: &'a G, latency: Duration) -> Self {
GatewayWithLatency { gateway, latency }
}
}
Expand Down Expand Up @@ -130,11 +151,14 @@ async fn connect(endpoint: &str) -> Result<WsConn, ClientCoreError> {
JSWebsocket::new(endpoint).map_err(|_| ClientCoreError::GatewayJsConnectionFailure)
}

async fn measure_latency(gateway: &gateway::Node) -> Result<GatewayWithLatency, ClientCoreError> {
async fn measure_latency<G>(gateway: &G) -> Result<GatewayWithLatency<G>, ClientCoreError>
where
G: ConnectableGateway,
{
let addr = gateway.clients_address();
trace!(
"establishing connection to {} ({addr})...",
gateway.identity_key,
gateway.identity(),
);
let mut stream = connect(&addr).await?;

Expand Down Expand Up @@ -177,7 +201,7 @@ async fn measure_latency(gateway: &gateway::Node) -> Result<GatewayWithLatency,
let count = results.len() as u64;
if count == 0 {
return Err(ClientCoreError::NoGatewayMeasurements {
identity: gateway.identity_key.to_base58_string(),
identity: gateway.identity().to_base58_string(),
});
}

Expand All @@ -187,11 +211,11 @@ async fn measure_latency(gateway: &gateway::Node) -> Result<GatewayWithLatency,
Ok(GatewayWithLatency::new(gateway, avg))
}

pub async fn choose_gateway_by_latency<R: Rng>(
pub async fn choose_gateway_by_latency<'a, R: Rng, G: ConnectableGateway + Clone>(
rng: &mut R,
gateways: &[gateway::Node],
gateways: &[G],
must_use_tls: bool,
) -> Result<gateway::Node, ClientCoreError> {
) -> Result<G, ClientCoreError> {
let gateways = filter_by_tls(gateways, must_use_tls)?;

info!(
Expand Down Expand Up @@ -223,21 +247,19 @@ pub async fn choose_gateway_by_latency<R: Rng>(

info!(
"chose gateway {} with average latency of {:?}",
chosen.gateway.identity_key, chosen.latency
chosen.gateway.identity(),
chosen.latency
);

Ok(chosen.gateway.clone())
}

fn filter_by_tls(
gateways: &[gateway::Node],
fn filter_by_tls<G: ConnectableGateway>(
gateways: &[G],
must_use_tls: bool,
) -> Result<Vec<&gateway::Node>, ClientCoreError> {
) -> Result<Vec<&G>, ClientCoreError> {
if must_use_tls {
let filtered = gateways
.iter()
.filter(|g| g.clients_wss_port.is_some())
.collect::<Vec<_>>();
let filtered = gateways.iter().filter(|g| g.is_wss()).collect::<Vec<_>>();

if filtered.is_empty() {
return Err(ClientCoreError::NoWssGateways);
Expand Down

0 comments on commit dff82f9

Please sign in to comment.