
Pre-validate checkpoint location write permission #414

Merged
@@ -5,7 +5,7 @@

package org.opensearch.flint.spark

-import java.io.IOException
+import java.util.UUID

import org.apache.hadoop.fs.Path
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
@@ -17,6 +17,7 @@
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods
Comment on lines 19 to 20

Collaborator:
nit: can we merge the import?

@dai-chen (Collaborator, Author), Jul 11, 2024:
I couldn't find a way to do this in Scala. Please show me if you know how to do it. Thanks!

Collaborator:
Based on the class structure of CheckpointFileManager, it appears that the only way to avoid another import is to reference CheckpointFileManager.RenameHelperMethods by its qualified name.

import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}

/**
@@ -76,14 +77,37 @@ trait FlintSparkValidationHelper extends Logging {
new Path(checkpointLocation),
spark.sessionState.newHadoopConf())

-      // The primary intent here is to catch any exceptions during the accessibility check.
-      // The actual result is ignored, as Spark can create any necessary sub-folders
-      // when the streaming job starts.
+      /*
+       * Read permission check: The primary intent here is to catch any exceptions
+       * during the accessibility check. The actual result is ignored, as the write
+       * permission check below will create any necessary sub-folders.
+       */
checkpointManager.exists(new Path(checkpointLocation))

+      /*
+       * Write permission check: Attempt to create a temporary file to verify write access.
+       * The temporary file is left in place in case additional delete permissions are required.
+       */
+      checkpointManager match {
+        case manager: RenameHelperMethods =>
+          val tempFilePath =
+            new Path(
+              checkpointManager
+                .createCheckpointDirectory(), // create all parent folders if needed
+              s"${UUID.randomUUID().toString}.tmp")
+
+          manager.createTempFile(tempFilePath).close()
+        case _ =>
+          logInfo(
+            s"Bypass checkpoint location write permission check: ${checkpointManager.getClass}")
+      }

true
} catch {
-      case e: IOException =>
-        logWarning(s"Failed to check if checkpoint location $checkpointLocation exists", e)
+      case e: Exception =>
+        logWarning(
+          s"Exception occurred while verifying access to checkpoint location $checkpointLocation",
+          e)
false
}
}
@@ -56,7 +56,7 @@ class AutoIndexRefresh(indexName: String, index: FlintSparkIndex)
if (checkpointLocation.isDefined) {
require(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"No permission to access the checkpoint location ${checkpointLocation.get}")
s"No sufficient permission to access the checkpoint location ${checkpointLocation.get}")
}
}

@@ -39,7 +39,7 @@ class IncrementalIndexRefresh(indexName: String, index: FlintSparkIndex)
"Checkpoint location is required by incremental refresh")
require(
isCheckpointLocationAccessible(spark, checkpointLocation.get),
s"No permission to access the checkpoint location ${checkpointLocation.get}")
s"No sufficient permission to access the checkpoint location ${checkpointLocation.get}")
}

override def start(spark: SparkSession, flintSparkConf: FlintSparkConf): Option[String] = {
@@ -7,15 +7,20 @@
package org.opensearch.flint.spark

import java.util.{Locale, UUID}

+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path, PathFilter}
import org.opensearch.flint.spark.covering.FlintSparkCoveringIndex
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.refresh.FlintSparkIndexRefresh.RefreshMode.{AUTO, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.scalatest.matchers.must.Matchers.have
import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, the}
+import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.sql.SparkHiveSupportSuite
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.flint.config.FlintSparkConf.CHECKPOINT_MANDATORY
+import org.apache.spark.sql.internal.SQLConf

class FlintSparkIndexValidationITSuite extends FlintSparkSuite with SparkHiveSupportSuite {

@@ -103,6 +108,38 @@ class FlintSparkIndexValidationITSuite extends FlintSparkSuite with SparkHiveSupportSuite
}
}

+  Seq(
+    (AUTO, createSkippingIndexStatement),
+    (AUTO, createCoveringIndexStatement),
+    (AUTO, createMaterializedViewStatement),
+    (INCREMENTAL, createSkippingIndexStatement),
+    (INCREMENTAL, createCoveringIndexStatement),
+    (INCREMENTAL, createMaterializedViewStatement))
+    .foreach { case (refreshMode, statement) =>
+      test(
+        s"should fail to create $refreshMode refresh Flint index if checkpoint location is not writable: $statement") {
+        withTable(testTable) {
+          sql(s"CREATE TABLE $testTable (name STRING) USING JSON")
+
+          withTempDir { checkpointDir =>
+            // Set checkpoint dir readonly to simulate the exception
+            checkpointDir.setWritable(false)
+
+            the[IllegalArgumentException] thrownBy {
+              sql(s"""
+                   | $statement
+                   | WITH (
+                   |   ${optionName(refreshMode)} = true,
+                   |   checkpoint_location = "$checkpointDir"
+                   | )
+                   |""".stripMargin)
+            } should have message
+              s"requirement failed: Insufficient permission to access the checkpoint location $checkpointDir"
+          }
+        }
+      }
+    }

