Skip to content

Commit

Permalink
Update ConnectionProvider to support AWS writer failovers
Browse files Browse the repository at this point in the history
  • Loading branch information
aegbert5 committed Sep 5, 2024
1 parent d8d7e50 commit 81e2fb2
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 17 deletions.
136 changes: 121 additions & 15 deletions worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala
Original file line number Diff line number Diff line change
@@ -1,25 +1,131 @@
package com.lucidchart.piezo

import org.quartz.utils.HikariCpPoolingConnectionProvider
import java.sql.{Connection, SQLTransientConnectionException}
import java.util.Properties
import java.util.concurrent.TimeUnit
import org.quartz.utils.HikariCpPoolingConnectionProvider
import org.slf4j.LoggerFactory

class ConnectionProvider(props: Properties) {
class ConnectionProvider(props: Properties, causeFailoverEveryConnection: Boolean = false) {

/**
 * A connection pool bound to one resolved IP address of the data source host.
 *
 * Construction builds the Hikari provider (if a data source is configured), logs the JDBC URL
 * in use and then warms the pool up.
 *
 * @param ipToSet the IP address this pool is associated with
 */
private class Pool(ipToSet: String) {
  val ip: String = ipToSet

  // The hostname is swapped for the IP only when failover detection is on and a URL is configured;
  // otherwise the configured URL (possibly null) is used untouched.
  val finalJdbcURL: String = {
    val pinToIp = detectIpAddressFailover && originalJdbcURL != null
    if (pinToIp) originalJdbcURL.replace(dataSourceHostname, ip) else originalJdbcURL
  }

  val connectionProvider: Option[HikariCpPoolingConnectionProvider] =
    createNewConnectionProvider(finalJdbcURL)

  logger.info(s"Initialized Db connection pool for ${finalJdbcURL}")

