Skip to content

Commit

Permalink
Merged PR 41305: Added new set of credential env variables to be used…
Browse files Browse the repository at this point in the history
… separately for pull and...

Added new set of credential env variables to be used separately for pull and push operations. Old set of variables is used for fallback.

----
#### AI description  (iteration 1)
#### PR Classification
New feature: Added support for separate credential environment variables for different registry modes (push, pull, pull from output).

#### PR Summary
This pull request introduces new environment variables for Docker credentials based on registry modes and updates the relevant classes and tests to support this feature.
- `AuthHandshakeMessageHandler.cs`: Added `GetDockerCredentialsFromEnvironment` method to fetch credentials based on registry mode.
- `Registry.cs`: Introduced `RegistryMode` enum and updated constructors to handle different registry modes.
- `DefaultRegistryAPI.cs`: Updated to use registry mode when creating HTTP clients.
- `ContainerHelpers.cs`: Added new constants for push and pull registry credentials.
- Added unit tests in `AuthHandshakeMessageHandlerTests.cs` to verify the new credential fetching logic.
  • Loading branch information
baronfel authored and marcpopMSFT committed Jul 24, 2024
1 parent 9952265 commit 375e494
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Net;
Expand Down Expand Up @@ -38,12 +39,14 @@ private sealed record AuthInfo(string Realm, string? Service, string? Scope);

private readonly string _registryName;
private readonly ILogger _logger;
private readonly RegistryMode _registryMode;
private static ConcurrentDictionary<string, AuthenticationHeaderValue?> _authenticationHeaders = new();

public AuthHandshakeMessageHandler(string registryName, HttpMessageHandler innerHandler, ILogger logger) : base(innerHandler)
public AuthHandshakeMessageHandler(string registryName, HttpMessageHandler innerHandler, ILogger logger, RegistryMode mode) : base(innerHandler)
{
_registryName = registryName;
_logger = logger;
_registryMode = mode;
}

/// <summary>
Expand Down Expand Up @@ -156,14 +159,10 @@ public DateTimeOffset ResolvedExpiration
/// </summary>
private async Task<(AuthenticationHeaderValue, DateTimeOffset)?> GetAuthenticationAsync(string registry, string scheme, AuthInfo? bearerAuthInfo, CancellationToken cancellationToken)
{
// Allow overrides for auth via environment variables
string? credU = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectUser) ?? Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectUserLegacy);
string? credP = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectPass) ?? Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectPassLegacy);

// fetch creds for the host

DockerCredentials? privateRepoCreds;

if (!string.IsNullOrEmpty(credU) && !string.IsNullOrEmpty(credP))
// Allow overrides for auth via environment variables
if (GetDockerCredentialsFromEnvironment(_registryMode) is (string credU, string credP))
{
privateRepoCreds = new DockerCredentials(credU, credP);
}
Expand Down Expand Up @@ -196,6 +195,63 @@ public DateTimeOffset ResolvedExpiration
}
}

internal static (string credU, string credP)? TryGetCredentialsFromEnvVars(string unameVar, string passwordVar)
{
var credU = Environment.GetEnvironmentVariable(unameVar);
var credP = Environment.GetEnvironmentVariable(passwordVar);
if (!string.IsNullOrEmpty(credU) && !string.IsNullOrEmpty(credP))
{
return (credU, credP);
}
else
{
return null;
}
}

