Skip to content

Commit

Permalink
Merge pull request #113 from aegbert5/update-connection-provider-to-s…
Browse files Browse the repository at this point in the history
…upport-aws-writer-failovers

Update ConnectionProvider to support AWS writer failovers
  • Loading branch information
tmccombs authored Sep 19, 2024
2 parents 2e4281e + c571102 commit d0801c5
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 16 deletions.
141 changes: 125 additions & 16 deletions worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala
Original file line number Diff line number Diff line change
@@ -1,25 +1,134 @@
package com.lucidchart.piezo

import org.quartz.utils.HikariCpPoolingConnectionProvider
import java.net.UnknownHostException
import java.sql.{Connection, SQLTransientConnectionException}
import java.util.Properties
import java.util.concurrent.TimeUnit
import org.quartz.utils.HikariCpPoolingConnectionProvider
import org.slf4j.LoggerFactory
import scala.annotation.tailrec

class ConnectionProvider(props: Properties) {

private class Pool(val ip: String) {
val connectionProvider: Option[HikariCpPoolingConnectionProvider] = createNewConnectionProvider()
logger.info(s"Initialized Db connection pool for ${jdbcURL}")
// Hikari takes about a second to add connections to the connection pool
// We are now going to warm-up connectionPool(with timelimit of 2500ms)
connectionProvider.map(warmUpCP)
}

val logger = LoggerFactory.getLogger(this.getClass)
private val dataSource = props.getProperty("org.quartz.jobStore.dataSource")
private val provider = if(dataSource != null) {
Some(new HikariCpPoolingConnectionProvider(
props.getProperty("org.quartz.dataSource." + dataSource + ".driver"),
props.getProperty("org.quartz.dataSource." + dataSource + ".URL"),
props.getProperty("org.quartz.dataSource." + dataSource + ".user"),
props.getProperty("org.quartz.dataSource." + dataSource + ".password"),
props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt,
props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery")
))
} else {
logger.info("No job store found in config")
None
}

def getConnection = provider.get.getConnection
private val jdbcURL = if (dataSource != null) props.getProperty("org.quartz.dataSource." + dataSource + ".URL") else null
private val detectIpAddressFailover = if (dataSource != null) props.getProperty("org.quartz.dataSource." + dataSource + ".ipFailover") == "true" else false
// Removes "jdbc:mysql://" prefix and ":{port}..." suffix
private val dataSourceHostname = if (jdbcURL != null) jdbcURL.replace("jdbc:mysql://", "").split(":")(0) else null

// Time (in milliseconds) that the in-memory cache will retain the ip address
private val cachedIpTTL: Long = 1000
private val getIpNumRetries = 10
// Class for storing the ip address of the host, along with an expiration date
private case class CachedIpWithExpiration(ip: String, expiration: Long)
// Cache for ip address and its expiration for a host
@volatile
private var cachedIpWithExpiration: Option[CachedIpWithExpiration] = None

private val pool: Pool = new Pool(getIP)

// Intended to be used only for tests. This mocks an IP failover every time a connection is retreived
private val causeFailoverEveryConnection = if (dataSource != null) props.getProperty("org.quartz.dataSource." + dataSource + ".causeFailoverEveryConnection") == "true" else false

def createNewConnectionProvider(): Option[HikariCpPoolingConnectionProvider] = {
if(dataSource != null) {
Some(new HikariCpPoolingConnectionProvider(
props.getProperty("org.quartz.dataSource." + dataSource + ".driver"),
jdbcURL,
props.getProperty("org.quartz.dataSource." + dataSource + ".user"),
props.getProperty("org.quartz.dataSource." + dataSource + ".password"),
props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt,
props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery")
))
} else {
logger.info("No job store found in config to get connections")
None
}
}

/**
* HikariCP connection pools don't automatically close when IP addresses for a hostname change. This function returns True, iff at
* least one of the following conditions is met:
* - IP addresses have changed for the CNAME record used for DNS lookup
* - causeFailoverEveryConnection is set to "true", which is used for testing failover functionality
*
* @param pool
* the connection pool currently being used
* @param dnsIP
* the IP returned when performing a DNS lookup
* @return
*/
private def hasIpAddressChanged(pool: Pool, dnsIP: String): Boolean = {
causeFailoverEveryConnection || pool.ip != dnsIP
}

@tailrec
private def retryGettingIp(n: Int)(fn: => String): String = {
try {
return fn
} catch {
// Failed to resolve it from JVM
case e: UnknownHostException if n > 1 =>
}
// Wait 10 milliseconds between retries
Thread.sleep(10)
retryGettingIp(n - 1)(fn)
}

def _getIp: String = {
retryGettingIp(getIpNumRetries) {
// Get the ip address of the hostname. The result is cached in the JVM
val ip = java.net.InetAddress.getByName(dataSourceHostname).getHostAddress
cachedIpWithExpiration = Some(CachedIpWithExpiration(ip, System.currentTimeMillis() + cachedIpTTL))
ip
}
}

def getIP: String = {
synchronized {
cachedIpWithExpiration.map { cachedValue =>
if (System.currentTimeMillis() > cachedValue.expiration) {
_getIp
} else {
cachedValue.ip
}
}.getOrElse(_getIp)
}
}

def getConnection = {
if (detectIpAddressFailover && dataSourceHostname != null) {
// If the IP has changed, then we know a failover has occurred, and we need to create a new hikari config
val newIP: String = getIP
if (hasIpAddressChanged(pool, newIP)) {
// A failover has occurred, so we evict connections softly. New connectons look up the new IP address
pool.connectionProvider.map(_.getDataSource().getHikariPoolMXBean().softEvictConnections())
}
}
pool.connectionProvider.get.getConnection
}

private def warmUpCP(connectionPool: HikariCpPoolingConnectionProvider): Unit = {
var testConn: Connection = null
val start = System.currentTimeMillis
while (testConn == null && (System.currentTimeMillis - start) < 2500) {
try {
testConn = connectionPool.getConnection()
} catch {
case _: SQLTransientConnectionException => { TimeUnit.MILLISECONDS.sleep(100) } // do nothing
}
}
if (testConn != null) {
testConn.close()
}
}
}
1 change: 1 addition & 0 deletions worker/src/test/resources/quartz_test_mysql.properties
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ org.quartz.dataSource.test_jobs.user: root
org.quartz.dataSource.test_jobs.password: root
org.quartz.dataSource.test_jobs.maxConnections: 10
org.quartz.dataSource.test_jobs.validationQuery: select 0
org.quartz.dataSource.test_jobs.ipFailover: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Same configuration as quartz_test_mysql.properties
#============================================================================
# Configure Main Scheduler Properties
#============================================================================

