Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release 1.21.5 #1806

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,10 @@ private set

private Task<object> InitializedTask => _initializedTcs.Task;

public AadAccessKey(Uri uri, TokenCredential credential) : base(uri)
public AadAccessKey(Uri endpoint, TokenCredential credential, Uri serverEndpoint = null) : base(endpoint)
{
var builder = new UriBuilder(Endpoint)
{
Path = "/api/v1/auth/accessKey",
Port = uri.Port
};
AuthorizeUrl = builder.Uri.AbsoluteUri;
var authorizeUri = (serverEndpoint ?? endpoint).Append("/api/v1/auth/accessKey");
AuthorizeUrl = authorizeUri.AbsoluteUri;
TokenCredential = credential;
}

Expand Down
5 changes: 3 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ namespace Microsoft.Azure.SignalR
internal class AccessKey
{
public string Id => Key?.Item1;
public string Value => Key?.Item2;

protected Tuple<string, string> Key { get; set; }
public string Value => Key?.Item2;

public Uri Endpoint { get; }

protected Tuple<string, string> Key { get; set; }

public AccessKey(string uri, string key) : this(new Uri(uri))
{
Key = new Tuple<string, string>(key.GetHashCode().ToString(), key);
Expand Down
41 changes: 33 additions & 8 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ namespace Microsoft.Azure.SignalR
public class ServiceEndpoint
{
private readonly Uri _serviceEndpoint;

private readonly Uri _serverEndpoint;

private readonly Uri _clientEndpoint;

private readonly TokenCredential _tokenCredential;

private readonly object _lock = new object();

private volatile AccessKey _accessKey;

public string ConnectionString { get; }

public EndpointType EndpointType { get; } = EndpointType.Primary;
Expand Down Expand Up @@ -42,6 +50,7 @@ public Uri ClientEndpoint
_clientEndpoint = value;
}
}

/// <summary>
/// When current app server instance has server connections connected to the target endpoint for current hub, it can deliver messages to that endpoint.
/// The endpoint is then considered as *Online*; otherwise, *Offline*.
Expand Down Expand Up @@ -69,7 +78,21 @@ public Uri ClientEndpoint

internal string Version { get; }

internal AccessKey AccessKey { get; private set; }
internal AccessKey AccessKey
{
get
{
if (_accessKey is null)
{
lock (_lock)
{
_accessKey ??= new AadAccessKey(_serviceEndpoint, _tokenCredential, ServerEndpoint);
}
}
return _accessKey;
}
private init => _accessKey = value;
}

// Flag to indicate an updaing endpoint needs staging
internal virtual bool PendingReload { get; set; }
Expand Down Expand Up @@ -132,16 +155,18 @@ public ServiceEndpoint(string nameWithEndpointType, Uri endpoint, TokenCredentia
/// <param name="name">The endpoint name.</param>
/// <param name="serverEndpoint">The endpoint for servers to connect to Azure SignalR.</param>
/// <param name="clientEndpoint">The endpoint for clients to connect to Azure SignalR.</param>
public ServiceEndpoint(Uri endpoint, TokenCredential credential, EndpointType endpointType = EndpointType.Primary, string name = "",
Uri serverEndpoint = null, Uri clientEndpoint = null)
public ServiceEndpoint(Uri endpoint,
TokenCredential credential,
EndpointType endpointType = EndpointType.Primary,
string name = "",
Uri serverEndpoint = null,
Uri clientEndpoint = null)
{
_serviceEndpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint));
CheckScheme(endpoint);
if (credential is null)
{
throw new ArgumentNullException(nameof(credential));
}
AccessKey = new AadAccessKey(endpoint, credential);

_tokenCredential = credential ?? throw new ArgumentNullException(nameof(credential));

EndpointType = endpointType;
Name = name;

Expand Down
16 changes: 16 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/UriExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;

namespace Microsoft.Azure.SignalR
{
internal static class UriExtensions
{
public static Uri Append(this Uri uri, params string[] paths)
{
return new Uri(paths.Aggregate(uri.AbsoluteUri, (current, path) => string.Format("{0}/{1}", current.TrimEnd('/'), path.TrimStart('/'))));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ internal static class ConnectionStringParser

private const string EndpointProperty = "endpoint";

private const string ServerEndpointProperty = "ServerEndpoint";

private const string InvalidVersionValueFormat = "Version {0} is not supported.";

private const string PortProperty = "port";

private const string ServerEndpoint = "ServerEndpoint";

// For SDK 1.x, only support Azure SignalR Service 1.x
private const string SupportedVersion = "1";

Expand Down Expand Up @@ -114,6 +114,7 @@ internal static ParsedConnectionString Parse(string connectionString)
}

Uri clientEndpointUri = null;
Uri serverEndpointUri = null;

// parse and validate clientEndpoint.
if (dict.TryGetValue(ClientEndpointProperty, out var clientEndpoint))
Expand All @@ -124,25 +125,26 @@ internal static ParsedConnectionString Parse(string connectionString)
}
}

// parse and validate clientEndpoint.
if (dict.TryGetValue(ServerEndpointProperty, out var serverEndpoint))
{
if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri))
{
throw new ArgumentException($"{ServerEndpointProperty} property in connection string is not a valid URI: {serverEndpoint}.");
}
}

// try building accesskey.
dict.TryGetValue(AuthTypeProperty, out var type);
var accessKey = type?.ToLower() switch
{
TypeAzureAD => BuildAadAccessKey(builder.Uri, dict),
TypeAzure => BuildAzureAccessKey(builder.Uri, dict),
TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, dict),
TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, dict),
TypeAzureAD => BuildAzureADAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzure => BuildAzureAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, serverEndpointUri, dict),
TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, serverEndpointUri, dict),
_ => BuildAccessKey(builder.Uri, dict),
};

