Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: restrict access list includes #364

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public async Task<Page<AccessListInfo, string>> GetAccessListsByOwner(
string owner,
Page<string>.Request request,
AccessListIncludes includes = default,
string? resourceIdentifier = null,
CancellationToken cancellationToken = default)
{
Guard.IsNotNull(owner);
Expand All @@ -44,6 +45,7 @@ public async Task<Page<AccessListInfo, string>> GetAccessListsByOwner(
continueFrom: request.ContinuationToken,
count: SMALL_PAGE_SIZE + 1,
includes,
resourceIdentifier,
cancellationToken);

return Page.Create(accessLists, SMALL_PAGE_SIZE, static list => list.Identifier);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,18 @@ public interface IAccessListService
/// <param name="owner">The resource owner (org.nr.).</param>
/// <param name="request">The page request.</param>
/// <param name="includes">What additional to include in the response.</param>
/// <param name="resourceIdentifier">
/// Optional resource identifier. Used if <paramref name="includes"/> contains flag <see cref="AccessListIncludes.ResourceConnections"/>
/// to filter the resource connections included.
/// </param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/>.</param>
/// <returns>A <see cref="Page{TItem, TToken}"/> of <see cref="AccessListInfo"/>.</returns>
Task<Page<AccessListInfo, string>> GetAccessListsByOwner(string owner, Page<string>.Request request, AccessListIncludes includes = default, CancellationToken cancellationToken = default);
Task<Page<AccessListInfo, string>> GetAccessListsByOwner(
string owner,
Page<string>.Request request,
AccessListIncludes includes = default,
string? resourceIdentifier = null,
CancellationToken cancellationToken = default);

/// <summary>
/// Gets an access list by owner and identifier.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@ public interface IAccessListsRepository
/// <param name="continueFrom">An optional value to continue iterating from. This value is an <see cref="AccessListInfo.Identifier"/> to start from, using greater than or equals comparison.</param>
/// <param name="count">The total number of entries to return.</param>
/// <param name="includes">What additional to include in the response.</param>
/// <param name="resourceIdentifier">
/// Optional resource identifier. Used if <paramref name="includes"/> contains flag <see cref="AccessListIncludes.ResourceConnections"/>
/// to filter the resource connections included.
/// </param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/>.</param>
/// <returns>A list of <see cref="AccessListInfo"/>, sorted by <see cref="AccessListInfo.Identifier"/> and limited by <paramref name="count"/>.</returns>
Task<IReadOnlyList<AccessListInfo>> GetAccessListsByOwner(
string resourceOwner,
string? continueFrom,
int count,
AccessListIncludes includes = default,
string? resourceIdentifier = null,
CancellationToken cancellationToken = default);

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ static NpgsqlCommand ByOwnerAndIdentQuery(NpgsqlConnection conn, string owner, s
}
}

public async Task<IReadOnlyList<AccessListInfo>> GetAccessListsByOwner(string resourceOwner, string? continueFrom, int count, AccessListIncludes includes, CancellationToken cancellationToken)
public async Task<IReadOnlyList<AccessListInfo>> GetAccessListsByOwner(string resourceOwner, string? continueFrom, int count, AccessListIncludes includes, string? resourceIdentifier, CancellationToken cancellationToken)
{
Guard.IsNotNullOrEmpty(resourceOwner);
Guard.IsGreaterThan(count, 0);
Expand Down Expand Up @@ -203,6 +203,7 @@ ORDER BY identifier ASC
await LoadResourceConnections(
accessLists,
idSet,
resourceIdentifier,
includeActions: includes.HasFlag(AccessListIncludes.ResourceConnectionsActions),
cancellationToken);
}
Expand Down Expand Up @@ -472,16 +473,18 @@ FROM resourceregistry.access_list_state
return reader.GetGuid(0);
}

