
Commit dac235d
Clean up logs and add UTs
Signed-off-by: Louis Chu <[email protected]>
noCharger committed Aug 17, 2024
1 parent e85a010 commit dac235d
Showing 8 changed files with 69 additions and 54 deletions.
@@ -7,6 +7,14 @@ package org.apache.spark.sql

import org.opensearch.flint.common.model.FlintStatement

/**
* Trait for writing the result of a query execution to an external data storage.
*/
trait QueryResultWriter {

/**
* Writes the given DataFrame, which represents the result of a query execution, to an external
* data storage based on the provided FlintStatement metadata.
*/
def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit
}
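For orientation, a minimal sketch of what an implementation of this new trait could look like. It is illustrative only: the class name, the resultIndex parameter, and the "flint" data source format string are assumptions, not part of this commit.

import org.apache.spark.sql.{DataFrame, QueryResultWriter}
import org.opensearch.flint.common.model.FlintStatement

/**
 * Illustrative sketch: write the query result to an external index via the
 * DataFrame writer. Names and the format string below are assumptions.
 */
class ExampleQueryResultWriter(resultIndex: String) extends QueryResultWriter {

  override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
    // The FlintStatement metadata is available here for correlating the result
    // with the originating statement (e.g. choosing or tagging the target index).
    dataFrame.write
      .format("flint") // assumed OpenSearch/Flint data source short name
      .mode("append")
      .save(resultIndex)
  }
}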
@@ -24,15 +24,11 @@
import org.opensearch.flint.core.RestHighLevelClientWrapper;
import org.opensearch.flint.core.auth.ResourceBasedAWSRequestSigningApacheInterceptor;
import org.opensearch.flint.core.http.RetryableHttpAsyncClient;
import java.util.logging.Logger;


/**
* Utility functions to create {@link IRestHighLevelClient}.
*/
public class OpenSearchClientUtils {
private static final Logger LOG = Logger.getLogger(OpenSearchClientUtils.class.getName());


/**
* Metadata log index name prefix
@@ -62,7 +58,6 @@ public static RestHighLevelClient createRestHighLevelClient(FlintOptions options
}

public static IRestHighLevelClient createClient(FlintOptions options) {
LOG.info("createClient called");
return new RestHighLevelClientWrapper(createRestHighLevelClient(options));
}

@@ -17,14 +17,11 @@
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;

/**
* Abstract OpenSearch Reader.
*/
public abstract class OpenSearchReader implements FlintReader {
private static final Logger LOG = Logger.getLogger(OpenSearchReader.class.getName());

@VisibleForTesting
/** Search request source builder. */
public final SearchRequest searchRequest;
@@ -50,7 +47,6 @@ public OpenSearchReader(IRestHighLevelClient client, SearchRequest searchRequest
return false;
}
List<SearchHit> searchHits = Arrays.asList(response.get().getHits().getHits());
LOG.info("Result sets: " + searchHits.size());
iterator = searchHits.iterator();
}
return iterator.hasNext();
@@ -22,7 +22,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_POR
/**
* Test required OpenSearch domain should extend OpenSearchSuite.
*/
trait OpenSearchSuite extends BeforeAndAfterAll with Logging {
trait OpenSearchSuite extends BeforeAndAfterAll {
self: Suite =>

protected lazy val container = new OpenSearchContainer()
@@ -146,7 +146,6 @@ trait OpenSearchSuite extends BeforeAndAfterAll with Logging {

val response =
openSearchClient.bulk(request, RequestOptions.DEFAULT)
logInfo(response.toString)
assume(
!response.hasFailures,
s"bulk index docs to $index failed: ${response.buildFailureMessage()}")
@@ -64,10 +64,10 @@ object FlintREPL extends Logging with FlintJobExecutor {

// init SparkContext
val conf: SparkConf = createSparkConf()
val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "unknown")
val dataSource = conf.get(FlintSparkConf.DATA_SOURCE_NAME.key, "")

if (dataSource == "unknown") {
logInfo(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set")
if (dataSource.trim.isEmpty) {
logAndThrow(FlintSparkConf.DATA_SOURCE_NAME.key + " is not set or is empty")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
@@ -919,12 +919,13 @@ object FlintREPL extends Logging with FlintJobExecutor {
result.getOrElse(throw new RuntimeException("Failed after retries"))
}

private def getSessionId(conf: SparkConf): String = {
val sessionIdOption: Option[String] = Option(conf.get(FlintSparkConf.SESSION_ID.key, null))
if (sessionIdOption.isEmpty) {
logAndThrow(FlintSparkConf.SESSION_ID.key + " is not set")
def getSessionId(conf: SparkConf): String = {
conf.getOption(FlintSparkConf.SESSION_ID.key) match {
case Some(sessionId) if sessionId.nonEmpty =>
sessionId
case _ =>
logAndThrow(s"${FlintSparkConf.SESSION_ID.key} is not set or is empty")
}
sessionIdOption.get
}

private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
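Both the dataSource check and the reworked getSessionId above fail fast through logAndThrow, which this view of the commit does not show. The sketch below is an assumption (the trait name is hypothetical), written to be consistent with the new tests further down, which intercept an IllegalArgumentException carrying the exact message.

import org.apache.spark.internal.Logging

trait FlintJobValidation extends Logging {
  // Log the problem, then abort. The FlintREPLTest cases below intercept
  // IllegalArgumentException and compare its message verbatim, so the message
  // is passed through unchanged.
  def logAndThrow(message: String): Nothing = {
    logError(message)
    throw new IllegalArgumentException(message)
  }
}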
@@ -121,7 +121,6 @@ class OSClient(val flintOptions: FlintOptions) extends Logging {
case Success(response) =>
IRestHighLevelClient.recordOperationSuccess(
MetricConstants.REQUEST_METADATA_READ_METRIC_PREFIX)
logInfo(response.toString)
response
case Failure(e: Exception) =>
IRestHighLevelClient.recordOperationFailure(
@@ -22,17 +22,17 @@ class SessionManagerImpl(spark: SparkSession, resultIndexOption: Option[String])
with FlintJobExecutor
with Logging {

// we don't allow default value for sessionIndex, sessionId and datasource. Throw exception if key not found.
if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}

// we don't allow default value for sessionIndex. Throw exception if key not found.
val sessionIndex: String = spark.conf.get(FlintSparkConf.REQUEST_INDEX.key, "")

if (sessionIndex.isEmpty) {
logAndThrow(FlintSparkConf.REQUEST_INDEX.key + " is not set")
}

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}

val osClient = new OSClient(FlintSparkConf().flintOptions())
val flintSessionIndexUpdater = osClient.createUpdater(sessionIndex)

@@ -41,7 +41,6 @@ import org.apache.spark.sql.types.{LongType, NullType, StringType, StructField,
import org.apache.spark.sql.util.{DefaultThreadPoolFactory, MockThreadPoolFactory, MockTimeProvider, RealTimeProvider, ShutdownHookManagerTrait}
import org.apache.spark.util.ThreadUtils

@Ignore
class FlintREPLTest
extends SparkFunSuite
with MockitoSugar
@@ -50,38 +49,56 @@ class FlintREPLTest
// By using a type alias and casting, I can bypass the type checking error.
type AnyScheduledFuture = ScheduledFuture[_]

test(
"parseArgs with one argument should return None for query and the argument as resultIndex") {
test("parseArgs with no arguments should return (None, None)") {
val args = Array.empty[String]
val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args)
queryOption shouldBe None
resultIndexOption shouldBe None
}

test("parseArgs with one argument should return None for query and Some for resultIndex") {
val args = Array("resultIndexName")
val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args)
queryOption shouldBe None
resultIndex shouldBe "resultIndexName"
resultIndexOption shouldBe Some("resultIndexName")
}

test(
"parseArgs with two arguments should return the first argument as query and the second as resultIndex") {
test("parseArgs with two arguments should return Some for both query and resultIndex") {
val args = Array("SELECT * FROM table", "resultIndexName")
val (queryOption, resultIndex) = FlintREPL.parseArgs(args)
val (queryOption, resultIndexOption) = FlintREPL.parseArgs(args)
queryOption shouldBe Some("SELECT * FROM table")
resultIndex shouldBe "resultIndexName"
resultIndexOption shouldBe Some("resultIndexName")
}

test(
"parseArgs with no arguments should throw IllegalArgumentException with specific message") {
val args = Array.empty[String]
"parseArgs with more than two arguments should throw IllegalArgumentException with specific message") {
val args = Array("arg1", "arg2", "arg3")
val exception = intercept[IllegalArgumentException] {
FlintREPL.parseArgs(args)
}
exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
exception.getMessage shouldBe "Unsupported number of arguments. Expected no more than two arguments."
}
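Taken together, the parseArgs tests above pin down the method's contract. The production method is not visible in this view of the diff; the sketch below is a reconstruction from those expectations only (the wrapping object name is hypothetical), not the actual FlintREPL code.

object ParseArgsSketch {
  def parseArgs(args: Array[String]): (Option[String], Option[String]) = args match {
    case Array() => (None, None) // both query and result index resolved elsewhere
    case Array(resultIndex) => (None, Some(resultIndex)) // only the result index is given
    case Array(query, resultIndex) => (Some(query), Some(resultIndex))
    case _ =>
      throw new IllegalArgumentException(
        "Unsupported number of arguments. Expected no more than two arguments.")
  }
}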

test(
"parseArgs with more than two arguments should throw IllegalArgumentException with specific message") {
val args = Array("arg1", "arg2", "arg3")
test("getSessionId should throw exception when SESSION_ID is not set") {
val conf = new SparkConf()
val exception = intercept[IllegalArgumentException] {
FlintREPL.parseArgs(args)
FlintREPL.getSessionId(conf)
}
exception.getMessage shouldBe "Unsupported number of arguments. Expected 1 or 2 arguments."
assert(exception.getMessage === FlintSparkConf.SESSION_ID.key + " is not set or is empty")
}

test("getSessionId should return the session ID when it's set") {
val sessionId = "test-session-id"
val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, sessionId)
assert(FlintREPL.getSessionId(conf) === sessionId)
}

test("getSessionId should throw exception when SESSION_ID is set to empty string") {
val conf = new SparkConf().set(FlintSparkConf.SESSION_ID.key, "")
val exception = intercept[IllegalArgumentException] {
FlintREPL.getSessionId(conf)
}
assert(exception.getMessage === FlintSparkConf.SESSION_ID.key + " is not set or is empty")
}

test("getQuery should return query from queryOption if present") {
@@ -159,7 +176,7 @@ class FlintREPLTest
}
}

test("createHeartBeatUpdater should update heartbeat correctly") {
ignore("createHeartBeatUpdater should update heartbeat correctly") {
// Mocks
val threadPool = mock[ScheduledExecutorService]
val scheduledFutureRaw = mock[ScheduledFuture[_]]
@@ -321,7 +338,7 @@ class FlintREPLTest
assert(!result) // The function should return false
}

test("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") {
ignore("test canPickNextStatement: Doc Exists, JobId Matches, but JobId is Excluded") {
val sessionId = "session123"
val jobId = "jobABC"
val osClient = mock[OSClient]
@@ -545,7 +562,7 @@ class FlintREPLTest
assert(!result)
}

test("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") {
ignore("Doc Exists and excludeJobIds is an ArrayList Not Containing JobId") {
val sessionId = "session123"
val jobId = "jobABC"
val osClient = mock[OSClient]
@@ -639,7 +656,7 @@ class FlintREPLTest
}
}

test("executeAndHandle should handle TimeoutException properly") {
ignore("executeAndHandle should handle TimeoutException properly") {
val mockSparkSession = mock[SparkSession]
val mockConf = mock[RuntimeConfig]
when(mockSparkSession.conf).thenReturn(mockConf)
@@ -695,7 +712,7 @@ class FlintREPLTest
} finally threadPool.shutdown()
}

test("executeAndHandle should handle ParseException properly") {
ignore("executeAndHandle should handle ParseException properly") {
val mockSparkSession = mock[SparkSession]
val mockConf = mock[RuntimeConfig]
when(mockSparkSession.conf).thenReturn(mockConf)
@@ -795,7 +812,7 @@ class FlintREPLTest
assert(!result) // Expecting false as the job should proceed normally
}

test("setupFlintJobWithExclusionCheck should exit early if current job is excluded") {
ignore("setupFlintJobWithExclusionCheck should exit early if current job is excluded") {
val osClient = mock[OSClient]
val getResponse = mock[GetResponse]
val applicationId = "app1"
@@ -911,7 +928,7 @@ class FlintREPLTest
assert(!result) // Expecting false as the job proceeds normally
}

test(
ignore(
"setupFlintJobWithExclusionCheck should throw NoSuchElementException if sessionIndex or sessionId is missing") {
val osClient = mock[OSClient]
val flintSessionIndexUpdater = mock[OpenSearchUpdater]
@@ -933,7 +950,7 @@ class FlintREPLTest
}
}

test("queryLoop continue until inactivity limit is reached") {
ignore("queryLoop continue until inactivity limit is reached") {
val mockReader = mock[FlintReader]
val osClient = mock[OSClient]
when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
@@ -984,7 +1001,7 @@ class FlintREPLTest
spark.stop()
}

test("queryLoop should stop when canPickUpNextStatement is false") {
ignore("queryLoop should stop when canPickUpNextStatement is false") {
val mockReader = mock[FlintReader]
val osClient = mock[OSClient]
when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
@@ -1046,7 +1063,7 @@ class FlintREPLTest
spark.stop()
}

test("queryLoop should properly shut down the thread pool after execution") {
ignore("queryLoop should properly shut down the thread pool after execution") {
val mockReader = mock[FlintReader]
val osClient = mock[OSClient]
when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
@@ -1155,7 +1172,7 @@ class FlintREPLTest
}
}

test("queryLoop should correctly update loop control variables") {
ignore("queryLoop should correctly update loop control variables") {
val mockReader = mock[FlintReader]
val osClient = mock[OSClient]
when(osClient.createQueryReader(any[String], any[String], any[String], eqTo(SortOrder.ASC)))
@@ -1240,7 +1257,7 @@ class FlintREPLTest
(100, 300L) // 100 ms, 300 ms
)

test(
ignore(
"queryLoop should execute loop without processing any commands for different inactivity limits and frequencies") {
forAll(testCases) { (inactivityLimit, queryLoopExecutionFrequency) =>
val mockReader = mock[FlintReader]