Skip to content

Commit

Permalink
Merge branch 'dev' into ff
Browse files Browse the repository at this point in the history
  • Loading branch information
vicancy authored Oct 25, 2023
2 parents 7928fbb + b26a920 commit df759d5
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 161 deletions.

This file was deleted.

25 changes: 5 additions & 20 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,20 @@ internal class RestClient
{
private readonly IHttpClientFactory _httpClientFactory;
private readonly IPayloadContentBuilder _payloadContentBuilder;
private readonly bool _enableMessageTracing;

public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder, bool enableMessageTracing)
public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder)
{
_httpClientFactory = httpClientFactory;
_payloadContentBuilder = contentBuilder;
_enableMessageTracing = enableMessageTracing;
}

public RestClient(IHttpClientFactory httpClientFactory, ObjectSerializer objectSerializer, bool enableMessageTracing) : this(httpClientFactory, new JsonPayloadContentBuilder(objectSerializer), enableMessageTracing)
// TODO: Test only, will remove later
internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer()))
{
}


public RestClient() : this(HttpClientFactory.Instance, new JsonObjectSerializer(), false)
// TODO: remove later
public RestClient() : this(HttpClientFactory.Instance)
{
}

Expand Down Expand Up @@ -180,10 +179,6 @@ private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null)
{
var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null;
if (_enableMessageTracing)
{
AddTracingId(api);
}
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token);
}

Expand All @@ -194,15 +189,5 @@ private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, S
request.Content = _payloadContentBuilder.Build(payload);
return request;
}

private void AddTracingId(RestApiEndpoint api)
{
var id = MessageWithTracingIdHelper.Generate();
if (api.Query == null)
{
api.Query = new Dictionary<string, StringValues>();
}
api.Query.Add(Constants.Headers.AsrsMessageTracingId, id.ToString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ private static IServiceCollection AddRestClientFactory(this IServiceCollection s
client.Timeout = Timeout.InfiniteTimeSpan;
}
ConfigureProduceInfo(sp, client);
ConfigureMessageTracingId(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<RetryHttpMessageHandler>(sp, (HttpStatusCode code) => IsTransientErrorForNonMessageApi(code)))
Expand All @@ -207,6 +208,7 @@ private static IServiceCollection AddRestClientFactory(this IServiceCollection s
{
client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout;
ConfigureProduceInfo(sp, client);
ConfigureMessageTracingId(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<RetryHttpMessageHandler>(sp, (HttpStatusCode code) => IsTransientErrorAndIdempotentForMessageApi(code)));
Expand Down Expand Up @@ -236,6 +238,13 @@ static void ConfigureProduceInfo(IServiceProvider sp, HttpClient client) =>
// The following value should not be used.
"Microsoft.Azure.SignalR.Management/");

static void ConfigureMessageTracingId(IServiceProvider sp, HttpClient client)
{
if (sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.EnableMessageTracing)
{
client.DefaultRequestHeaders.Add(Constants.Headers.AsrsMessageTracingId, MessageWithTracingIdHelper.Generate().ToString());
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var payloadBuilderResolver = _serviceProvider.GetRequiredService<PayloadBuilderResolver>();
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder(), _options.EnableMessageTracing);
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,21 @@
using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.Options;

namespace Microsoft.Azure.SignalR
{
internal class DefaultEndpointRouter : DefaultMessageRouter, IEndpointRouter
{
private readonly EndpointRoutingMode _mode;

public DefaultEndpointRouter(IOptions<ServiceOptions> options)
{
_mode = options?.Value.EndpointRoutingMode ?? EndpointRoutingMode.Weighted;
}

/// <summary>
/// Select an endpoint for negotiate request according to the mode
/// Select an endpoint for negotiate request
/// </summary>
/// <param name="context">The http context of the incoming request</param>
/// <param name="endpoints">All the available endpoints</param>
public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable<ServiceEndpoint> endpoints)
{
// get primary endpoints snapshot
var availableEndpoints = GetNegotiateEndpoints(endpoints);
return _mode switch
{
EndpointRoutingMode.Random => GetEndpointRandomly(availableEndpoints),
EndpointRoutingMode.LeastConnection => GetEndpointWithLeastConnection(availableEndpoints),
_ => GetEndpointAccordingToWeight(availableEndpoints),
};
return GetEndpointAccordingToWeight(availableEndpoints);
}

/// <summary>
Expand Down Expand Up @@ -69,7 +56,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) ||
availableEndpoints.Length == 1)
{
return GetEndpointRandomly(availableEndpoints);
return availableEndpoints[StaticRandom.Next(availableEndpoints.Length)];
}

var we = new int[availableEndpoints.Length];
Expand All @@ -89,38 +76,5 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available

return availableEndpoints[Array.FindLastIndex(we, x => x <= index) + 1];
}

/// <summary>
/// Choose endpoint with least connection count
/// </summary>
private ServiceEndpoint GetEndpointWithLeastConnection(ServiceEndpoint[] availableEndpoints)
{
//first check if weight is available or necessary
if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) ||
availableEndpoints.Length == 1)
{
return GetEndpointRandomly(availableEndpoints);
}

var leastConnectionCount = int.MaxValue;
var index = 0;
for (var i = 0; i < availableEndpoints.Length; i++)
{
var endpointMetrics = availableEndpoints[i].EndpointMetrics;
var connectionCount = endpointMetrics.ClientConnectionCount + endpointMetrics.ServerConnectionCount;
if (connectionCount < leastConnectionCount)
{
leastConnectionCount = connectionCount;
index = i;
}
}

return availableEndpoints[index];
}

