diff --git a/build.gradle b/build.gradle index 550f21d3..636c745e 100644 --- a/build.gradle +++ b/build.gradle @@ -72,7 +72,7 @@ ext { argparse4jVersion = '0.7.0' junitVersion = '4.12' evaluatorVersion = '3.5.4' - neo4jJavaDriverVersion = '2.0.0-alpha01' + neo4jJavaDriverVersion = '2.0.0-alpha02' findbugsVersion = '3.0.0' jansiVersion = '1.13' jlineVersion = '2.14.6' diff --git a/cypher-shell/src/integration-test/java/org/neo4j/shell/commands/CypherShellVerboseIntegrationTest.java b/cypher-shell/src/integration-test/java/org/neo4j/shell/commands/CypherShellVerboseIntegrationTest.java index cfcbdad7..a50cd019 100644 --- a/cypher-shell/src/integration-test/java/org/neo4j/shell/commands/CypherShellVerboseIntegrationTest.java +++ b/cypher-shell/src/integration-test/java/org/neo4j/shell/commands/CypherShellVerboseIntegrationTest.java @@ -200,7 +200,7 @@ public void cypherWithOrder() throws CommandException { //then String actual = linePrinter.output(); - assertThat( actual, containsString( "Ordered by" ) ); + assertThat( actual, containsString( "Order" ) ); assertThat( actual, containsString( "n.age ASC" ) ); } diff --git a/cypher-shell/src/main/java/org/neo4j/shell/CypherShell.java b/cypher-shell/src/main/java/org/neo4j/shell/CypherShell.java index 6b26f7d1..a19cdbb4 100644 --- a/cypher-shell/src/main/java/org/neo4j/shell/CypherShell.java +++ b/cypher-shell/src/main/java/org/neo4j/shell/CypherShell.java @@ -3,6 +3,7 @@ import org.neo4j.cypher.internal.evaluator.EvaluationException; import org.neo4j.cypher.internal.evaluator.Evaluator; import org.neo4j.cypher.internal.evaluator.ExpressionEvaluator; +import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.shell.commands.Command; import org.neo4j.shell.commands.CommandExecutable; import org.neo4j.shell.commands.CommandHelper; @@ -37,6 +38,7 @@ public class CypherShell implements StatementExecuter, Connector, TransactionHan private final PrettyPrinter prettyPrinter; private CommandHelper commandHelper; private ExpressionEvaluator evaluator = Evaluator.expressionEvaluator(); + private String lastNeo4jErrorCode; public CypherShell(@Nonnull LinePrinter linePrinter, @Nonnull PrettyConfig prettyConfig) { this(linePrinter, new BoltStateHandler(), new PrettyPrinter(prettyConfig)); @@ -80,6 +82,11 @@ public void execute(@Nonnull final String cmdString) throws ExitException, Comma executeCypher(cmdString); } + @Override + public String lastNeo4jErrorCode() { + return lastNeo4jErrorCode; + } + /** * Executes a piece of text as if it were Cypher. By default, all of the cypher is executed in single statement * (with an implicit transaction). @@ -87,8 +94,14 @@ public void execute(@Nonnull final String cmdString) throws ExitException, Comma * @param cypher non-empty cypher text to executeLine */ private void executeCypher(@Nonnull final String cypher) throws CommandException { - final Optional result = boltStateHandler.runCypher(cypher, allParameterValues()); - result.ifPresent(boltResult -> prettyPrinter.format(boltResult, linePrinter)); + try { + final Optional result = boltStateHandler.runCypher( cypher, allParameterValues() ); + result.ifPresent(boltResult -> prettyPrinter.format(boltResult, linePrinter)); + lastNeo4jErrorCode = null; + } catch (Neo4jException e) { + lastNeo4jErrorCode = e.code(); + throw e; + } } @Override @@ -142,9 +155,15 @@ public void beginTransaction() throws CommandException { @Override public Optional> commitTransaction() throws CommandException { - Optional> results = boltStateHandler.commitTransaction(); - results.ifPresent(boltResult -> boltResult.forEach(result -> prettyPrinter.format(result, linePrinter))); - return results; + try { + Optional> results = boltStateHandler.commitTransaction(); + results.ifPresent(boltResult -> boltResult.forEach(result -> prettyPrinter.format(result, linePrinter))); + lastNeo4jErrorCode = null; + return results; + } catch (Neo4jException e) { + lastNeo4jErrorCode = e.code(); + throw e; + } } @Override @@ -202,12 +221,24 @@ protected void addRuntimeHookToResetShell() { @Override public void setActiveDatabase(String databaseName) throws CommandException { - boltStateHandler.setActiveDatabase(databaseName); + try { + boltStateHandler.setActiveDatabase(databaseName); + lastNeo4jErrorCode = null; + } catch (Neo4jException e) { + lastNeo4jErrorCode = e.code(); + throw e; + } + } + + @Override + public String getActiveDatabaseAsSetByUser() + { + return boltStateHandler.getActiveDatabaseAsSetByUser(); } @Override - public String getActiveDatabase() + public String getActualDatabaseAsReportedByServer() { - return boltStateHandler.getActiveDatabase(); + return boltStateHandler.getActualDatabaseAsReportedByServer(); } } diff --git a/cypher-shell/src/main/java/org/neo4j/shell/DatabaseManager.java b/cypher-shell/src/main/java/org/neo4j/shell/DatabaseManager.java index bcd36d65..56c5e823 100644 --- a/cypher-shell/src/main/java/org/neo4j/shell/DatabaseManager.java +++ b/cypher-shell/src/main/java/org/neo4j/shell/DatabaseManager.java @@ -8,10 +8,15 @@ public interface DatabaseManager { String ABSENT_DB_NAME = ""; - String DEFAULT_DEFAULT_DB_NAME = "neo4j"; String SYSTEM_DB_NAME = "system"; + String DEFAULT_DEFAULT_DB_NAME = "neo4j"; + + String DATABASE_NOT_FOUND_ERROR_CODE = "Neo.ClientError.Database.DatabaseNotFound"; + String DATABASE_UNAVAILABLE_ERROR_CODE = "Neo.TransientError.General.DatabaseUnavailable"; void setActiveDatabase(String databaseName) throws CommandException; - String getActiveDatabase(); + String getActiveDatabaseAsSetByUser(); + + String getActualDatabaseAsReportedByServer(); } diff --git a/cypher-shell/src/main/java/org/neo4j/shell/StatementExecuter.java b/cypher-shell/src/main/java/org/neo4j/shell/StatementExecuter.java index ab5227cd..19825f08 100644 --- a/cypher-shell/src/main/java/org/neo4j/shell/StatementExecuter.java +++ b/cypher-shell/src/main/java/org/neo4j/shell/StatementExecuter.java @@ -22,4 +22,9 @@ public interface StatementExecuter { * Stops any running statements */ void reset(); + + /** + * Get the error code from the last executed Cypher statement, or null if the last execution was successful. + */ + String lastNeo4jErrorCode(); } diff --git a/cypher-shell/src/main/java/org/neo4j/shell/cli/InteractiveShellRunner.java b/cypher-shell/src/main/java/org/neo4j/shell/cli/InteractiveShellRunner.java index 03eb947e..4318584e 100644 --- a/cypher-shell/src/main/java/org/neo4j/shell/cli/InteractiveShellRunner.java +++ b/cypher-shell/src/main/java/org/neo4j/shell/cli/InteractiveShellRunner.java @@ -26,8 +26,9 @@ import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; -import static org.neo4j.driver.internal.messaging.request.MultiDatabaseUtil.ABSENT_DB_NAME; -import static org.neo4j.shell.DatabaseManager.DEFAULT_DEFAULT_DB_NAME; +import static org.neo4j.shell.DatabaseManager.ABSENT_DB_NAME; +import static org.neo4j.shell.DatabaseManager.DATABASE_NOT_FOUND_ERROR_CODE; +import static org.neo4j.shell.DatabaseManager.DATABASE_UNAVAILABLE_ERROR_CODE; /** * A shell runner intended for interactive sessions where lines are input one by one and execution should happen @@ -39,6 +40,10 @@ public class InteractiveShellRunner implements ShellRunner, SignalHandler { private final static String TRANSACTION_PROMPT = "# "; private final static String USERNAME_DB_DELIMITER = "@"; private final static int ONELINE_PROMPT_MAX_LENGTH = 50; + private static final String UNRESOLVED_DEFAULT_DB_PROPMPT_TEXT = ""; + private static final String DATABASE_NOT_FOUND_ERROR_PROMPT_TEXT = "[NOT_FOUND]"; + private static final String DATABASE_UNAVAILABLE_ERROR_PROMPT_TEXT = "[UNAVAILABLE]"; + // Need to know if we are currently executing when catch Ctrl-C, needs to be atomic due to // being called from different thread private final AtomicBoolean currentlyExecuting; @@ -162,19 +167,23 @@ AnsiFormattedText updateAndGetPrompt() { return continuationPrompt; } - String databaseName = databaseManager.getActiveDatabase(); + String databaseName = databaseManager.getActualDatabaseAsReportedByServer(); + if (databaseName == null) { + // We have failed to get a successful response from the connection ping query + // Build the prompt from the db name as set by the user + a suffix indicating that we are in a disconnected state + String dbNameSetByUser = databaseManager.getActiveDatabaseAsSetByUser(); + databaseName = ABSENT_DB_NAME.equals(dbNameSetByUser)? UNRESOLVED_DEFAULT_DB_PROPMPT_TEXT : dbNameSetByUser; + } else if (ABSENT_DB_NAME.equals(databaseName)) { + // The driver did not give us a database name in the response from the connection ping query + databaseName = UNRESOLVED_DEFAULT_DB_PROPMPT_TEXT; + } - // Substitute empty name for the default default-database-name - // For now we just use a hard-coded default name - // Ideally we would like to receive the actual name in the ResultSummary when we connect (in BoltStateHandler.reconnect()) - // (If the user is an admin we could also query for the default database config value with: - // "CALL dbms.listConfig() YIELD name, value WHERE name = "dbms.default_database" RETURN value" - // but that does not work in general) - databaseName = ABSENT_DB_NAME.equals(databaseName) ? DEFAULT_DEFAULT_DB_NAME : databaseName; + String errorSuffix = getErrorPrompt(executer.lastNeo4jErrorCode()); int promptIndent = connectionConfig.username().length() + USERNAME_DB_DELIMITER.length() + databaseName.length() + + errorSuffix.length() + FRESH_PROMPT.length(); AnsiFormattedText prePrompt = AnsiFormattedText.s().bold() @@ -182,6 +191,11 @@ AnsiFormattedText updateAndGetPrompt() { .append("@") .append(databaseName); + // If we encountered an error with the connection ping query we display it in the prompt in RED + if (!errorSuffix.isEmpty()) { + prePrompt.colorRed().append(errorSuffix).colorDefault(); + } + if (promptIndent <= ONELINE_PROMPT_MAX_LENGTH) { continuationPrompt = AnsiFormattedText.s().bold().append(OutputFormatter.repeat(' ', promptIndent)); return prePrompt @@ -195,6 +209,19 @@ AnsiFormattedText updateAndGetPrompt() { } } + private String getErrorPrompt(String errorCode) { + // NOTE: errorCode can be null + String errorPromptSuffix; + if (DATABASE_NOT_FOUND_ERROR_CODE.equals(errorCode)) { + errorPromptSuffix = DATABASE_NOT_FOUND_ERROR_PROMPT_TEXT; + } else if (DATABASE_UNAVAILABLE_ERROR_CODE.equals(errorCode)) { + errorPromptSuffix = DATABASE_UNAVAILABLE_ERROR_PROMPT_TEXT; + } else { + errorPromptSuffix = ""; + } + return errorPromptSuffix; + } + /** * Catch Ctrl-C from user and handle it nicely * diff --git a/cypher-shell/src/main/java/org/neo4j/shell/state/BoltStateHandler.java b/cypher-shell/src/main/java/org/neo4j/shell/state/BoltStateHandler.java index d764e94f..5ed81344 100644 --- a/cypher-shell/src/main/java/org/neo4j/shell/state/BoltStateHandler.java +++ b/cypher-shell/src/main/java/org/neo4j/shell/state/BoltStateHandler.java @@ -12,6 +12,7 @@ import org.neo4j.driver.StatementResult; import org.neo4j.driver.Transaction; import org.neo4j.driver.exceptions.SessionExpiredException; +import org.neo4j.driver.summary.DatabaseInfo; import org.neo4j.shell.ConnectionConfig; import org.neo4j.shell.Connector; import org.neo4j.shell.DatabaseManager; @@ -36,10 +37,11 @@ public class BoltStateHandler implements TransactionHandler, Connector, DatabaseManager { private final TriFunction driverProvider; protected Driver driver; - protected Session session; + Session session; private String version; private List transactionStatements; - private String activeDatabaseName; + private String activeDatabaseNameAsSetByUser; + private String actualDatabaseNameAsReportedByServer; public BoltStateHandler() { this(GraphDatabase::driver); @@ -47,7 +49,7 @@ public BoltStateHandler() { BoltStateHandler(TriFunction driverProvider) { this.driverProvider = driverProvider; - activeDatabaseName = ""; + activeDatabaseNameAsSetByUser = ABSENT_DB_NAME; } @Override @@ -56,16 +58,22 @@ public void setActiveDatabase(String databaseName) throws CommandException if (isTransactionOpen()) { throw new CommandException("There is an open transaction. You need to close it before you can switch database."); } - activeDatabaseName = databaseName; + activeDatabaseNameAsSetByUser = databaseName; if (isConnected()) { reconnect(false); } } @Override - public String getActiveDatabase() + public String getActiveDatabaseAsSetByUser() { - return activeDatabaseName; + return activeDatabaseNameAsSetByUser; + } + + @Override + public String getActualDatabaseAsReportedByServer() + { + return actualDatabaseNameAsReportedByServer; } @Override @@ -146,14 +154,21 @@ private void reconnect(boolean keepBookmark) { session.close(); sessionOptionalArgs = t -> t.withBookmarks(bookmark); } - Consumer sessionArgs = t -> t.withDefaultAccessMode(AccessMode.WRITE).withDatabase(activeDatabaseName); + Consumer sessionArgs = t -> { + t.withDefaultAccessMode(AccessMode.WRITE); + if (!ABSENT_DB_NAME.equals(activeDatabaseNameAsSetByUser)) { + t.withDatabase(activeDatabaseNameAsSetByUser); + } + }; session = driver.session(sessionArgs.andThen(sessionOptionalArgs)); - String query = activeDatabaseName.equals(SYSTEM_DB_NAME) ? "SHOW DATABASES" : "RETURN 1"; + String query = activeDatabaseNameAsSetByUser.equals(SYSTEM_DB_NAME) ? "SHOW DATABASES" : "RETURN 1"; + + resetActualDbName(); // Set this to null first in case run throws an exception StatementResult run = session.run(query); + this.version = run.summary().server().version(); - // It would be nice if we could also get the actual database name here, in the case where we used ABSENT_DB_NAME - run.consume(); + updateActualDbName(run); } @Nonnull @@ -207,9 +222,24 @@ private Optional getBoltResult(@Nonnull String cypher, @Nonnull Map< return Optional.empty(); } + updateActualDbName(statementResult); + return Optional.of(new StatementBoltResult(statementResult)); } + private String getActualDbName(@Nonnull StatementResult statementResult) { + DatabaseInfo dbInfo = statementResult.summary().database(); + return dbInfo.name() == null ? ABSENT_DB_NAME : dbInfo.name(); + } + + private void updateActualDbName(@Nonnull StatementResult statementResult) { + actualDatabaseNameAsReportedByServer = getActualDbName(statementResult); + } + + private void resetActualDbName() { + actualDatabaseNameAsReportedByServer = null; + } + /** * Disconnect from Neo4j, clearing up any session resources, but don't give any output. * Intended only to be used if connect fails. @@ -225,6 +255,7 @@ void silentDisconnect() { } finally { session = null; driver = null; + resetActualDbName(); } } @@ -265,7 +296,9 @@ private Optional> captureResults(@Nonnull List trans List results = executeWithRetry(transactionStatements, (statement, transaction) -> { // calling list() is what actually executes cypher on the server StatementResult sr = transaction.run(statement); - return new ListBoltResult(sr.list(), sr.consume(), sr.keys()); + BoltResult singleResult = new ListBoltResult(sr.list(), sr.consume(), sr.keys()); + updateActualDbName(sr); + return singleResult; }); clearTransactionStatements(); diff --git a/cypher-shell/src/test/java/org/neo4j/shell/cli/InteractiveShellRunnerTest.java b/cypher-shell/src/test/java/org/neo4j/shell/cli/InteractiveShellRunnerTest.java index 1445c2c4..ab8fd3f7 100644 --- a/cypher-shell/src/test/java/org/neo4j/shell/cli/InteractiveShellRunnerTest.java +++ b/cypher-shell/src/test/java/org/neo4j/shell/cli/InteractiveShellRunnerTest.java @@ -77,7 +77,7 @@ public void setup() throws Exception { historyFile = temp.newFile(); badLineError = new ClientException("Found a bad line"); userMessagesHandler = mock(UserMessagesHandler.class); - when(databaseManager.getActiveDatabase()).thenReturn("mydb"); + when(databaseManager.getActualDatabaseAsReportedByServer()).thenReturn("mydb"); when(userMessagesHandler.getWelcomeMessage()).thenReturn("Welcome to cypher-shell!"); when(userMessagesHandler.getExitMessage()).thenReturn("Exit message"); when(connectionConfig.username()).thenReturn("myusername"); @@ -96,6 +96,7 @@ public void testSimple() throws Exception { verify(cmdExecuter).execute("good1;"); verify(cmdExecuter).execute("\ngood2;"); + verify(cmdExecuter, times(3)).lastNeo4jErrorCode(); verifyNoMoreInteractions(cmdExecuter); } @@ -118,6 +119,7 @@ public void runUntilEndShouldKeepGoingOnErrors() throws IOException, CommandExce verify(cmdExecuter).execute("\ngood2;"); verify(cmdExecuter).execute("\nbad2;"); verify(cmdExecuter).execute("\ngood3;"); + verify(cmdExecuter, times(6)).lastNeo4jErrorCode(); verifyNoMoreInteractions(cmdExecuter); verify(logger, times(2)).printError(badLineError); @@ -144,6 +146,7 @@ public void runUntilEndShouldStopOnExitExceptionAndReturnCode() throws IOExcepti verify(cmdExecuter).execute("\nbad1;"); verify(cmdExecuter).execute("\ngood2;"); verify(cmdExecuter).execute("\nexit;"); + verify(cmdExecuter, times(4)).lastNeo4jErrorCode(); verifyNoMoreInteractions(cmdExecuter); verify(logger).printError(badLineError); @@ -293,8 +296,8 @@ public void testPrompt() throws Exception { public void testLongPrompt() throws Exception { // given InputStream inputStream = new ByteArrayInputStream("".getBytes()); - String dbName = "TheLongestDbNameEverCreatedInAllOfHistoryAndTheUniversePlusSome"; - when(databaseManager.getActiveDatabase()).thenReturn(dbName); + String actualDbName = "TheLongestDbNameEverCreatedInAllOfHistoryAndTheUniversePlusSome"; + when(databaseManager.getActualDatabaseAsReportedByServer()).thenReturn(actualDbName); InteractiveShellRunner runner = new InteractiveShellRunner(cmdExecuter, txHandler, databaseManager, logger, statementParser, inputStream, historyFile, userMessagesHandler, connectionConfig); @@ -303,7 +306,7 @@ public void testLongPrompt() throws Exception { AnsiFormattedText prompt = runner.updateAndGetPrompt(); // then - String wantedPrompt = format("myusername@%s%n> ", dbName); + String wantedPrompt = format("myusername@%s%n> ", actualDbName); assertEquals(wantedPrompt, prompt.plainString()); // when @@ -363,6 +366,7 @@ public void multilineRequiresNewLineOrSemicolonToEnd() throws Exception { runner.runUntilEnd(); // then + verify(cmdExecuter).lastNeo4jErrorCode(); verifyNoMoreInteractions(cmdExecuter); } @@ -423,6 +427,7 @@ public void testSignalHandleOutsideExecution() throws Exception { runner.handle(new Signal(InteractiveShellRunner.INTERRUPT_SIGNAL)); // then + verify(cmdExecuter).lastNeo4jErrorCode(); verifyNoMoreInteractions(cmdExecuter); verify(logger).printError("@|RED \nInterrupted (Note that Cypher queries must end with a |@" + "@|RED,BOLD semicolon. |@" + @@ -484,8 +489,13 @@ public void reset() { } @Override - public String getActiveDatabase() { + public String getActiveDatabaseAsSetByUser() { return ABSENT_DB_NAME; } + + @Override + public String getActualDatabaseAsReportedByServer() { + return DEFAULT_DEFAULT_DB_NAME; + } } } diff --git a/cypher-shell/src/test/java/org/neo4j/shell/prettyprint/OutputFormatterTest.java b/cypher-shell/src/test/java/org/neo4j/shell/prettyprint/OutputFormatterTest.java index 6e3c3029..38b596cb 100644 --- a/cypher-shell/src/test/java/org/neo4j/shell/prettyprint/OutputFormatterTest.java +++ b/cypher-shell/src/test/java/org/neo4j/shell/prettyprint/OutputFormatterTest.java @@ -7,6 +7,7 @@ import java.util.Map; import org.neo4j.driver.internal.BoltServerAddress; +import org.neo4j.driver.internal.summary.InternalDatabaseInfo; import org.neo4j.driver.internal.summary.InternalResultSummary; import org.neo4j.driver.internal.summary.InternalServerInfo; import org.neo4j.driver.internal.util.ServerVersion; @@ -35,6 +36,7 @@ public void shouldReportTotalDBHits() { ResultSummary summary = new InternalResultSummary( new Statement( "PROFILE MATCH (n:LABEL) WHERE 20 < n.age < 35 return n" ), new InternalServerInfo( new BoltServerAddress( "localhost:7687" ), ServerVersion.vInDev ), + new InternalDatabaseInfo("neo4j"), StatementType.READ_ONLY, null, plan, diff --git a/cypher-shell/src/test/java/org/neo4j/shell/state/BoltStateHandlerTest.java b/cypher-shell/src/test/java/org/neo4j/shell/state/BoltStateHandlerTest.java index e1777da4..1093807b 100644 --- a/cypher-shell/src/test/java/org/neo4j/shell/state/BoltStateHandlerTest.java +++ b/cypher-shell/src/test/java/org/neo4j/shell/state/BoltStateHandlerTest.java @@ -12,7 +12,9 @@ import org.neo4j.driver.Statement; import org.neo4j.driver.StatementResult; import org.neo4j.driver.Value; +import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.SessionExpiredException; +import org.neo4j.driver.summary.DatabaseInfo; import org.neo4j.driver.summary.ResultSummary; import org.neo4j.driver.summary.ServerInfo; import org.neo4j.shell.ConnectionConfig; @@ -44,7 +46,8 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.neo4j.driver.internal.messaging.request.MultiDatabaseUtil.ABSENT_DB_NAME; +import static org.neo4j.shell.DatabaseManager.ABSENT_DB_NAME; +import static org.neo4j.shell.DatabaseManager.DEFAULT_DEFAULT_DB_NAME; public class BoltStateHandlerTest { @Rule @@ -84,7 +87,7 @@ public Driver apply(String uri, AuthToken authToken, Config config) { @Test public void versionIsNotEmptyAfterConnect() throws CommandException { - Driver driverMock = stubVersionInAnOpenSession(mock(StatementResult.class), mock(Session.class), "Neo4j/9.4.1-ALPHA"); + Driver driverMock = stubResultSummaryInAnOpenSession(mock(StatementResult.class), mock(Session.class), "Neo4j/9.4.1-ALPHA"); BoltStateHandler handler = new BoltStateHandler((s, authToken, config) -> driverMock); ConnectionConfig config = new ConnectionConfig("bolt://", "", -1, "", "", false, ABSENT_DB_NAME); @@ -93,6 +96,44 @@ public void versionIsNotEmptyAfterConnect() throws CommandException { assertEquals("9.4.1-ALPHA", handler.getServerVersion()); } + @Test + public void actualDatabaseNameIsNotEmptyAfterConnect() throws CommandException { + Driver driverMock = + stubResultSummaryInAnOpenSession(mock(StatementResult.class), mock(Session.class), "Neo4j/9.4.1-ALPHA", "my_default_db"); + + BoltStateHandler handler = new BoltStateHandler((s, authToken, config) -> driverMock); + ConnectionConfig config = new ConnectionConfig("bolt://", "", -1, "", "", false, ABSENT_DB_NAME); + handler.connect(config); + + assertEquals("my_default_db", handler.getActualDatabaseAsReportedByServer()); + } + + @Test + public void exceptionFromRunQueryDoesNotResetActualDatabaseNameToUnresolved() throws CommandException { + Session sessionMock = mock(Session.class); + StatementResult resultMock = mock(StatementResult.class); + Driver driverMock = + stubResultSummaryInAnOpenSession(resultMock, sessionMock, "Neo4j/9.4.1-ALPHA", "my_default_db"); + + ClientException databaseNotFound = new ClientException("Neo.ClientError.Database.DatabaseNotFound", "blah"); + + when(sessionMock.run(any(Statement.class))) + .thenThrow(databaseNotFound) + .thenReturn(resultMock); + + BoltStateHandler handler = new BoltStateHandler((s, authToken, config) -> driverMock); + ConnectionConfig config = new ConnectionConfig("bolt://", "", -1, "", "", false, ABSENT_DB_NAME); + handler.connect(config); + + try { + handler.runCypher("RETURN \"hello\"", Collections.emptyMap()); + fail("should fail on runCypher"); + } catch (Exception e) { + assertThat(e, is(databaseNotFound)); + assertEquals("my_default_db", handler.getActualDatabaseAsReportedByServer()); + } + } + @Test public void closeTransactionAfterRollback() throws CommandException { boltStateHandler.connect(); @@ -113,11 +154,10 @@ public void exceptionsFromSilentDisconnectAreSuppressedToReportOriginalErrors() RuntimeException originalException = new RuntimeException("original exception"); RuntimeException thrownFromSilentDisconnect = new RuntimeException("exception from silent disconnect"); - - Driver mockedDriver = stubVersionInAnOpenSession(resultMock, session, "neo4j-version"); + Driver mockedDriver = stubResultSummaryInAnOpenSession(resultMock, session, "neo4j-version"); OfflineBoltStateHandler boltStateHandler = new OfflineBoltStateHandler(mockedDriver); - when(resultMock.consume()).thenThrow(originalException); + when(resultMock.summary()).thenThrow(originalException); doThrow(thrownFromSilentDisconnect).when(session).close(); try { @@ -171,7 +211,7 @@ public void beginNeedsToInitialiseTransactionStatements() throws CommandExceptio @Test public void commitPurgesTheTransactionStatementsAndCollectsResults() throws CommandException { Session sessionMock = mock(Session.class); - Driver driverMock = stubVersionInAnOpenSession(mock(StatementResult.class), sessionMock, "neo4j-version"); + Driver driverMock = stubResultSummaryInAnOpenSession(mock(StatementResult.class), sessionMock, "neo4j-version"); Record record1 = mock(Record.class); Record record2 = mock(Record.class); @@ -247,12 +287,11 @@ public void shouldExecuteInTransactionIfOpen() throws CommandException { @Test public void shouldRunCypherQuery() throws CommandException { Session sessionMock = mock(Session.class); - StatementResult versionMock = mock(StatementResult.class); StatementResult resultMock = mock(StatementResult.class); Record recordMock = mock(Record.class); Value valueMock = mock(Value.class); - Driver driverMock = stubVersionInAnOpenSession(versionMock, sessionMock, "neo4j-version"); + Driver driverMock = stubResultSummaryInAnOpenSession(resultMock, sessionMock, "neo4j-version"); when(resultMock.list()).thenReturn(asList(recordMock)); @@ -274,12 +313,11 @@ public void shouldRunCypherQuery() throws CommandException { @Test public void triesAgainOnSessionExpired() throws Exception { Session sessionMock = mock(Session.class); - StatementResult versionMock = mock(StatementResult.class); StatementResult resultMock = mock(StatementResult.class); Record recordMock = mock(Record.class); Value valueMock = mock(Value.class); - Driver driverMock = stubVersionInAnOpenSession(versionMock, sessionMock, "neo4j-version"); + Driver driverMock = stubResultSummaryInAnOpenSession(resultMock, sessionMock, "neo4j-version"); when(resultMock.list()).thenReturn(asList(recordMock)); @@ -326,7 +364,7 @@ public void canOnlyConnectOnce() throws CommandException { public void resetSessionOnReset() throws Exception { // given Session sessionMock = mock(Session.class); - Driver driverMock = stubVersionInAnOpenSession(mock(StatementResult.class), sessionMock, "neo4j-version"); + Driver driverMock = stubResultSummaryInAnOpenSession(mock(StatementResult.class), sessionMock, "neo4j-version"); OfflineBoltStateHandler boltStateHandler = new OfflineBoltStateHandler(driverMock); @@ -377,17 +415,24 @@ public void turnOnEncryptionIfRequested() throws CommandException { assertEquals(Config.EncryptionLevel.REQUIRED, provider.config.encryptionLevel()); } - private Driver stubVersionInAnOpenSession(StatementResult versionMock, Session sessionMock, String value) { + private Driver stubResultSummaryInAnOpenSession(StatementResult resultMock, Session sessionMock, String version) { + return stubResultSummaryInAnOpenSession(resultMock, sessionMock, version, DEFAULT_DEFAULT_DB_NAME); + } + + private Driver stubResultSummaryInAnOpenSession(StatementResult resultMock, Session sessionMock, String version, String databaseName) { Driver driverMock = mock(Driver.class); ResultSummary resultSummary = mock(ResultSummary.class); ServerInfo serverInfo = mock(ServerInfo.class); + DatabaseInfo databaseInfo = mock(DatabaseInfo.class); when(resultSummary.server()).thenReturn(serverInfo); - when(serverInfo.version()).thenReturn(value); - when(versionMock.summary()).thenReturn(resultSummary); + when(serverInfo.version()).thenReturn(version); + when(resultMock.summary()).thenReturn(resultSummary); + when(resultSummary.database()).thenReturn(databaseInfo); + when(databaseInfo.name()).thenReturn(databaseName); when(sessionMock.isOpen()).thenReturn(true); - when(sessionMock.run("RETURN 1")).thenReturn(versionMock); + when(sessionMock.run("RETURN 1")).thenReturn(resultMock); when(driverMock.session(any())).thenReturn(sessionMock); return driverMock; diff --git a/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeDriver.java b/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeDriver.java index 54f0cd8a..0dd40752 100644 --- a/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeDriver.java +++ b/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeDriver.java @@ -7,6 +7,7 @@ import org.neo4j.driver.async.AsyncSession; import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.driver.reactive.RxSession; +import org.neo4j.driver.types.TypeSystem; import java.util.concurrent.CompletionStage; import java.util.function.Consumer; @@ -65,4 +66,10 @@ public AsyncSession asyncSession(Consumer templateCon { return null; } + + @Override + public TypeSystem defaultTypeSystem() + { + return null; + } } diff --git a/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeResultSummary.java b/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeResultSummary.java index b6fd8606..e15cebfd 100644 --- a/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeResultSummary.java +++ b/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeResultSummary.java @@ -80,4 +80,17 @@ public String version() } }; } + + @Override + public DatabaseInfo database() + { + return new DatabaseInfo() + { + @Override + public String name() + { + return null; + } + }; + } } diff --git a/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeSession.java b/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeSession.java index 9e286dfe..92149728 100644 --- a/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeSession.java +++ b/cypher-shell/src/test/java/org/neo4j/shell/test/bolt/FakeSession.java @@ -17,23 +17,18 @@ public Transaction beginTransaction() { } @Override - public Transaction beginTransaction( TransactionConfig config ) + public Transaction beginTransaction(TransactionConfig config) { return null; } - @Override - public Transaction beginTransaction(String bookmark) { - return null; - } - @Override public T readTransaction(TransactionWork work) { return null; } @Override - public T readTransaction( TransactionWork work, TransactionConfig config ) + public T readTransaction(TransactionWork work, TransactionConfig config ) { return null; } @@ -44,25 +39,25 @@ public T writeTransaction(TransactionWork work) { } @Override - public T writeTransaction( TransactionWork work, TransactionConfig config ) + public T writeTransaction(TransactionWork work, TransactionConfig config) { return null; } @Override - public StatementResult run( String statement, TransactionConfig config ) + public StatementResult run(String statement, TransactionConfig config) { return FakeStatementResult.parseStatement(statement); } @Override - public StatementResult run( String statement, Map parameters, TransactionConfig config ) + public StatementResult run(String statement, Map parameters, TransactionConfig config) { return FakeStatementResult.parseStatement(statement); } @Override - public StatementResult run( Statement statement, TransactionConfig config ) + public StatementResult run(Statement statement, TransactionConfig config) { return new FakeStatementResult(); } @@ -110,9 +105,4 @@ public StatementResult run(String statementTemplate) { public StatementResult run(Statement statement) { return new FakeStatementResult(); } - - @Override - public TypeSystem typeSystem() { - return null; - } }