Skip to content

Commit

Permalink
Refactoring + unit tests for token events.
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-guilbault committed Dec 10, 2018
1 parent 49da521 commit 84548d3
Show file tree
Hide file tree
Showing 23 changed files with 253 additions and 132 deletions.
12 changes: 8 additions & 4 deletions src/AspNetCore.NonInteractiveOidcHandlers/CachingTokenHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@ protected async Task<TokenResponse> GetTokenAsync(string cacheKey, Func<Cancella
_logger.LogTrace("Token is not cached.");

var tokenResponse = await requestToken(cancellationToken).ConfigureAwait(false);
await _cache
.SetTokenAsync(prefixedCacheKey, tokenResponse, _options, cancellationToken)
.ConfigureAwait(false);
if (tokenResponse != null && !tokenResponse.IsError)
{
await _cache
.SetTokenAsync(prefixedCacheKey, tokenResponse, _options, cancellationToken)
.ConfigureAwait(false);
}

return tokenResponse;
}

Expand All @@ -55,7 +59,7 @@ await _cache
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
var token = await GetTokenAsync(cancellationToken);
if (token != null && token.AccessToken.IsPresent())
if (token != null && !token.IsError && token.AccessToken.IsPresent())
{
request.SetBearerToken(token.AccessToken);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,15 @@ public override async Task<TokenResponse> GetTokenAsync(CancellationToken cancel

private async Task<TokenResponse> AcquireTokenAsync(CancellationToken cancellationToken)
{
var tokenResponseTask = _options.TokenMutex.AcquireAsync(GetToken);
var tokenResponseTask = _options.TokenMutex.AcquireAsync(RequestTokenAsync);
try
{
var tokenResponse = await tokenResponseTask.ConfigureAwait(false);
if (tokenResponse.IsError)
{
_logger.LogError($"Error returned from token endpoint: {tokenResponse.Error}");
await _options.Events.OnTokenRequestFailed.Invoke(tokenResponse).ConfigureAwait(false);
throw new InvalidOperationException(
$"Token retrieval failed: {tokenResponse.Error} {tokenResponse.ErrorDescription}",
tokenResponse.Exception);
return tokenResponse;
}

await _options.Events.OnTokenAcquired(tokenResponse).ConfigureAwait(false);
Expand All @@ -59,7 +57,7 @@ private async Task<TokenResponse> AcquireTokenAsync(CancellationToken cancellati
}
}

private async Task<TokenResponse> GetToken()
private async Task<TokenResponse> RequestTokenAsync()
{
var httpClient = _httpClientFactory.CreateClient(_options.AuthorityHttpClientName);
var tokenEndpoint = await _options.GetTokenEndpointAsync(httpClient).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class ClientCredentialsTokenHandlerOptions: TokenHandlerOptions
/// </summary>
public IDictionary<string, string> ExtraTokenParameters { get; set; }

internal AsyncMutex<TokenResponse> TokenMutex { get; set; }
internal AsyncMutex<TokenResponse> TokenMutex { get; } = new AsyncMutex<TokenResponse>();

public override IEnumerable<string> GetValidationErrors()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ public override async Task<TokenResponse> GetTokenAsync(CancellationToken cancel
return null;
}

return await GetTokenAsync($"delegation:{inboundToken}", ct => AcquireToken(inboundToken, ct), cancellationToken)
return await GetTokenAsync($"delegation:{inboundToken}", _ => AcquireTokenAsync(inboundToken), cancellationToken)
.ConfigureAwait(false);
}

private async Task<TokenResponse> AcquireToken(string inboundToken, CancellationToken cancellationToken)
private async Task<TokenResponse> AcquireTokenAsync(string inboundToken)
{
var lazyToken = _options.LazyTokens.GetOrAdd(inboundToken, CreateLazyDelegatedToken);

Expand All @@ -64,9 +64,7 @@ private async Task<TokenResponse> AcquireToken(string inboundToken, Cancellation
{
_logger.LogError($"Error returned from token endpoint: {tokenResponse.Error}");
await _options.Events.OnTokenRequestFailed.Invoke(tokenResponse).ConfigureAwait(false);
throw new InvalidOperationException(
$"Token retrieval failed: {tokenResponse.Error} {tokenResponse.ErrorDescription}",
tokenResponse.Exception);
return tokenResponse;
}

await _options.Events.OnTokenAcquired(tokenResponse).ConfigureAwait(false);
Expand All @@ -82,9 +80,9 @@ private async Task<TokenResponse> AcquireToken(string inboundToken, Cancellation
}

private AsyncLazy<TokenResponse> CreateLazyDelegatedToken(string inboundToken)
=> new AsyncLazy<TokenResponse>(() => RequestToken(inboundToken));
=> new AsyncLazy<TokenResponse>(() => RequestTokenAsync(inboundToken));

private async Task<TokenResponse> RequestToken(string inboundToken)
private async Task<TokenResponse> RequestTokenAsync(string inboundToken)
{
var httpClient = _httpClientFactory.CreateClient(_options.AuthorityHttpClientName);
var tokenEndpoint = await _options.GetTokenEndpointAsync(httpClient).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class DelegationTokenHandlerOptions: TokenHandlerOptions
/// </summary>
public Func<HttpContext, Task<string>> TokenRetriever { get; set; } = TokenRetrieval.FromAuthenticationService();

internal ConcurrentDictionary<string, AsyncLazy<TokenResponse>> LazyTokens { get; set; }
internal ConcurrentDictionary<string, AsyncLazy<TokenResponse>> LazyTokens { get; } = new ConcurrentDictionary<string, AsyncLazy<TokenResponse>>();

public override IEnumerable<string> GetValidationErrors()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static IHttpClientBuilder AddOidcTokenDelegation(this IHttpClientBuilder
.AddHttpContextAccessor()
.Configure(builder.Name, configureOptions)
.AddPostConfigure<DelegationTokenHandlerOptions, PostConfigureTokenHandlerOptions<DelegationTokenHandlerOptions>>()
.AddPostConfigure<DelegationTokenHandlerOptions, PostConfigureDelegationTokenHandlerOptions>();
;

var instanceName = builder.Name;
return builder.AddHttpMessageHandler(sp =>
Expand All @@ -71,7 +71,7 @@ public static IHttpClientBuilder AddOidcClientCredentials(this IHttpClientBuilde
builder.Services
.Configure(builder.Name, configureOptions)
.AddPostConfigure<ClientCredentialsTokenHandlerOptions, PostConfigureTokenHandlerOptions<ClientCredentialsTokenHandlerOptions>>()
.AddPostConfigure<ClientCredentialsTokenHandlerOptions, PostConfigureClientCredentialsTokenHandlerOptions>();
;

var instanceName = builder.Name;
return builder.AddHttpMessageHandler(sp =>
Expand All @@ -94,7 +94,7 @@ public static IHttpClientBuilder AddOidcPassword(this IHttpClientBuilder builder
builder.Services
.Configure(builder.Name, configureOptions)
.AddPostConfigure<PasswordTokenHandlerOptions, PostConfigureTokenHandlerOptions<PasswordTokenHandlerOptions>>()
.AddPostConfigure<PasswordTokenHandlerOptions, PostConfigurePasswordTokenHandlerOptions>();
;

var instanceName = builder.Name;
return builder.AddHttpMessageHandler(sp =>
Expand All @@ -118,7 +118,7 @@ public static IHttpClientBuilder AddOidcRefreshToken(this IHttpClientBuilder bui
builder.Services
.Configure(builder.Name, configureOptions)
.AddPostConfigure<RefreshTokenHandlerOptions, PostConfigureTokenHandlerOptions<RefreshTokenHandlerOptions>>()
.AddPostConfigure<RefreshTokenHandlerOptions, PostConfigureRefreshTokenHandlerOptions>();
;

var instanceName = builder.Name;
return builder.AddHttpMessageHandler(sp =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ public Task<T> AcquireAsync(Func<Task<T>> factory)
{
lock (_taskGuard)
{
if (_task != null)
{
return _task;
}

return _task = factory();
return _task ?? (_task = factory());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace AspNetCore.NonInteractiveOidcHandlers.Infrastructure
{
internal static class CachingExtensions
{
private static readonly Encoding CacheEncoding = Encoding.UTF8;

public static async Task<TokenResponse> GetTokenAsync(this IDistributedCache cache, string key, CancellationToken cancellationToken = default(CancellationToken))
{
var bytes = await cache
Expand All @@ -19,7 +21,7 @@ internal static class CachingExtensions
return null;
}

var json = Encoding.UTF8.GetString(bytes);
var json = CacheEncoding.GetString(bytes);
var tokenResponse = new TokenResponse(json);
return tokenResponse;
}
Expand All @@ -35,7 +37,7 @@ internal static class CachingExtensions
var absoluteExpiration = DateTimeOffset.UtcNow.Add(expiresIn < options.CacheDuration ? expiresIn : options.CacheDuration);

var json = tokenResponse.Raw;
var bytes = Encoding.UTF8.GetBytes(json);
var bytes = CacheEncoding.GetBytes(json);
await cache
.SetAsync(key, bytes, new DistributedCacheEntryOptions { AbsoluteExpiration = absoluteExpiration }, cancellationToken)
.ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,21 @@ public override async Task<TokenResponse> GetTokenAsync(CancellationToken cancel
}

var (userName, password) = userCredentials.Value;
return await GetTokenAsync($"password:{userName}", ct => AcquireToken(userName, password, ct), cancellationToken)
return await GetTokenAsync($"password:{userName}", _ => AcquireTokenAsync(userName, password), cancellationToken)
.ConfigureAwait(false);
}

private async Task<TokenResponse> AcquireToken(string userName, string password, CancellationToken cancellationToken)
private async Task<TokenResponse> AcquireTokenAsync(string userName, string password)
{
var lazyToken = _options.LazyTokens.GetOrAdd(userName, _ => new AsyncLazy<TokenResponse>(() => RequestToken(userName, password)));
var lazyToken = _options.LazyTokens.GetOrAdd(userName, _ => new AsyncLazy<TokenResponse>(() => RequestTokenAsync(userName, password)));
try
{
var tokenResponse = await lazyToken.Value.ConfigureAwait(false);
if (tokenResponse.IsError)
{
_logger.LogError($"Error returned from token endpoint: {tokenResponse.Error}");
await _options.Events.OnTokenRequestFailed.Invoke(tokenResponse).ConfigureAwait(false);
throw new InvalidOperationException(
$"Token retrieval failed: {tokenResponse.Error} {tokenResponse.ErrorDescription}",
tokenResponse.Exception);
return tokenResponse;
}

await _options.Events.OnTokenAcquired(tokenResponse).ConfigureAwait(false);
Expand All @@ -72,7 +70,7 @@ private async Task<TokenResponse> AcquireToken(string userName, string password,
}
}

private async Task<TokenResponse> RequestToken(string userName, string password)
private async Task<TokenResponse> RequestTokenAsync(string userName, string password)
{
var httpClient = _httpClientFactory.CreateClient(_options.AuthorityHttpClientName);
var tokenEndpoint = await _options.GetTokenEndpointAsync(httpClient).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class PasswordTokenHandlerOptions: TokenHandlerOptions
/// </summary>
public IDictionary<string, string> ExtraTokenParameters { get; set; }

internal ConcurrentDictionary<string, AsyncLazy<TokenResponse>> LazyTokens { get; set; }
internal ConcurrentDictionary<string, AsyncLazy<TokenResponse>> LazyTokens { get; } = new ConcurrentDictionary<string, AsyncLazy<TokenResponse>>();

public override IEnumerable<string> GetValidationErrors()
{
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

12 changes: 5 additions & 7 deletions src/AspNetCore.NonInteractiveOidcHandlers/RefreshTokenHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,21 @@ public override async Task<TokenResponse> GetTokenAsync(CancellationToken cancel
return null;
}

return await GetTokenAsync($"refresh_token:{refreshToken.ToSha512()}", ct => AcquireToken(refreshToken, ct), cancellationToken)
return await GetTokenAsync($"refresh_token:{refreshToken.ToSha512()}", _ => AcquireTokenAsync(refreshToken), cancellationToken)
.ConfigureAwait(false);
}

private async Task<TokenResponse> AcquireToken(string refreshToken, CancellationToken cancellationToken)
private async Task<TokenResponse> AcquireTokenAsync(string refreshToken)
{
var lazyToken = _options.LazyTokens.GetOrAdd(refreshToken, rt => new AsyncLazy<TokenResponse>(() => RequestToken(rt)));
var lazyToken = _options.LazyTokens.GetOrAdd(refreshToken, rt => new AsyncLazy<TokenResponse>(() => RequestTokenAsync(rt)));
try
{
var tokenResponse = await lazyToken.Value.ConfigureAwait(false);
if (tokenResponse.IsError)
{
_logger.LogError($"Error returned from token endpoint: {tokenResponse.Error}");
await _options.Events.OnTokenRequestFailed.Invoke(tokenResponse).ConfigureAwait(false);
throw new InvalidOperationException(
$"Token retrieval failed: {tokenResponse.Error} {tokenResponse.ErrorDescription}",
tokenResponse.Exception);
return tokenResponse;
}

await _options.Events.OnTokenAcquired(tokenResponse).ConfigureAwait(false);
Expand All @@ -72,7 +70,7 @@ private async Task<TokenResponse> AcquireToken(string refreshToken, Cancellation
}
}

private async Task<TokenResponse> RequestToken(string refreshToken)
private async Task<TokenResponse> RequestTokenAsync(string refreshToken)
{
var httpClient = _httpClientFactory.CreateClient(_options.AuthorityHttpClientName);
var tokenEndpoint = await _options.GetTokenEndpointAsync(httpClient).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class RefreshTokenHandlerOptions: TokenHandlerOptions
/// </summary>
public Func<IServiceProvider, string> RefreshTokenRetriever { get; set; }

internal ConcurrentDictionary<string, AsyncLazy<TokenResponse>> LazyTokens { get; set; }
internal ConcurrentDictionary<string, AsyncLazy<TokenResponse>> LazyTokens { get; } = new ConcurrentDictionary<string, AsyncLazy<TokenResponse>>();

public override IEnumerable<string> GetValidationErrors()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ public static async Task<string> GetTokenEndpointAsync(this TokenHandlerOptions
return options.TokenEndpoint;
}

var endpoint = await options.GetTokenEndpointFromDiscoveryDocument(authorityHttpClient).ConfigureAwait(false);
var endpoint = await authorityHttpClient.GetTokenEndpointFromDiscoveryDocument(options).ConfigureAwait(false);
return endpoint;
}

public static async Task<string> GetTokenEndpointFromDiscoveryDocument(this TokenHandlerOptions options, HttpClient authorityHttpClient)
public static async Task<string> GetTokenEndpointFromDiscoveryDocument(this HttpClient authorityHttpClient, TokenHandlerOptions options)
{
var discoveryRequest = new DiscoveryDocumentRequest
{
Expand Down
Loading

0 comments on commit 84548d3

Please sign in to comment.