/// <summary>
/// Gets docker credentials from the environment variables based on registry mode.
/// </summary>
internal static (string credU, string credP)? GetDockerCredentialsFromEnvironment(RegistryMode mode)
{
if (mode == RegistryMode.Push)
{
if (TryGetCredentialsFromEnvVars(ContainerHelpers.PushHostObjectUser, ContainerHelpers.PushHostObjectPass) is (string, string) pushCreds)
{
return pushCreds;
}

if (TryGetCredentialsFromEnvVars(ContainerHelpers.HostObjectUser, ContainerHelpers.HostObjectPass) is (string, string) genericCreds)
{
return genericCreds;
}

return TryGetCredentialsFromEnvVars(ContainerHelpers.HostObjectUserLegacy, ContainerHelpers.HostObjectPassLegacy);
}
else if (mode == RegistryMode.Pull)
{
return TryGetCredentialsFromEnvVars(ContainerHelpers.PullHostObjectUser, ContainerHelpers.PullHostObjectPass);
}
else if (mode == RegistryMode.PullFromOutput)
{
if (TryGetCredentialsFromEnvVars(ContainerHelpers.PullHostObjectUser, ContainerHelpers.PullHostObjectPass) is (string, string) pullCreds)
{
return pullCreds;
}

if (TryGetCredentialsFromEnvVars(ContainerHelpers.HostObjectUser, ContainerHelpers.HostObjectPass) is (string, string) genericCreds)
{
return genericCreds;
}

return TryGetCredentialsFromEnvVars(ContainerHelpers.HostObjectUserLegacy, ContainerHelpers.HostObjectPassLegacy);
}
else
{
throw new InvalidEnumArgumentException(nameof(mode), (int)mode, typeof(RegistryMode));
}
}

/// <summary>
/// Implements the Docker OAuth2 Authentication flow as documented at <see href="https://docs.docker.com/registry/spec/auth/oauth/"/>.
/// </summary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ internal static async Task<int> ContainerizeAsync(
logger.LogTrace("Trace logging: enabled.");

bool isLocalPull = string.IsNullOrEmpty(baseRegistry);
Registry? sourceRegistry = isLocalPull ? null : new Registry(baseRegistry, logger);
RegistryMode sourceRegistryMode = baseRegistry.Equals(outputRegistry, StringComparison.InvariantCultureIgnoreCase) ? RegistryMode.PullFromOutput : RegistryMode.Pull;
Registry? sourceRegistry = isLocalPull ? null : new Registry(baseRegistry, logger, sourceRegistryMode);
SourceImageReference sourceImageReference = new(sourceRegistry, baseImageName, baseImageTag);

DestinationImageReference destinationImageReference = DestinationImageReference.CreateFromSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ public static class ContainerHelpers
internal const string HostObjectPass = "DOTNET_CONTAINER_REGISTRY_PWORD";
internal const string HostObjectPassLegacy = "SDK_CONTAINER_REGISTRY_PWORD";

internal const string PushHostObjectUser = "DOTNET_CONTAINER_PUSH_REGISTRY_UNAME";
internal const string PushHostObjectPass = "DOTNET_CONTAINER_PUSH_REGISTRY_PWORD";

internal const string PullHostObjectUser = "DOTNET_CONTAINER_PULL_REGISTRY_UNAME";
internal const string PullHostObjectPass = "DOTNET_CONTAINER_PULL_REGISTRY_PWORD";

