Skip to content

Commit

Permalink
Fix bug in triggers
Browse files Browse the repository at this point in the history
  • Loading branch information
loveleif authored Dec 4, 2023
1 parent 7bd73ae commit e8ac28f
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 75 deletions.
1 change: 0 additions & 1 deletion common/src/main/java/apoc/ApocConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import static org.neo4j.internal.helpers.ProcessUtils.executeCommandWithOutput;

import apoc.export.util.ExportConfig;
import inet.ipaddr.IPAddressString;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
Expand Down
25 changes: 16 additions & 9 deletions core/src/main/java/apoc/spatial/Geocode.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import java.util.Map;
import java.util.stream.Stream;
import org.apache.commons.configuration2.Configuration;

import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.*;

Expand Down Expand Up @@ -147,7 +146,8 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessC
}

@Override
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
public Stream<GeoCodeResult> reverseGeocode(
Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
if (latitude == null || longitude == null) {
return Stream.empty();
}
Expand Down Expand Up @@ -230,14 +230,16 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessC
}

@Override
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
public Stream<GeoCodeResult> reverseGeocode(
Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
if (latitude == null || longitude == null) {
return Stream.empty();
}
throttler.waitForThrottle();

Object value = JsonUtil.loadJson(
OSM_URL_REVERSE_GEOCODE + String.format("lat=%s&lon=%s", latitude, longitude), urlAccessChecker)
OSM_URL_REVERSE_GEOCODE + String.format("lat=%s&lon=%s", latitude, longitude),
urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof Map) {
Expand Down Expand Up @@ -286,7 +288,8 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessC
}
throttler.waitForThrottle();
Object value = JsonUtil.loadJson(
String.format(GEOCODE_URL, credentials(this.config)) + Util.encodeUrlComponent(address), urlAccessChecker)
String.format(GEOCODE_URL, credentials(this.config)) + Util.encodeUrlComponent(address),
urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof Map) {
Expand All @@ -311,13 +314,16 @@ public Stream<GeoCodeResult> geocode(String address, long maxResults, URLAccessC
}

@Override
public Stream<GeoCodeResult> reverseGeocode(Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
public Stream<GeoCodeResult> reverseGeocode(
Double latitude, Double longitude, URLAccessChecker urlAccessChecker) {
if (latitude == null || longitude == null) {
return Stream.empty();
}
throttler.waitForThrottle();
Object value = JsonUtil.loadJson(String.format(REVERSE_GEOCODE_URL, credentials(this.config))
+ Util.encodeUrlComponent(latitude + "," + longitude), urlAccessChecker)
Object value = JsonUtil.loadJson(
String.format(REVERSE_GEOCODE_URL, credentials(this.config))
+ Util.encodeUrlComponent(latitude + "," + longitude),
urlAccessChecker)
.findFirst()
.orElse(null);
if (value instanceof Map) {
Expand Down Expand Up @@ -402,7 +408,8 @@ public Stream<GeoCodeResult> geocode(
return getSupplier(config)
.geocode(
address,
maxResults == 0 ? MAX_RESULTS : Math.min(Math.max(maxResults, 1), MAX_RESULTS), urlAccessChecker);
maxResults == 0 ? MAX_RESULTS : Math.min(Math.max(maxResults, 1), MAX_RESULTS),
urlAccessChecker);
} catch (IllegalStateException re) {
if (!quotaException && re.getMessage().startsWith("QUOTA_EXCEEDED")) return Stream.empty();
throw re;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public static ResourceIterator<Node> getTriggerNodes(String databaseName, Transa
return tx.findNodes(label, dbNameKey, databaseName, SystemPropertyKeys.name.name(), name);
}

private static void setLastUpdate(String databaseName, Transaction tx) {
public static void setLastUpdate(String databaseName, Transaction tx) {
Node node = tx.findNode(SystemLabels.ApocTriggerMeta, SystemPropertyKeys.database.name(), databaseName);
if (node == null) {
node = tx.createNode(SystemLabels.ApocTriggerMeta);
Expand Down
37 changes: 21 additions & 16 deletions core/src/main/java/apoc/trigger/TriggerNewProcedures.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,10 @@ public Stream<TriggerInfo> install(
checkTargetDatabase(databaseName);
Map<String, Object> params = (Map) config.getOrDefault("params", Collections.emptyMap());

return withTransaction(tx -> {
TriggerInfo triggerInfo =
TriggerHandlerNewProcedures.install(databaseName, name, statement, selector, params, tx);
return Stream.of(triggerInfo);
});
return withUpdatingTransaction(
databaseName,
tx -> Stream.of(
TriggerHandlerNewProcedures.install(databaseName, name, statement, selector, params, tx)));
}

// TODO - change with @SystemOnlyProcedure
Expand All @@ -111,10 +110,8 @@ public Stream<TriggerInfo> install(
public Stream<TriggerInfo> drop(@Name("databaseName") String databaseName, @Name("name") String name) {
checkInSystemWriter();

return withTransaction(tx -> {
final TriggerInfo removed = TriggerHandlerNewProcedures.drop(databaseName, name, tx);
return Stream.ofNullable(removed);
});
return withUpdatingTransaction(
databaseName, tx -> Stream.ofNullable(TriggerHandlerNewProcedures.drop(databaseName, name, tx)));
}

// TODO - change with @SystemOnlyProcedure
Expand All @@ -125,8 +122,9 @@ public Stream<TriggerInfo> drop(@Name("databaseName") String databaseName, @Name
public Stream<TriggerInfo> dropAll(@Name("databaseName") String databaseName) {
checkInSystemWriter();

return withTransaction(tx -> TriggerHandlerNewProcedures.dropAll(databaseName, tx).stream()
.sorted(Comparator.comparing(i -> i.name)));
return withUpdatingTransaction(
databaseName, tx -> TriggerHandlerNewProcedures.dropAll(databaseName, tx).stream()
.sorted(Comparator.comparing(i -> i.name)));
}

// TODO - change with @SystemOnlyProcedure
Expand All @@ -137,7 +135,7 @@ public Stream<TriggerInfo> dropAll(@Name("databaseName") String databaseName) {
public Stream<TriggerInfo> stop(@Name("databaseName") String databaseName, @Name("name") String name) {
checkInSystemWriter();

return withTransaction(tx -> {
return withUpdatingTransaction(databaseName, tx -> {
final TriggerInfo triggerInfo = TriggerHandlerNewProcedures.updatePaused(databaseName, name, true, tx);
return Stream.ofNullable(triggerInfo);
});
Expand All @@ -151,7 +149,7 @@ public Stream<TriggerInfo> stop(@Name("databaseName") String databaseName, @Name
public Stream<TriggerInfo> start(@Name("databaseName") String databaseName, @Name("name") String name) {
checkInSystemWriter();

return withTransaction(tx -> {
return withUpdatingTransaction(databaseName, tx -> {
final TriggerInfo triggerInfo = TriggerHandlerNewProcedures.updatePaused(databaseName, name, false, tx);
return Stream.ofNullable(triggerInfo);
});
Expand All @@ -168,11 +166,18 @@ public Stream<TriggerInfo> show(@Name("databaseName") String databaseName) {
return TriggerHandlerNewProcedures.getTriggerNodesList(databaseName, tx);
}

public <T> T withTransaction(Function<Transaction, T> action) {
public <T> T withUpdatingTransaction(String databaseName, Function<Transaction, T> action) {
T result = null;
try (Transaction tx = db.beginTx()) {
T result = action.apply(tx);
result = action.apply(tx);
tx.commit();
}

// Last update time needs to be after the installation commit happened to not risk missing updates
try (final var tx = db.beginTx()) {
TriggerHandlerNewProcedures.setLastUpdate(databaseName, tx);
tx.commit();
return result;
}
return result;
}
}
4 changes: 2 additions & 2 deletions core/src/test/java/apoc/util/LogsUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public void shouldReturnInputIfInvalidQuery() {

@Test
public void whitespaceDeprecationSucceedsSanitization() {
String sanitized =
LogsUtil.sanitizeQuery("CREATE USER dum\u0085my IF NOT EXISTS SET PASSWORD 'pass12345' CHANGE NOT REQUIRED");
String sanitized = LogsUtil.sanitizeQuery(
"CREATE USER dum\u0085my IF NOT EXISTS SET PASSWORD 'pass12345' CHANGE NOT REQUIRED");
assertEquals(sanitized, "CREATE USER dum\u0085my IF NOT EXISTS SET PASSWORD '******' CHANGE NOT REQUIRED");
}
}
3 changes: 1 addition & 2 deletions core/src/test/java/apoc/util/QueryUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ public void shouldReturnFalseForInvalidQueries() {

@Test
public void shouldReturnTrueForQueryWithParserDeprecation() {
assertTrue(
QueryUtil.isValidQuery("CREATE (n:My\u0085Label)"));
assertTrue(QueryUtil.isValidQuery("CREATE (n:My\u0085Label)"));
}
}
48 changes: 26 additions & 22 deletions it/src/test/java/apoc/it/common/UtilIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,18 @@
import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.mockito.stubbing.Answer;
import org.neo4j.configuration.Config;
import org.neo4j.configuration.GraphDatabaseInternalSettings;
import org.neo4j.graphdb.security.URLAccessValidationError;
import org.neo4j.graphdb.security.URLAccessChecker;

import org.mockito.stubbing.Answer;
import org.neo4j.graphdb.security.URLAccessValidationError;
import org.testcontainers.containers.GenericContainer;

public class UtilIT {
private GenericContainer httpServer;

public UtilIT() throws Exception {
googleUrl = new URL( "https://www.google.com" );
googleUrl = new URL("https://www.google.com");
}

private GenericContainer setUpServer(String redirectURL) {
Expand Down Expand Up @@ -81,26 +80,28 @@ public void redirectShouldWorkWhenProtocolNotChangesWithUrlLocation() throws Exc

// given
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when( mockChecker.checkURL( googleUrl ) ).thenReturn( googleUrl );
when(mockChecker.checkURL(url)).thenReturn(url);
when(mockChecker.checkURL(googleUrl)).thenReturn(googleUrl);

// when
String page = IOUtils.toString( Util.openInputStream(url.toString(), null, null, null, mockChecker ), StandardCharsets.UTF_8);
String page = IOUtils.toString(
Util.openInputStream(url.toString(), null, null, null, mockChecker), StandardCharsets.UTF_8);

// then
assertTrue(page.contains("<title>Google</title>"));
}

@Test
public void redirectWithBlockedIPsWithUrlLocation() throws Exception{
public void redirectWithBlockedIPsWithUrlLocation() throws Exception {
URLAccessChecker mockChecker = mock(URLAccessChecker.class);

httpServer = setUpServer("http://127.168.0.1");
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when( mockChecker.checkURL( new URL("http://127.168.0.1") ) ).thenThrow( new URLAccessValidationError( "no" ) );
when(mockChecker.checkURL(url)).thenReturn(url);
when(mockChecker.checkURL(new URL("http://127.168.0.1"))).thenThrow(new URLAccessValidationError("no"));

IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker));
IOException e = Assert.assertThrows(
IOException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker));
TestCase.assertTrue(e.getMessage().contains("no"));
}

Expand All @@ -109,11 +110,12 @@ public void redirectWithProtocolUpgradeIsAllowed() throws Exception {
URLAccessChecker mockChecker = mock(URLAccessChecker.class);
httpServer = setUpServer("https://www.google.com");
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when( mockChecker.checkURL( googleUrl ) ).thenReturn( googleUrl );
when(mockChecker.checkURL(url)).thenReturn(url);
when(mockChecker.checkURL(googleUrl)).thenReturn(googleUrl);

// when
String page = IOUtils.toString( Util.openInputStream(url.toString(), null, null, null, mockChecker), StandardCharsets.UTF_8 );
String page = IOUtils.toString(
Util.openInputStream(url.toString(), null, null, null, mockChecker), StandardCharsets.UTF_8);

// then
assertTrue(page.contains("<title>Google</title>"));
Expand All @@ -136,7 +138,8 @@ public void shouldFailForExceedingRedirectLimit() throws Exception {
URLAccessChecker mockChecker = mock(URLAccessChecker.class);
httpServer = setUpServer("https://127.0.0.0");
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( any() ) ).thenAnswer( (Answer<URL>) invocation -> (URL) invocation.getArguments()[0] );
when(mockChecker.checkURL(any()))
.thenAnswer((Answer<URL>) invocation -> (URL) invocation.getArguments()[0]);

ArrayList<GenericContainer> servers = new ArrayList<>();
for (int i = 1; i <= 10; i++) {
Expand All @@ -146,7 +149,8 @@ public void shouldFailForExceedingRedirectLimit() throws Exception {
}

URL finalUrl = url;
IOException e = Assert.assertThrows(IOException.class, () -> Util.openInputStream(finalUrl.toString(), null, null, null, mockChecker));
IOException e = Assert.assertThrows(
IOException.class, () -> Util.openInputStream(finalUrl.toString(), null, null, null, mockChecker));

TestCase.assertTrue(e.getMessage().contains("Redirect limit exceeded"));

Expand All @@ -161,19 +165,19 @@ public void redirectShouldThrowExceptionWhenProtocolChangesWithFileLocation() th
httpServer = setUpServer("file:/etc/passwd");
// given
URL url = getServerUrl(httpServer);
when( mockChecker.checkURL( url ) ).thenReturn( url );
when(mockChecker.checkURL(url)).thenReturn(url);
Config neo4jConfig = mock(Config.class);
when(neo4jConfig.get(GraphDatabaseInternalSettings.cypher_ip_blocklist)).thenReturn(Collections.emptyList());

// when
RuntimeException e =
Assert.assertThrows(RuntimeException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker));
RuntimeException e = Assert.assertThrows(
RuntimeException.class, () -> Util.openInputStream(url.toString(), null, null, null, mockChecker));

assertEquals("The redirect URI has a different protocol: file:/etc/passwd", e.getMessage());
}

private URL getServerUrl(GenericContainer httpServer) throws MalformedURLException
{
return new URL(String.format("http://%s:%s", httpServer.getContainerIpAddress(), httpServer.getMappedPort(8000)));
private URL getServerUrl(GenericContainer httpServer) throws MalformedURLException {
return new URL(
String.format("http://%s:%s", httpServer.getContainerIpAddress(), httpServer.getMappedPort(8000)));
}
}
28 changes: 7 additions & 21 deletions it/src/test/java/apoc/it/core/TriggerEnterpriseFeaturesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import static apoc.ApocConfig.APOC_CONFIG_INITIALIZER;
import static apoc.ApocConfig.APOC_TRIGGER_ENABLED;
import static apoc.SystemPropertyKeys.database;
import static apoc.it.core.CreateTriggers.CreateTrigger;
import static apoc.it.core.CreateAndDropTriggers.CreateTrigger;
import static apoc.trigger.TriggerHandler.TRIGGER_REFRESH;
import static apoc.trigger.TriggerTestUtil.TIMEOUT;
import static apoc.trigger.TriggerTestUtil.TRIGGER_DEFAULT_REFRESH;
Expand Down Expand Up @@ -350,10 +350,8 @@ public void stressTest() throws InterruptedException, ExecutionException, Timeou
try {
final var nodesFuture1 = executor.submit(createNodes);
final var nodesFuture2 = executor.submit(createNodes);
final var createTriggersFuture = executor.submit(new CreateTriggers(driver, db, iterations));
final var removeTriggersFuture = executor.submit(new DropRandomTriggers(driver, db, iterations));
final var createTriggersFuture = executor.submit(new CreateAndDropTriggers(driver, db, iterations));
createTriggersFuture.get(5, TimeUnit.MINUTES);
removeTriggersFuture.get(30, TimeUnit.SECONDS);
createNodes.stop();
nodesFuture1.get(30, TimeUnit.SECONDS);
nodesFuture2.get(30, TimeUnit.SECONDS);
Expand Down Expand Up @@ -435,32 +433,20 @@ public void stop() {
}
}

record CreateTriggers(Driver driver, String db, int iterations) implements Runnable {
record CreateAndDropTriggers(Driver driver, String db, int iterations) implements Runnable {
public static final String CreateTrigger = "call apoc.trigger.install($db, $name, $trigger,{})";

@Override
public void run() {
try (final var session = driver.session(forDatabase(SYSTEM_DATABASE_NAME))) {
for (int i = 0; i < iterations; ++i) {
final var name = "temp-trigger-" + i;
session.run(CreateTrigger, Map.of("db", db, "name", name, "trigger", "RETURN 1"))
.consume();
}
}
}
}

record DropRandomTriggers(Driver driver, String db, int iterations) implements Runnable {
static final String DropTrigger = "call apoc.trigger.drop($db, $name)";

@Override
public void run() {
final var rand = new Random();
try (final var session = driver.session(forDatabase(SYSTEM_DATABASE_NAME))) {
for (int i = 0; i < iterations; ++i) {
final var name = "temp-trigger-" + rand.nextInt(iterations);
session.run(DropTrigger, Map.<String, Object>of("db", db, "name", name))
final var name = "temp-trigger-" + i;
session.run(CreateTrigger, Map.of("db", db, "name", name, "trigger", "RETURN 1"))
.consume();
final var deleteName = "temp-trigger-" + rand.nextInt(iterations);
session.run(DropTrigger, Map.of("db", db, "name", deleteName)).consume();
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion test-utils/src/main/java/apoc/trigger/TriggerTestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.neo4j.graphdb.GraphDatabaseService;

public class TriggerTestUtil {
public static final long TIMEOUT = 10L;
public static final long TIMEOUT = 20L;
public static final long TRIGGER_DEFAULT_REFRESH = 3000;

public static void awaitTriggerDiscovered(GraphDatabaseService db, String name, String query) {
Expand Down

0 comments on commit e8ac28f

Please sign in to comment.