Skip to content

Commit

Permalink
Add an IUserProvider abstraction that is a combination of IUserAccess…
Browse files Browse the repository at this point in the history
…or, IImpersonator, IUserClaimCreator, IUserRetrieveService and IUserCacheInvalidator
  • Loading branch information
volkanceylan committed Oct 19, 2024
1 parent e03de4e commit 844b436
Show file tree
Hide file tree
Showing 15 changed files with 299 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void OnResourceExecuting(ResourceExecutingContext context)
if (string.IsNullOrEmpty(path))
path = "/";

if (context.HttpContext.User.GetUserDefinition(userRetrieveService) is IHasPassword { HasPassword: false })
if (userRetrieveService.GetUserDefinition(context.HttpContext.User) is IHasPassword { HasPassword: false })
context.Result = new LocalRedirectResult("~/Account/SetPassword?reason=elevate");
else
context.Result = new LocalRedirectResult("~/Account/Elevate?returnUrl=" + Uri.EscapeDataString(path));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ public abstract class AccountPasswordActionsPageBase<TUserRow> : MembershipPageB

[HttpGet, PageAuthorize]
public virtual ActionResult ChangePassword(
[FromServices] IUserRetrieveService userRetrieveService)
[FromServices] IUserRetrieveService userRetriever)
{
var userDefinition = User.GetUserDefinition<IUserDefinition>(userRetrieveService);
if (userDefinition is IHasPassword hasPassword &&
if (userRetriever.GetUserDefinition(User) is IHasPassword hasPassword &&
!hasPassword.HasPassword)
{
return SetPassword();
Expand All @@ -36,13 +35,13 @@ public ActionResult SetPassword()

[HttpPost, ServiceAuthorize]
public virtual ActionResult SendResetPassword(
[FromServices] IUserRetrieveService userRetrieveService,
[FromServices] IUserRetrieveService userRetriever,
[FromServices] IEmailSender emailSender,
[FromServices] ISiteAbsoluteUrl siteAbsoluteUrl,
[FromServices] ITwoLevelCache cache,
[FromServices] ITextLocalizer localizer)
{
var userDefinition = User.GetUserDefinition<IUserDefinition>(userRetrieveService) ??
var userDefinition = userRetriever.GetUserDefinition(User) ??
throw new ValidationError("Couldn't find user definition.");

#if (IsPublicDemo)
Expand Down Expand Up @@ -70,7 +69,7 @@ public virtual Result<ServiceResponse> ChangePassword(ChangePasswordRequest requ
[FromServices] ITwoLevelCache cache,
[FromServices] IUserPasswordValidator passwordValidator,
[FromServices] IPasswordStrengthValidator passwordStrengthValidator,
[FromServices] IUserRetrieveService userRetrieveService,
[FromServices] IUserRetrieveService userRetriever,
[FromServices] IOptions<MembershipSettings> membershipOptions,
[FromServices] IOptions<EnvironmentSettings> environmentOptions,
[FromServices] ITextLocalizer localizer)
Expand All @@ -83,7 +82,7 @@ public virtual Result<ServiceResponse> ChangePassword(ChangePasswordRequest requ
var username = User.Identity?.Name;
var userDefinition = User.GetUserDefinition<IUserDefinition>(userRetrieveService);
var userDefinition = userRetriever.GetUserDefinition(User);
if (userDefinition is not IHasPassword hasPassword ||
hasPassword.HasPassword)
Expand Down
4 changes: 1 addition & 3 deletions serene/src/Serene.Web/Initialization/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ public void ConfigureServices(IServiceCollection services)
services.AddSingleton<IRolePermissionService, AppServices.RolePermissionService>();
services.AddSingleton<IUploadAVScanner, ClamAVUploadScanner>();
services.AddSingleton<IUserPasswordValidator, AppServices.UserPasswordValidator>();
services.AddSingleton<IUserAccessor, AppServices.UserAccessor>();
services.AddSingleton<IUserClaimCreator, DefaultUserClaimCreator>();
services.AddSingleton<IUserRetrieveService, AppServices.UserRetrieveService>();
services.AddUserProvider<AppServices.UserAccessor, AppServices.UserRetrieveService>();
services.AddServiceHandlers();
services.AddDynamicScripts();
services.AddCssBundling();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,24 @@
using Serene.Administration.Repositories;

namespace Serene.Administration;

/// <summary>
/// This declares a dynamic script with key 'UserData' that will be available from client side.
/// </summary>
[DataScript("UserData", CacheDuration = -1, Permission = "*")]
public class UserDataScript : DataScript<ScriptUserDefinition>
public class UserDataScript(ITwoLevelCache cache, IPermissionService permissions,
IPermissionKeyLister permissionKeyLister, IUserProvider userProvider) : DataScript<ScriptUserDefinition>
{
private readonly ITwoLevelCache cache;
private readonly IPermissionService permissions;
private readonly IPermissionKeyLister permissionKeyLister;
private readonly IUserAccessor userAccessor;
private readonly IUserRetrieveService userRetriever;

public UserDataScript(ITwoLevelCache cache, IPermissionService permissions,
IPermissionKeyLister permissionKeyLister, IUserAccessor userAccessor, IUserRetrieveService userRetriever)
{
this.cache = cache ?? throw new ArgumentNullException(nameof(cache));
this.permissions = permissions ?? throw new ArgumentNullException(nameof(permissions));
this.permissionKeyLister = permissionKeyLister ?? throw new ArgumentNullException(nameof(permissionKeyLister));
this.userAccessor = userAccessor ?? throw new ArgumentNullException(nameof(userAccessor));
this.userRetriever = userRetriever ?? throw new ArgumentNullException(nameof(userRetriever));
}
private readonly ITwoLevelCache cache = cache ?? throw new ArgumentNullException(nameof(cache));
private readonly IPermissionService permissions = permissions ?? throw new ArgumentNullException(nameof(permissions));
private readonly IPermissionKeyLister permissionKeyLister = permissionKeyLister ?? throw new ArgumentNullException(nameof(permissionKeyLister));
private readonly IUserProvider userProvider = userProvider ?? throw new ArgumentNullException(nameof(userProvider));

protected override ScriptUserDefinition GetData()
{
{
var result = new ScriptUserDefinition();

if (userAccessor.User?.GetUserDefinition(userRetriever) is not UserDefinition user)
if (userProvider.GetUserDefinition() is not UserDefinition user)
{
result.Permissions = new Dictionary<string, bool>();
result.Permissions = [];
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ namespace Serene.AppServices;
public class UserPasswordValidator(ITwoLevelCache cache, ISqlConnections sqlConnections, IUserRetrieveService userRetriever,
ILogger<UserPasswordValidator> log = null, IDirectoryService directoryService = null) : IUserPasswordValidator
{
protected ITwoLevelCache Cache { get; } = cache ?? throw new ArgumentNullException(nameof(cache));
public ISqlConnections SqlConnections { get; } = sqlConnections ?? throw new ArgumentNullException(nameof(sqlConnections));
protected IUserRetrieveService UserRetriever { get; } = userRetriever ?? throw new ArgumentNullException(nameof(userRetriever));
protected IDirectoryService DirectoryService { get; } = directoryService;
protected ILogger<UserPasswordValidator> Log { get; } = log;
protected readonly ITwoLevelCache cache = cache ?? throw new ArgumentNullException(nameof(cache));
protected readonly ISqlConnections sqlConnections = sqlConnections ?? throw new ArgumentNullException(nameof(sqlConnections));
protected readonly IUserRetrieveService userRetriever = userRetriever ?? throw new ArgumentNullException(nameof(userRetriever));
protected readonly IDirectoryService directoryService = directoryService;
protected readonly ILogger<UserPasswordValidator> Log = log;

public PasswordValidationResult Validate(ref string username, string password)
{
Expand All @@ -22,7 +22,7 @@ public PasswordValidationResult Validate(ref string username, string password)

username = username.TrimToEmpty();

if (UserRetriever.ByUsername(username) is UserDefinition user)
if (userRetriever.ByUsername(username) is UserDefinition user)
return ValidateExistingUser(ref username, password, user);

return ValidateFirstTimeUser(ref username, password);
Expand All @@ -39,14 +39,14 @@ private PasswordValidationResult ValidateExistingUser(ref string username, strin
}

// prevent more than 50 invalid login attempts in 30 minutes
var throttler = new Throttler(Cache.Memory, "ValidateUser:" + username.ToLowerInvariant(), TimeSpan.FromMinutes(30), 50);
var throttler = new Throttler(cache.Memory, "ValidateUser:" + username.ToLowerInvariant(), TimeSpan.FromMinutes(30), 50);
if (!throttler.Check())
return PasswordValidationResult.Throttle;

bool validatePassword() => UserHelper.CalculateHash(password, user.PasswordSalt)
.Equals(user.PasswordHash, StringComparison.OrdinalIgnoreCase);

if (user.Source == "site" || user.Source == "sign" || DirectoryService == null)
if (user.Source == "site" || user.Source == "sign" || directoryService == null)
{
if (validatePassword())
{
Expand Down Expand Up @@ -76,7 +76,7 @@ bool validatePassword() => UserHelper.CalculateHash(password, user.PasswordSalt)
DirectoryEntry entry;
try
{
entry = DirectoryService.Validate(username, password);
entry = directoryService.Validate(username, password);
if (entry == null)
return PasswordValidationResult.Invalid;

Expand Down Expand Up @@ -108,7 +108,7 @@ bool validatePassword() => UserHelper.CalculateHash(password, user.PasswordSalt)
var displayName = entry.FirstName + " " + entry.LastName;
var email = entry.Email.TrimToNull() ?? user.Email ?? (username + "@yourdefaultdomain.com");

using var connection = SqlConnections.NewFor<UserRow>();
using var connection = sqlConnections.NewFor<UserRow>();
using var uow = new UnitOfWork(connection);
var fld = UserRow.Fields;
new SqlUpdate(fld.TableName)
Expand All @@ -122,8 +122,7 @@ bool validatePassword() => UserHelper.CalculateHash(password, user.PasswordSalt)

uow.Commit();

if (userRetriever is IUserCacheInvalidator cacheInvalidator)
cacheInvalidator.InvalidateItem(user);
userRetriever.InvalidateItem(user, cache);

return PasswordValidationResult.Valid;
}
Expand All @@ -136,17 +135,17 @@ bool validatePassword() => UserHelper.CalculateHash(password, user.PasswordSalt)

private PasswordValidationResult ValidateFirstTimeUser(ref string username, string password)
{
var throttler = new Throttler(Cache.Memory, "ValidateUser:" + username.ToLowerInvariant(), TimeSpan.FromMinutes(30), 50);
var throttler = new Throttler(cache.Memory, "ValidateUser:" + username.ToLowerInvariant(), TimeSpan.FromMinutes(30), 50);
if (!throttler.Check())
return PasswordValidationResult.Throttle;

if (DirectoryService == null)
if (directoryService == null)
return PasswordValidationResult.Invalid;

DirectoryEntry entry;
try
{
entry = DirectoryService.Validate(username, password);
entry = directoryService.Validate(username, password);
if (entry == null)
return PasswordValidationResult.Invalid;

Expand All @@ -166,7 +165,7 @@ private PasswordValidationResult ValidateFirstTimeUser(ref string username, stri
var email = entry.Email.TrimToNull() ?? (username + "@yourdefaultdomain.com");
username = entry.Username.TrimToNull() ?? username;

using var connection = SqlConnections.NewFor<UserRow>();
using var connection = sqlConnections.NewFor<UserRow>();
using var uow = new UnitOfWork(connection);
var userId = (int)connection.InsertAndGetID(new UserRow
{
Expand All @@ -184,11 +183,8 @@ private PasswordValidationResult ValidateFirstTimeUser(ref string username, stri

uow.Commit();

if (userRetriever is IUserCacheInvalidator cacheInvalidator)
{
cacheInvalidator.InvalidateById(userId.ToInvariant());
cacheInvalidator.InvalidateByUsername(username);
}
userRetriever.InvalidateById(userId.ToInvariant(), cache);
userRetriever.InvalidateByUsername(username, cache);

return PasswordValidationResult.Valid;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public ActionResult AccessDenied(string returnURL)
[HttpPost, JsonRequest]
public Result<ServiceResponse> Login(LoginRequest request,
[FromServices] IUserPasswordValidator passwordValidator,
[FromServices] IUserRetrieveService userRetriever,
[FromServices] IUserClaimCreator userClaimCreator)
{

Expand All @@ -54,9 +53,6 @@ public Result<ServiceResponse> Login(LoginRequest request,
if (passwordValidator is null)
throw new ArgumentNullException(nameof(passwordValidator));
if (userRetriever is null)
throw new ArgumentNullException(nameof(userRetriever));
if (userClaimCreator is null)
throw new ArgumentNullException(nameof(userClaimCreator));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,8 @@ public Result<SignUpResponse> SignUp(SignUpRequest request,
uow.Commit();
if (userRetriever is IUserCacheInvalidator cacheInvalidator)
{
cacheInvalidator.InvalidateById(userId.ToInvariant());
cacheInvalidator.InvalidateByUsername(username);
}
userRetriever.InvalidateById(userId.ToInvariant(), Cache);
userRetriever.InvalidateByUsername(username, Cache);
if (environmentOptions?.Value.IsPublicDemo == true)
{
Expand Down
Loading

0 comments on commit 844b436

Please sign in to comment.