diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 11b9f7bb2..a4189e5b3 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -23,7 +23,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.opensearch.action.get.GetResponse import org.opensearch.flint.core.storage.{FlintReader, OpenSearchReader, OpenSearchUpdater} -import org.opensearch.flint.data.FlintStatement +import org.opensearch.flint.data.{FlintStatement, InteractiveSession, SessionStates} import org.opensearch.search.sort.SortOrder import org.scalatestplus.mockito.MockitoSugar @@ -50,7 +50,7 @@ class FlintREPLTest val args = Array("resultIndexName") val (queryOption, resultIndex) = FlintREPL.parseArgs(args) queryOption shouldBe None - resultIndex shouldBe "resultIndexName" + resultIndex shouldBe Some("resultIndexName") } test( @@ -58,16 +58,15 @@ class FlintREPLTest val args = Array("SELECT * FROM table", "resultIndexName") val (queryOption, resultIndex) = FlintREPL.parseArgs(args) queryOption shouldBe Some("SELECT * FROM table") - resultIndex shouldBe "resultIndexName" + resultIndex shouldBe Some("resultIndexName") } test( "parseArgs with no arguments should throw IllegalArgumentException with specific message") { val args = Array.empty[String] - val exception = intercept[IllegalArgumentException] { - FlintREPL.parseArgs(args) - } - exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments." + val (queryOption, resultIndex) = FlintREPL.parseArgs(args) + queryOption shouldBe None + resultIndex shouldBe None } test( @@ -76,7 +75,7 @@ class FlintREPLTest val exception = intercept[IllegalArgumentException] { FlintREPL.parseArgs(args) } - exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments." + exception.getMessage shouldBe "Unsupported number of arguments. Expected no more than two arguments." } test("getQuery should return query from queryOption if present") { @@ -131,19 +130,19 @@ class FlintREPLTest test("createHeartBeatUpdater should update heartbeat correctly") { // Mocks - val flintSessionUpdater = mock[OpenSearchUpdater] - val osClient = mock[OSClient] val threadPool = mock[ScheduledExecutorService] - val getResponse = mock[GetResponse] val scheduledFutureRaw = mock[ScheduledFuture[_]] - + val sessionManager = mock[SessionManager] + val sessionId = "session1" + val currentInterval = 1000L + val initialDelayMillis = 0L // when scheduled task is scheduled, execute the runnable immediately only once and become no-op afterwards. when( threadPool.scheduleAtFixedRate( any[Runnable], - eqTo(0), - *, - eqTo(java.util.concurrent.TimeUnit.MILLISECONDS))) + eqTo(initialDelayMillis), + eqTo(currentInterval), + eqTo(TimeUnit.MILLISECONDS))) .thenAnswer((invocation: InvocationOnMock) => { val runnable = invocation.getArgument[Runnable](0) runnable.run() @@ -151,43 +150,43 @@ class FlintREPLTest }) // Invoke the method - FlintREPL.createHeartBeatUpdater() - + FlintREPL.createHeartBeatUpdater( + sessionId, + sessionManager, + currentInterval, + initialDelayMillis, + threadPool) // Verifications - verify(flintSessionUpdater, atLeastOnce()).upsert(eqTo("session1"), *) + verify(sessionManager).recordHeartbeat(sessionId) } test("PreShutdownListener updates FlintInstance if conditions are met") { // Mock dependencies - val osClient = mock[OSClient] - val getResponse = mock[GetResponse] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val sessionIndex = "testIndex" val sessionId = "testSessionId" val timerContext = mock[Timer.Context] + val sessionManager = mock[SessionManager] - // Setup the getDoc to return a document indicating the session is running - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - when(getResponse.getSourceAsMap).thenReturn( - Map[String, Object]( - "applicationId" -> "app1", - "jobId" -> "job1", - "sessionId" -> "session1", - "state" -> "running", - "lastUpdateTime" -> java.lang.Long.valueOf(12345L), - "error" -> "someError", - "state" -> "running", - "jobStartTime" -> java.lang.Long.valueOf(0L)).asJava) + val interactiveSession = new InteractiveSession( + "app123", + "job123", + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000 + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) // Instantiate the listener - val listener = new PreShutdownListener(osClient, flintSessionIndexUpdater) + val listener = new PreShutdownListener(sessionId, sessionManager, timerContext) // Simulate application end listener.onApplicationEnd(SparkListenerApplicationEnd(System.currentTimeMillis())) - // Verify the update is called with the correct arguments - verify(flintSessionIndexUpdater).updateIf(*, *, *, *) + verify(sessionManager).updateSessionDetails( + interactiveSession, + SessionUpdateMode.UPDATE_IF + ) + interactiveSession.state shouldBe SessionStates.DEAD } test("Test getFailedData method") { @@ -256,21 +255,22 @@ class FlintREPLTest } } - test("test canPickNextStatement: Doc Exists and Valid JobId") { + test("test canPickNextStatement: Valid jobId") { val sessionId = "session123" val jobId = "jobABC" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", jobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000 + ) + when(sessionManager.getSessionDetails(sessionId)).thenReturn(Some(interactiveSession)) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) assert(result) } @@ -279,19 +279,19 @@ class FlintREPLTest val sessionId = "session123" val jobId = "jobABC" val differentJobId = "jobXYZ" - val osClient = mock[OSClient] - val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] - val getResponse = mock[GetResponse] - when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(true) - - val sourceMap = new java.util.HashMap[String, Object]() - sourceMap.put("jobId", differentJobId.asInstanceOf[Object]) - when(getResponse.getSourceAsMap).thenReturn(sourceMap) + val interactiveSession = new InteractiveSession( + "app123", + jobId, + sessionId, + SessionStates.RUNNING, + System.currentTimeMillis(), + System.currentTimeMillis() - 10000 + ) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, differentJobId) // Assertions assert(!result) // The function should return false @@ -302,6 +302,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -316,7 +317,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Assertions assert(!result) // The function should return false because jobId is excluded @@ -327,6 +328,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Mock the getDoc response val getResponse = mock[GetResponse] @@ -335,7 +337,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(null) // Simulate the source being null // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Assertions assert(result) // The function should return true despite the null source @@ -346,7 +348,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" - + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) @@ -360,7 +362,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) assert(result) // The function should return true } @@ -370,6 +372,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Set up the mock GetResponse val getResponse = mock[GetResponse] @@ -377,7 +380,7 @@ class FlintREPLTest when(getResponse.isExists()).thenReturn(false) // Simulate the document does not exist // Execute the function under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Assert the function returns true assert(result) @@ -388,13 +391,14 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] // Set up the mock OSClient to throw an exception when(osClient.getDoc(sessionIndex, sessionId)) .thenThrow(new RuntimeException("OpenSearch cluster unresponsive")) // Execute the method under test and expect true, since the method is designed to return true even in case of an exception - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // Verify the result is true despite the exception assert(result) @@ -407,6 +411,7 @@ class FlintREPLTest val nonMatchingExcludeJobId = "jobXYZ" // This ID does not match the jobId val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -420,7 +425,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // The function should return true since jobId is not excluded assert(result) @@ -436,11 +441,9 @@ class FlintREPLTest exception.setServiceName("AWSGlue") val mockFlintStatement = mock[FlintStatement] - val expectedError = ( - """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. (Service: AWSGlue; """ + - """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + - """"ErrorSource":"AWSGlue","StatusCode":"400"}""" - ) + val expectedError = """{"Message":"Fail to read data from Glue. Cause: Access denied in AWS Glue service. Please check permissions. (Service: AWSGlue; """ + + """Status Code: 400; Error Code: AccessDeniedException; Request ID: null; Proxy: null)",""" + + """"ErrorSource":"AWSGlue","StatusCode":"400"}""" val result = FlintREPL.processQueryException(exception, mockFlintStatement) @@ -457,6 +460,7 @@ class FlintREPLTest val osClient = mock[OSClient] val sessionIndex = "sessionIndex" val handleSessionError = mock[Function1[String, Unit]] + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -474,7 +478,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // The function should return false since jobId is excluded assert(!result) @@ -485,6 +489,7 @@ class FlintREPLTest val jobId = "jobABC" val osClient = mock[OSClient] val sessionIndex = "sessionIndex" + val sessionManager = mock[SessionManager] val getResponse = mock[GetResponse] when(osClient.getDoc(sessionIndex, sessionId)).thenReturn(getResponse) @@ -502,7 +507,7 @@ class FlintREPLTest when(getResponse.getSourceAsMap).thenReturn(sourceMap) // Execute the method under test - val result = FlintREPL.canPickNextStatement(sessionId, jobId, osClient, sessionIndex) + val result = FlintREPL.canPickNextStatement(sessionManager, sessionId, jobId) // The function should return true since the jobId is not in the excludeJobIds list assert(result) @@ -531,19 +536,20 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() try { val flintSessionIndexUpdater = mock[OpenSearchUpdater] - - val commandContext = QueryExecutionContext( + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - 60, - 60) + 60L, + 60L) intercept[RuntimeException] { FlintREPL.exponentialBackoffRetry(maxRetries, 2.seconds) { @@ -574,6 +580,9 @@ class FlintREPLTest val sessionId = "someSessionId" val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] + val flintStatement = mock[FlintStatement] + val state = mock[InMemoryQueryExecutionState] + val context = mock[StatementExecutionContext] when(mockFlintStatement.query).thenReturn("SELECT 1") when(mockFlintStatement.submitTime).thenReturn(Instant.now().toEpochMilli()) @@ -595,7 +604,7 @@ class FlintREPLTest val sparkContext = mock[SparkContext] when(mockSparkSession.sparkContext).thenReturn(sparkContext) - val result = FlintREPL.executeAndHandle() + val result = FlintREPL.executeAndHandle(flintStatement, state, context) verify(mockSparkSession, times(1)).sql(any[String]) verify(sparkContext, times(1)).cancelJobGroup(any[String]) @@ -624,6 +633,8 @@ class FlintREPLTest val sessionId = "someSessionId" val startTime = System.currentTimeMillis() val expectedDataFrame = mock[DataFrame] + val state = mock[InMemoryQueryExecutionState] + val context = mock[StatementExecutionContext] // sql method can only throw RuntimeException when(mockSparkSession.sql(any[String])).thenThrow( @@ -637,7 +648,7 @@ class FlintREPLTest .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) - val result = FlintREPL.executeAndHandle() + val result = FlintREPL.executeAndHandle(flintStatement, state, context) // Verify that ParseException was caught and handled result should not be None // or result.isDefined shouldBe true @@ -650,6 +661,13 @@ class FlintREPLTest test("setupFlintJobWithExclusionCheck should proceed normally when no jobs are excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) when(getResponse.getSourceAsMap).thenReturn( @@ -668,13 +686,26 @@ class FlintREPLTest val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") // other mock objects like osClient, flintSessionIndexUpdater with necessary mocking - val result = FlintREPL.setupFlintJobWithExclusionCheck() + val result = FlintREPL.setupFlintJobWithExclusionCheck( + conf, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(!result) // Expecting false as the job should proceed normally } test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the rest of the GetResponse as needed @@ -682,13 +713,26 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "jobId") - val result = FlintREPL.setupFlintJobWithExclusionCheck() + val result = FlintREPL.setupFlintJobWithExclusionCheck( + conf, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(result) // Expecting true as the job should exit early } test("setupFlintJobWithExclusionCheck should exit early if a duplicate job is running") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) // Mock the GetResponse to simulate a scenario of a duplicate job @@ -708,44 +752,65 @@ class FlintREPLTest val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-1,job-2") - val result = FlintREPL.setupFlintJobWithExclusionCheck() + val result = FlintREPL.setupFlintJobWithExclusionCheck( + conf, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(result) // Expecting true for early exit due to duplicate job } test("setupFlintJobWithExclusionCheck should setup job normally when conditions are met") { val osClient = mock[OSClient] val getResponse = mock[GetResponse] + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] + val conf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + when(osClient.getDoc(*, *)).thenReturn(getResponse) when(getResponse.isExists()).thenReturn(true) val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "job-3,job-4") - val result = FlintREPL.setupFlintJobWithExclusionCheck() + val result = FlintREPL.setupFlintJobWithExclusionCheck( + conf, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) assert(!result) // Expecting false as the job proceeds normally } test( "setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") { - val osClient = mock[OSClient] - val flintSessionIndexUpdater = mock[OpenSearchUpdater] val mockConf = new SparkConf().set("spark.flint.deployment.excludeJobs", "") + val applicationId = "app1" + val jobId = "job1" + val sessionId = "session1" + val jobStartTime = System.currentTimeMillis() + val sessionManager = mock[SessionManager] assertThrows[NoSuchElementException] { - FlintREPL.setupFlintJobWithExclusionCheck() + FlintREPL.setupFlintJobWithExclusionCheck( + mockConf, + applicationId, + jobId, + sessionId, + sessionManager, + jobStartTime) } } test("queryLoop continue until inactivity limit is reached") { - val mockReader = mock[FlintReader] - val osClient = mock[OSClient] - when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC))) - .thenReturn(mockReader) - when(mockReader.hasNext).thenReturn(false) - val resultIndex = "testResultIndex" val dataSource = "testDataSource" - val sessionIndex = "testSessionIndex" val sessionId = "testSessionId" val jobId = "testJobId" @@ -754,25 +819,24 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() - val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] - val commandContext = QueryExecutionContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - shortInactivityLimit, - 60) + 60L, + 60L) // Mock processCommands to always allow loop continuation - val getResponse = mock[GetResponse] - when(osClient.getDoc(*, *)).thenReturn(getResponse) - when(getResponse.isExists()).thenReturn(false) + when(sessionManager.getNextStatement(sessionId)).thenReturn(None) val startTime = System.currentTimeMillis() @@ -800,24 +864,26 @@ class FlintREPLTest val sessionId = "testSessionId" val jobId = "testJobId" val longInactivityLimit = 10000 // 10 seconds + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = QueryExecutionContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - longInactivityLimit, - 60) + 60L, + 60L) // Mocking canPickNextStatement to return false when(osClient.getDoc(sessionIndex, sessionId)).thenAnswer(_ => { @@ -859,21 +925,23 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = QueryExecutionContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - inactivityLimit, - 60) + 60L, + 60L) try { // Mocking ThreadUtils to track the shutdown call @@ -909,21 +977,23 @@ class FlintREPLTest // Create a SparkSession for testing val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] val flintSessionIndexUpdater = mock[OpenSearchUpdater] - val commandContext = QueryExecutionContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - inactivityLimit, - 60) + 60L, + 60L) try { // Mocking ThreadUtils to track the shutdown call @@ -992,19 +1062,22 @@ class FlintREPLTest when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] - val commandContext = QueryExecutionContext( - mockSparkSession, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, + val commandContext = StatementExecutionContext( + spark, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - inactivityLimit, - 60) + 60L, + 60L) val startTime = Instant.now().toEpochMilli() @@ -1042,19 +1115,21 @@ class FlintREPLTest val spark = SparkSession.builder().master("local").appName("FlintREPLTest").getOrCreate() val flintSessionIndexUpdater = mock[OpenSearchUpdater] + val sessionManager = mock[SessionManager] + val statementLifecycleManager = mock[StatementLifecycleManager] + val queryResultWriter = mock[QueryResultWriter] - val commandContext = QueryExecutionContext( + val commandContext = StatementExecutionContext( spark, - dataSource, - resultIndex, - sessionId, - flintSessionIndexUpdater, - osClient, - sessionIndex, jobId, + sessionId, + sessionManager, + statementLifecycleManager, + queryResultWriter, + dataSource, Duration(10, MINUTES), - inactivityLimit, - 60) + 60L, + 60L) val startTime = Instant.now().toEpochMilli()