Skip to content

Commit

Permalink
Stabilized transactions for access tokens feature
Browse files Browse the repository at this point in the history
  • Loading branch information
zaychenko-sergei committed Jun 14, 2024
1 parent 5b4c061 commit 8cf31a7
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use opendatafabric::AccountID;
use thiserror::Error;
use uuid::Uuid;

use crate::{AccessToken, Account};
use crate::AccessToken;

///////////////////////////////////////////////////////////////////////////////

Expand All @@ -38,11 +38,11 @@ pub trait AccessTokenRepository: Send + Sync {
revoke_time: DateTime<Utc>,
) -> Result<(), RevokeTokenError>;

async fn find_account_by_active_token_id(
async fn find_account_id_by_active_token_id(
&self,
token_id: &Uuid,
token_hash: [u8; 32],
) -> Result<Account, FindAccountByTokenError>;
) -> Result<AccountID, FindAccountByTokenError>;
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use uuid::Uuid;
use crate::{
AccessToken,
AccessTokenPaginationOpts,
Account,
CreateAccessTokenError,
FindAccountByTokenError,
GetAccessTokenError,
Expand All @@ -31,11 +30,11 @@ pub trait AccessTokenService: Sync + Send {
account_id: &AccountID,
) -> Result<KamuAccessToken, CreateAccessTokenError>;

async fn find_account_by_active_token_id(
async fn find_account_id_by_active_token_id(
&self,
token_id: &Uuid,
token_hash: [u8; 32],
) -> Result<Account, FindAccountByTokenError>;
) -> Result<AccountID, FindAccountByTokenError>;

async fn get_token_by_id(&self, token_id: &Uuid) -> Result<AccessToken, GetAccessTokenError>;

Expand Down
7 changes: 3 additions & 4 deletions src/domain/accounts/services/src/access_token_service_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use kamu_accounts::{
AccessTokenPaginationOpts,
AccessTokenRepository,
AccessTokenService,
Account,
CreateAccessTokenError,
FindAccountByTokenError,
GetAccessTokenError,
Expand Down Expand Up @@ -81,13 +80,13 @@ impl AccessTokenService for AccessTokenServiceImpl {
Ok(kamu_access_token)
}

async fn find_account_by_active_token_id(
async fn find_account_id_by_active_token_id(
&self,
token_id: &Uuid,
token_hash: [u8; 32],
) -> Result<Account, FindAccountByTokenError> {
) -> Result<AccountID, FindAccountByTokenError> {
self.access_token_repository
.find_account_by_active_token_id(token_id, token_hash)
.find_account_id_by_active_token_id(token_id, token_hash)
.await
}

Expand Down
24 changes: 16 additions & 8 deletions src/domain/accounts/services/src/authentication_service_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,22 @@ impl AuthenticationServiceImpl {
Err(e) => Err(GetAccountInfoError::Internal(e)),
}
}
AccessTokenType::KamuAccessToken(kamu_access_token) => self
.access_token_svc
.find_account_by_active_token_id(
&kamu_access_token.id,
kamu_access_token.random_bytes_hash,
)
.await
.map_err(|err| GetAccountInfoError::Internal(err.int_err())),
AccessTokenType::KamuAccessToken(kamu_access_token) => {
let account_id = self
.access_token_svc
.find_account_id_by_active_token_id(
&kamu_access_token.id,
kamu_access_token.random_bytes_hash,
)
.await
.map_err(|err| GetAccountInfoError::Internal(err.int_err()))?;

match self.account_by_id(&account_id).await {
Ok(Some(account)) => Ok(account),
Ok(None) => Err(GetAccountInfoError::AccountUnresolved),
Err(e) => Err(GetAccountInfoError::Internal(e)),
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use std::sync::{Arc, Mutex};

use chrono::{DateTime, Utc};
use dill::*;
use internal_error::ErrorIntoInternal;
use kamu_accounts::AccessToken;
use opendatafabric::AccountID;
use uuid::Uuid;
Expand All @@ -24,7 +23,6 @@ use crate::domain::*;

pub struct AccessTokenRepositoryInMemory {
state: Arc<Mutex<State>>,
account_repository: Arc<dyn AccountRepository>,
}

/////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -52,9 +50,8 @@ impl State {
#[interface(dyn AccessTokenRepository)]
#[scope(Singleton)]
impl AccessTokenRepositoryInMemory {
pub fn new(account_repository: Arc<dyn AccountRepository>) -> Self {
pub fn new() -> Self {
Self {
account_repository,
state: Arc::new(Mutex::new(State::new())),
}
}
Expand Down Expand Up @@ -157,11 +154,11 @@ impl AccessTokenRepository for AccessTokenRepositoryInMemory {
}))
}

async fn find_account_by_active_token_id(
async fn find_account_id_by_active_token_id(
&self,
token_id: &Uuid,
token_hash: [u8; 32],
) -> Result<Account, FindAccountByTokenError> {
) -> Result<AccountID, FindAccountByTokenError> {
let access_token = self
.get_token_by_id(token_id)
.await
Expand All @@ -181,12 +178,6 @@ impl AccessTokenRepository for AccessTokenRepositoryInMemory {
return Err(FindAccountByTokenError::InvalidTokenHash);
}

let account = self
.account_repository
.get_account_by_id(&access_token.account_id)
.await
.map_err(|err| FindAccountByTokenError::Internal(err.int_err()))?;

Ok(account)
Ok(access_token.account_id)
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -206,36 +206,25 @@ impl AccessTokenRepository for PostgresAccessTokenRepository {
Ok(())
}

async fn find_account_by_active_token_id(
async fn find_account_id_by_active_token_id(
&self,
token_id: &Uuid,
token_hash: [u8; 32],
) -> Result<Account, FindAccountByTokenError> {
) -> Result<AccountID, FindAccountByTokenError> {
let mut tr = self.transaction.lock().await;

let connection_mut = tr
.connection_mut()
.await
.map_err(FindAccountByTokenError::Internal)?;

let maybe_account_row = sqlx::query_as!(
AccountWithTokenRowModel,
let maybe_account_row = sqlx::query!(
r#"
SELECT
at.token_hash,
a.id as "id: AccountID",
a.account_name,
a.email as "email?",
a.display_name,
a.account_type as "account_type: AccountType",
a.avatar_url,
a.registered_at,
a.is_admin,
a.provider,
a.provider_identity_key
at.account_id
FROM access_tokens at
LEFT JOIN accounts a ON a.id = account_id
WHERE at.id = $1 AND revoked_at IS null
WHERE at.id = $1 AND at.revoked_at IS null
"#,
token_id
)
Expand All @@ -248,7 +237,7 @@ impl AccessTokenRepository for PostgresAccessTokenRepository {
if token_hash != account_row.token_hash.as_slice() {
return Err(FindAccountByTokenError::InvalidTokenHash);
}
Ok(account_row.into())
Ok(AccountID::from_did_str(&account_row.account_id).unwrap())
} else {
Err(FindAccountByTokenError::NotFound(
AccessTokenNotFoundError {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,25 +178,28 @@ pub async fn test_find_account_by_active_token_id(catalog: &Catalog) {
.await
.unwrap();

let db_account = access_token_repo
.find_account_by_active_token_id(&access_token.id, access_token.token_hash)
let db_account_id = access_token_repo
.find_account_id_by_active_token_id(&access_token.id, access_token.token_hash)
.await
.unwrap();
assert_eq!(db_account, account);
assert_eq!(db_account_id, account.id);

let db_account = access_token_repo
.find_account_by_active_token_id(&access_token.id, fake_access_token.token_hash)
let db_account_res = access_token_repo
.find_account_id_by_active_token_id(&access_token.id, fake_access_token.token_hash)
.await;
assert_matches!(db_account, Err(FindAccountByTokenError::InvalidTokenHash));
assert_matches!(
db_account_res,
Err(FindAccountByTokenError::InvalidTokenHash)
);

let revoke_time = Utc::now().round_subsecs(6);
let revoke_result = access_token_repo
.mark_revoked(&access_token.id, revoke_time)
.await;
assert!(revoke_result.is_ok());

let db_account = access_token_repo
.find_account_by_active_token_id(&access_token.id, access_token.token_hash)
let db_account_res = access_token_repo
.find_account_id_by_active_token_id(&access_token.id, access_token.token_hash)
.await;
assert_matches!(db_account, Err(FindAccountByTokenError::NotFound(_)));
assert_matches!(db_account_res, Err(FindAccountByTokenError::NotFound(_)));
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8cf31a7

Please sign in to comment.