Uri serverEndpointUri = null;
if (dict.TryGetValue(ServerEndpoint, out var serverEndpoint))
{
if (!TryGetEndpointUri(serverEndpoint, out serverEndpointUri))
{
throw new ArgumentException($"{ServerEndpoint} property in connection string is not a valid URI: {serverEndpoint}.");
}
}
return new ParsedConnectionString()
{
Endpoint = builder.Uri,
Expand All @@ -159,19 +161,19 @@ internal static bool TryGetEndpointUri(string endpoint, out Uri uriResult)
(uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps);
}

private static AccessKey BuildAadAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureADAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
if (dict.TryGetValue(ClientIdProperty, out var clientId))
{
if (dict.TryGetValue(TenantIdProperty, out var tenantId))
{
if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret));
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath));
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
}
else
{
Expand All @@ -180,30 +182,28 @@ private static AccessKey BuildAadAccessKey(Uri uri, Dictionary<string, string> d
}
else
{
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId));
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri);
}
}
else
{
return new AadAccessKey(uri, new ManagedIdentityCredential());
return new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
}
}

private static AccessKey BuildAccessKey(Uri uri, Dictionary<string, string> dict)
{
if (dict.TryGetValue(AccessKeyProperty, out var key))
{
return new AccessKey(uri, key);
}
throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty);
return dict.TryGetValue(AccessKeyProperty, out var key)
? new AccessKey(uri, key)
: throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty);
}

private static AccessKey BuildAzureAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
return new AadAccessKey(uri, new DefaultAzureCredential());
return new AadAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri);
}

private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureAppAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
if (!dict.TryGetValue(ClientIdProperty, out var clientId))
{
Expand All @@ -217,22 +217,20 @@ private static AccessKey BuildAzureAppAccessKey(Uri uri, Dictionary<string, stri

if (dict.TryGetValue(ClientSecretProperty, out var clientSecret))
{
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret));
return new AadAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri);
}
else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath))
{
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath));
return new AadAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri);
}
throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty);
}

private static AccessKey BuildAzureMsiAccessKey(Uri uri, Dictionary<string, string> dict)
private static AccessKey BuildAzureMsiAccessKey(Uri uri, Uri serverEndpointUri, Dictionary<string, string> dict)
{
if (dict.TryGetValue(ClientIdProperty, out var clientId))
{
return new AadAccessKey(uri, new ManagedIdentityCredential(clientId));
}
return new AadAccessKey(uri, new ManagedIdentityCredential());
return dict.TryGetValue(ClientIdProperty, out var clientId)
? new AadAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri)
: new AadAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri);
}

private static Dictionary<string, string> ToDictionary(string connectionString)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public async Task SendAsync(
}
catch (HttpRequestException ex)
{
throw new AzureSignalRInaccessibleEndpointException(request.RequestUri.ToString(), ex);
throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ public class ServiceManagerOptions
/// </summary>
public ServiceTransportType ServiceTransportType { get; set; } = ServiceTransportType.Transient;

/// <summary>
/// Gets or sets the timespan to wait before the HTTP request times out. The default value is 100 seconds.
/// </summary>
public TimeSpan HttpClientTimeout { get; set; } = TimeSpan.FromSeconds(100);

/// <summary>
/// Gets the json serializer settings that will be used to serialize content sent to Azure SignalR Service.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ private static IServiceCollection TrySetProductInfo(this IServiceCollection serv
}

private static IServiceCollection AddRestClientFactory(this IServiceCollection services) => services
.AddHttpClient(Options.DefaultName)
.AddHttpClient(Options.DefaultName, (sp, client) => client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout)
.ConfigurePrimaryHttpMessageHandler(sp => new HttpClientHandler() { Proxy = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.Proxy }).Services
.AddSingleton(sp =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;

using Azure.Identity;

using Xunit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ public void TestAzureApplication(string connectionString)
{
var r = ConnectionStringParser.Parse(connectionString);

var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ClientSecretCredential>(aadAccessKey.TokenCredential);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ClientSecretCredential>(key.TokenCredential);
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
Assert.Null(r.Version);
Assert.Null(r.ClientEndpoint);
Expand Down Expand Up @@ -148,8 +148,8 @@ internal void TestDefaultAzureCredential(string expectedEndpoint, string connect
var r = ConnectionStringParser.Parse(connectionString);

Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/'));
var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<DefaultAzureCredential>(aadAccessKey.TokenCredential);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<DefaultAzureCredential>(key.TokenCredential);
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
}

Expand All @@ -165,12 +165,25 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri
var r = ConnectionStringParser.Parse(connectionString);

Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/'));
var aadAccessKey = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ManagedIdentityCredential>(aadAccessKey.TokenCredential);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.IsType<ManagedIdentityCredential>(key.TokenCredential);
Assert.Same(r.Endpoint, r.AccessKey.Endpoint);
Assert.Null(r.ClientEndpoint);
}

[Theory]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo", "https://foo/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123", "https://foo:123/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar", "https://foo/bar/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar/", "https://foo/bar/api/v1/auth/accesskey")]
[InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123/bar/", "https://foo:123/bar/api/v1/auth/accesskey")]
internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl)
{
var r = ConnectionStringParser.Parse(connectionString);
var key = Assert.IsType<AadAccessKey>(r.AccessKey);
Assert.Equal(expectedAuthorizeUrl, key.AuthorizeUrl, StringComparer.OrdinalIgnoreCase);
}

public class ClientEndpointTestData : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
Expand Down
Loading
Loading