Skip to content

Commit

Permalink
Deal with various auth token types (#247)
Browse files Browse the repository at this point in the history
As of [1], the Janus aggregator API is explicit about which kind of
authentication token it uses (`Bearer` or `DapAuth`) and also generates
bearer tokens by default. This PR adopts changes to the aggregator API
message definitions.

[1]: divviup/janus#1548
  • Loading branch information
tgeoghegan authored Aug 1, 2023
1 parent 9381b83 commit fa6950e
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 42 deletions.
7 changes: 7 additions & 0 deletions client/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ pub struct NewAggregator {
pub api_url: Url,
pub bearer_token: String,
}

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(tag = "type")]
// Currently, Janus collector authentication tokens are always bearer tokens.
pub enum CollectorAuthenticationToken {
Bearer { token: String },
}
6 changes: 5 additions & 1 deletion client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub const CONTENT_TYPE: &str = "application/vnd.divviup+json;version=0.1";
pub const DEFAULT_URL: &str = "https://api.staging.divviup.org/";
pub const USER_AGENT: &str = concat!("divviup-client/", env!("CARGO_PKG_VERSION"));

use aggregator::CollectorAuthenticationToken;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
use std::{future::Future, pin::Pin};
Expand Down Expand Up @@ -248,7 +249,10 @@ impl DivviupClient {
.await
}

pub async fn task_collector_auth_tokens(&self, task_id: &str) -> ClientResult<Vec<String>> {
pub async fn task_collector_auth_tokens(
&self,
task_id: &str,
) -> ClientResult<Vec<CollectorAuthenticationToken>> {
self.get(&format!("api/tasks/{task_id}/collector_auth_tokens"))
.await
}
Expand Down
14 changes: 7 additions & 7 deletions src/api_mocks/aggregator_api.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::random_chars;
use crate::clients::aggregator_client::api_types::{
AggregatorApiConfig, HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId, HpkePublicKey,
JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds, TaskMetrics, TaskResponse,
VdafInstance,
AggregatorApiConfig, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId,
HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds, TaskMetrics,
TaskResponse, VdafInstance,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use querystrong::QueryStrong;
Expand Down Expand Up @@ -80,8 +80,8 @@ async fn get_task(conn: &mut Conn, (): ()) -> Json<TaskResponse> {
time_precision: JanusDuration::from_seconds(60),
tolerable_clock_skew: JanusDuration::from_seconds(60),
collector_hpke_config: random_hpke_config(),
aggregator_auth_token: Some(random_chars(32)),
collector_auth_token: Some(random_chars(32)),
aggregator_auth_token: Some(AuthenticationToken::new(random_chars(32))),
collector_auth_token: Some(AuthenticationToken::new(random_chars(32))),
aggregator_hpke_configs: repeat_with(random_hpke_config).take(5).collect(),
})
}
Expand Down Expand Up @@ -109,8 +109,8 @@ pub fn task_response(task_create: TaskCreate) -> TaskResponse {
time_precision: JanusDuration::from_seconds(task_create.time_precision),
tolerable_clock_skew: JanusDuration::from_seconds(60),
collector_hpke_config: random_hpke_config(),
aggregator_auth_token: Some(random_chars(32)),
collector_auth_token: Some(random_chars(32)),
aggregator_auth_token: Some(AuthenticationToken::new(random_chars(32))),
collector_auth_token: Some(AuthenticationToken::new(random_chars(32))),
aggregator_hpke_configs: repeat_with(random_hpke_config).take(5).collect(),
}
}
Expand Down
46 changes: 39 additions & 7 deletions src/clients/aggregator_client/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,35 @@ impl From<Option<i64>> for QueryType {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum AuthenticationToken {
/// Type of the authentication token. Authentication token type is always "Bearer" in
/// divviup-api.
Bearer {
/// Encoded value of the token. The encoding is opaque to divviup-api.
token: String,
},
}

impl AuthenticationToken {
pub fn new(token: String) -> Self {
Self::Bearer { token }
}

pub fn token(self) -> String {
match self {
Self::Bearer { token } => token,
}
}
}

#[derive(Serialize, Deserialize, Debug)]
pub struct TaskCreate {
#[serde(skip_serializing_if = "Option::is_none")]
pub aggregator_auth_token: Option<String>,
pub aggregator_auth_token: Option<AuthenticationToken>,
#[serde(skip_serializing_if = "Option::is_none")]
pub collector_auth_token: Option<String>,
pub collector_auth_token: Option<AuthenticationToken>,
pub peer_aggregator_endpoint: Url,
pub query_type: QueryType,
pub vdaf: VdafInstance,
Expand Down Expand Up @@ -156,7 +179,10 @@ impl TaskCreate {
time_precision: new_task.time_precision_seconds,
collector_hpke_config: new_task.hpke_config.clone(),
vdaf_verify_key: new_task.vdaf_verify_key.clone(),
aggregator_auth_token: new_task.aggregator_auth_token.clone(),
aggregator_auth_token: new_task
.aggregator_auth_token
.clone()
.map(AuthenticationToken::new),
collector_auth_token: None,
})
}
Expand All @@ -177,8 +203,8 @@ pub struct TaskResponse {
pub time_precision: JanusDuration,
pub tolerable_clock_skew: JanusDuration,
pub collector_hpke_config: HpkeConfig,
pub aggregator_auth_token: Option<String>,
pub collector_auth_token: Option<String>,
pub aggregator_auth_token: Option<AuthenticationToken>,
pub collector_auth_token: Option<AuthenticationToken>,
pub aggregator_hpke_configs: Vec<HpkeConfig>,
}

