diff --git a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java index 757236c80..0c4f7b918 100644 --- a/src/main/java/io/supertokens/authRecipe/AuthRecipe.java +++ b/src/main/java/io/supertokens/authRecipe/AuthRecipe.java @@ -935,7 +935,7 @@ private static void deleteNonAuthRecipeUser(TransactionConnection con, AppIdenti appIdentifierWithStorage.getActiveUsersStorage() .deleteUserActive_Transaction(con, appIdentifierWithStorage, userId); appIdentifierWithStorage.getMfaStorage() - .deleteMfaInfoForUser(appIdentifierWithStorage, userId); + .deleteMfaInfoForUser_Transaction(con, appIdentifierWithStorage, userId); } private static void deleteAuthRecipeUser(TransactionConnection con, diff --git a/src/main/java/io/supertokens/inmemorydb/Start.java b/src/main/java/io/supertokens/inmemorydb/Start.java index 4107b79a0..cf6e858ae 100644 --- a/src/main/java/io/supertokens/inmemorydb/Start.java +++ b/src/main/java/io/supertokens/inmemorydb/Start.java @@ -50,6 +50,7 @@ import io.supertokens.pluginInterface.jwt.exceptions.DuplicateKeyIdException; import io.supertokens.pluginInterface.jwt.sqlstorage.JWTRecipeSQLStorage; import io.supertokens.pluginInterface.mfa.MfaStorage; +import io.supertokens.pluginInterface.mfa.sqlStorage.MfaSQLStorage; import io.supertokens.pluginInterface.multitenancy.*; import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateClientTypeException; import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateTenantException; @@ -103,7 +104,7 @@ public class Start implements SessionSQLStorage, EmailPasswordSQLStorage, EmailVerificationSQLStorage, ThirdPartySQLStorage, JWTRecipeSQLStorage, PasswordlessSQLStorage, UserMetadataSQLStorage, UserRolesSQLStorage, UserIdMappingStorage, UserIdMappingSQLStorage, MultitenancyStorage, MultitenancySQLStorage, TOTPSQLStorage, ActiveUsersStorage, - DashboardSQLStorage, AuthRecipeSQLStorage, MfaStorage { + ActiveUsersSQLStorage, DashboardSQLStorage, AuthRecipeSQLStorage, MfaStorage, MfaSQLStorage { private static final Object appenderLock = new Object(); private static final String APP_ID_KEY_NAME = "app_id"; @@ -2853,10 +2854,10 @@ public boolean disableFactor(TenantIdentifier tenantIdentifier, String userId, S } @Override - public boolean deleteMfaInfoForUser(AppIdentifier appIdentifier, String userId) + public boolean deleteMfaInfoForUser_Transaction(TransactionConnection con, AppIdentifier appIdentifier, String userId) throws StorageQueryException { try { - int deletedCount = MfaQueries.deleteUser(this, appIdentifier, userId); + int deletedCount = MfaQueries.deleteUser_Transaction(this, (Connection) con.getConnection(), appIdentifier, userId); if (deletedCount == 0) { return false; } diff --git a/src/main/java/io/supertokens/inmemorydb/queries/MfaQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/MfaQueries.java index 7758dd261..9cca4bf53 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/MfaQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/MfaQueries.java @@ -22,6 +22,7 @@ import io.supertokens.pluginInterface.multitenancy.AppIdentifier; import io.supertokens.pluginInterface.multitenancy.TenantIdentifier; +import java.sql.Connection; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; @@ -101,10 +102,10 @@ public static int disableFactor(Start start, TenantIdentifier tenantIdentifier, }); } - public static int deleteUser(Start start, AppIdentifier appIdentifier, String userId) throws StorageQueryException, SQLException { + public static int deleteUser_Transaction(Start start, Connection sqlCon, AppIdentifier appIdentifier, String userId) throws StorageQueryException, SQLException { String QUERY = "DELETE FROM " + Config.getConfig(start).getMfaUserFactorsTable() + " WHERE app_id = ? AND user_id = ?"; - return update(start, QUERY, pst -> { + return update(sqlCon, QUERY, pst -> { pst.setString(1, appIdentifier.getAppId()); pst.setString(2, userId); }); diff --git a/src/test/java/io/supertokens/test/FeatureFlagTest.java b/src/test/java/io/supertokens/test/FeatureFlagTest.java index 480b60f2f..8c249464b 100644 --- a/src/test/java/io/supertokens/test/FeatureFlagTest.java +++ b/src/test/java/io/supertokens/test/FeatureFlagTest.java @@ -297,8 +297,10 @@ public void testThatCallingGetFeatureFlagAPIReturnsMfaStats() throws Exception { JsonObject usageStats = response.get("usageStats").getAsJsonObject(); JsonArray maus = usageStats.get("maus").getAsJsonArray(); - assert features.size() == 1; - assert features.get(0).getAsString().equals("mfa"); + if (!StorageLayer.isInMemDb(process.getProcess())) { + assert features.size() == 1; + assert features.get(0).getAsString().equals("mfa"); + } assert maus.size() == 30; assert maus.get(0).getAsInt() == 0; assert maus.get(29).getAsInt() == 0; @@ -349,8 +351,10 @@ public void testThatCallingGetFeatureFlagAPIReturnsMfaStats() throws Exception { JsonObject usageStats = response.get("usageStats").getAsJsonObject(); JsonArray maus = usageStats.get("maus").getAsJsonArray(); - assert features.size() == 1; - assert features.get(0).getAsString().equals("mfa"); + if (!StorageLayer.isInMemDb(process.getProcess())) { + assert features.size() == 1; + assert features.get(0).getAsString().equals("mfa"); + } assert maus.size() == 30; assert maus.get(0).getAsInt() == 2; // 2 users have signed up assert maus.get(29).getAsInt() == 2; diff --git a/src/test/java/io/supertokens/test/mfa/MfaLicenseTest.java b/src/test/java/io/supertokens/test/mfa/MfaLicenseTest.java index 1a902a6d1..96d32a3f6 100644 --- a/src/test/java/io/supertokens/test/mfa/MfaLicenseTest.java +++ b/src/test/java/io/supertokens/test/mfa/MfaLicenseTest.java @@ -22,12 +22,12 @@ import io.supertokens.mfa.Mfa; import io.supertokens.pluginInterface.mfa.MfaStorage; import io.supertokens.pluginInterface.multitenancy.TenantIdentifierWithStorage; +import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.httpRequest.HttpResponseException; import org.junit.Test; import java.util.HashMap; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; public class MfaLicenseTest extends MfaTestBase { @@ -37,6 +37,11 @@ public void testTotpWithoutLicense() throws Exception { if (result == null) { return; } + + if (StorageLayer.isInMemDb(result.process.getProcess())) { + return; + } + Main main = result.process.getProcess(); MfaStorage storage = result.storage; TenantIdentifierWithStorage tid = new TenantIdentifierWithStorage(null, null, null, storage); diff --git a/src/test/java/io/supertokens/test/mfa/MfaStorageTest.java b/src/test/java/io/supertokens/test/mfa/MfaStorageTest.java index 1393ce57b..6af701819 100644 --- a/src/test/java/io/supertokens/test/mfa/MfaStorageTest.java +++ b/src/test/java/io/supertokens/test/mfa/MfaStorageTest.java @@ -22,7 +22,9 @@ import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.multitenancy.Multitenancy; import io.supertokens.pluginInterface.mfa.MfaStorage; +import io.supertokens.pluginInterface.mfa.sqlStorage.MfaSQLStorage; import io.supertokens.pluginInterface.multitenancy.*; +import io.supertokens.pluginInterface.sqlStorage.SQLStorage; import org.junit.Test; import static org.junit.Assert.assertNotNull; @@ -110,7 +112,7 @@ public void deleteUserTest() throws Exception { if (result == null) { return; } - MfaStorage storage = result.storage; + MfaSQLStorage storage = result.storage; TenantIdentifier tid = new TenantIdentifier(null, null, null); assert storage.enableFactor(tid, "user1", "f1") == true; @@ -119,8 +121,11 @@ public void deleteUserTest() throws Exception { assert storage.enableFactor(tid, "user2", "f1") == true; assert storage.enableFactor(tid, "user2", "f3") == true; - assert storage.deleteMfaInfoForUser(tid.toAppIdentifier(), "non-existent-user") == false; - assert storage.deleteMfaInfoForUser(tid.toAppIdentifier(), "user2") == true; + ((SQLStorage) storage).startTransaction(con -> { + assert storage.deleteMfaInfoForUser_Transaction(con, tid.toAppIdentifier(), "non-existent-user") == false; + assert storage.deleteMfaInfoForUser_Transaction(con, tid.toAppIdentifier(), "user2") == true; + return null; + }); String[] factors = storage.listFactors(tid, "user2"); assert factors.length == 0; diff --git a/src/test/java/io/supertokens/test/mfa/MfaTestBase.java b/src/test/java/io/supertokens/test/mfa/MfaTestBase.java index 54627036a..1d9b236d5 100644 --- a/src/test/java/io/supertokens/test/mfa/MfaTestBase.java +++ b/src/test/java/io/supertokens/test/mfa/MfaTestBase.java @@ -21,7 +21,7 @@ import io.supertokens.featureflag.EE_FEATURES; import io.supertokens.featureflag.FeatureFlagTestContent; import io.supertokens.pluginInterface.STORAGE_TYPE; -import io.supertokens.pluginInterface.mfa.MfaStorage; +import io.supertokens.pluginInterface.mfa.sqlStorage.MfaSQLStorage; import io.supertokens.storageLayer.StorageLayer; import io.supertokens.test.TestingProcessManager; import io.supertokens.test.Utils; @@ -53,10 +53,10 @@ public void beforeEach() { public class TestSetupResult { - public MfaStorage storage; + public MfaSQLStorage storage; public TestingProcessManager.TestingProcess process; - public TestSetupResult(MfaStorage storage, TestingProcessManager.TestingProcess process) { + public TestSetupResult(MfaSQLStorage storage, TestingProcessManager.TestingProcess process) { this.storage = storage; this.process = process; } @@ -72,7 +72,8 @@ public TestSetupResult initSteps(boolean enableMfaFeature) if (StorageLayer.getStorage(process.getProcess()).getType() != STORAGE_TYPE.SQL) { return null; } - MfaStorage storage = (MfaStorage) StorageLayer.getStorage(process.getProcess()); + + MfaSQLStorage storage = (MfaSQLStorage) StorageLayer.getStorage(process.getProcess()); if (enableMfaFeature) { FeatureFlagTestContent.getInstance(process.main) diff --git a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java index fbd8c863f..0b589bec9 100644 --- a/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java +++ b/src/test/java/io/supertokens/test/totp/TOTPRecipeTest.java @@ -190,6 +190,8 @@ public void createDeviceAndVerifyCodeTest() throws Exception { InvalidTotpException.class, () -> Totp.verifyCode(main, "user", generateTotpCode(main, unverifiedDevice))); + Thread.sleep(1000 - System.currentTimeMillis() % 1000 + 10); + // Valid code & verified device (Success) String validCode = generateTotpCode(main, device); Totp.verifyCode(main, "user", validCode); @@ -200,7 +202,7 @@ public void createDeviceAndVerifyCodeTest() throws Exception { () -> Totp.verifyCode(main, "user", validCode)); // Sleep for 1s so that code changes. - Thread.sleep(1000); + Thread.sleep(1000 - System.currentTimeMillis() % 1000 + 10); // Use a new valid code: String newValidCode = generateTotpCode(main, device); @@ -278,6 +280,10 @@ public void createDeviceAndVerifyCodeTest() throws Exception { public int triggerAndCheckRateLimit(Main main, TOTPDevice device) throws Exception { int N = Config.getConfig(main).getTotpMaxAttempts(); + // Sleep until we finish the current second so that TOTP verification won't change in the time limit + Thread.sleep(1000 - System.currentTimeMillis() % 1000 + 10); + Thread.sleep(1000); // sleep another second so that the rate limit state is kind of reset + // First N attempts should fail with invalid code: // This is to trigger rate limiting for (int i = 0; i < N; i++) { @@ -447,6 +453,9 @@ public void removeDeviceTest() throws Exception { // Delete one of the devices { assertThrows(InvalidTotpException.class, () -> Totp.verifyCode(main, "user", "ic0")); + + Thread.sleep(1000 - System.currentTimeMillis() % 1000 + 10); + Totp.verifyCode(main, "user", generateTotpCode(main, device1)); Totp.verifyCode(main, "user", generateTotpCode(main, device2));