From 5caea7ff679ec6096fdb2d4828e6cdee30563854 Mon Sep 17 00:00:00 2001 From: Terence Fan Date: Wed, 16 Oct 2024 15:58:35 +0800 Subject: [PATCH] some code reformat (#2057) --- .../Auth/AccessKey.cs | 69 ++-- .../MicrosoftEntra/MicrosoftEntraAccessKey.cs | 2 +- .../Utilities/RestClient.cs | 263 +++++++------ .../Auth/AccessKeyForMicrosoftEntraTests.cs | 267 +++++++------ .../Auth/AuthUtilityTests.cs | 57 ++- .../Auth/ConnectionStringParserTests.cs | 359 +++++++++--------- .../Auth/MicrosoftEntraApplicationTests.cs | 111 +++--- 7 files changed, 560 insertions(+), 568 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/AccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/AccessKey.cs index b16d4bff6..2a036cf28 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/AccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/AccessKey.cs @@ -7,42 +7,41 @@ using System.Threading; using System.Threading.Tasks; -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class AccessKey { - internal class AccessKey + public string Id => Key?.Item1; + + public string Value => Key?.Item2; + + public Uri Endpoint { get; } + + protected Tuple Key { get; set; } + + public AccessKey(string uri, string key) : this(new Uri(uri)) + { + Key = new Tuple(key.GetHashCode().ToString(), key); + } + + public AccessKey(Uri uri, string key) : this(uri) + { + Key = new Tuple(key.GetHashCode().ToString(), key); + } + + protected AccessKey(Uri uri) + { + Endpoint = uri; + } + + public virtual Task GenerateAccessTokenAsync( + string audience, + IEnumerable claims, + TimeSpan lifetime, + AccessTokenAlgorithm algorithm, + CancellationToken ctoken = default) { - public string Id => Key?.Item1; - - public string Value => Key?.Item2; - - public Uri Endpoint { get; } - - protected Tuple Key { get; set; } - - public AccessKey(string uri, string key) : this(new Uri(uri)) - { - Key = new Tuple(key.GetHashCode().ToString(), key); - } - - public AccessKey(Uri uri, string key) : this(uri) - { - Key = new Tuple(key.GetHashCode().ToString(), key); - } - - protected AccessKey(Uri uri) - { - Endpoint = uri; - } - - public virtual Task GenerateAccessTokenAsync( - string audience, - IEnumerable claims, - TimeSpan lifetime, - AccessTokenAlgorithm algorithm, - CancellationToken ctoken = default) - { - var token = AuthUtility.GenerateAccessToken(this, audience, claims, lifetime, algorithm); - return Task.FromResult(token); - } + var token = AuthUtility.GenerateAccessToken(this, audience, claims, lifetime, algorithm); + return Task.FromResult(token); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs index 7a331a22c..681636f54 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraAccessKey.cs @@ -106,7 +106,7 @@ public override async Task GenerateAccessTokenAsync( { await task; return IsAuthorized - ? await base.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) + ? await base.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm, ctoken) : throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential.GetType().Name, _lastException); } else diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs index 82d953553..3a30de04b 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs @@ -15,173 +15,170 @@ #nullable enable -namespace Microsoft.Azure.SignalR +namespace Microsoft.Azure.SignalR; + +internal class RestClient { - internal class RestClient + private readonly IHttpClientFactory _httpClientFactory; + + private readonly IPayloadContentBuilder _payloadContentBuilder; + + public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder) { - private readonly IHttpClientFactory _httpClientFactory; - private readonly IPayloadContentBuilder _payloadContentBuilder; + _httpClientFactory = httpClientFactory; + _payloadContentBuilder = contentBuilder; + } - public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder) - { - _httpClientFactory = httpClientFactory; - _payloadContentBuilder = contentBuilder; - } + // TODO: Test only, will remove later + internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer())) + { + } - // TODO: Test only, will remove later - internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer())) - { - } + public Task SendAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func? handleExpectedResponse = null, + CancellationToken cancellationToken = default) + { + return handleExpectedResponse == null + ? SendAsync(api, httpMethod, methodName, args, handleExpectedResponseAsync: null, cancellationToken) + : SendAsync(api, httpMethod, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + } - public Task SendAsync( - RestApiEndpoint api, - HttpMethod httpMethod, - string? methodName = null, - object[]? args = null, - Func? handleExpectedResponse = null, - CancellationToken cancellationToken = default) - { - if (handleExpectedResponse == null) - { - return SendAsync(api, httpMethod, methodName, args, handleExpectedResponseAsync: null, cancellationToken); - } + public Task SendAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func>? handleExpectedResponseAsync = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, methodName, args, handleExpectedResponseAsync, cancellationToken); + } - return SendAsync(api, httpMethod, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); - } + public Task SendWithRetryAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func? handleExpectedResponse = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + } - public Task SendAsync( - RestApiEndpoint api, - HttpMethod httpMethod, - string? methodName = null, - object[]? args = null, - Func>? handleExpectedResponseAsync = null, - CancellationToken cancellationToken = default) + public Task SendMessageWithRetryAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func? handleExpectedResponse = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + } + + private static Uri GetUri(string url, IDictionary? query) + { + if (query == null || query.Count == 0) { - return SendAsyncCore(Constants.HttpClientNames.UserDefault, api, httpMethod, methodName, args, handleExpectedResponseAsync, cancellationToken); + return new Uri(url); } - - public Task SendWithRetryAsync( - RestApiEndpoint api, - HttpMethod httpMethod, - string? methodName = null, - object[]? args = null, - Func? handleExpectedResponse = null, - CancellationToken cancellationToken = default) + var builder = new UriBuilder(url); + var sb = new StringBuilder(builder.Query); + if (sb.Length == 1 && sb[0] == '?') { - return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + sb.Clear(); } - - public Task SendMessageWithRetryAsync( - RestApiEndpoint api, - HttpMethod httpMethod, - string? methodName = null, - object[]? args = null, - Func? handleExpectedResponse = null, - CancellationToken cancellationToken = default) + else if (sb.Length > 0 && sb[0] != '?') { - return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + sb.Insert(0, '?'); } - - private async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response) + foreach (var item in query) { - if (response.IsSuccessStatusCode) + foreach (var value in item.Value) { - return; + sb.Append(sb.Length > 0 ? '&' : '?'); + sb.Append(Uri.EscapeDataString(item.Key)); + sb.Append('='); + sb.Append(Uri.EscapeDataString(value!)); } + } + builder.Query = sb.ToString(); + return builder.Uri; + } - var detail = await response.Content.ReadAsStringAsync(); + private static async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response) + { + if (response.IsSuccessStatusCode) + { + return; + } + + var detail = await response.Content.ReadAsStringAsync(); #if NET5_0_OR_GREATER - var innerException = new HttpRequestException( - $"Response status code does not indicate success: {(int)response.StatusCode} ({response.ReasonPhrase})", null, response.StatusCode); + var innerException = new HttpRequestException( +$"Response status code does not indicate success: {(int)response.StatusCode} ({response.ReasonPhrase})", null, response.StatusCode); #else - var innerException = new HttpRequestException( - $"Response status code does not indicate success: {(int)response.StatusCode} ({response.ReasonPhrase})"); + var innerException = new HttpRequestException( + $"Response status code does not indicate success: {(int)response.StatusCode} ({response.ReasonPhrase})"); #endif - throw response.StatusCode switch - { - HttpStatusCode.BadRequest => new AzureSignalRInvalidArgumentException(response.RequestMessage?.RequestUri?.ToString(), innerException, detail), - HttpStatusCode.Unauthorized => new AzureSignalRUnauthorizedException(response.RequestMessage?.RequestUri?.ToString(), innerException), - HttpStatusCode.NotFound => new AzureSignalRInaccessibleEndpointException(response.RequestMessage?.RequestUri?.ToString(), innerException), - _ => new AzureSignalRRuntimeException(response.RequestMessage?.RequestUri?.ToString(), innerException), - }; - } - - private async Task SendAsyncCore( - string httpClientName, - RestApiEndpoint api, - HttpMethod httpMethod, - string? methodName = null, - object[]? args = null, - Func>? handleExpectedResponseAsync = null, - CancellationToken cancellationToken = default) + throw response.StatusCode switch { - using var httpClient = _httpClientFactory.CreateClient(httpClientName); - using var request = BuildRequest(api, httpMethod, methodName, args); + HttpStatusCode.BadRequest => new AzureSignalRInvalidArgumentException(response.RequestMessage?.RequestUri?.ToString(), innerException, detail), + HttpStatusCode.Unauthorized => new AzureSignalRUnauthorizedException(response.RequestMessage?.RequestUri?.ToString(), innerException), + HttpStatusCode.NotFound => new AzureSignalRInaccessibleEndpointException(response.RequestMessage?.RequestUri?.ToString(), innerException), + _ => new AzureSignalRRuntimeException(response.RequestMessage?.RequestUri?.ToString(), innerException), + }; + } - try - { - using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); - if (handleExpectedResponseAsync == null) - { - await ThrowExceptionOnResponseFailureAsync(response); - } - else - { - if (!await handleExpectedResponseAsync(response)) - { - await ThrowExceptionOnResponseFailureAsync(response); - } - } - } - catch (HttpRequestException ex) - { - throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex); - } - } + private async Task SendAsyncCore( + string httpClientName, + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func>? handleExpectedResponseAsync = null, + CancellationToken cancellationToken = default) + { + using var httpClient = _httpClientFactory.CreateClient(httpClientName); + using var request = BuildRequest(api, httpMethod, methodName, args); - private static Uri GetUri(string url, IDictionary? query) + try { - if (query == null || query.Count == 0) - { - return new Uri(url); - } - var builder = new UriBuilder(url); - var sb = new StringBuilder(builder.Query); - if (sb.Length == 1 && sb[0] == '?') + using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + if (handleExpectedResponseAsync == null) { - sb.Clear(); + await ThrowExceptionOnResponseFailureAsync(response); } - else if (sb.Length > 0 && sb[0] != '?') + else { - sb.Insert(0, '?'); - } - foreach (var item in query) - { - foreach (var value in item.Value) + if (!await handleExpectedResponseAsync(response)) { - sb.Append(sb.Length > 0 ? '&' : '?'); - sb.Append(Uri.EscapeDataString(item.Key)); - sb.Append('='); - sb.Append(Uri.EscapeDataString(value!)); + await ThrowExceptionOnResponseFailureAsync(response); } } - builder.Query = sb.ToString(); - return builder.Uri; } - - private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null) + catch (HttpRequestException ex) { - var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null; - return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token); + throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex); } + } - private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString) - { - var request = new HttpRequestMessage(httpMethod, GetUri(url, query)); - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenString); - request.Content = _payloadContentBuilder.Build(payload); - return request; - } + private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null) + { + var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null; + return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token); + } + + private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString) + { + var request = new HttpRequestMessage(httpMethod, GetUri(url, query)); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenString); + request.Content = _payloadContentBuilder.Build(payload); + return request; } } \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs index 1684c1ddd..d58d6fe84 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AccessKeyForMicrosoftEntraTests.cs @@ -8,150 +8,149 @@ using Moq; using Xunit; -namespace Microsoft.Azure.SignalR.Common.Tests.Auth +namespace Microsoft.Azure.SignalR.Common.Tests.Auth; + +[Collection("Auth")] +public class AccessKeyForMicrosoftEntraTests { - [Collection("Auth")] - public class AccessKeyForMicrosoftEntraTests + private const string DefaultSigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + + private static Uri DefaultEndpoint = new Uri("http://localhost"); + + [Theory] + [InlineData("https://a.bc", "https://a.bc/api/v1/auth/accessKey")] + [InlineData("https://a.bc:80", "https://a.bc:80/api/v1/auth/accessKey")] + [InlineData("https://a.bc:443", "https://a.bc/api/v1/auth/accessKey")] + public void TestExpectedGetAccessKeyUrl(string endpoint, string expectedGetAccessKeyUrl) { - private const string DefaultSigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + var key = new MicrosoftEntraAccessKey(new Uri(endpoint), new DefaultAzureCredential()); + Assert.Equal(expectedGetAccessKeyUrl, key.GetAccessKeyUrl); + } - private static Uri DefaultEndpoint = new Uri("http://localhost"); + [Fact] + public async Task TestUpdateAccessKey() + { + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + + var audience = "http://localhost/chat"; + var claims = Array.Empty(); + var lifetime = TimeSpan.FromHours(1); + var algorithm = AccessTokenAlgorithm.HS256; + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1)); + await Assert.ThrowsAsync( + async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm, cts.Token) + ); + + var (kid, accessKey) = ("foo", DefaultSigningKey); + key.UpdateAccessKey(kid, accessKey); + + var token = await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm); + Assert.NotNull(token); + } - [Theory] - [InlineData("https://a.bc", "https://a.bc/api/v1/auth/accessKey")] - [InlineData("https://a.bc:80", "https://a.bc:80/api/v1/auth/accessKey")] - [InlineData("https://a.bc:443", "https://a.bc/api/v1/auth/accessKey")] - public void TestExpectedGetAccessKeyUrl(string endpoint, string expectedGetAccessKeyUrl) - { - var key = new MicrosoftEntraAccessKey(new Uri(endpoint), new DefaultAzureCredential()); - Assert.Equal(expectedGetAccessKeyUrl, key.GetAccessKeyUrl); - } + [Theory] + [InlineData(false, 1, true)] + [InlineData(false, 4, true)] + [InlineData(false, 6, false)] + [InlineData(true, 6, true)] + [InlineData(true, 54, true)] + [InlineData(true, 56, false)] + public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int timeElapsed, bool shouldSkip) + { + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); + isAuthorizedField.SetValue(key, isAuthorized); + Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key)); - [Fact] - public async Task TestUpdateAccessKey() - { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); - - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; - - var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1)); - await Assert.ThrowsAsync( - async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm, cts.Token) - ); - - var (kid, accessKey) = ("foo", DefaultSigningKey); - key.UpdateAccessKey(kid, accessKey); - - var token = await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm); - Assert.NotNull(token); - } + var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed); + var lastUpdatedTimeField = typeof(MicrosoftEntraAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); + lastUpdatedTimeField.SetValue(key, lastUpdatedTime); - [Theory] - [InlineData(false, 1, true)] - [InlineData(false, 4, true)] - [InlineData(false, 6, false)] - [InlineData(true, 6, true)] - [InlineData(true, 54, true)] - [InlineData(true, 56, false)] - public async Task TestUpdateAccessKeyAsyncShouldSkip(bool isAuthorized, int timeElapsed, bool shouldSkip) - { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); - var isAuthorizedField = typeof(MicrosoftEntraAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); - isAuthorizedField.SetValue(key, isAuthorized); - Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key)); - - var lastUpdatedTime = DateTime.UtcNow - TimeSpan.FromMinutes(timeElapsed); - var lastUpdatedTimeField = typeof(MicrosoftEntraAccessKey).GetField("_lastUpdatedTime", BindingFlags.NonPublic | BindingFlags.Instance); - lastUpdatedTimeField.SetValue(key, lastUpdatedTime); - - var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); - var initializedTcs = (TaskCompletionSource)initializedTcsField.GetValue(key); - - var lastExceptionFields = typeof(MicrosoftEntraAccessKey).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); - - await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30)); - var actualLastUpdatedTime = Assert.IsType(lastUpdatedTimeField.GetValue(key)); - - if (shouldSkip) - { - Assert.Equal(isAuthorized, Assert.IsType(isAuthorizedField.GetValue(key))); - Assert.Equal(lastUpdatedTime, actualLastUpdatedTime); - Assert.Null(lastExceptionFields.GetValue(key)); - Assert.False(initializedTcs.Task.IsCompleted); - } - else - { - Assert.False(Assert.IsType(isAuthorizedField.GetValue(key))); - Assert.True(lastUpdatedTime < actualLastUpdatedTime); - Assert.NotNull(Assert.IsType(lastExceptionFields.GetValue(key))); - Assert.True(initializedTcs.Task.IsCompleted); - } - } + var initializedTcsField = typeof(MicrosoftEntraAccessKey).GetField("_initializedTcs", BindingFlags.NonPublic | BindingFlags.Instance); + var initializedTcs = (TaskCompletionSource)initializedTcsField.GetValue(key); - [Fact] - public async Task TestInitializeFailed() - { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); - - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; - - await key.UpdateAccessKeyAsync(); - - var exception = await Assert.ThrowsAsync( - async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) - ); - Assert.IsType(exception.InnerException); - } + var lastExceptionFields = typeof(MicrosoftEntraAccessKey).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); - [Fact] - public async Task TestUpdateAccessKeyAfterInitializeFailed() + await key.UpdateAccessKeyAsync().OrTimeout(TimeSpan.FromSeconds(30)); + var actualLastUpdatedTime = Assert.IsType(lastUpdatedTimeField.GetValue(key)); + + if (shouldSkip) { - var mockCredential = new Mock(); - mockCredential.Setup(credential => credential.GetTokenAsync( - It.IsAny(), - It.IsAny())) - .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); - var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); - - var audience = "http://localhost/chat"; - var claims = Array.Empty(); - var lifetime = TimeSpan.FromHours(1); - var algorithm = AccessTokenAlgorithm.HS256; - - await key.UpdateAccessKeyAsync(); - - var exception = await Assert.ThrowsAsync( - async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) - ); - Assert.IsType(exception.InnerException); - - var lastExceptionFields = typeof(MicrosoftEntraAccessKey).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); - - Assert.NotNull(lastExceptionFields.GetValue(key)); - var (kid, accessKey) = ("foo", DefaultSigningKey); - key.UpdateAccessKey(kid, accessKey); + Assert.Equal(isAuthorized, Assert.IsType(isAuthorizedField.GetValue(key))); + Assert.Equal(lastUpdatedTime, actualLastUpdatedTime); Assert.Null(lastExceptionFields.GetValue(key)); + Assert.False(initializedTcs.Task.IsCompleted); + } + else + { + Assert.False(Assert.IsType(isAuthorizedField.GetValue(key))); + Assert.True(lastUpdatedTime < actualLastUpdatedTime); + Assert.NotNull(Assert.IsType(lastExceptionFields.GetValue(key))); + Assert.True(initializedTcs.Task.IsCompleted); } } + + [Fact] + public async Task TestInitializeFailed() + { + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + + var audience = "http://localhost/chat"; + var claims = Array.Empty(); + var lifetime = TimeSpan.FromHours(1); + var algorithm = AccessTokenAlgorithm.HS256; + + await key.UpdateAccessKeyAsync(); + + var exception = await Assert.ThrowsAsync( + async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) + ); + Assert.IsType(exception.InnerException); + } + + [Fact] + public async Task TestUpdateAccessKeyAfterInitializeFailed() + { + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new MicrosoftEntraAccessKey(DefaultEndpoint, mockCredential.Object); + + var audience = "http://localhost/chat"; + var claims = Array.Empty(); + var lifetime = TimeSpan.FromHours(1); + var algorithm = AccessTokenAlgorithm.HS256; + + await key.UpdateAccessKeyAsync(); + + var exception = await Assert.ThrowsAsync( + async () => await key.GenerateAccessTokenAsync(audience, claims, lifetime, algorithm) + ); + Assert.IsType(exception.InnerException); + + var lastExceptionFields = typeof(MicrosoftEntraAccessKey).GetField("_lastException", BindingFlags.NonPublic | BindingFlags.Instance); + + Assert.NotNull(lastExceptionFields.GetValue(key)); + var (kid, accessKey) = ("foo", DefaultSigningKey); + key.UpdateAccessKey(kid, accessKey); + Assert.Null(lastExceptionFields.GetValue(key)); + } } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs index 31a25bce0..931317fb3 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AuthUtilityTests.cs @@ -10,43 +10,42 @@ using Xunit; -namespace Microsoft.Azure.SignalR.Common.Tests.Auth +namespace Microsoft.Azure.SignalR.Common.Tests.Auth; + +[Collection("Auth")] +public class AuthUtilityTests { - [Collection("Auth")] - public class AuthUtilityTests - { - private const string Audience = "https://localhost/aspnetclient?hub=testhub"; + private const string Audience = "https://localhost/aspnetclient?hub=testhub"; - private const string SigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + private const string SigningKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; - private static readonly TimeSpan DefaultLifetime = TimeSpan.FromHours(1); + private static readonly TimeSpan DefaultLifetime = TimeSpan.FromHours(1); - [Fact] - public void TestAccessTokenTooLongThrowsException() - { - var claims = GenerateClaims(100); - var accessKey = new AccessKey("http://localhost:443", SigningKey); - var exception = Assert.Throws(() => AuthUtility.GenerateAccessToken(accessKey, Audience, claims, DefaultLifetime, AccessTokenAlgorithm.HS256)); + [Fact] + public void TestAccessTokenTooLongThrowsException() + { + var claims = GenerateClaims(100); + var accessKey = new AccessKey("http://localhost:443", SigningKey); + var exception = Assert.Throws(() => AuthUtility.GenerateAccessToken(accessKey, Audience, claims, DefaultLifetime, AccessTokenAlgorithm.HS256)); - Assert.Equal("AccessToken must not be longer than 4K.", exception.Message); - } + Assert.Equal("AccessToken must not be longer than 4K.", exception.Message); + } - private static Claim[] GenerateClaims(int count) - { - return Enumerable.Range(0, count).Select(s => new Claim($"ClaimSubject{s}", $"ClaimValue{s}")).ToArray(); - } + private static Claim[] GenerateClaims(int count) + { + return Enumerable.Range(0, count).Select(s => new Claim($"ClaimSubject{s}", $"ClaimValue{s}")).ToArray(); + } - public class CachingTestData : IEnumerable + public class CachingTestData : IEnumerable + { + public IEnumerator GetEnumerator() { - public IEnumerator GetEnumerator() - { - yield return new object[] { new AccessKey("http://localhost:443", SigningKey), true }; - var key = new MicrosoftEntraAccessKey(new Uri("http://localhost"), new DefaultAzureCredential()); - key.UpdateAccessKey("foo", SigningKey); - yield return new object[] { key, false }; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + yield return new object[] { new AccessKey("http://localhost:443", SigningKey), true }; + var key = new MicrosoftEntraAccessKey(new Uri("http://localhost"), new DefaultAzureCredential()); + key.UpdateAccessKey("foo", SigningKey); + yield return new object[] { key, false }; } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs index 439f43a2b..a5f7d085d 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs @@ -8,218 +8,217 @@ using Azure.Identity; using Xunit; -namespace Microsoft.Azure.SignalR.Common.Tests.Auth +namespace Microsoft.Azure.SignalR.Common.Tests.Auth; + +[Collection("Auth")] +public class ConnectionStringParserTests { - [Collection("Auth")] - public class ConnectionStringParserTests - { - private const string ClientEndpoint = "http://bbb"; + private const string ClientEndpoint = "http://bbb"; - private const string DefaultKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + private const string DefaultKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; - private const string HttpEndpoint = "http://aaa"; + private const string HttpEndpoint = "http://aaa"; - private const string HttpsEndpoint = "https://aaa"; + private const string HttpsEndpoint = "https://aaa"; - private const string ServerEndpoint = "http://ccc"; + private const string ServerEndpoint = "http://ccc"; - public static IEnumerable ServerEndpointTestData + public static IEnumerable ServerEndpointTestData + { + get { - get - { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}", ServerEndpoint, 80 }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}:500", $"{ServerEndpoint}:500", 500 }; - } + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}", ServerEndpoint, 80 }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}:500", $"{ServerEndpoint}:500", 500 }; } + } - [Theory] - [InlineData("endpoint=https://aaa;AuthType=aad;clientId=123;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] - [InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=123;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] - public void InvalidAzureApplication(string connectionString) - { - var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains("Connection string missing required properties clientSecret or clientCert", exception.Message); - } + [Theory] + [InlineData("endpoint=https://aaa;AuthType=aad;clientId=123;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] + [InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=123;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] + public void InvalidAzureApplication(string connectionString) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains("Connection string missing required properties clientSecret or clientCert", exception.Message); + } - [Theory] - [InlineData("endpoint=https://aaa;clientEndpoint=aaa;AccessKey=bbb;")] - [InlineData("endpoint=https://aaa;ClientEndpoint=endpoint=aaa;AccessKey=bbb;")] - public void InvalidClientEndpoint(string connectionString) - { - var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains("Invalid value for clientEndpoint property, it must be a valid URI. (Parameter 'clientEndpoint')", exception.Message); - } + [Theory] + [InlineData("endpoint=https://aaa;clientEndpoint=aaa;AccessKey=bbb;")] + [InlineData("endpoint=https://aaa;ClientEndpoint=endpoint=aaa;AccessKey=bbb;")] + public void InvalidClientEndpoint(string connectionString) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains("Invalid value for clientEndpoint property, it must be a valid URI. (Parameter 'clientEndpoint')", exception.Message); + } - [Theory] - [InlineData("Endpoint=xxx")] - [InlineData("AccessKey=xxx")] - [InlineData("XXX=yyy")] - [InlineData("XXX")] - public void InvalidConnectionStrings(string connectionString) - { - var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains("Connection string missing required properties", exception.Message); - } + [Theory] + [InlineData("Endpoint=xxx")] + [InlineData("AccessKey=xxx")] + [InlineData("XXX=yyy")] + [InlineData("XXX")] + public void InvalidConnectionStrings(string connectionString) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains("Connection string missing required properties", exception.Message); + } - [Theory] - [InlineData("Endpoint=aaa;AccessKey=bbb;")] - [InlineData("Endpoint=endpoint=aaa;AccessKey=bbb;")] - public void InvalidEndpoint(string connectionString) - { - var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains("Invalid value for endpoint property, it must be a valid URI. (Parameter 'endpoint')", exception.Message); - } + [Theory] + [InlineData("Endpoint=aaa;AccessKey=bbb;")] + [InlineData("Endpoint=endpoint=aaa;AccessKey=bbb;")] + public void InvalidEndpoint(string connectionString) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains("Invalid value for endpoint property, it must be a valid URI. (Parameter 'endpoint')", exception.Message); + } - [Theory] - [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.0;port=2.3")] - [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.1;port=1000000")] - [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.0-preview;port=0")] - public void InvalidPort(string connectionString) - { - var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains("Invalid value for port property, it must be an positive integer between (0, 65536) (Parameter 'port')", exception.Message); - } + [Theory] + [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.0;port=2.3")] + [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.1;port=1000000")] + [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.0-preview;port=0")] + public void InvalidPort(string connectionString) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains("Invalid value for port property, it must be an positive integer between (0, 65536) (Parameter 'port')", exception.Message); + } - [Theory] - [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=abc", "abc")] - [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.x", "1.x")] - [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=2.0", "2.0")] - public void InvalidVersion(string connectionString, string version) - { - var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); - Assert.Contains($"Version {version} is not supported.", exception.Message); - } + [Theory] + [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=abc", "abc")] + [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=1.x", "1.x")] + [InlineData("Endpoint=https://aaa;AccessKey=bbb;version=2.0", "2.0")] + public void InvalidVersion(string connectionString, string version) + { + var exception = Assert.Throws(() => ConnectionStringParser.Parse(connectionString)); + Assert.Contains($"Version {version} is not supported.", exception.Message); + } - [Theory] - [InlineData("endpoint=https://aaa;AuthType=aad;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] - [InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] - public void TestAzureApplication(string connectionString) - { - var r = ConnectionStringParser.Parse(connectionString); + [Theory] + [InlineData("endpoint=https://aaa;AuthType=aad;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] + [InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] + public void TestAzureApplication(string connectionString) + { + var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); - Assert.IsType(key.TokenCredential); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - Assert.Null(r.Version); - Assert.Null(r.ClientEndpoint); - } + var key = Assert.IsType(r.AccessKey); + Assert.IsType(key.TokenCredential); + Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + Assert.Null(r.Version); + Assert.Null(r.ClientEndpoint); + } - [Theory] - [ClassData(typeof(ClientEndpointTestData))] - public void TestClientEndpoint(string connectionString, string expectedClientEndpoint, int? expectedPort) - { - var r = ConnectionStringParser.Parse(connectionString); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - var expectedUri = expectedClientEndpoint == null ? null : new Uri(expectedClientEndpoint); - Assert.Equal(expectedUri, r.ClientEndpoint); - Assert.Equal(expectedPort, r.ClientEndpoint?.Port); - } + [Theory] + [ClassData(typeof(ClientEndpointTestData))] + public void TestClientEndpoint(string connectionString, string expectedClientEndpoint, int? expectedPort) + { + var r = ConnectionStringParser.Parse(connectionString); + Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + var expectedUri = expectedClientEndpoint == null ? null : new Uri(expectedClientEndpoint); + Assert.Equal(expectedUri, r.ClientEndpoint); + Assert.Equal(expectedPort, r.ClientEndpoint?.Port); + } - [Theory] - [MemberData(nameof(ServerEndpointTestData))] - public void TestServerEndpoint(string connectionString, string expectedServerEndpoint, int? expectedPort) - { - var r = ConnectionStringParser.Parse(connectionString); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - var expectedUri = expectedServerEndpoint == null ? null : new Uri(expectedServerEndpoint); - Assert.Equal(expectedUri, r.ServerEndpoint); - Assert.Equal(expectedPort, r.ServerEndpoint?.Port); - } + [Theory] + [MemberData(nameof(ServerEndpointTestData))] + public void TestServerEndpoint(string connectionString, string expectedServerEndpoint, int? expectedPort) + { + var r = ConnectionStringParser.Parse(connectionString); + Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + var expectedUri = expectedServerEndpoint == null ? null : new Uri(expectedServerEndpoint); + Assert.Equal(expectedUri, r.ServerEndpoint); + Assert.Equal(expectedPort, r.ServerEndpoint?.Port); + } - [Theory] - [ClassData(typeof(VersionTestData))] - public void TestVersion(string connectionString, string expectedVersion) - { - var r = ConnectionStringParser.Parse(connectionString); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - Assert.Equal(expectedVersion, r.Version); - } + [Theory] + [ClassData(typeof(VersionTestData))] + public void TestVersion(string connectionString, string expectedVersion) + { + var r = ConnectionStringParser.Parse(connectionString); + Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + Assert.Equal(expectedVersion, r.Version); + } - [Theory] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientId=xxxx;")] // should ignore the clientId - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;tenantId=xxxx;")] // should ignore the tenantId - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientSecret=xxxx;")] // should ignore the clientSecret - internal void TestDefaultAzureCredential(string expectedEndpoint, string connectionString) - { - var r = ConnectionStringParser.Parse(connectionString); + [Theory] + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientId=xxxx;")] // should ignore the clientId + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;tenantId=xxxx;")] // should ignore the tenantId + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientSecret=xxxx;")] // should ignore the clientSecret + internal void TestDefaultAzureCredential(string expectedEndpoint, string connectionString) + { + var r = ConnectionStringParser.Parse(connectionString); - Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); - Assert.IsType(key.TokenCredential); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - } + Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); + var key = Assert.IsType(r.AccessKey); + Assert.IsType(key.TokenCredential); + Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + } - [Theory] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;")] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientId=123;")] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;tenantId=xxxx;")] // should ignore the tenantId - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientSecret=xxxx;")] // should ignore the clientSecret - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;")] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;clientId=123;")] - internal void TestManagedIdentity(string expectedEndpoint, string connectionString) - { - var r = ConnectionStringParser.Parse(connectionString); + [Theory] + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;")] + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientId=123;")] + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;tenantId=xxxx;")] // should ignore the tenantId + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientSecret=xxxx;")] // should ignore the clientSecret + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;")] + [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;clientId=123;")] + internal void TestManagedIdentity(string expectedEndpoint, string connectionString) + { + var r = ConnectionStringParser.Parse(connectionString); - Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); - Assert.IsType(key.TokenCredential); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - Assert.Null(r.ClientEndpoint); - } + Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); + var key = Assert.IsType(r.AccessKey); + Assert.IsType(key.TokenCredential); + Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + Assert.Null(r.ClientEndpoint); + } - [Theory] - [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo", "https://foo/api/v1/auth/accesskey")] - [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123", "https://foo:123/api/v1/auth/accesskey")] - [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar", "https://foo/bar/api/v1/auth/accesskey")] - [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar/", "https://foo/bar/api/v1/auth/accesskey")] - [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123/bar/", "https://foo:123/bar/api/v1/auth/accesskey")] - internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl) - { - var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); - Assert.Equal(expectedAuthorizeUrl, key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); - } + [Theory] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo", "https://foo/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123", "https://foo:123/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar", "https://foo/bar/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar/", "https://foo/bar/api/v1/auth/accesskey")] + [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123/bar/", "https://foo:123/bar/api/v1/auth/accesskey")] + internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl) + { + var r = ConnectionStringParser.Parse(connectionString); + var key = Assert.IsType(r.AccessKey); + Assert.Equal(expectedAuthorizeUrl, key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); + } - public class ClientEndpointTestData : IEnumerable + public class ClientEndpointTestData : IEnumerable + { + public IEnumerator GetEnumerator() { - public IEnumerator GetEnumerator() - { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}", ClientEndpoint, 80 }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}:500", $"{ClientEndpoint}:500", 500 }; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}", ClientEndpoint, 80 }; + yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}:500", $"{ClientEndpoint}:500", 500 }; } - public class EndpointEndWithSlash : IEnumerable + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class EndpointEndWithSlash : IEnumerable + { + public IEnumerator GetEnumerator() { - public IEnumerator GetEnumerator() - { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpEndpoint}/;accesskey={DefaultKey}", HttpEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey}", HttpsEndpoint }; - yield return new object[] { $"endpoint={HttpsEndpoint}/;accesskey={DefaultKey}", HttpsEndpoint }; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpEndpoint}/;accesskey={DefaultKey}", HttpEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint};accesskey={DefaultKey}", HttpsEndpoint }; + yield return new object[] { $"endpoint={HttpsEndpoint}/;accesskey={DefaultKey}", HttpsEndpoint }; } - public class VersionTestData : IEnumerable + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class VersionTestData : IEnumerable + { + public IEnumerator GetEnumerator() { - public IEnumerator GetEnumerator() - { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null }; - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey};version=1.0", "1.0" }; - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey};version=1.1-preview", "1.1-preview" }; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null }; + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey};version=1.0", "1.0" }; + yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey};version=1.1-preview", "1.1-preview" }; } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); } } \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs index 5c46bc0b2..f6f1f6881 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/MicrosoftEntraApplicationTests.cs @@ -9,76 +9,75 @@ using Microsoft.IdentityModel.Tokens; using Xunit; -namespace Microsoft.Azure.SignalR.Common.Tests.Auth +namespace Microsoft.Azure.SignalR.Common.Tests.Auth; + +[Collection("Auth")] +public class MicrosoftEntraApplicationTests { - [Collection("Auth")] - public class MicrosoftEntraApplicationTests - { - private const string IssuerEndpoint = "https://sts.windows.net/"; + private const string IssuerEndpoint = "https://sts.windows.net/"; - private const string TestClientId = ""; - private const string TestClientSecret = ""; - private const string TestTenantId = ""; + private const string TestClientId = ""; + private const string TestClientSecret = ""; + private const string TestTenantId = ""; - private static readonly string[] DefaultScopes = new string[] { "https://signalr.azure.com/.default" }; + private static readonly string[] DefaultScopes = new string[] { "https://signalr.azure.com/.default" }; - [Fact(Skip = "Provide valid Microsoft Entra application options")] - public async Task TestAcquireAccessToken() - { - var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var key = new MicrosoftEntraAccessKey(new Uri("https://localhost:8080"), options); - var token = await key.GetMicrosoftEntraTokenAsync(); - Assert.NotNull(token); - } + [Fact(Skip = "Provide valid Microsoft Entra application options")] + public async Task TestAcquireAccessToken() + { + var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); + var key = new MicrosoftEntraAccessKey(new Uri("https://localhost:8080"), options); + var token = await key.GetMicrosoftEntraTokenAsync(); + Assert.NotNull(token); + } - [Fact(Skip = "Provide valid Microsoft Entra application options")] - public async Task TestGetMicrosoftEntraTokenAndAuthenticate() - { - var credential = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); + [Fact(Skip = "Provide valid Microsoft Entra application options")] + public async Task TestGetMicrosoftEntraTokenAndAuthenticate() + { + var credential = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var configManager = new ConfigurationManager( - "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration", - new OpenIdConnectConfigurationRetriever() - ); - var keys = (await configManager.GetConfigurationAsync()).SigningKeys; + var configManager = new ConfigurationManager( + "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration", + new OpenIdConnectConfigurationRetriever() + ); + var keys = (await configManager.GetConfigurationAsync()).SigningKeys; - var p = new TokenValidationParameters() - { - ValidateLifetime = true, - ValidateAudience = false, + var p = new TokenValidationParameters() + { + ValidateLifetime = true, + ValidateAudience = false, - IssuerValidator = (string issuer, SecurityToken securityToken, TokenValidationParameters validationParameters) => + IssuerValidator = (string issuer, SecurityToken securityToken, TokenValidationParameters validationParameters) => + { + if (issuer.StartsWith(IssuerEndpoint)) { - if (issuer.StartsWith(IssuerEndpoint)) - { - return IssuerEndpoint; - } - throw new SecurityTokenInvalidIssuerException(); - }, + return IssuerEndpoint; + } + throw new SecurityTokenInvalidIssuerException(); + }, - ValidateIssuerSigningKey = true, - IssuerSigningKeys = keys, - }; + ValidateIssuerSigningKey = true, + IssuerSigningKeys = keys, + }; - var handler = new JwtSecurityTokenHandler(); - IdentityModelEventSource.ShowPII = true; + var handler = new JwtSecurityTokenHandler(); + IdentityModelEventSource.ShowPII = true; - var accessToken = await credential.GetTokenAsync(new TokenRequestContext(DefaultScopes)); - var claims = handler.ValidateToken(accessToken.Token, p, out var validToken); + var accessToken = await credential.GetTokenAsync(new TokenRequestContext(DefaultScopes)); + var claims = handler.ValidateToken(accessToken.Token, p, out var validToken); - Assert.NotNull(validToken); - } + Assert.NotNull(validToken); + } - [Fact(Skip = "Provide valid Microsoft Entra application options")] - internal async Task TestAuthenticateAsync() - { - var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); - var key = new MicrosoftEntraAccessKey(new Uri("https://localhost:8080"), options); - await key.UpdateAccessKeyAsync(); + [Fact(Skip = "Provide valid Microsoft Entra application options")] + internal async Task TestAuthenticateAsync() + { + var options = new ClientSecretCredential(TestTenantId, TestClientId, TestClientSecret); + var key = new MicrosoftEntraAccessKey(new Uri("https://localhost:8080"), options); + await key.UpdateAccessKeyAsync(); - Assert.True(key.IsAuthorized); - Assert.NotNull(key.Id); - Assert.NotNull(key.Value); - } + Assert.True(key.IsAuthorized); + Assert.NotNull(key.Id); + Assert.NotNull(key.Value); } }