private async Task LoadResourceConnections(List<AccessListInfo> accessLists, List<Guid> idSet, bool includeActions, CancellationToken cancellationToken)
private async Task LoadResourceConnections(List<AccessListInfo> accessLists, List<Guid> idSet, string? resourceIdentifier, bool includeActions, CancellationToken cancellationToken)
{
const string QUERY = /*strpsql*/@"
SELECT aggregate_id, resource_identifier, actions, created, modified
FROM resourceregistry.access_list_resource_connections_state
WHERE aggregate_id = ANY(@aggregate_ids)
AND (@resource_identifier IS NULL OR resource_identifier = @resource_identifier)
ORDER BY aggregate_id, resource_identifier;";

await using var cmd = _conn.CreateCommand(QUERY);
cmd.Parameters.Add<IList<Guid>>("aggregate_ids", NpgsqlDbType.Array | NpgsqlDbType.Uuid).TypedValue = idSet;
cmd.Parameters.AddWithNullableValue("resource_identifier", NpgsqlDbType.Text, resourceIdentifier);

AccessListInfo? current = null;
List<AccessListResourceConnection>? connections = null;
Expand Down
10 changes: 8 additions & 2 deletions src/Altinn.ResourceRegistry.Persistence/AccessListsRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ public AccessListsRepository(
}

/// <inheritdoc/>
public Task<IReadOnlyList<AccessListInfo>> GetAccessListsByOwner(string resourceOwner, string? continueFrom, int count, AccessListIncludes includes = default, CancellationToken cancellationToken = default)
=> InTransaction(repo => repo.GetAccessListsByOwner(resourceOwner, continueFrom, count, includes, cancellationToken), cancellationToken);
public Task<IReadOnlyList<AccessListInfo>> GetAccessListsByOwner(
string resourceOwner,
string? continueFrom,
int count,
AccessListIncludes includes = default,
string? resourceIdentifier = null,
CancellationToken cancellationToken = default)
=> InTransaction(repo => repo.GetAccessListsByOwner(resourceOwner, continueFrom, count, includes, resourceIdentifier, cancellationToken), cancellationToken);

/// <inheritdoc/>
public Task<AccessListInfo?> LookupInfo(Guid id, AccessListIncludes includes = default, CancellationToken cancellationToken = default)
Expand Down
13 changes: 12 additions & 1 deletion src/ResourceRegistry/Controllers/AccessListsController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ public AccessListsController(IAccessListService service)
/// <param name="owner">The resource owner</param>
/// <param name="token">Optional continuation token</param>
/// <param name="include">What additional information to include in the response</param>
/// <param name="resourceIdentifier">
/// Optional resource identifier. Required if <paramref name="include"/> has flag <see cref="AccessListIncludes.ResourceConnections"/>
/// set. This is used to filter the resource connections included in the access lists to only the provided resource.
/// </param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/></param>
/// <returns>A paginated set of <see cref="AccessListInfoDto"/></returns>
[HttpGet("", Name = ROUTE_GET_BY_OWNER)]
Expand All @@ -65,9 +69,16 @@ public async Task<ActionResult<Paginated<AccessListInfoDto>>> GetAccessListsByOw
string owner,
[FromQuery(Name = "token")] Opaque<string>? token = null,
[FromQuery(Name = "include")] AccessListIncludes include = AccessListIncludes.None,
[FromQuery(Name = "resource")] string? resourceIdentifier = null,
CancellationToken cancellationToken = default)
{
var page = await _service.GetAccessListsByOwner(owner, Page.ContinueFrom(token?.Value), include, cancellationToken);
if (include.HasFlag(AccessListIncludes.ResourceConnections) && string.IsNullOrWhiteSpace(resourceIdentifier))
{
ModelState.AddModelError("resource", "Resource identifier is required when including resource connections");
return BadRequest(ModelState);
}

var page = await _service.GetAccessListsByOwner(owner, Page.ContinueFrom(token?.Value), include, resourceIdentifier, cancellationToken);
if (page == null)
{
return NotFound();
Expand Down
46 changes: 41 additions & 5 deletions test/ResourceRegistryTest/AccessListControllerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Altinn.ResourceRegistry.Tests.Utils;
using Altinn.ResourceRegistry.TestUtils;
using FluentAssertions;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
using Npgsql;
using System;
Expand Down Expand Up @@ -121,33 +122,68 @@ public async Task Can_Include_Additional_Data()
}

{
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resources");
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resources&resource=test1");

var content = await response.Content.ReadFromJsonAsync<Paginated<AccessListInfoDto>>();
Assert.NotNull(content);

content.Items.Should().HaveCount(1);
content.Items.Should().Contain(al => al.Identifier == "test1")
.Which.ResourceConnections.Should().HaveCount(2)
.Which.ResourceConnections.Should().HaveCount(1)
.And.AllSatisfy(rc => rc.Actions.Should().BeNull())
.And.Contain(rc => rc.ResourceIdentifier == RESOURCE1_NAME);
}

{
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resources&resource=test2");

var content = await response.Content.ReadFromJsonAsync<Paginated<AccessListInfoDto>>();
Assert.NotNull(content);

content.Items.Should().HaveCount(1);
content.Items.Should().Contain(al => al.Identifier == "test1")
.Which.ResourceConnections.Should().HaveCount(1)
.And.AllSatisfy(rc => rc.Actions.Should().BeNull())
.And.Contain(rc => rc.ResourceIdentifier == RESOURCE1_NAME)
.And.Contain(rc => rc.ResourceIdentifier == RESOURCE2_NAME);
}

{
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resource-actions");
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resource-actions&resource=test1");

var content = await response.Content.ReadFromJsonAsync<Paginated<AccessListInfoDto>>();
Assert.NotNull(content);

content.Items.Should().HaveCount(1);
content.Items.Should().Contain(al => al.Identifier == "test1")
.Which.ResourceConnections.Should().HaveCount(2)
.Which.ResourceConnections.Should().HaveCount(1)
.And.AllSatisfy(rc => rc.Actions.Should().NotBeNull())
.And.Contain(rc => rc.ResourceIdentifier == RESOURCE1_NAME)
.Which.Actions.Should().BeEmpty();
}

{
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resource-actions&resource=test2");

var content = await response.Content.ReadFromJsonAsync<Paginated<AccessListInfoDto>>();
Assert.NotNull(content);

content.Items.Should().HaveCount(1);
content.Items.Should().Contain(al => al.Identifier == "test1")
.Which.ResourceConnections.Should().HaveCount(1)
.And.AllSatisfy(rc => rc.Actions.Should().NotBeNull())
.And.Contain(rc => rc.ResourceIdentifier == RESOURCE2_NAME)
.Which.Actions.Should().BeEquivalentTo([ACTION_READ]);
}

{
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resources");
response.StatusCode.Should().Be(HttpStatusCode.BadRequest);
}

{
using var response = await client.GetAsync($"/resourceregistry/api/v1/access-lists/{ORG_NR}?include=resource-actions");
response.StatusCode.Should().Be(HttpStatusCode.BadRequest);
}
}

[Fact]
Expand Down
Loading