Expand Down Expand Up @@ -280,8 +306,14 @@ mod test {
"aead_id": "Aes128Gcm",
"public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
},
"aggregator_auth_token": "YWdncmVnYXRvci0xMjM0NTY3OA",
"collector_auth_token": "Y29sbGVjdG9yLWFiY2RlZjAw",
"aggregator_auth_token": {
"type": "Bearer",
"token": "YWdncmVnYXRvci0xMjM0NTY3OA"
},
"collector_auth_token": {
"type": "Bearer",
"token": "Y29sbGVjdG9yLWFiY2RlZjAw"
},
"aggregator_hpke_configs": [
{
"id": 13,
Expand Down
16 changes: 0 additions & 16 deletions src/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,3 @@ pub use aggregator::{
pub use api_token::{
Column as ApiTokenColumn, Entity as ApiTokens, Model as ApiToken, UpdateApiToken,
};

mod validators {
const BASE64_CHARS: &[u8] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";

pub(super) fn base64(data: &str) -> Result<(), validator::ValidationError> {
if data
.chars()
.all(|c| u8::try_from(c).map_or(false, |c| BASE64_CHARS.contains(&c)))
{
Ok(())
} else {
Err(validator::ValidationError::new("base64"))
}
}
}
3 changes: 1 addition & 2 deletions src/entity/aggregator/new_aggregator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::ActiveModel;
use crate::{
clients::{AggregatorClient, ClientError},
entity::{validators::base64, Account, Aggregator},
entity::{Account, Aggregator},
handler::Error,
};
use sea_orm::IntoActiveModel;
Expand All @@ -19,7 +19,6 @@ pub struct NewAggregator {
pub name: Option<String>,
#[cfg_attr(not(feature = "integration-testing"), validate(custom = "https"))]
pub api_url: Option<String>,
#[validate(required, custom = "base64", length(min = 8))]
pub bearer_token: Option<String>,
pub is_first_party: Option<bool>,
}
Expand Down
4 changes: 1 addition & 3 deletions src/entity/aggregator/update_aggregator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
clients::{AggregatorClient, ClientError},
entity::{validators::base64, Aggregator},
entity::Aggregator,
Error,
};
use sea_orm::{ActiveModelTrait, ActiveValue, IntoActiveModel};
Expand All @@ -14,8 +14,6 @@ use validator::{Validate, ValidationError, ValidationErrors};
pub struct UpdateAggregator {
#[validate(length(min = 1))]
pub name: Option<String>,

#[validate(custom = "base64", length(min = 8))]
pub bearer_token: Option<String>,
}

Expand Down
3 changes: 2 additions & 1 deletion src/entity/task/provisionable_task.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{ActiveModel, *};
use crate::{
clients::aggregator_client::api_types::AuthenticationToken,
entity::{Account, Aggregator},
handler::Error,
};
Expand Down Expand Up @@ -84,7 +85,7 @@ impl ProvisionableTask {
.provision_aggregator(client.clone(), self.helper_aggregator.clone())
.await?;

self.aggregator_auth_token = helper.aggregator_auth_token;
self.aggregator_auth_token = helper.aggregator_auth_token.map(AuthenticationToken::token);

let _leader = self
.provision_aggregator(client, self.leader_aggregator.clone())
Expand Down
1 change: 0 additions & 1 deletion tests/aggregators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,6 @@ mod create {
let error: Value = conn.response_json().await;
assert!(error.get("name").is_some());
assert!(error.get("api_url").is_some());
assert!(error.get("bearer_token").is_some());
Ok(())
}

Expand Down
10 changes: 6 additions & 4 deletions tests/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ mod update {
}

mod collector_auth_tokens {
use divviup_api::clients::aggregator_client::api_types::AuthenticationToken;

use super::{assert_eq, test, *};

#[test(harness = with_client_logs)]
Expand All @@ -668,7 +670,7 @@ mod collector_auth_tokens {
.unwrap();

assert_ok!(conn);
let body: Vec<String> = conn.response_json().await;
let body: Vec<AuthenticationToken> = conn.response_json().await;
assert_eq!(vec![auth_token], body);
Ok(())
}
Expand Down Expand Up @@ -705,7 +707,7 @@ mod collector_auth_tokens {
.unwrap();

assert_ok!(conn);
let body: Vec<String> = conn.response_json().await;
let body: Vec<AuthenticationToken> = conn.response_json().await;
assert_eq!(vec![auth_token], body);
Ok(())
}
Expand All @@ -727,7 +729,7 @@ mod collector_auth_tokens {
.unwrap();

assert_ok!(conn);
let body: Vec<String> = conn.response_json().await;
let body: Vec<AuthenticationToken> = conn.response_json().await;
assert_eq!(vec![auth_token], body);
Ok(())
}
Expand All @@ -749,7 +751,7 @@ mod collector_auth_tokens {
.unwrap();

assert_ok!(conn);
let body: Vec<String> = conn.response_json().await;
let body: Vec<AuthenticationToken> = conn.response_json().await;
assert_eq!(vec![auth_token], body);
Ok(())
}
Expand Down

0 comments on commit fa6950e

Please sign in to comment.