Skip to content

Commit

Permalink
Set message tracing ID in DI instead of in REST client (#1857)
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-Sindo authored Oct 25, 2023
1 parent 6fa8775 commit b26a920
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 47 deletions.
25 changes: 5 additions & 20 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,20 @@ internal class RestClient
{
private readonly IHttpClientFactory _httpClientFactory;
private readonly IPayloadContentBuilder _payloadContentBuilder;
private readonly bool _enableMessageTracing;

public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder, bool enableMessageTracing)
public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder)
{
_httpClientFactory = httpClientFactory;
_payloadContentBuilder = contentBuilder;
_enableMessageTracing = enableMessageTracing;
}

public RestClient(IHttpClientFactory httpClientFactory, ObjectSerializer objectSerializer, bool enableMessageTracing) : this(httpClientFactory, new JsonPayloadContentBuilder(objectSerializer), enableMessageTracing)
// TODO: Test only, will remove later
internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer()))
{
}


public RestClient() : this(HttpClientFactory.Instance, new JsonObjectSerializer(), false)
// TODO: remove later
public RestClient() : this(HttpClientFactory.Instance)
{
}

Expand Down Expand Up @@ -180,10 +179,6 @@ private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null)
{
var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null;
if (_enableMessageTracing)
{
AddTracingId(api);
}
return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token);
}

Expand All @@ -194,15 +189,5 @@ private HttpRequestMessage GenerateHttpRequest(string url, IDictionary<string, S
request.Content = _payloadContentBuilder.Build(payload);
return request;
}

private void AddTracingId(RestApiEndpoint api)
{
var id = MessageWithTracingIdHelper.Generate();
if (api.Query == null)
{
api.Query = new Dictionary<string, StringValues>();
}
api.Query.Add(Constants.Headers.AsrsMessageTracingId, id.ToString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ private static IServiceCollection AddRestClientFactory(this IServiceCollection s
client.Timeout = Timeout.InfiniteTimeSpan;
}
ConfigureProduceInfo(sp, client);
ConfigureMessageTracingId(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<RetryHttpMessageHandler>(sp, (HttpStatusCode code) => IsTransientErrorForNonMessageApi(code)))
Expand All @@ -207,6 +208,7 @@ private static IServiceCollection AddRestClientFactory(this IServiceCollection s
{
client.Timeout = sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.HttpClientTimeout;
ConfigureProduceInfo(sp, client);
ConfigureMessageTracingId(sp, client);
})
.ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
.AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance<RetryHttpMessageHandler>(sp, (HttpStatusCode code) => IsTransientErrorAndIdempotentForMessageApi(code)));
Expand Down Expand Up @@ -236,6 +238,13 @@ static void ConfigureProduceInfo(IServiceProvider sp, HttpClient client) =>
// The following value should not be used.
"Microsoft.Azure.SignalR.Management/");

static void ConfigureMessageTracingId(IServiceProvider sp, HttpClient client)
{
if (sp.GetRequiredService<IOptions<ServiceManagerOptions>>().Value.EnableMessageTracing)
{
client.DefaultRequestHeaders.Add(Constants.Headers.AsrsMessageTracingId, MessageWithTracingIdHelper.Generate().ToString());
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var payloadBuilderResolver = _serviceProvider.GetRequiredService<PayloadBuilderResolver>();
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder(), _options.EnableMessageTracing);
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
Expand All @@ -21,7 +20,7 @@ public async Task TestHttpRequestExceptionWithStatusCodeSetAsync()
var httpClientFactory = new ServiceCollection()
.AddHttpClient("").ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(HttpStatusCode.InsufficientStorage)).Services
.BuildServiceProvider().GetRequiredService<IHttpClientFactory>();
var client = new RestClient(httpClientFactory, new JsonObjectSerializer(), true);
var client = new RestClient(httpClientFactory);
var apiEndpoint = new RestApiEndpoint("https://localhost.test.com", "token");
var exception = await Assert.ThrowsAsync<AzureSignalRRuntimeException>(() =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,68 @@ public async Task HttpClientProductInfoTestAsync(string httpClientName)
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc"));
}

[Theory]
[InlineData(Constants.HttpClientNames.Resilient)]
[InlineData(Constants.HttpClientNames.MessageResilient)]
public async Task HttpClientMessageTracingIdEnabledTestAsync(string httpClientName)
{
using var hubContext = await new ServiceManagerBuilder()
.WithOptions(o =>
{
o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single();
o.EnableMessageTracing = true;
})
.ConfigureServices(services => services.AddHttpClient(httpClientName)
.ConfigurePrimaryHttpMessageHandler(() =>
new TestRootHandler((message, token) =>
{
if (message.Headers.TryGetValues(Constants.Headers.AsrsMessageTracingId, out var values))
{
Assert.Single(values);
Convert.ToUInt64(values.Single());
}
else
{
throw new Exception("Message tracing Id header is missing");
}
})))
.BuildServiceManager()
.CreateHubContextAsync("hubName", default);
var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider;
var httpClientFactory = serviceProvider.GetRequiredService<IHttpClientFactory>();
using var httpClient = httpClientFactory.CreateClient(httpClientName);
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc"));
}


[Theory]
[InlineData(Constants.HttpClientNames.Resilient)]
[InlineData(Constants.HttpClientNames.MessageResilient)]
public async Task HttpClientMessageTracingIdDisabledTestAsync(string httpClientName)
{
using var hubContext = await new ServiceManagerBuilder()
.WithOptions(o =>
{
o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single();
o.EnableMessageTracing = false;
})
.ConfigureServices(services => services.AddHttpClient(httpClientName)
.ConfigurePrimaryHttpMessageHandler(() =>
new TestRootHandler((message, token) =>
{
if (message.Headers.TryGetValues(Constants.Headers.AsrsMessageTracingId, out var values))
{
throw new Exception("Message tracing Id header is not expected");
}
})))
.BuildServiceManager()
.CreateHubContextAsync("hubName", default);
var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider;
var httpClientFactory = serviceProvider.GetRequiredService<IHttpClientFactory>();
using var httpClient = httpClientFactory.CreateClient(httpClientName);
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc"));
}

private class WaitInfinitelyHandler : DelegatingHandler
{
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.Azure.SignalR.Tests;
using Xunit;

Expand All @@ -33,28 +31,6 @@ internal async Task RestApiTest(Task<RestApiEndpoint> task, string expectedAudie
Assert.Equal(expectedTokenString, api.Token);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
internal async Task EnableMessageTracingIdInRestApiTest(bool enable)
{
var api = await _restApiProvider.GetBroadcastEndpointAsync("app", "hub");
var client = new RestClient(HttpClientFactory.Instance, new NewtonsoftJsonObjectSerializer(), enable);
try
{
await client.SendAsync(api, HttpMethod.Post, "", handleExpectedResponse: default).OrTimeout(200);
}
catch
{
}
Assert.Equal(enable, api.Query?.ContainsKey(Constants.Headers.AsrsMessageTracingId) ?? false);
if (enable)
{
var id = Convert.ToUInt64(api.Query[Constants.Headers.AsrsMessageTracingId]);
Assert.Equal(MessageWithTracingIdHelper.Prefix, id);
}
}

public static IEnumerable<object[]> GetTestData() =>
from context in GetContext()
from pair in GetTestDataByContext(context)
Expand Down

0 comments on commit b26a920

Please sign in to comment.