private static ServiceEndpoint GetEndpointRandomly(ServiceEndpoint[] availableEndpoints)
{
return availableEndpoints[StaticRandom.Next(availableEndpoints.Length)];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class EndpointRouterDecorator : IEndpointRouter

public EndpointRouterDecorator(IEndpointRouter router = null)
{
_inner = router ?? new DefaultEndpointRouter(null);
_inner = router ?? new DefaultEndpointRouter();
}

public virtual ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable<ServiceEndpoint> endpoints)
Expand Down
6 changes: 0 additions & 6 deletions src/Microsoft.Azure.SignalR/ServiceOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,5 @@ public int ConnectionCount
/// Gets or sets a function which accepts <see cref="HttpContext"/> and returns a bitmask combining one or more <see cref="HttpTransportType"/> values that specify what transports the service should use to receive HTTP requests.
/// </summary>
public Func<HttpContext, HttpTransportType> TransportTypeDetector { get; set; } = null;

/// <summary>
/// Gets or sets the default endpoint routing mode when using multiple endpoints.
/// <see cref="EndpointRoutingMode.Weighted"/> by default.
/// </summary>
public EndpointRoutingMode EndpointRoutingMode { get; set; } = EndpointRoutingMode.Weighted;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
Expand All @@ -21,7 +20,7 @@ public async Task TestHttpRequestExceptionWithStatusCodeSetAsync()
var httpClientFactory = new ServiceCollection()
.AddHttpClient("").ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(HttpStatusCode.InsufficientStorage)).Services
.BuildServiceProvider().GetRequiredService<IHttpClientFactory>();
var client = new RestClient(httpClientFactory, new JsonObjectSerializer(), true);
var client = new RestClient(httpClientFactory);
var apiEndpoint = new RestApiEndpoint("https://localhost.test.com", "token");
var exception = await Assert.ThrowsAsync<AzureSignalRRuntimeException>(() =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,68 @@ public async Task HttpClientProductInfoTestAsync(string httpClientName)
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc"));
}

[Theory]
[InlineData(Constants.HttpClientNames.Resilient)]
[InlineData(Constants.HttpClientNames.MessageResilient)]
public async Task HttpClientMessageTracingIdEnabledTestAsync(string httpClientName)
{
using var hubContext = await new ServiceManagerBuilder()
.WithOptions(o =>
{
o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single();
o.EnableMessageTracing = true;
})
.ConfigureServices(services => services.AddHttpClient(httpClientName)
.ConfigurePrimaryHttpMessageHandler(() =>
new TestRootHandler((message, token) =>
{
if (message.Headers.TryGetValues(Constants.Headers.AsrsMessageTracingId, out var values))
{
Assert.Single(values);
Convert.ToUInt64(values.Single());
}
else
{
throw new Exception("Message tracing Id header is missing");
}
})))
.BuildServiceManager()
.CreateHubContextAsync("hubName", default);
var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider;
var httpClientFactory = serviceProvider.GetRequiredService<IHttpClientFactory>();
using var httpClient = httpClientFactory.CreateClient(httpClientName);
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc"));
}


[Theory]
[InlineData(Constants.HttpClientNames.Resilient)]
[InlineData(Constants.HttpClientNames.MessageResilient)]
public async Task HttpClientMessageTracingIdDisabledTestAsync(string httpClientName)
{
using var hubContext = await new ServiceManagerBuilder()
.WithOptions(o =>
{
o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single();
o.EnableMessageTracing = false;
})
.ConfigureServices(services => services.AddHttpClient(httpClientName)
.ConfigurePrimaryHttpMessageHandler(() =>
new TestRootHandler((message, token) =>
{
if (message.Headers.TryGetValues(Constants.Headers.AsrsMessageTracingId, out var values))
{
throw new Exception("Message tracing Id header is not expected");
}
})))
.BuildServiceManager()
.CreateHubContextAsync("hubName", default);
var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider;
var httpClientFactory = serviceProvider.GetRequiredService<IHttpClientFactory>();
using var httpClient = httpClientFactory.CreateClient(httpClientName);
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc"));
}

private class WaitInfinitelyHandler : DelegatingHandler
{
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.Azure.SignalR.Tests;
using Xunit;

Expand All @@ -33,28 +31,6 @@ internal async Task RestApiTest(Task<RestApiEndpoint> task, string expectedAudie
Assert.Equal(expectedTokenString, api.Token);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
internal async Task EnableMessageTracingIdInRestApiTest(bool enable)
{
var api = await _restApiProvider.GetBroadcastEndpointAsync("app", "hub");
var client = new RestClient(HttpClientFactory.Instance, new NewtonsoftJsonObjectSerializer(), enable);
try
{
await client.SendAsync(api, HttpMethod.Post, "", handleExpectedResponse: default).OrTimeout(200);
}
catch
{
}
Assert.Equal(enable, api.Query?.ContainsKey(Constants.Headers.AsrsMessageTracingId) ?? false);
if (enable)
{
var id = Convert.ToUInt64(api.Query[Constants.Headers.AsrsMessageTracingId]);
Assert.Equal(MessageWithTracingIdHelper.Prefix, id);
}
}

public static IEnumerable<object[]> GetTestData() =>
from context in GetContext()
from pair in GetTestDataByContext(context)
Expand Down
Loading

0 comments on commit df759d5

Please sign in to comment.