  // Hikari needs roughly a second to populate a fresh pool, so request one connection up front
  // (warmUpCP gives up after 2500ms).
  connectionProvider.foreach(warmUpCP)
}

// Logger named after this class
val logger = LoggerFactory.getLogger(this.getClass)
// Name of the Quartz data source backing the job store; null when the property is not set
private val dataSource = props.getProperty("org.quartz.jobStore.dataSource")
private val provider = if(dataSource != null) {
Some(new HikariCpPoolingConnectionProvider(
props.getProperty("org.quartz.dataSource." + dataSource + ".driver"),
props.getProperty("org.quartz.dataSource." + dataSource + ".URL"),
props.getProperty("org.quartz.dataSource." + dataSource + ".user"),
props.getProperty("org.quartz.dataSource." + dataSource + ".password"),
props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt,
props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery")
))
} else {
logger.info("No job store found in config")
None
// JDBC URL exactly as configured for the data source; null when the property is not set
private val originalJdbcURL = props.getProperty("org.quartz.dataSource." + dataSource + ".URL")
// Failover detection is opt-in: only the literal value "true" enables it
private val detectIpAddressFailover = props.getProperty("org.quartz.dataSource." + dataSource + ".ipFailover") == "true"
// Hostname portion of the JDBC URL: drops the "jdbc:mysql://" prefix and everything from the first
// ':' (port), '/' (database path) or '?' (query parameters) onward. Cutting only at ':' would turn a
// URL without an explicit port, e.g. "jdbc:mysql://db.example.com/jobs", into the bogus hostname
// "db.example.com/jobs".
private val dataSourceHostname =
  if (originalJdbcURL != null) {
    originalJdbcURL.replace("jdbc:mysql://", "").takeWhile(c => c != ':' && c != '/' && c != '?')
  } else {
    null
  }

// Needs to be mutable so that the pool can be exchanged during a failover AND kept in memory so that connections are long-lived.
// Marked @volatile because getConnection reads this field without holding poolLock; without it a
// thread is not guaranteed to observe a pool that another thread swapped in under the lock.
@volatile private var pool: Pool = new Pool(getIP)
// Guards replacement of `pool` so that only one thread builds the new pool after a failover
private val poolLock = new Object()

/**
 * Builds a Hikari-backed connection provider from the configured data source properties.
 *
 * @param finalJdbcURL the JDBC URL the provider should connect to
 * @return the new provider, or None (after logging) when no data source is configured
 */
def createNewConnectionProvider(finalJdbcURL: String): Option[HikariCpPoolingConnectionProvider] = {
  if (dataSource == null) {
    logger.info("No job store found in config to get connections")
    None
  } else {
    // Every setting lives under "org.quartz.dataSource.<name>"
    val keyPrefix = "org.quartz.dataSource." + dataSource
    def setting(suffix: String): String = props.getProperty(keyPrefix + suffix)

    val hikariProvider = new HikariCpPoolingConnectionProvider(
      setting(".driver"),
      finalJdbcURL,
      setting(".user"),
      setting(".password"),
      setting(".maxConnections").toInt,
      setting(".validationQuery")
    )
    Some(hikariProvider)
  }
}

/**
 * Decides whether the current pool must be replaced. HikariCP connection pools are not closed
 * automatically when the IP address behind a hostname changes, so the caller compares them itself.
 *
 * Returns true iff at least one of the following holds:
 *  - the IP the pool was created for differs from the IP returned by the DNS lookup
 *  - causeFailoverEveryConnection is set, which is used for testing failover functionality
 *
 * @param pool
 *   the connection pool currently being used
 * @param dnsIP
 *   the IP returned when performing a DNS lookup
 * @return
 *   true when a failover should be performed, false otherwise
 */
private def hasIpAddressChanged(pool: Pool, dnsIP: String): Boolean = {
  if (causeFailoverEveryConnection) true
  else pool.ip != dnsIP
}

def getConnection = provider.get.getConnection
/**
 * Evaluates `fn`, re-evaluating it after a failure until it succeeds or the attempts run out.
 *
 * Only non-fatal exceptions are retried. Fatal throwables (e.g. InterruptedException,
 * VirtualMachineError) propagate immediately instead of being swallowed and retried.
 *
 * @param n  maximum number of attempts; values of 1 or less mean a single attempt
 * @param fn the by-name computation to run (typically a DNS lookup)
 * @tparam T result type of the computation
 * @return the result of the first successful attempt
 * @throws Throwable the failure of the last attempt, or any fatal throwable as soon as it occurs
 */
def retryGettingIp[T](n: Int)(fn: => T): T = {
  import scala.util.control.NonFatal
  try {
    fn
  } catch {
    // Failed to resolve it from JVM; try again while attempts remain
    case NonFatal(_) if n > 1 => retryGettingIp(n - 1)(fn)
  }
}

// Maximum number of DNS lookup attempts made by getIP.
// Declared as a `final val` without a type annotation so that it is a constant value definition and
// the literal is inlined at use sites. As a plain `val` it would still hold 0 while `pool` (declared
// earlier in this class) is being initialised through getIP, which disabled retries for the very
// first lookup.
final val numRetries = 10

/**
 * Resolves the data source hostname to an IP address, retrying failed lookups.
 *
 * @return the textual IP address for dataSourceHostname
 */
def getIP: String = {
  retryGettingIp(numRetries) {
    // Get the ip address of the hostname. The result is cached in the JVM
    java.net.InetAddress.getByName(dataSourceHostname).getHostAddress
  }
}

/**
 * Hands out a connection from the current pool, replacing the pool first when a failover is detected.
 *
 * When failover detection is disabled (or no JDBC URL is configured) the existing pool is used as is.
 * Otherwise the hostname is resolved again; if the pool should be replaced, the check is repeated
 * while holding poolLock so that only one thread builds the new pool, and the replaced pool is shut
 * down afterwards.
 *
 * NOTE(review): the replaced pool is shut down right after the swap; whether connections already
 * handed out from it are affected depends on the provider's shutdown semantics - confirm.
 */
def getConnection = {
  val failoverDetectionActive = detectIpAddressFailover && originalJdbcURL != null

  val activePool: Pool =
    if (!failoverDetectionActive) {
      pool
    } else {
      // If the IP has changed, then we know a failover has occurred, and we need a new hikari config
      val resolvedIp: String = getIP
      if (!hasIpAddressChanged(pool, resolvedIp)) {
        pool
      } else {
        // A failover has occurred, so we lock the pool and swap it out with a new hikari config
        val (current, retired) = poolLock.synchronized {
          val previous = pool
          val recheckedIp: String = getIP
          if (hasIpAddressChanged(previous, recheckedIp)) {
            val replacement = new Pool(recheckedIp)
            pool = replacement
            (replacement, Some(previous))
          } else {
            // Already up to date by another thread
            (previous, None)
          }
        }

        // Close the previous config
        retired.foreach { old =>
          // TODO: Get "server.databaseName" from somewhere
          logger.info(s"Closing DB connection pool for ${originalJdbcURL}: IP changed (${old.ip} -> ${current.ip}).")
          old.connectionProvider.get.shutdown()
        }

        current
      }
    }

  activePool.connectionProvider.get.getConnection
}

/**
 * Warms a freshly created pool up by borrowing one connection and closing it again.
 *
 * Attempts are repeated while they fail with SQLTransientConnectionException, pausing 100ms
 * between them, until 2500ms have passed since the start. Any other exception propagates to the
 * caller.
 *
 * @param connectionPool the provider whose pool should be warmed up
 */
private def warmUpCP(connectionPool: HikariCpPoolingConnectionProvider): Unit = {
  val deadline = System.currentTimeMillis + 2500
  var borrowed: Option[Connection] = None
  while (borrowed.isEmpty && System.currentTimeMillis < deadline) {
    try {
      // Option(...) keeps a null result as "not yet obtained", so the loop carries on
      borrowed = Option(connectionPool.getConnection())
    } catch {
      // Pool not ready yet: back off briefly and try again
      case _: SQLTransientConnectionException => TimeUnit.MILLISECONDS.sleep(100)
    }
  }
  borrowed.foreach(_.close())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ case class JobRecord(
fire_instance_id: String
)

class JobHistoryModel(props: Properties) {
class JobHistoryModel(props: Properties, causeFailoverEveryConnection: Boolean = false) {
val logger = LoggerFactory.getLogger(this.getClass)
val connectionProvider = new ConnectionProvider(props)
val connectionProvider = new ConnectionProvider(props, causeFailoverEveryConnection)

def addJob(
fireInstanceId: String,
Expand Down
1 change: 1 addition & 0 deletions worker/src/test/resources/quartz_test_mysql.properties
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ org.quartz.dataSource.test_jobs.user: root
org.quartz.dataSource.test_jobs.password: root
org.quartz.dataSource.test_jobs.maxConnections: 10
org.quartz.dataSource.test_jobs.validationQuery: select 0
org.quartz.dataSource.test_jobs.ipFailover: true
10 changes: 10 additions & 0 deletions worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ class ModelTest extends Specification with BeforeAll with AfterAll {
jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome
jobHistoryModel.getJobs().nonEmpty must beTrue
}

"work correctly with a failover for every connection to the database" in {
val jobHistoryModel = new JobHistoryModel(properties, causeFailoverEveryConnection = true)
val jobKey = new JobKey("blahc", "blahc")
val triggerKey = new TriggerKey("blahtnc", "blahtgc")
jobHistoryModel.getJob(jobKey).headOption must beNone
jobHistoryModel.addJob("abc", jobKey, triggerKey, new Date(), 1000, true)
jobHistoryModel.getJob(jobKey).headOption must beSome
jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome
}
}

"TriggerMonitoringModel" should {
Expand Down

0 comments on commit 81e2fb2

Please sign in to comment.