diff --git a/src/backend/TrafficCourts/Staff.Service/Middleware/ETagMiddleware.cs b/src/backend/TrafficCourts/Staff.Service/Middleware/ETagMiddleware.cs new file mode 100644 index 000000000..449cbab36 --- /dev/null +++ b/src/backend/TrafficCourts/Staff.Service/Middleware/ETagMiddleware.cs @@ -0,0 +1,69 @@ +using System.Security.Cryptography; +using System.Text; + +namespace TrafficCourts.Staff.Service.Middleware; + +public class ETagMiddleware +{ + private readonly RequestDelegate _next; + + public ETagMiddleware(RequestDelegate next) + { + _next = next; + } + + public async Task InvokeAsync(HttpContext context) + { + // Only handle GET requests + if (context.Request.Method != HttpMethods.Get) + { + await _next(context); + return; + } + + // capture the response body + Stream originalBodyStream = context.Response.Body; + + // create new response body stream + using var memoryStream = new MemoryStream(); + context.Response.Body = memoryStream; + await _next(context); + + // Only generate ETag for successful responses + if (context.Response.StatusCode == StatusCodes.Status200OK) + { + // Generate ETag based on the response content + memoryStream.Seek(0, SeekOrigin.Begin); + var content = await new StreamReader(memoryStream).ReadToEndAsync(); + var eTag = GenerateETag(content); + + // Check the If-None-Match header + if (context.Request.Headers.TryGetValue("If-None-Match", out Microsoft.Extensions.Primitives.StringValues value) && value == eTag) + { + context.Response.StatusCode = StatusCodes.Status304NotModified; + context.Response.Headers.ETag = eTag; + context.Response.Body = originalBodyStream; + context.Response.ContentLength = 0; // Ensure the response body is empty + return; + } + + // Include the ETag in the response headers + context.Response.Headers.ETag = eTag; + + // Write the response body back to the original stream + memoryStream.Seek(0, SeekOrigin.Begin); + await memoryStream.CopyToAsync(originalBodyStream); + } + + context.Response.Body = originalBodyStream; + } + + private string GenerateETag(string content) + { + using (var md5 = MD5.Create()) + { + var hash = md5.ComputeHash(Encoding.UTF8.GetBytes(content)); + return $"\"{Convert.ToBase64String(hash)}\""; + } + } +} diff --git a/src/backend/TrafficCourts/Staff.Service/Program.cs b/src/backend/TrafficCourts/Staff.Service/Program.cs index c6b104420..41a70c0dd 100644 --- a/src/backend/TrafficCourts/Staff.Service/Program.cs +++ b/src/backend/TrafficCourts/Staff.Service/Program.cs @@ -3,6 +3,7 @@ using TrafficCourts.Configuration.Validation; using TrafficCourts.Diagnostics; using TrafficCourts.Staff.Service; +using TrafficCourts.Staff.Service.Middleware; var builder = WebApplication.CreateBuilder(args); var logger = builder.GetProgramLogger(); @@ -18,6 +19,8 @@ app.UseAuthentication(); app.UseAuthorization(); +app.UseMiddleware(); + app.UseFastEndpoints(c => { c.Endpoints.RoutePrefix = "api"; c.Endpoints.PrefixNameWithFirstTag = true;