From 1aa3a31c24a2c59a788d2ba1c3ef3a5f17cf9f46 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Mon, 12 Aug 2024 10:03:27 -0700 Subject: [PATCH] Refactor UT Signed-off-by: Chen Dai --- .../FlintSparkTransactionSupportSuite.scala | 100 +++++++++++++----- 1 file changed, 71 insertions(+), 29 deletions(-) diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionSupportSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionSupportSuite.scala index 57114f230..be9201112 100644 --- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionSupportSuite.scala +++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/FlintSparkTransactionSupportSuite.scala @@ -5,7 +5,7 @@ package org.opensearch.flint.spark -import org.mockito.Mockito.{never, reset, times, verify, when, RETURNS_DEEP_STUBS} +import org.mockito.Mockito._ import org.opensearch.flint.common.metadata.log.FlintMetadataLogService import org.opensearch.flint.core.FlintClient import org.scalatest.matchers.should.Matchers @@ -28,47 +28,42 @@ class FlintSparkTransactionSupportSuite extends FlintSuite with Matchers { mockFlintMetadataLogService } - override protected def beforeEach(): Unit = { + override protected def afterEach(): Unit = { reset(mockFlintClient, mockFlintMetadataLogService) } test("execute transaction without force initialization when index exists") { - when(mockFlintClient.exists(testIndex)).thenReturn(true) - val result = - transactionSupport - .withTransaction[Boolean](testIndex, testOpName) { _ => true } - - result shouldBe Some(true) - verify(mockFlintMetadataLogService, times(1)).startTransaction(testIndex, false) + assertIndexOperation() + .withForceInit(false) + .withResult("test") + .whenIndexDataExists() + .expectResult("test") + .verifyTransaction(forceInit = false) } test("execute transaction with force initialization when index exists") { - when(mockFlintClient.exists(testIndex)).thenReturn(true) - val result = - transactionSupport - .withTransaction[Boolean](testIndex, testOpName, forceInit = true) { _ => true } - - result shouldBe Some(true) - verify(mockFlintMetadataLogService, times(1)).startTransaction(testIndex, true) + assertIndexOperation() + .withForceInit(true) + .withResult("test") + .whenIndexDataExists() + .expectResult("test") + .verifyTransaction(forceInit = true) } test("bypass transaction without force initialization when index does not exist") { - when(mockFlintClient.exists(testIndex)).thenReturn(false) - val result = - transactionSupport - .withTransaction[Boolean](testIndex, testOpName) { _ => true } - - result shouldBe None + assertIndexOperation() + .withForceInit(false) + .withResult("test") + .whenIndexDataNotExist() + .expectNoResult() } test("execute transaction with force initialization even if index does not exist") { - when(mockFlintClient.exists(testIndex)).thenReturn(false) - val result = - transactionSupport - .withTransaction[Boolean](testIndex, testOpName, forceInit = true) { _ => true } - - result shouldBe Some(true) - verify(mockFlintMetadataLogService, times(1)).startTransaction(testIndex, true) + assertIndexOperation() + .withForceInit(true) + .withResult("test") + .whenIndexDataNotExist() + .expectResult("test") } test("propagate original exception thrown within transaction") { @@ -82,4 +77,51 @@ class FlintSparkTransactionSupportSuite extends FlintSuite with Matchers { } } should have message "Fake cause" } + + private def assertIndexOperation(): FlintIndexAssertion = new FlintIndexAssertion + + class FlintIndexAssertion { + private var forceInit: Boolean = false + private var expectedResult: Option[String] = None + + def withForceInit(forceInit: Boolean): FlintIndexAssertion = { + this.forceInit = forceInit + this + } + + def withResult(expectedResult: String): FlintIndexAssertion = { + this.expectedResult = Some(expectedResult) + this + } + + def whenIndexDataExists(): FlintIndexAssertion = { + when(mockFlintClient.exists(testIndex)).thenReturn(true) + this + } + + def whenIndexDataNotExist(): FlintIndexAssertion = { + when(mockFlintClient.exists(testIndex)).thenReturn(false) + this + } + + def verifyTransaction(forceInit: Boolean): FlintIndexAssertion = { + verify(mockFlintMetadataLogService, times(1)).startTransaction(testIndex, forceInit) + this + } + + def expectResult(expectedResult: String): FlintIndexAssertion = { + val result = transactionSupport.withTransaction[String](testIndex, testOpName, forceInit) { + _ => expectedResult + } + result shouldBe Some(expectedResult) + this + } + + def expectNoResult(): Unit = { + val result = transactionSupport.withTransaction[String](testIndex, testOpName, forceInit) { + _ => expectedResult.getOrElse("") + } + result shouldBe None + } + } }