internal const string DockerRegistryAlias = "docker.io";

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ public static DestinationImageReference CreateFromSettings(
}
else if (!string.IsNullOrEmpty(outputRegistry))
{
destinationImageReference = new DestinationImageReference(new Registry(outputRegistry, loggerFactory.CreateLogger<Registry>()), repository, imageTags);
destinationImageReference = new DestinationImageReference(
new Registry(outputRegistry, loggerFactory.CreateLogger<Registry>(), RegistryMode.Push),
repository,
imageTags);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ internal class DefaultRegistryAPI : IRegistryAPI
// Making this a round 30 for convenience.
private static TimeSpan LongRequestTimeout = TimeSpan.FromMinutes(30);

internal DefaultRegistryAPI(string registryName, Uri baseUri, bool isInsecureRegistry, ILogger logger)
internal DefaultRegistryAPI(string registryName, Uri baseUri, bool isInsecureRegistry, ILogger logger, RegistryMode mode)
{
_baseUri = baseUri;
_logger = logger;
_client = CreateClient(registryName, baseUri, isInsecureRegistry, logger);
_client = CreateClient(registryName, baseUri, logger, isInsecureRegistry, mode);
Manifest = new DefaultManifestOperations(_baseUri, registryName, _client, _logger);
Blob = new DefaultBlobOperations(_baseUri, registryName, _client, _logger);
}
Expand All @@ -35,11 +35,11 @@ internal DefaultRegistryAPI(string registryName, Uri baseUri, bool isInsecureReg

public IManifestOperations Manifest { get; }

private static HttpClient CreateClient(string registryName, Uri baseUri, bool isInsecureRegistry, ILogger logger)
private static HttpClient CreateClient(string registryName, Uri baseUri, ILogger logger, bool isInsecureRegistry, RegistryMode mode)
{
HttpMessageHandler innerHandler = CreateHttpHandler(baseUri, isInsecureRegistry, logger);

HttpMessageHandler clientHandler = new AuthHandshakeMessageHandler(registryName, innerHandler, logger);
HttpMessageHandler clientHandler = new AuthHandshakeMessageHandler(registryName, innerHandler, logger, mode);

if (baseUri.IsAmazonECRRegistry())
{
Expand Down
49 changes: 45 additions & 4 deletions src/Containers/Microsoft.NET.Build.Containers/Registry/Registry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ public RidGraphManifestPicker(string runtimeIdentifierGraphPath)

}

internal enum RegistryMode
{
Push,
Pull,
PullFromOutput
}

internal sealed class Registry
{
private const string DockerHubRegistry1 = "registry-1.docker.io";
Expand All @@ -70,11 +77,24 @@ internal sealed class Registry
/// </summary>
public string RegistryName { get; }

internal Registry(string registryName, ILogger logger, IRegistryAPI? registryAPI = null, RegistrySettings? settings = null) :
internal Registry(string registryName, ILogger logger, IRegistryAPI registryAPI, RegistrySettings? settings = null) :
this(new Uri($"https://{registryName}"), logger, registryAPI, settings)
{ }

internal Registry(Uri baseUri, ILogger logger, IRegistryAPI? registryAPI = null, RegistrySettings? settings = null)
internal Registry(string registryName, ILogger logger, RegistryMode mode, RegistrySettings? settings = null) :
this(new Uri($"https://{registryName}"), logger, new RegistryApiFactory(mode), settings)
{ }


internal Registry(Uri baseUri, ILogger logger, IRegistryAPI registryAPI, RegistrySettings? settings = null) :
this(baseUri, logger, new RegistryApiFactory(registryAPI), settings)
{ }

internal Registry(Uri baseUri, ILogger logger, RegistryMode mode, RegistrySettings? settings = null) :
this(baseUri, logger, new RegistryApiFactory(mode), settings)
{ }

private Registry(Uri baseUri, ILogger logger, RegistryApiFactory factory, RegistrySettings? settings = null)
{
RegistryName = DeriveRegistryName(baseUri);

Expand All @@ -87,15 +107,15 @@ internal Registry(Uri baseUri, ILogger logger, IRegistryAPI? registryAPI = null,

_logger = logger;
_settings = settings ?? new RegistrySettings(RegistryName);
_registryAPI = registryAPI ?? new DefaultRegistryAPI(RegistryName, BaseUri, _settings.IsInsecure, logger);
_registryAPI = factory.Create(RegistryName, BaseUri, logger, _settings.IsInsecure);
}

private static string DeriveRegistryName(Uri baseUri)
{
var port = baseUri.Port == -1 ? string.Empty : $":{baseUri.Port}";
if (baseUri.OriginalString.EndsWith(port, ignoreCase: true, culture: null))
{
// the port was part of the original assignment, so it's ok to consider it part of the 'name
// the port was part of the original assignment, so it's ok to consider it part of the 'name'
return baseUri.GetComponents(UriComponents.HostAndPort, UriFormat.Unescaped);
}
else
Expand Down Expand Up @@ -507,4 +527,25 @@ private async Task PushAsync(BuiltImage builtImage, SourceImageReference source,
_logger.LogInformation(Strings.Registry_ManifestUploaded, RegistryName);
}
}

private readonly ref struct RegistryApiFactory
{
private readonly IRegistryAPI? _registryApi;
private readonly RegistryMode? _mode;

public RegistryApiFactory(IRegistryAPI registryApi)
{
_registryApi = registryApi;
}

public RegistryApiFactory(RegistryMode mode)
{
_mode = mode;
}

public IRegistryAPI Create(string registryName, Uri baseUri, ILogger logger, bool isInsecureRegistry)
{
return _registryApi ?? new DefaultRegistryAPI(registryName, baseUri, isInsecureRegistry, logger, _mode!.Value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ internal async Task<bool> ExecuteAsync(CancellationToken cancellationToken)
return !Log.HasLoggedErrors;
}

Registry? sourceRegistry = IsLocalPull ? null : new Registry(BaseRegistry, logger);
RegistryMode sourceRegistryMode = BaseRegistry.Equals(OutputRegistry, StringComparison.InvariantCultureIgnoreCase) ? RegistryMode.PullFromOutput : RegistryMode.Pull;
Registry? sourceRegistry = IsLocalPull ? null : new Registry(BaseRegistry, logger, sourceRegistryMode);
SourceImageReference sourceImageReference = new(sourceRegistry, BaseImageName, BaseImageTag);

DestinationImageReference destinationImageReference = DestinationImageReference.CreateFromSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public async System.Threading.Tasks.Task CreateNewImage_RootlessBaseImage()
var logger = loggerFactory.CreateLogger(nameof(CreateNewImage_RootlessBaseImage));

// Build a rootless base runtime image.
Registry registry = new Registry(DockerRegistryManager.LocalRegistry, logger);
Registry registry = new(DockerRegistryManager.LocalRegistry, logger, RegistryMode.Push);

ImageBuilder imageBuilder = await registry.GetImageManifestAsync(
DockerRegistryManager.RuntimeBaseImage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public static async Task StartAndPopulateDockerRegistry(ITestOutputHelper testOu
int spawnRegistryDelay = 1000; //ms
StringBuilder failureReasons = new();

var pullRegistry = new Registry(BaseImageSource, logger);
var pushRegistry = new Registry(LocalRegistry, logger);
var pullRegistry = new Registry(BaseImageSource, logger, RegistryMode.Pull);
var pushRegistry = new Registry(LocalRegistry, logger, RegistryMode.Push);

for (int spawnRegistryAttempt = 1; spawnRegistryAttempt <= spawnRegistryMaxRetry; spawnRegistryAttempt++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public async Task GetFromRegistry()
{
var loggerFactory = new TestLoggerFactory(_testOutput);
var logger = loggerFactory.CreateLogger(nameof(GetFromRegistry));
Registry registry = new Registry(DockerRegistryManager.LocalRegistry, logger);
Registry registry = new(DockerRegistryManager.LocalRegistry, logger, RegistryMode.Push);
var ridgraphfile = ToolsetUtils.GetRuntimeGraphFilePath();

// Don't need rid graph for local registry image pulls - since we're only pushing single image manifests (not manifest lists)
Expand Down Expand Up @@ -74,9 +74,9 @@ public async Task WriteToPrivateBasicRegistry()
// login to that registry
ContainerCli.LoginCommand(_testOutput, "--username", "testuser", "--password", "testpassword", registryName).Execute().Should().Pass();
// push an image to that registry using username/password
Registry localAuthed = new(new Uri($"https://{registryName}"), logger, settings: new(registryName) { ParallelUploadEnabled = false, ForceChunkedUpload = true });
Registry localAuthed = new(new Uri($"https://{registryName}"), logger, RegistryMode.Push, settings: new() { ParallelUploadEnabled = false, ForceChunkedUpload = true });
var ridgraphfile = ToolsetUtils.GetRuntimeGraphFilePath();
Registry mcr = new Registry(DockerRegistryManager.BaseImageSource, logger);
Registry mcr = new(DockerRegistryManager.BaseImageSource, logger, RegistryMode.Pull);

var sourceImage = new SourceImageReference(mcr, DockerRegistryManager.RuntimeBaseImage, DockerRegistryManager.Net6ImageTag);
var destinationImage = new DestinationImageReference(localAuthed, DockerRegistryManager.RuntimeBaseImage, new[] { DockerRegistryManager.Net6ImageTag });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public async Task ApiEndToEndWithRegistryPushAndPull()

// Build the image

Registry registry = new Registry(DockerRegistryManager.LocalRegistry, logger);
Registry registry = new(DockerRegistryManager.LocalRegistry, logger, RegistryMode.Push);

ImageBuilder imageBuilder = await registry.GetImageManifestAsync(
DockerRegistryManager.RuntimeBaseImage,
Expand Down Expand Up @@ -93,7 +93,7 @@ public async Task ApiEndToEndWithLocalLoad()

// Build the image

Registry registry = new Registry(DockerRegistryManager.LocalRegistry, logger);
Registry registry = new(DockerRegistryManager.LocalRegistry, logger, RegistryMode.Push);

ImageBuilder imageBuilder = await registry.GetImageManifestAsync(
DockerRegistryManager.RuntimeBaseImage,
Expand Down Expand Up @@ -134,7 +134,7 @@ public async Task ApiEndToEndWithArchiveWritingAndLoad()

// Build the image

Registry registry = new Registry(DockerRegistryManager.LocalRegistry, logger);
Registry registry = new(DockerRegistryManager.LocalRegistry, logger, RegistryMode.Push);

ImageBuilder imageBuilder = await registry.GetImageManifestAsync(
DockerRegistryManager.RuntimeBaseImage,
Expand Down Expand Up @@ -555,7 +555,7 @@ public async Task CanPackageForAllSupportedContainerRIDs(string dockerPlatform,
string publishDirectory = BuildLocalApp(tfm: ToolsetInfo.CurrentTargetFramework, rid: rid);

// Build the image
Registry registry = new(DockerRegistryManager.BaseImageSource, logger);
Registry registry = new(DockerRegistryManager.BaseImageSource, logger, RegistryMode.Push);
var isWin = rid.StartsWith("win");
ImageBuilder? imageBuilder = await registry.GetImageManifestAsync(
DockerRegistryManager.RuntimeBaseImage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public async Task CanReadManifestFromRegistry(string fullyQualifiedContainerName
containerTag ??= "latest";

ILogger logger = _loggerFactory.CreateLogger(nameof(CanReadManifestFromRegistry));
Registry registry = new Registry(containerRegistry, logger);
Registry registry = new(containerRegistry, logger, RegistryMode.Pull);

var ridgraphfile = ToolsetUtils.GetRuntimeGraphFilePath();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.NET.Build.Containers.UnitTests
{
public class AuthHandshakeMessageHandlerTests
{
[Theory]
[InlineData("SDK_CONTAINER_REGISTRY_UNAME", "SDK_CONTAINER_REGISTRY_PWORD", (int)RegistryMode.Push)]
[InlineData("DOTNET_CONTAINER_PUSH_REGISTRY_UNAME", "DOTNET_CONTAINER_PUSH_REGISTRY_PWORD", (int)RegistryMode.Push)]
[InlineData("DOTNET_CONTAINER_PULL_REGISTRY_UNAME", "DOTNET_CONTAINER_PULL_REGISTRY_PWORD", (int)RegistryMode.Pull)]
[InlineData("DOTNET_CONTAINER_PULL_REGISTRY_UNAME", "DOTNET_CONTAINER_PULL_REGISTRY_PWORD", (int)RegistryMode.PullFromOutput)]
[InlineData("SDK_CONTAINER_REGISTRY_UNAME", "SDK_CONTAINER_REGISTRY_PWORD", (int)RegistryMode.PullFromOutput)]
public void GetDockerCredentialsFromEnvironment_ReturnsCorrectValues(string unameVarName, string pwordVarName, int mode)
{
string? originalUnameValue = Environment.GetEnvironmentVariable(unameVarName);
string? originalPwordValue = Environment.GetEnvironmentVariable(pwordVarName);

Environment.SetEnvironmentVariable(unameVarName, "uname");
Environment.SetEnvironmentVariable(pwordVarName, "pword");

if (AuthHandshakeMessageHandler.GetDockerCredentialsFromEnvironment((RegistryMode)mode) is (string credU, string credP))
{
Assert.Equal("uname", credU);
Assert.Equal("pword", credP);
}
else
{
Assert.Fail("Should have parsed credentials from environment");
}


// restore env variable values
Environment.SetEnvironmentVariable(unameVarName, originalUnameValue);
Environment.SetEnvironmentVariable(pwordVarName, originalPwordValue);
}
}
}
Loading

0 comments on commit 375e494

Please sign in to comment.