org.quartz.scheduler.instanceName: TestScheduler
org.quartz.scheduler.instanceId: AUTO

org.quartz.scheduler.skipUpdateCheck: true

#============================================================================
# Configure ThreadPool
#============================================================================

org.quartz.threadPool.class: org.quartz.simpl.SimpleThreadPool
org.quartz.threadPool.threadCount: 2
org.quartz.threadPool.threadPriority: 5

#============================================================================
# Configure JobStore
#============================================================================

org.quartz.jobStore.misfireThreshold: 60000
org.quartz.jobStore.class=org.quartz.impl.jdbcjobstore.JobStoreTX
org.quartz.jobStore.dataSource=test_jobs


org.quartz.dataSource.test_jobs.driver: com.mysql.cj.jdbc.Driver
org.quartz.dataSource.test_jobs.URL: jdbc:mysql://localhost:3306/test_jobs
org.quartz.dataSource.test_jobs.user: root
org.quartz.dataSource.test_jobs.password: root
org.quartz.dataSource.test_jobs.maxConnections: 10
org.quartz.dataSource.test_jobs.validationQuery: select 0
org.quartz.dataSource.test_jobs.ipFailover: true

# Unique configuration to this file
org.quartz.dataSource.test_jobs.causeFailoverEveryConnection: true
16 changes: 16 additions & 0 deletions worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class ModelTest extends Specification with BeforeAll with AfterAll {
val properties = new Properties
properties.load(propertiesStream)

val propertiesStreamFailoverEveryConnection = getClass().getResourceAsStream("/quartz_test_mysql_failover_every_connection.properties")
val propertiesWithFailoverEveryConnection = new Properties
propertiesWithFailoverEveryConnection.load(propertiesStreamFailoverEveryConnection)

val username = properties.getProperty("org.quartz.dataSource.test_jobs.user")
val password = properties.getProperty("org.quartz.dataSource.test_jobs.password")
val dbUrl = properties.getProperty("org.quartz.dataSource.test_jobs.URL")
Expand Down Expand Up @@ -65,6 +69,7 @@ class ModelTest extends Specification with BeforeAll with AfterAll {

"JobHistoryModel" should {
"work correctly" in {
properties.getProperty("org.quartz.dataSource.test_jobs.causeFailoverEveryConnection") must beNull
val jobHistoryModel = new JobHistoryModel(properties)
val jobKey = new JobKey("blah", "blah")
val triggerKey = new TriggerKey("blahtn", "blahtg")
Expand All @@ -74,6 +79,17 @@ class ModelTest extends Specification with BeforeAll with AfterAll {
jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome
jobHistoryModel.getJobs().nonEmpty must beTrue
}

"work correctly with a failover for every connection to the database" in {
propertiesWithFailoverEveryConnection.getProperty("org.quartz.dataSource.test_jobs.causeFailoverEveryConnection") mustEqual("true")
val jobHistoryModel = new JobHistoryModel(propertiesWithFailoverEveryConnection)
val jobKey = new JobKey("blahc", "blahc")
val triggerKey = new TriggerKey("blahtnc", "blahtgc")
jobHistoryModel.getJob(jobKey).headOption must beNone
jobHistoryModel.addJob("abc", jobKey, triggerKey, new Date(), 1000, true)
jobHistoryModel.getJob(jobKey).headOption must beSome
jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome
}
}

"TriggerMonitoringModel" should {
Expand Down

0 comments on commit d0801c5

Please sign in to comment.