Seq(
(AUTO, createSkippingIndexStatement),
(AUTO, createCoveringIndexStatement),
@@ -127,7 +164,7 @@ class FlintSparkIndexValidationITSuite extends FlintSparkSuite with SparkHiveSupportSuite
| )
|""".stripMargin)
} should have message
s"requirement failed: No permission to access the checkpoint location $checkpointDir"
s"requirement failed: No sufficient permission to access the checkpoint location $checkpointDir"
}
}
}
@@ -173,14 +210,97 @@ class FlintSparkIndexValidationITSuite extends FlintSparkSuite with SparkHiveSupportSuite
sql(statement)
flint.refreshIndex(flintIndexName)
flint.queryIndex(flintIndexName).count() shouldBe 1

+        deleteTestIndex(flintIndexName)
}
}
}

+  Seq(
+    (skippingIndexName, AUTO, createSkippingIndexStatement),
+    (coveringIndexName, AUTO, createCoveringIndexStatement),
+    (materializedViewName, AUTO, createMaterializedViewStatement),
+    (skippingIndexName, INCREMENTAL, createSkippingIndexStatement),
+    (coveringIndexName, INCREMENTAL, createCoveringIndexStatement),
+    (materializedViewName, INCREMENTAL, createMaterializedViewStatement))
+    .foreach { case (flintIndexName, refreshMode, statement) =>
+      test(
+        s"should succeed to create $refreshMode refresh Flint index even if checkpoint sub-folder doesn't exist: $statement") {
+        withTable(testTable) {
+          sql(s"CREATE TABLE $testTable (name STRING) USING JSON")
+          sql(s"INSERT INTO $testTable VALUES ('test')")
+
+          withTempDir { checkpointDir =>
+            // Specify a nonexistent sub-folder and expect pre-validation to pass
+            val nonExistCheckpointDir = s"$checkpointDir/${UUID.randomUUID().toString}"
+            sql(s"""
+                 | $statement
+                 | WITH (
+                 |   ${optionName(refreshMode)} = true,
+                 |   checkpoint_location = '$nonExistCheckpointDir'
+                 | )
+                 |""".stripMargin)
+
+            deleteTestIndex(flintIndexName)
+          }
+        }
+      }
+    }
+
+  test(
+    "should bypass write permission check for checkpoint location if checkpoint manager class doesn't support creating temp files") {
+    withTable(testTable) {
+      sql(s"CREATE TABLE $testTable (name STRING) USING JSON")
+      sql(s"INSERT INTO $testTable VALUES ('test')")
+
+      withTempDir { checkpointDir =>
+        // Set readonly to verify write permission check bypass
+        checkpointDir.setWritable(false)
+
+        // Configure fake checkpoint file manager
+        val confKey = SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key
+        withSQLConf(confKey -> classOf[FakeCheckpointFileManager].getName) {
+          sql(s"""
+               | $createSkippingIndexStatement
+               | WITH (
+               |   incremental_refresh = true,
+               |   checkpoint_location = '${checkpointDir.getAbsolutePath}'
+               | )
+               |""".stripMargin)
+        }
+      }
+    }
+  }
+
private def lowercase(mode: RefreshMode): String = mode.toString.toLowerCase(Locale.ROOT)

private def optionName(mode: RefreshMode): String = mode match {
case AUTO => "auto_refresh"
case INCREMENTAL => "incremental_refresh"
}
}

+/**
+ * Fake checkpoint file manager.
+ */
+class FakeCheckpointFileManager(path: Path, conf: Configuration) extends CheckpointFileManager {
+
+  override def createAtomic(
+      path: Path,
+      overwriteIfPossible: Boolean): CheckpointFileManager.CancellableFSDataOutputStream =
+    throw new UnsupportedOperationException
+
+  override def open(path: Path): FSDataInputStream = mock[FSDataInputStream]
+
+  override def list(path: Path, filter: PathFilter): Array[FileStatus] = Array()
+
+  override def mkdirs(path: Path): Unit = throw new UnsupportedOperationException
+
+  override def exists(path: Path): Boolean = true
+
+  override def delete(path: Path): Unit = throw new UnsupportedOperationException
+
+  override def isLocal: Boolean = throw new UnsupportedOperationException
+
+  override def createCheckpointDirectory(): Path = path
+}