Skip to content

Commit

Permalink
Add UT and more doc
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Sep 4, 2024
1 parent 9c03d7d commit 20a6935
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,33 @@ import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelp
* @param checkpointLocation
* The path to the checkpoint directory.
*/
class FlintSparkCheckpoint(spark: SparkSession, checkpointLocation: String) extends Logging {
private val checkpointDir = new Path(checkpointLocation)
class FlintSparkCheckpoint(spark: SparkSession, val checkpointLocation: String) extends Logging {

/** Checkpoint root directory path */
private val checkpointRootDir = new Path(checkpointLocation)

/** Spark checkpoint manager */
private val checkpointManager =
CheckpointFileManager.create(checkpointDir, spark.sessionState.newHadoopConf())
CheckpointFileManager.create(checkpointRootDir, spark.sessionState.newHadoopConf())

/**
* Checks if the checkpoint directory exists.
*
* @return
* true if the checkpoint directory exists, false otherwise.
*/
def exists(): Boolean = checkpointManager.exists(checkpointDir)
def exists(): Boolean = checkpointManager.exists(checkpointRootDir)

/**
 * Creates the checkpoint directory and all necessary parent directories if they do not already
 * exist.
 *
 * @return
 *   The path to the created checkpoint directory.
 */
def createDirectory(): Path = {
  // Explicit parentheses: this is a side-effecting zero-arity call (it creates directories on
  // the checkpoint file system), matching the parenthesized call style used elsewhere in this
  // class and the declaration in Spark's CheckpointFileManager.
  checkpointManager.createCheckpointDirectory()
}

/**
* Creates a temporary file in the checkpoint directory.
Expand All @@ -45,9 +60,7 @@ class FlintSparkCheckpoint(spark: SparkSession, checkpointLocation: String) exte
checkpointManager match {
case manager: RenameHelperMethods =>
val tempFilePath =
new Path(
checkpointManager.createCheckpointDirectory(), // create all parent folders if needed
s"${UUID.randomUUID().toString}.tmp")
new Path(createDirectory(), s"${UUID.randomUUID().toString}.tmp")
Some(manager.createTempFile(tempFilePath))
case _ =>
logInfo(s"Cannot create temp file at checkpoint location: ${checkpointManager.getClass}")
Expand All @@ -62,11 +75,11 @@ class FlintSparkCheckpoint(spark: SparkSession, checkpointLocation: String) exte
*/
def delete(): Unit = {
try {
checkpointManager.delete(checkpointDir)
logInfo(s"Checkpoint directory $checkpointDir deleted.")
checkpointManager.delete(checkpointRootDir)
logInfo(s"Checkpoint directory $checkpointRootDir deleted")
} catch {
case e: Exception =>
logError(s"Error deleting checkpoint directory $checkpointDir", e)
logError(s"Error deleting checkpoint directory $checkpointRootDir", e)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.FlintSuite

class FlintSparkCheckpointSuite extends FlintSuite with Matchers {

  test("exists") {
    withCheckpoint { checkpoint =>
      checkpoint.exists() shouldBe false
      checkpoint.createDirectory()
      checkpoint.exists() shouldBe true
    }
  }

  test("create directory") {
    withTempPath { tempDir =>
      // Point the checkpoint at a nested path whose parents do not exist yet
      val nestedDir = new Path(tempDir.getAbsolutePath, "sub/subsub")
      new FlintSparkCheckpoint(spark, nestedDir.toString).createDirectory()

      tempDir.exists() shouldBe true
    }
  }

  test("create temp file") {
    withCheckpoint { checkpoint =>
      val maybeTempFile = checkpoint.createTempFile()
      maybeTempFile shouldBe defined

      // Close the stream to ensure the file is flushed
      maybeTempFile.get.close()

      // A .tmp file must now be present under the checkpoint root
      val hasTempFile = listFiles(checkpoint.checkpointLocation).exists(isTempFile)
      hasTempFile shouldBe true
    }
  }

  test("delete") {
    withCheckpoint { checkpoint =>
      checkpoint.createDirectory()
      checkpoint.delete()
      checkpoint.exists() shouldBe false
    }
  }

  /** Runs the given test body against a checkpoint rooted in a fresh temporary directory. */
  private def withCheckpoint(block: FlintSparkCheckpoint => Unit): Unit = {
    withTempPath { dir =>
      block(new FlintSparkCheckpoint(spark, dir.getAbsolutePath))
    }
  }

  /** Lists the statuses of the entries directly under the given directory. */
  private def listFiles(dir: String): Array[FileStatus] =
    FileSystem.get(spark.sessionState.newHadoopConf()).listStatus(new Path(dir))

  /** Whether the given status is a regular file carrying the ".tmp" suffix. */
  private def isTempFile(file: FileStatus): Boolean =
    file.isFile && file.getPath.getName.endsWith(".tmp")
}

0 comments on commit 20a6935

Please sign in to comment.