Skip to content

Commit

Permalink
Merge branch 'dev' into ff
Browse files Browse the repository at this point in the history
  • Loading branch information
vicancy authored Oct 24, 2023
2 parents ce98f4c + ea96052 commit 7928fbb
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 39 deletions.
6 changes: 4 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System;
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
Expand Down Expand Up @@ -169,7 +172,6 @@ private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken
await new RestClient().SendAsync(
api,
HttpMethod.Get,
"",
handleExpectedResponseAsync: HandleHttpResponseAsync,
cancellationToken: ctoken);
}
Expand Down
24 changes: 9 additions & 15 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,54 +43,50 @@ public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder c
public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string productInfo,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
if (handleExpectedResponse == null)
{
return SendAsync(api, httpMethod, productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken);
return SendAsync(api, httpMethod, methodName, args, handleExpectedResponseAsync: null, cancellationToken);
}

return SendAsync(api, httpMethod, productInfo, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
return SendAsync(api, httpMethod, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
}

public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string productInfo,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Options.DefaultName, api, httpMethod, productInfo, methodName, args, handleExpectedResponseAsync, cancellationToken);
return SendAsyncCore(Options.DefaultName, api, httpMethod, methodName, args, handleExpectedResponseAsync, cancellationToken);
}

public Task SendWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string productInfo,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, productInfo, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
}

public Task SendMessageWithRetryAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
string productInfo,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, productInfo, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
}

private async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response)
Expand Down Expand Up @@ -122,14 +118,13 @@ private async Task SendAsyncCore(
string httpClientName,
RestApiEndpoint api,
HttpMethod httpMethod,
string productInfo,
string? methodName = null,
object[]? args = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
using var httpClient = _httpClientFactory.CreateClient(httpClientName);
using var request = BuildRequest(api, httpMethod, productInfo, methodName, args);
using var request = BuildRequest(api, httpMethod, methodName, args);

try
{
Expand Down Expand Up @@ -182,21 +177,20 @@ private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
return builder.Uri;
}

private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string productInfo, string? methodName = null, object[]? args = null)
private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null)
{
var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null;
if (_enableMessageTracing)
{
AddTracingId(api);
}
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token, productInfo);
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token);
}

private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues> query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString, string productInfo)
private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, StringValues> query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString)
{
var request = new HttpRequestMessage(httpMethod, GetUri(url, query));
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenString);
request.Headers.Add(Constants.AsrsUserAgent, productInfo);
request.Content = _payloadContentBuilder.Build(payload);
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ private static IServiceCollection AddRestClientFactory(this IServiceCollection s
{
// For AAD, health check.
services
.AddHttpClient(Options.DefaultName, (sp, client) => client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout)
.AddHttpClient(Options.DefaultName, (sp, client) =>
{
client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout;
ConfigureProduceInfo(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy);

// For other data plane APIs.
Expand Down Expand Up @@ -192,13 +196,18 @@ private static IServiceCollection AddRestClientFactory(this IServiceCollection s
// The timeout is enforced by TimeoutHttpMessageHandler.
client.Timeout = Timeout.InfiniteTimeSpan;
}
ConfigureProduceInfo(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<RetryHttpMessageHandler>(sp, (HttpStatusCode code) => IsTransientErrorForNonMessageApi(code)))
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<TimeoutHttpMessageHandler>(sp));

services
.AddHttpClient(Constants.HttpClientNames.MessageResilient, (sp, client) => client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout)
.AddHttpClient(Constants.HttpClientNames.MessageResilient, (sp, client) =>
{
client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout;
ConfigureProduceInfo(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<RetryHttpMessageHandler>(sp, (HttpStatusCode code) => IsTransientErrorAndIdempotentForMessageApi(code)));

Expand All @@ -221,6 +230,12 @@ static bool IsTransientErrorAndIdempotentForMessageApi(HttpStatusCode code) =>
static bool IsTransientErrorForNonMessageApi(HttpStatusCode code) =>
code >= HttpStatusCode.InternalServerError ||
code == HttpStatusCode.RequestTimeout;

static void ConfigureProduceInfo(IServiceProvider sp, HttpClient client) =>
client.DefaultRequestHeaders.Add(Constants.AsrsUserAgent, sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.ProductInfo ??
// The following value should not be used.
"Microsoft.Azure.SignalR.Management/");

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder(), _options.EnableMessageTracing);
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ProductInfo, _options.ApplicationName, restClient);
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
Loading

0 comments on commit 7928fbb

Please sign in to comment.