diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index 065e2e44f..433448421 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -108,7 +108,7 @@ jobs: run: docker run -d -it -p 39227:9200 -p 39337:9300 -e "discovery.type=single-node" -v /home/runner/work/elastic4s/elastic4s/elastic4s-tests/src/test/resources/elasticsearch.yml:/usr/share/elasticsearch/config/elasticsearch.yml docker.elastic.co/elasticsearch/elasticsearch:8.5.3 - name: run tests - run: sbt ++3.2.0 elastic4s-scala3/test + run: sbt ++3.3.0 elastic4s-scala3/test - name: Import GPG key id: import_gpg @@ -125,7 +125,7 @@ jobs: echo "email: ${{ steps.import_gpg.outputs.email }}" - name: publish snapshot - run: sbt ++3.2.0 elastic4s-scala3/publish + run: sbt ++3.3.0 elastic4s-scala3/publish env: OSSRH_USERNAME: ${{ secrets.OSSRH_USERNAME }} OSSRH_PASSWORD: ${{ secrets.OSSRH_PASSWORD }} diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 4e46d13ac..0310991ab 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -65,4 +65,4 @@ jobs: run: docker run -d -it -p 39227:9200 -p 39337:9300 -e "discovery.type=single-node" -v /home/runner/work/elastic4s/elastic4s/elastic4s-tests/src/test/resources/elasticsearch.yml:/usr/share/elasticsearch/config/elasticsearch.yml docker.elastic.co/elasticsearch/elasticsearch:8.5.3 - name: run tests - run: sbt ++3.2.0 elastic4s-scala3/test + run: sbt ++3.3.0 elastic4s-scala3/test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 66c066bea..80a6e5886 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -57,7 +57,7 @@ jobs: OSSRH_PASSWORD: ${{ secrets.OSSRH_PASSWORD }} - name: publish 3.0 release - run: sbt ++3.2.0 elastic4s-scala3/publishSigned + run: sbt ++3.3.0 elastic4s-scala3/publishSigned env: RELEASE_VERSION: ${{ github.event.inputs.version }} OSSRH_USERNAME: ${{ secrets.OSSRH_USERNAME }} diff --git a/build.sbt b/build.sbt index 448ed679f..f6b6a26f3 100644 --- a/build.sbt +++ b/build.sbt @@ -25,7 +25,7 @@ def ossrhUsername = sys.env.getOrElse("OSSRH_USERNAME", "") def ossrhPassword = sys.env.getOrElse("OSSRH_PASSWORD", "") val scala2Versions = Seq("2.12.17", "2.13.11") -val scalaAllVersions = scala2Versions :+ "3.2.2" +val scalaAllVersions = scala2Versions :+ "3.3.0" lazy val commonScalaVersionSettings = Seq( scalaVersion := "2.12.17", crossScalaVersions := Nil @@ -139,7 +139,8 @@ lazy val scala3Projects: Seq[ProjectReference] = Seq( ziojson, clientsttp, httpstreams, - akkastreams + akkastreams, + pekkostreams ) lazy val scala3_root = Project("elastic4s-scala3", file("scala3")) .settings(name := "elastic4s") @@ -157,7 +158,7 @@ lazy val root = Project("elastic4s", file(".")) noPublishSettings ) .aggregate( - Seq[ProjectReference](scalaz, sprayjson, ziojson_1, clientakka) ++ scala3Projects: _* + Seq[ProjectReference](scalaz, sprayjson, ziojson_1, clientakka, clientpekko) ++ scala3Projects: _* ) lazy val domain = (project in file("elastic4s-domain")) @@ -269,6 +270,12 @@ lazy val akkastreams = (project in file("elastic4s-streams-akka")) .settings(scala3Settings) .settings(libraryDependencies += Dependencies.akkaStream) +lazy val pekkostreams = (project in file("elastic4s-streams-pekko")) + .dependsOn(core, testkit % "test", jackson % "test") + .settings(name := "elastic4s-streams-pkko") + .settings(scala3Settings) + .settings(libraryDependencies += Dependencies.pekkoStream) + lazy val jackson = (project in file("elastic4s-json-jackson")) .dependsOn(core) .settings(name := "elastic4s-json-jackson") @@ -324,8 +331,15 @@ lazy val clientsttp = (project in file("elastic4s-client-sttp")) lazy val clientakka = (project in file("elastic4s-client-akka")) .dependsOn(core, testkit % "test") .settings(name := "elastic4s-client-akka") - .settings(scala2Settings) // tests need re-writing to not use scalaMock. We also need akka-http to be cross-published, which depends on an akka bump with restrictive licensing changes - .settings(libraryDependencies ++= Seq(akkaHTTP, akkaStream, scalaMock)) + .settings(scala2Settings) // We need akka-http to be cross-published, which depends on an akka bump with restrictive licensing changes + .settings(libraryDependencies ++= Seq(akkaHTTP, akkaStream)) + +lazy val clientpekko = (project in file("elastic4s-client-pekko")) + .dependsOn(core, testkit % "test") + .settings(name := "elastic4s-client-pekko") + .settings(scala3Settings) + .settings(libraryDependencies ++= Seq(pekkoHTTP, pekkoStream)) + lazy val tests = (project in file("elastic4s-tests")) .settings(name := "elastic4s-tests") diff --git a/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientMockTest.scala b/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientMockTest.scala index 86a9a9cbd..d21319320 100644 --- a/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientMockTest.scala +++ b/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientMockTest.scala @@ -3,12 +3,13 @@ package com.sksamuel.elastic4s.akka import akka.actor.ActorSystem import akka.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes, Uri} import com.sksamuel.elastic4s.{ElasticRequest, HttpEntity => ElasticEntity, HttpResponse => ElasticResponse} -import org.scalamock.function.MockFunction1 -import org.scalamock.scalatest.MockFactory +import org.mockito.ArgumentMatchers._ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent._ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar +import org.mockito.Mockito._ import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} @@ -16,19 +17,19 @@ import scala.util.{Failure, Success, Try} class AkkaHttpClientMockTest extends AnyWordSpec with Matchers - with MockFactory + with MockitoSugar with ScalaFutures with IntegrationPatience with BeforeAndAfterAll { private implicit lazy val system: ActorSystem = ActorSystem() - override def afterAll: Unit = { + override def afterAll(): Unit = { system.terminate() } - def mockHttpPool(): (MockFunction1[HttpRequest, Try[HttpResponse]], TestHttpPoolFactory) = { - val sendRequest = mockFunction[HttpRequest, Try[HttpResponse]] + def mockHttpPool(): (Function[HttpRequest, Try[HttpResponse]], TestHttpPoolFactory) = { + val sendRequest = mock[Function[HttpRequest, Try[HttpResponse]]] val poolFactory = new TestHttpPoolFactory(sendRequest) (sendRequest, poolFactory) } @@ -49,22 +50,22 @@ class AkkaHttpClientMockTest val client = new AkkaHttpClient(AkkaHttpClientSettings(hosts), blacklist, httpPool) - (blacklist.contains _).expects("host1").returns(false) - (blacklist.contains _).expects("host2").returns(false) - (blacklist.add _).expects("host1").returns(true) - (blacklist.remove _).expects("host2").returns(false) + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + when(blacklist.remove("host2")).thenReturn(false) - sendRequest - .expects(argThat { (r: HttpRequest) => - r.uri == Uri("http://host1/test") - }) - .returns(Success(HttpResponse(StatusCodes.BadGateway))) + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host1/test") + })) + .thenReturn(Success(HttpResponse(StatusCodes.BadGateway))) - sendRequest - .expects(argThat { (r: HttpRequest) => - r.uri == Uri("http://host2/test") - }) - .returns(Success(HttpResponse().withEntity("ok"))) + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host2/test") + })) + .thenReturn(Success(HttpResponse().withEntity("ok"))) client .sendAsync(ElasticRequest("GET", "/test")) @@ -91,14 +92,14 @@ class AkkaHttpClientMockTest blacklist, httpPool) - (blacklist.contains _).expects("host1").returns(false) - (blacklist.add _).expects("host1").returns(true) + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) - sendRequest - .expects(argThat { (r: HttpRequest) => + when(sendRequest + .apply(argThat { (r: HttpRequest) => r.uri == Uri("http://host1/test") - }) - .returns(Success(HttpResponse(StatusCodes.BadGateway))) + })) + .thenReturn(Success(HttpResponse(StatusCodes.BadGateway))) client .sendAsync(ElasticRequest("GET", "/test")) @@ -122,22 +123,22 @@ class AkkaHttpClientMockTest val client = new AkkaHttpClient(AkkaHttpClientSettings(hosts), blacklist, httpPool) - (blacklist.contains _).expects("host1").returns(false) - (blacklist.contains _).expects("host2").returns(false) - (blacklist.add _).expects("host1").returns(true) - (blacklist.remove _).expects("host2").returns(false) + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + when(blacklist.remove("host2")).thenReturn(false) - sendRequest - .expects(argThat { (r: HttpRequest) => - r.uri == Uri("http://host1/test") - }) - .returns(Success(HttpResponse(StatusCodes.BadGateway))) + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host1/test") + })) + .thenReturn(Success(HttpResponse(StatusCodes.BadGateway))) - sendRequest - .expects(argThat { (r: HttpRequest) => + when(sendRequest + .apply(argThat { (r: HttpRequest) => r.uri == Uri("http://host2/test") - }) - .returns(Success(HttpResponse().withEntity("host2"))) + })) + .thenReturn(Success(HttpResponse().withEntity("host2"))) client .sendAsync(ElasticRequest("GET", "/test")) @@ -158,22 +159,22 @@ class AkkaHttpClientMockTest val client = new AkkaHttpClient(AkkaHttpClientSettings(hosts), blacklist, httpPool) - (blacklist.contains _).expects("host1").returns(false) - (blacklist.contains _).expects("host2").returns(false) - (blacklist.add _).expects("host1").returns(true) - (blacklist.remove _).expects("host2").returns(false) + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + when(blacklist.remove("host2")).thenReturn(false) - sendRequest - .expects(argThat { (r: HttpRequest) => - r.uri == Uri("http://host1/test") - }) - .returns(Failure(new Exception("Some exception"))) + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host1/test") + })) + .thenReturn(Failure(new Exception("Some exception"))) - sendRequest - .expects(argThat { (r: HttpRequest) => - r.uri == Uri("http://host2/test") - }) - .returns(Success(HttpResponse().withEntity("host2"))) + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host2/test") + })) + .thenReturn(Success(HttpResponse().withEntity("host2"))) client .sendAsync(ElasticRequest("GET", "/test")) @@ -194,16 +195,16 @@ class AkkaHttpClientMockTest val client = new AkkaHttpClient(AkkaHttpClientSettings(hosts), blacklist, httpPool) - (blacklist.contains _).expects("host1").returns(true) - (blacklist.size _).expects().returns(1) - (blacklist.contains _).expects("host2").returns(false) - (blacklist.remove _).expects("host2").returns(false) + when(blacklist.contains("host1")).thenReturn(true) + when(blacklist.size).thenReturn(1) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.remove("host2")).thenReturn(false) - sendRequest - .expects(argThat { (r: HttpRequest) => + when(sendRequest + .apply(argThat { (r: HttpRequest) => r.uri == Uri("http://host2/test") - }) - .returns(Success(HttpResponse().withEntity("host2"))) + })) + .thenReturn(Success(HttpResponse().withEntity("host2"))) client .sendAsync(ElasticRequest("GET", "/test")) diff --git a/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientTest.scala b/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientTest.scala index 85926e8a8..ee86f4806 100644 --- a/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientTest.scala +++ b/elastic4s-client-akka/src/test/scala/com/sksamuel/elastic4s/akka/AkkaHttpClientTest.scala @@ -17,7 +17,7 @@ class AkkaHttpClientTest extends AnyFlatSpec with Matchers with DockerTests with private implicit lazy val system: ActorSystem = ActorSystem() - override def beforeAll: Unit = { + override def beforeAll(): Unit = { Try { client.execute { deleteIndex("testindex") @@ -25,7 +25,7 @@ class AkkaHttpClientTest extends AnyFlatSpec with Matchers with DockerTests with } } - override def afterAll: Unit = { + override def afterAll(): Unit = { Try { client.execute { deleteIndex("testindex") diff --git a/elastic4s-client-pekko/src/main/resources/reference.conf b/elastic4s-client-pekko/src/main/resources/reference.conf new file mode 100644 index 000000000..f9902ea74 --- /dev/null +++ b/elastic4s-client-pekko/src/main/resources/reference.conf @@ -0,0 +1,18 @@ +com.sksamuel.elastic4s.pekko { + hosts = [] + https = false + verify-ssl-certificate = true + // optionally provide credentials + // username = ... + // password = ... + queue-size = 1000 + blacklist { + min-duration = 1m + max-duration = 30m + } + max-retry-timeout = 30s + pekko.http { + // pekko-http settings specific for elastic4s + // can be overwritten in this section + } +} diff --git a/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/Blacklist.scala b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/Blacklist.scala new file mode 100644 index 000000000..29353087c --- /dev/null +++ b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/Blacklist.scala @@ -0,0 +1,45 @@ +package com.sksamuel.elastic4s.pekko + +/** + * List of 'bad' hosts. + * Implementation must have expiration logic backed-in. + */ +private[pekko] trait Blacklist { + + /** + * Adds a host to the blacklist. + * + * @param host host + * @return true if record is blacklisted for the first time + */ + def add(host: String): Boolean + + /** + * Removes a host from the blacklist. + * + * @param host host + * @return true if host was blacklisted + */ + def remove(host: String): Boolean + + /** + * Checks if a host can be used. + * + * @param host host + * @return true if host is not in a blacklist or temporary removed from it + */ + def contains(host: String): Boolean + + /** + * Number of hosts in blacklist + */ + def size: Int + + /** + * List all hosts in the blacklist + * + * @return + */ + def list: List[String] +} + diff --git a/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/DefaultBlacklist.scala b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/DefaultBlacklist.scala new file mode 100644 index 000000000..9cd7bb9bd --- /dev/null +++ b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/DefaultBlacklist.scala @@ -0,0 +1,78 @@ +package com.sksamuel.elastic4s.pekko + +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ +import scala.concurrent.duration.FiniteDuration + +/** + * Thread-safe host blacklist. + * Blacklist duration starts with `min` and exponentially increased up to `max` on subsequent calls to `add`. + * When `remove` is called - blacklist record is permanently removed and next `add` will start with `min` again. + * + * @param min minimum time to keep blacklist record + * @param max maximum time to keep blacklist record + * @param nanoTime clock in nanoseconds + */ +private[pekko] class DefaultBlacklist(min: FiniteDuration, + max: FiniteDuration, + nanoTime: => Long = System.nanoTime) + extends Blacklist { + + import DefaultBlacklist._ + + private val hosts = new ConcurrentHashMap[String, BlacklistRecord]() + + override def add(host: String): Boolean = { + val now = nanoTime + val record = hosts.getOrDefault( + host, + BlacklistRecord(enabled = true, startTime = now, untilTime = -1, -1)) + + if(now >= record.untilTime) { + val retries = record.retries + 1 + + val untilTime = now + Math + .min(min.toNanos * Math.pow(2, retries * 0.5), max.toNanos) + .toLong + + val updatedRecord = + record.copy( + enabled = true, + untilTime = untilTime, + retries = retries) + + hosts.put(host, updatedRecord) == null + } else false + } + + override def remove(host: String): Boolean = { + hosts.remove(host) != null + } + + override def contains(host: String): Boolean = { + hosts.get(host) match { + case null => false + case r => + if (r.enabled) { + if (nanoTime - r.untilTime >= 0) { + hosts.put(host, r.copy(enabled = false)) + false + } else true + } else false + } + } + + override def size: Int = hosts.size() + + override def list: List[String] = hosts.keys().asScala.toList +} + +object DefaultBlacklist { + + private case class BlacklistRecord(enabled: Boolean, + startTime: Long, + untilTime: Long, + retries: Int) + +} diff --git a/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/DefaultHttpPoolFactory.scala b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/DefaultHttpPoolFactory.scala new file mode 100644 index 000000000..2057dcffa --- /dev/null +++ b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/DefaultHttpPoolFactory.scala @@ -0,0 +1,62 @@ +package com.sksamuel.elastic4s.pekko + +import org.apache.pekko.NotUsed +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.http.scaladsl.{ConnectionContext, Http} +import org.apache.pekko.http.scaladsl.model.{HttpRequest, HttpResponse} +import org.apache.pekko.http.scaladsl.settings.ConnectionPoolSettings +import org.apache.pekko.stream.scaladsl.Flow + +import java.security.cert.X509Certificate +import javax.net.ssl.{KeyManager, SSLContext, X509TrustManager} +import scala.concurrent.Future +import scala.concurrent.duration.Duration +import scala.util.Try + +private[pekko] class DefaultHttpPoolFactory(settings: ConnectionPoolSettings, verifySslCertificate : Boolean)( + implicit system: ActorSystem) + extends HttpPoolFactory { + + private val http = Http() + + private val poolSettings = settings.withResponseEntitySubscriptionTimeout( + Duration.Inf) // we guarantee to consume consume data from all responses + + // take from https://gist.github.com/iRevive/7d17144284a7a2227487635ec815860d + private val trustfulSslContext: SSLContext = { + object NoCheckX509TrustManager extends X509TrustManager { + override def checkClientTrusted(chain: Array[X509Certificate], authType: String) = () + + override def checkServerTrusted(chain: Array[X509Certificate], authType: String) = () + + override def getAcceptedIssuers = Array[X509Certificate]() + } + + val context = SSLContext.getInstance("TLS") + context.init(Array[KeyManager](), Array(NoCheckX509TrustManager), null) + context + } + + // https://doc.akka.io/docs/akka-http/current/client-side/client-https-support.html#disabling-hostname-verification + private val insecureConnectionContext = ConnectionContext.httpsClient {(host,port)=> + val engine = trustfulSslContext.createSSLEngine(host,port) + engine.setUseClientMode(true) + engine + } + + override def create[T]() + : Flow[(HttpRequest, T), (HttpRequest, Try[HttpResponse], T), NotUsed] = { + Flow[(HttpRequest, T)].map { + case (request, state) => (request, (request, state)) + }.via{ + http.superPool[(HttpRequest, T)]( + settings = poolSettings, + connectionContext = if(verifySslCertificate) http.defaultClientHttpsContext else insecureConnectionContext + ).map { + case (response, (request, state)) => (request, response, state) + } + } + } + + override def shutdown(): Future[Unit] = http.shutdownAllConnectionPools() +} diff --git a/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/HttpPoolFactory.scala b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/HttpPoolFactory.scala new file mode 100644 index 000000000..7ae0cf9ae --- /dev/null +++ b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/HttpPoolFactory.scala @@ -0,0 +1,18 @@ +package com.sksamuel.elastic4s.pekko + +import org.apache.pekko.NotUsed +import org.apache.pekko.http.scaladsl.model.{HttpRequest, HttpResponse} +import org.apache.pekko.stream.scaladsl.Flow + +import scala.concurrent.Future +import scala.util.Try + +/** + * Factory for Pekko's http pool flow. + */ +private[pekko] trait HttpPoolFactory { + + def create[T](): Flow[(HttpRequest, T), (HttpRequest, Try[HttpResponse], T), NotUsed] + + def shutdown(): Future[Unit] +} diff --git a/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClient.scala b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClient.scala new file mode 100644 index 000000000..b95aeb18c --- /dev/null +++ b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClient.scala @@ -0,0 +1,310 @@ +package com.sksamuel.elastic4s.pekko + +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.http.scaladsl.model.Uri.Query +import org.apache.pekko.http.scaladsl.model._ +import org.apache.pekko.http.scaladsl.model.headers.{BasicHttpCredentials, RawHeader} +import org.apache.pekko.stream.scaladsl.{FileIO, Keep, Sink, Source, StreamConverters} +import org.apache.pekko.stream.{Materializer, OverflowStrategy, QueueOfferResult} +import org.apache.pekko.util.ByteString +import com.sksamuel.elastic4s.HttpEntity.StringEntity +import com.sksamuel.elastic4s.{ + ElasticRequest, + HttpClient => ElasticHttpClient, + HttpEntity => ElasticHttpEntity, + HttpResponse => ElasticHttpResponse +} + +import scala.concurrent.{Future, Promise} +import scala.util.{Failure, Success, Try} + +class PekkoHttpClient private[pekko]( + settings: PekkoHttpClientSettings, + blacklist: Blacklist, + httpPoolFactory: HttpPoolFactory +)(implicit system: ActorSystem) + extends ElasticHttpClient { + + import PekkoHttpClient._ + import system.dispatcher + + private implicit val materializer: Materializer = Materializer(system) + + private val scheme = if (settings.https) "https" else "http" + + private val queue = + Source + .queue[(ElasticRequest, RequestState)]( + settings.queueSize, + OverflowStrategy.backpressure + ) + .statefulMapConcat { () => + val hosts = iterateHosts + + in => + { + (in, hosts.next()) match { + case ((r, s), Some(host)) => + // if host is resolved - send request forward + toRequest(r, host) match { + case Success(req) => + s.host.success(host) + (req, s) :: Nil + case Failure(e) => + s.host.failure(e) + s.response.failure(e) + Nil + } + case ((_, s), None) => + // if not - all hosts are blacklisted, return an error + val exception = AllHostsBlacklistedException + s.host.failure(exception) + s.response.failure(exception) + Nil + } + } + } + .via(httpPoolFactory.create[RequestState]()) + .flatMapMerge( + settings.poolSettings.maxConnections, { + case (request, Success(response), state) + if request.method == HttpMethods.HEAD => + response.discardEntityBytes() + Source.single( + (Success(toResponse(response, ByteString.empty)), state) + ) + case (_, Success(r), s) => + r.entity.dataBytes + .fold(ByteString())(_ ++ _) + .map(data => (Success(toResponse(r, data)), s)) + .recoverWithRetries(1, { // in case of TCP timeout or response subscription timeout, etc. + case t: Throwable => + Source.single(Failure(t), s) + }) + case (_, Failure(e), s) => Source.single(Failure(e), s) + } + ) + .toMat(Sink.foreach({ + case (Success(resp), s) => s.response.success(resp) + case (Failure(e), s) => s.response.failure(e) + }))(Keep.left) + .run() + + /** + * Iterator of Some(host) or None if all hosts are blacklisted. + */ + private def iterateHosts: Iterator[Option[String]] = + Iterator + .continually(settings.hosts) + .flatten + .flatMap { host => + if (blacklist.contains(host)) { + logger.trace(s"[$host] is in blacklist") + if (blacklist.size < settings.hosts.size) Nil + else None :: Nil + } else Some(host) :: Nil + } + + private def queueRequest(request: ElasticRequest, + state: RequestState): Future[ElasticHttpResponse] = { + queue.offer(request -> state).flatMap { + case QueueOfferResult.Enqueued => state.response.future + case QueueOfferResult.Dropped => + Future.failed(new Exception("Queue overflowed. Try again later.")) + case QueueOfferResult.Failure(ex) => Future.failed(ex) + case QueueOfferResult.QueueClosed => + Future.failed( + new Exception( + "Queue was closed (pool shut down) while running the request. Try again later." + ) + ) + } + } + + private def queueRequestWithRetry( + request: ElasticRequest, + startTimeNanos: Long = System.nanoTime + ): Future[ElasticHttpResponse] = { + + val state = RequestState() + + def retryIfPossible( + notPossible: => Either[Throwable, ElasticHttpResponse] + ): Future[ElasticHttpResponse] = { + val timePassed = System.nanoTime - startTimeNanos + if (timePassed < settings.maxRetryTimeout.toNanos) { + logger.trace(s"Retrying a request: ${request.endpoint}") + queueRequestWithRetry(request, startTimeNanos) + } else { + notPossible match { + case Left(exc) => + Future.failed( + new Exception( + s"Request retries exceeded max retry timeout [${settings.maxRetryTimeout}]", + exc + ) + ) + case Right(resp) => + Future.successful(resp) + } + } + } + + def markDead(): Future[Unit] = { + state.host.future + .map { host => + if (blacklist.add(host)) { + logger.debug(s"added [$host] to blacklist") + } else { + logger.trace(s"updated [$host] in a blacklist") + } + } + } + + def markAlive(): Future[Unit] = { + state.host.future + .map { host => + if (blacklist.remove(host)) { + logger.debug(s"removed [$host] from blacklist") + } + } + } + + queueRequest(request, state) + .flatMap { response => + val status = StatusCode.int2StatusCode(response.statusCode) + if (status.isSuccess()) { + markAlive().map(_ => response) + } else { + if (isRetryStatus(status)) { + markDead().flatMap(_ => retryIfPossible(Right(response))) + } else { + // mark host alive and don't retry, as the error should be a request problem + markAlive().map(_ => response) + } + } + } + .recoverWith { + case err @ AllHostsBlacklistedException => retryIfPossible(Left(err)) + case err: Throwable => + markDead().flatMap(_ => retryIfPossible(Left(err))) + } + } + + private def isRetryStatus(statusCode: StatusCode) = { + statusCode match { + case StatusCodes.BadGateway => true + case StatusCodes.ServiceUnavailable => true + case StatusCodes.GatewayTimeout => true + case _ => false + } + } + + private[pekko] def sendAsync( + request: ElasticRequest + ): Future[ElasticHttpResponse] = { + queueRequestWithRetry(request) + } + + override def send( + request: ElasticRequest, + callback: Either[Throwable, ElasticHttpResponse] => Unit + ): Unit = { + sendAsync(request).onComplete { + case Success(r) => callback(Right(r)) + case Failure(e) => callback(Left(e)) + } + } + + def shutdown(): Future[Unit] = { + httpPoolFactory.shutdown() + } + + override def close(): Unit = { + shutdown() + } + + private def toRequest(request: ElasticRequest, + host: String): Try[HttpRequest] = Try { + val httpRequest = HttpRequest( + method = HttpMethods + .getForKeyCaseInsensitive(request.method) + .getOrElse(HttpMethod.custom(request.method)), + uri = Uri(request.endpoint) + .withQuery(Query(request.params)) + .withAuthority(Uri.Authority.parse(host)) + .withScheme(scheme), + headers = request.headers.map((RawHeader.apply _).tupled).toList, + entity = request.entity.map(toEntity).getOrElse(HttpEntity.Empty) + ) + + settings.requestCallback( + if (settings.hasCredentialsDefined) { + httpRequest.addCredentials(BasicHttpCredentials(settings.username.get, settings.password.get)) + } else { + httpRequest + } + ) + } + + private def toResponse(response: HttpResponse, + data: ByteString): ElasticHttpResponse = { + ElasticHttpResponse( + response.status.intValue(), + Some(StringEntity(data.utf8String, None)), + response.headers.map(h => h.name -> h.value).toMap + ) + } + + private def toEntity(entity: ElasticHttpEntity): RequestEntity = { + entity match { + case ElasticHttpEntity.StringEntity(content, contentType) => + val ct = + contentType + .flatMap(value => ContentType.parse(value).right.toOption) + .getOrElse(ContentTypes.`text/plain(UTF-8)`) + HttpEntity(ct, ByteString(content)) + case ElasticHttpEntity.ByteArrayEntity(content, contentType) => + val ct = + contentType + .flatMap(value => ContentType.parse(value).right.toOption) + .getOrElse(ContentTypes.`text/plain(UTF-8)`) + HttpEntity(ct, ByteString(content)) + case ElasticHttpEntity.FileEntity(file, contentType) => + val ct = contentType + .flatMap(value => ContentType.parse(value).right.toOption) + .getOrElse(ContentTypes.`application/octet-stream`) + HttpEntity(ct, file.length, FileIO.fromPath(file.toPath)) + case ElasticHttpEntity.InputStreamEntity(stream, contentType) => + val ct = contentType + .flatMap(value => ContentType.parse(value).right.toOption) + .getOrElse(ContentTypes.`application/octet-stream`) + HttpEntity(ct, StreamConverters.fromInputStream(() => stream)) + } + } +} + +object PekkoHttpClient { + + def apply( + settings: PekkoHttpClientSettings + )(implicit system: ActorSystem): PekkoHttpClient = { + + val blacklist = new DefaultBlacklist( + settings.blacklistMinDuration, + settings.blacklistMaxDuration + ) + + val httpPoolFactory = new DefaultHttpPoolFactory(settings.poolSettings, settings.verifySSLCertificate) + + new PekkoHttpClient(settings, blacklist, httpPoolFactory) + } + + private[pekko] case class RequestState(response: Promise[ElasticHttpResponse] = + Promise(), + host: Promise[String] = Promise()) + + private[pekko] case object AllHostsBlacklistedException + extends Exception("All hosts are blacklisted!") + +} diff --git a/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientSettings.scala b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientSettings.scala new file mode 100644 index 000000000..97321ac49 --- /dev/null +++ b/elastic4s-client-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientSettings.scala @@ -0,0 +1,83 @@ +package com.sksamuel.elastic4s.pekko + +import java.util.concurrent.TimeUnit + +import org.apache.pekko.http.scaladsl.model.HttpRequest +import org.apache.pekko.http.scaladsl.settings.ConnectionPoolSettings +import com.typesafe.config.{Config, ConfigFactory} + +import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.util.Try + +object PekkoHttpClientSettings { + + private def defaultConfig: Config = + ConfigFactory.load().getConfig("com.sksamuel.elastic4s.pekko") + + lazy val default: PekkoHttpClientSettings = apply(defaultConfig) + + def apply(config: Config): PekkoHttpClientSettings = { + val cfg = config.withFallback(defaultConfig) + val hosts = cfg.getStringList("hosts").asScala.toVector + val username = Try(cfg.getString("username")).map(Some(_)).getOrElse(None) + val password = Try(cfg.getString("password")).map(Some(_)).getOrElse(None) + val queueSize = cfg.getInt("queue-size") + val https = cfg.getBoolean("https") + val verifySslCertificate = Try(cfg.getBoolean("verify-ssl-certificate")).toOption.getOrElse(true) + val blacklistMinDuration = Duration( + cfg.getDuration("blacklist.min-duration", TimeUnit.MILLISECONDS), + TimeUnit.MILLISECONDS + ) + val blacklistMaxDuration = Duration( + cfg.getDuration("blacklist.max-duration", TimeUnit.MILLISECONDS), + TimeUnit.MILLISECONDS + ) + val maxRetryTimeout = Duration( + cfg.getDuration("max-retry-timeout", TimeUnit.MILLISECONDS), + TimeUnit.MILLISECONDS + ) + val poolSettings = ConnectionPoolSettings( + cfg.withFallback(ConfigFactory.load()) + ) + PekkoHttpClientSettings( + https, + hosts, + username, + password, + queueSize, + poolSettings, + verifySslCertificate, + blacklistMinDuration, + blacklistMaxDuration, + maxRetryTimeout + ) + } + + def apply(): PekkoHttpClientSettings = { + default + } + + def apply(hosts: Seq[String]): PekkoHttpClientSettings = { + apply().copy(hosts = hosts.toVector) + } +} + +case class PekkoHttpClientSettings( + https: Boolean, + hosts: Vector[String], + username: Option[String], + password: Option[String], + queueSize: Int, + poolSettings: ConnectionPoolSettings, + verifySSLCertificate : Boolean, + blacklistMinDuration: FiniteDuration = + PekkoHttpClientSettings.default.blacklistMinDuration, + blacklistMaxDuration: FiniteDuration = + PekkoHttpClientSettings.default.blacklistMaxDuration, + maxRetryTimeout: FiniteDuration = + PekkoHttpClientSettings.default.maxRetryTimeout, + requestCallback: HttpRequest => HttpRequest = identity +) { + def hasCredentialsDefined: Boolean = username.isDefined && password.isDefined +} diff --git a/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/DefaultBlacklistTest.scala b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/DefaultBlacklistTest.scala new file mode 100644 index 000000000..c803f3e21 --- /dev/null +++ b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/DefaultBlacklistTest.scala @@ -0,0 +1,84 @@ +package com.sksamuel.elastic4s.pekko + +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.concurrent.duration._ +import scala.language.postfixOps + +class DefaultBlacklistTest extends AnyWordSpec with Matchers { + + val minDuration = 1 second + val maxDuration = 10 seconds + val host = "elastic.test" + + "DefaultBlacklist" should { + + "add host to blacklist" in { + val blacklist = new DefaultBlacklist(minDuration, maxDuration) + blacklist.add(host) shouldBe true + blacklist.contains(host) shouldBe true + } + + "remove host from blacklist" in { + val blacklist = new DefaultBlacklist(minDuration, maxDuration) + blacklist.add(host) + blacklist.remove(host) shouldBe true + blacklist.contains(host) shouldBe false + } + + "ensure host is in blacklist" in { + val blacklist = new DefaultBlacklist(minDuration, maxDuration) + blacklist.add(host) shouldBe true + blacklist.add(host) shouldBe false + } + + "ensure host is not in blacklist" in { + val blacklist = new DefaultBlacklist(minDuration, maxDuration) + blacklist.remove(host) shouldBe false + } + + "remove host from blacklist on timeout" in { + var now: Long = 0 + val blacklist = new DefaultBlacklist(minDuration, maxDuration, now) + + blacklist.add(host) + blacklist.contains(host) shouldBe true + + now += minDuration.toNanos + blacklist.contains(host) shouldBe false + } + + "increase blacklist timeout up to max" in { + var now: Long = 0 + val blacklist = new DefaultBlacklist(minDuration, maxDuration, now) + + blacklist.add(host) + + now += minDuration.toNanos + blacklist.contains(host) shouldBe false + + // after first blacklist timed out add it again + blacklist.add(host) + + // check that the same time increase now doesn't result in invalidated blacklist record + now += minDuration.toNanos + blacklist.contains(host) shouldBe true + + // now when more time elapses it should invalidate it again + now += maxDuration.toNanos + blacklist.contains(host) shouldBe false + } + + "not increase blacklist timeout on early `add`" in { + var now: Long = 0 + val blacklist = new DefaultBlacklist(minDuration, maxDuration, now) + + blacklist.add(host) + now = minDuration.toNanos / 2 + blacklist.add(host) + now = minDuration.toNanos + blacklist.contains(host) shouldBe false + } + } +} diff --git a/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientMockTest.scala b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientMockTest.scala new file mode 100644 index 000000000..76a29fe97 --- /dev/null +++ b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientMockTest.scala @@ -0,0 +1,218 @@ +package com.sksamuel.elastic4s.pekko + +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes, Uri} +import com.sksamuel.elastic4s.{ElasticRequest, HttpEntity => ElasticEntity, HttpResponse => ElasticResponse} +import org.mockito.ArgumentMatchers._ +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent._ +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import org.scalatestplus.mockito.MockitoSugar +import org.mockito.Mockito._ + +import scala.concurrent.duration._ +import scala.util.{Failure, Success, Try} + +class PekkoHttpClientMockTest + extends AnyWordSpec + with Matchers + with MockitoSugar + with ScalaFutures + with IntegrationPatience + with BeforeAndAfterAll { + + private implicit lazy val system: ActorSystem = ActorSystem() + + override def afterAll(): Unit = { + system.terminate() + } + + def mockHttpPool(): (Function[HttpRequest, Try[HttpResponse]], TestHttpPoolFactory) = { + val sendRequest = mock[Function[HttpRequest, Try[HttpResponse]]] + val poolFactory = new TestHttpPoolFactory(sendRequest) + (sendRequest, poolFactory) + } + + "PekkoHttpClient" should { + + "retry on 502" in { + + val hosts = List( + "host1", + "host2" + ) + + val blacklist = mock[Blacklist] + + val (sendRequest, httpPool) = mockHttpPool() + + val client = + new PekkoHttpClient(PekkoHttpClientSettings(hosts), blacklist, httpPool) + + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + when(blacklist.remove("host2")).thenReturn(false) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host1/test") + })) + .thenReturn(Success(HttpResponse(StatusCodes.BadGateway))) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host2/test") + })) + .thenReturn(Success(HttpResponse().withEntity("ok"))) + + client + .sendAsync(ElasticRequest("GET", "/test")) + .futureValue shouldBe ElasticResponse( + 200, + Some(ElasticEntity.StringEntity("ok", None)), + Map.empty) + } + + "don't retry if no time left" in { + + val hosts = List( + "host1", + "host2" + ) + + val blacklist = mock[Blacklist] + + val (sendRequest, httpPool) = mockHttpPool() + + val client = + new PekkoHttpClient( + PekkoHttpClientSettings(hosts).copy(maxRetryTimeout = 0.seconds), + blacklist, + httpPool) + + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r.uri == Uri("http://host1/test") + })) + .thenReturn(Success(HttpResponse(StatusCodes.BadGateway))) + + client + .sendAsync(ElasticRequest("GET", "/test")) + .futureValue shouldBe ElasticResponse( + 502, + Some(ElasticEntity.StringEntity("", None)), + Map.empty) + } + + "blacklist on 502" in { + + val hosts = List( + "host1", + "host2" + ) + + val blacklist = mock[Blacklist] + + val (sendRequest, httpPool) = mockHttpPool() + + val client = + new PekkoHttpClient(PekkoHttpClientSettings(hosts), blacklist, httpPool) + + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + when(blacklist.remove("host2")).thenReturn(false) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host1/test") + })) + .thenReturn(Success(HttpResponse(StatusCodes.BadGateway))) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r.uri == Uri("http://host2/test") + })) + .thenReturn(Success(HttpResponse().withEntity("host2"))) + + client + .sendAsync(ElasticRequest("GET", "/test")) + .futureValue + } + + "blacklist on exception" in { + + val hosts = List( + "host1", + "host2" + ) + + val blacklist = mock[Blacklist] + + val (sendRequest, httpPool) = mockHttpPool() + + val client = + new PekkoHttpClient(PekkoHttpClientSettings(hosts), blacklist, httpPool) + + when(blacklist.contains("host1")).thenReturn(false) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.add("host1")).thenReturn(true) + when(blacklist.remove("host2")).thenReturn(false) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host1/test") + })) + .thenReturn(Failure(new Exception("Some exception"))) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r != null && r.uri == Uri("http://host2/test") + })) + .thenReturn(Success(HttpResponse().withEntity("host2"))) + + client + .sendAsync(ElasticRequest("GET", "/test")) + .futureValue + } + + "skip blacklisted hosts" in { + + val hosts = List( + "host1", + "host2" + ) + + val blacklist = mock[Blacklist] + + val (sendRequest, httpPool) = mockHttpPool() + + val client = + new PekkoHttpClient(PekkoHttpClientSettings(hosts), blacklist, httpPool) + + when(blacklist.contains("host1")).thenReturn(true) + when(blacklist.size).thenReturn(1) + when(blacklist.contains("host2")).thenReturn(false) + when(blacklist.remove("host2")).thenReturn(false) + + when(sendRequest + .apply(argThat { (r: HttpRequest) => + r.uri == Uri("http://host2/test") + })) + .thenReturn(Success(HttpResponse().withEntity("host2"))) + + client + .sendAsync(ElasticRequest("GET", "/test")) + .futureValue shouldBe ElasticResponse( + 200, + Some(ElasticEntity.StringEntity("host2", None)), + Map.empty) + } + + } +} diff --git a/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientTest.scala b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientTest.scala new file mode 100644 index 000000000..4393b8a9c --- /dev/null +++ b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/PekkoHttpClientTest.scala @@ -0,0 +1,121 @@ +package com.sksamuel.elastic4s.pekko + +import org.apache.pekko.actor.ActorSystem +import com.sksamuel.elastic4s.{ElasticClient, ElasticRequest, Executor, HttpClient, HttpResponse} +import com.sksamuel.elastic4s.requests.common.HealthStatus +import com.sksamuel.elastic4s.testkit.DockerTests +import org.scalatest.BeforeAndAfterAll +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.charset.StandardCharsets +import java.util.Base64 +import scala.concurrent.Future +import scala.util.Try + +class PekkoHttpClientTest extends AnyFlatSpec with Matchers with DockerTests with BeforeAndAfterAll { + + private implicit lazy val system: ActorSystem = ActorSystem() + + override def beforeAll(): Unit = { + Try { + client.execute { + deleteIndex("testindex") + }.await + } + } + + override def afterAll(): Unit = { + Try { + client.execute { + deleteIndex("testindex") + }.await + + pekkoClient.shutdown().await + system.terminate().await + } + } + + private lazy val pekkoClient = PekkoHttpClient(PekkoHttpClientSettings(List(s"$elasticHost:$elasticPort"))) + + override val client = ElasticClient(pekkoClient) + + "PekkoHttpClient" should "support utf-8" in { + + client.execute { + indexInto("testindex").doc("""{ "text":"¡Hola! ¿Qué tal?" }""") + }.await.result.result shouldBe "created" + } + + it should "work fine whith _cat endpoints " in { + + client.execute { + catSegments() + }.await.result + + client.execute { + catShards() + }.await.result + + client.execute { + catNodes() + }.await.result + + client.execute { + catPlugins() + }.await.result + + client.execute { + catThreadPool() + }.await.result + + client.execute { + catHealth() + }.await.result + + client.execute { + catCount() + }.await.result + + client.execute { + catMaster() + }.await.result + + client.execute { + catAliases() + }.await.result + + client.execute { + catIndices() + }.await.result + + client.execute { + catIndices(HealthStatus.Green) + }.await.result + + client.execute { + catAllocation() + }.await.result + + } + + it should "work with head methods" in { + client.execute( + indexExists("unknown_index") + ).await.result + } + + it should "propagate headers if included" in { + implicit val executor: Executor[Future] = new Executor[Future] { + override def exec(client: HttpClient, request: ElasticRequest): Future[HttpResponse] = { + val cred = Base64.getEncoder.encodeToString("user123:pass123".getBytes(StandardCharsets.UTF_8)) + Executor.FutureExecutor.exec(client, request.copy(headers = Map("Authorization" -> s"Basic $cred"))) + } + } + + client.execute { + catHealth() + }.await.result.status shouldBe "401" + } +} + diff --git a/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/TestHttpPoolFactory.scala b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/TestHttpPoolFactory.scala new file mode 100644 index 000000000..56c70c387 --- /dev/null +++ b/elastic4s-client-pekko/src/test/scala/com/sksamuel/elastic4s/pekko/TestHttpPoolFactory.scala @@ -0,0 +1,22 @@ +package com.sksamuel.elastic4s.pekko + +import org.apache.pekko.NotUsed +import org.apache.pekko.http.scaladsl.model.{HttpRequest, HttpResponse} +import org.apache.pekko.stream.scaladsl.Flow + +import scala.concurrent.Future +import scala.concurrent.duration.{FiniteDuration, _} +import scala.util.Try + +class TestHttpPoolFactory(sendRequest: HttpRequest => Try[HttpResponse], + timeout: FiniteDuration = 2.seconds) extends HttpPoolFactory { + + override def create[T](): Flow[(HttpRequest, T), (HttpRequest, Try[HttpResponse], T), NotUsed] = { + Flow[(HttpRequest, T)] + .map { + case (r, s) => (r, Try(sendRequest(r)).flatten, s) + } + } + + override def shutdown(): Future[Unit] = Future.successful(()) +} diff --git a/elastic4s-streams-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/streams/BatchElasticSink.scala b/elastic4s-streams-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/streams/BatchElasticSink.scala new file mode 100644 index 000000000..ba0a73f59 --- /dev/null +++ b/elastic4s-streams-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/streams/BatchElasticSink.scala @@ -0,0 +1,81 @@ +package com.sksamuel.elastic4s.pekko.streams + +import com.sksamuel.elastic4s._ +import com.sksamuel.elastic4s.handlers.bulk.BulkHandlers +import com.sksamuel.elastic4s.requests.bulk.{BulkCompatibleRequest, BulkRequest, BulkResponse} +import com.sksamuel.elastic4s.requests.common.RefreshPolicy +import org.apache.pekko.stream.stage.{GraphStage, GraphStageLogic, InHandler} +import org.apache.pekko.stream.{Attributes, Inlet, SinkShape} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success, Try} + +case class SinkSettings(refreshAfterOp: Boolean = false) + +class BatchElasticSink[T](client: ElasticClient, settings: SinkSettings)(implicit + ec: ExecutionContext, + builder: RequestBuilder[T]) + extends GraphStage[SinkShape[Seq[T]]] { + + private val in: Inlet[Seq[T]] = Inlet.create("ElasticSink.out") + override val shape: SinkShape[Seq[T]] = SinkShape.of(in) + + private implicit val bulkHandler: BulkHandlers.BulkHandler.type = BulkHandlers.BulkHandler + private implicit val executor: Executor[Future] = Executor.FutureExecutor + private implicit val functor: Functor[Future] = Functor.FutureFunctor + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) { + + private val handler: InHandler = new InHandler { + override def onPush(): Unit = { + val seq = grab(in) + index(seq.map(builder.request)) + } + } + + override def preStart(): Unit = pull(in) + + setHandler(in, handler) + + private def callBack(requests: Seq[BulkCompatibleRequest]) = + getAsyncCallback[Try[Response[BulkResponse]]] { + case Failure(t) => failStage(t) + case Success(resp) => + resp match { + case RequestFailure(_, _, _, error) => failStage(error.asException) + case RequestSuccess(_, _, _, result) => + val failedRequests = result.failures.map { item => + requests(item.itemId) + } + if (failedRequests.nonEmpty) + index(failedRequests) + else + pull(in) + } + } + + private def index(requests: Seq[BulkCompatibleRequest]): Unit = { + + val policy = if (settings.refreshAfterOp) RefreshPolicy.Immediate else RefreshPolicy.NONE + val f = client.execute { + BulkRequest(requests).refresh(policy) + } + f.onComplete(callBack(requests).invoke) + + } + + } +} + +/** + * An implementation of this typeclass must provide a bulk compatible request for the given instance of T. + * The bulk compatible request will then be sent to elastic. + * + * A bulk compatible request can be either an index, update, or delete. + * + * @tparam T the type of elements this builder supports + */ +trait RequestBuilder[T] { + def request(t: T): BulkCompatibleRequest +} diff --git a/elastic4s-streams-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/streams/ElasticSource.scala b/elastic4s-streams-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/streams/ElasticSource.scala new file mode 100644 index 000000000..49303a44d --- /dev/null +++ b/elastic4s-streams-pekko/src/main/scala/com/sksamuel/elastic4s/pekko/streams/ElasticSource.scala @@ -0,0 +1,94 @@ +package com.sksamuel.elastic4s.pekko.streams + +import com.sksamuel.elastic4s.ElasticDsl.searchScroll +import com.sksamuel.elastic4s._ +import com.sksamuel.elastic4s.requests.searches._ +import org.apache.pekko.stream.stage.{GraphStage, GraphStageLogic, OutHandler} +import org.apache.pekko.stream.{Attributes, Outlet, SourceShape} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success, Try} + +/** + * A Pekko [[org.apache.pekko.stream.scaladsl.Source]], that publishes documents using an elasticsearch + * scroll cursor. The initial query must be provided to the source, and there are helpers to create + * a query for all documents in an index. + * + * @param client a client for the cluster + * @param settings settings for how documents are queried + */ +class ElasticSource(client: ElasticClient, settings: SourceSettings) + (implicit ec: ExecutionContext) extends GraphStage[SourceShape[SearchHit]] { + require(settings.search.keepAlive.isDefined, "The SearchRequest must have a scroll defined (a keep alive time)") + + private val out: Outlet[SearchHit] = Outlet.create("ElasticSource.out") + override val shape: SourceShape[SearchHit] = SourceShape.of(out) + + private implicit val searchHandler: Handler[SearchRequest, SearchResponse] = SearchHandlers.SearchHandler + private implicit val scrollHandler: Handler[SearchScrollRequest, SearchResponse] = SearchScrollHandlers.SearchScrollHandler + private implicit val executor: Executor[Future] = Executor.FutureExecutor + private implicit val functor: Functor[Future] = Functor.FutureFunctor + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with OutHandler { + + private val buffer = scala.collection.mutable.Queue.empty[SearchHit] + private var scrollId: String = _ + private var fetching = false + + // Parse the keep alive setting out of the original query. + private val keepAlive = settings.search.keepAlive.map(_.toString).getOrElse("1m") + + if (settings.warm) + fetch() + + private val populateHandler = getAsyncCallback[Try[Response[SearchResponse]]] { + case Failure(e) => fail(out, e) + case Success(response) => response match { + case RequestFailure(_, _, _, error) => fail(out, error.asException) + case RequestSuccess(_, _, _, searchr) => + searchr.scrollId match { + case None => fail(out, new RuntimeException("Search response did not include a scroll id")) + case Some(id) => + scrollId = id + fetching = false + buffer ++= searchr.hits.hits + if (buffer.nonEmpty && isAvailable(out)) { + push(out, buffer.dequeue) + maybeFetch() + } + // complete when no more elements to emit + if (searchr.hits.hits.length == 0) { + complete(out) + } + } + } + } + + // check if the buffer has dropped below threshold (or is empty) and if so, trigger a fetch + private def maybeFetch(): Unit = { + if (buffer.isEmpty || buffer.size <= settings.fetchThreshold) + fetch() + } + + // if no fetch is in progress then fire one + private def fetch(): Unit = { + if (!fetching) { + Option(scrollId) match { + case None => client.execute(settings.search).onComplete(populateHandler.invoke) + case Some(id) => client.execute(searchScroll(id).keepAlive(keepAlive)).onComplete(populateHandler.invoke) + } + fetching = true + } + } + + override def onPull(): Unit = { + if (buffer.nonEmpty) + push(out, buffer.dequeue) + maybeFetch() + } + + setHandler(out, this) + } +} + +case class SourceSettings(search: SearchRequest, maxItems: Long, fetchThreshold: Int = 0, warm: Boolean) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 0c9d731de..ab62964ef 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -16,9 +16,10 @@ object Dependencies { val Log4jVersion = "2.15.0" val MockitoVersion = "5.4.0" val MonixVersion = "3.4.1" + val PekkoHttpVersion = "1.0.0" + val PekkoVersion = "1.0.1" val PlayJsonVersion = "2.10.0-RC6" val ReactiveStreamsVersion = "1.0.3" - val ScalamockVersion = "5.2.0" val ScalatestPlusMockitoArtifactId = "mockito-3-4" val ScalatestPlusVersion = "3.2.9.0" val ScalazVersion = "7.2.35" @@ -65,12 +66,15 @@ object Dependencies { lazy val cats2 = "org.typelevel" %% "cats-effect" % CatsEffect2Version lazy val elasticsearchRestClient = "org.elasticsearch.client" % "elasticsearch-rest-client" % ElasticsearchVersion lazy val json4s = Seq("org.json4s" %% "json4s-core" % Json4sVersion, "org.json4s" %% "json4s-jackson" % Json4sVersion) - lazy val monix = "io.monix" %% "monix" % MonixVersion - lazy val playJson = Seq("com.typesafe.play" %% "play-json" % PlayJsonVersion) - lazy val sprayJson = Seq("io.spray" %% "spray-json" % SprayJsonVersion) - lazy val sttp = "com.softwaremill.sttp.client3" %% "core" % SttpVersion - lazy val zioJson1 = "dev.zio" %% "zio-json" % ZIOJson1Version - lazy val zioJson = "dev.zio" %% "zio-json" % ZIOJsonVersion + lazy val monix = "io.monix" %% "monix" % MonixVersion + lazy val pekkoActor = "org.apache.pekko" %% "pekko-actor" % PekkoVersion + lazy val pekkoHTTP = "org.apache.pekko" %% "pekko-http" % PekkoHttpVersion + lazy val pekkoStream = "org.apache.pekko" %% "pekko-stream" % PekkoVersion + lazy val playJson = Seq("com.typesafe.play" %% "play-json" % PlayJsonVersion) + lazy val sprayJson = Seq("io.spray" %% "spray-json" % SprayJsonVersion) + lazy val sttp = "com.softwaremill.sttp.client3" %% "core" % SttpVersion + lazy val zioJson1 = "dev.zio" %% "zio-json" % ZIOJson1Version + lazy val zioJson = "dev.zio" %% "zio-json" % ZIOJsonVersion lazy val elasticsearchRestClientSniffer = "org.elasticsearch.client" % "elasticsearch-rest-client-sniffer" % ElasticsearchVersion @@ -78,7 +82,6 @@ object Dependencies { lazy val log4jApi = "org.apache.logging.log4j" % "log4j-api" % Log4jVersion % "test" lazy val mockitoCore = "org.mockito" % "mockito-core" % MockitoVersion % "test" lazy val reactiveStreamsTck = "org.reactivestreams" % "reactive-streams-tck" % ReactiveStreamsVersion % "test" - lazy val scalaMock = "org.scalamock" %% "scalamock" % ScalamockVersion % "test" lazy val scalaTestMain = "org.scalatest" %% "scalatest" % ScalatestVersion lazy val scalaTest = scalaTestMain % "test" lazy val scalaTestPlusMokito = "org.scalatestplus" %% ScalatestPlusMockitoArtifactId % ScalatestPlusVersion