diff --git a/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala b/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala
index 337794a..88dea3e 100644
--- a/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala
+++ b/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala
@@ -1,25 +1,131 @@
 package com.lucidchart.piezo
 
-import org.quartz.utils.HikariCpPoolingConnectionProvider
+import java.sql.{Connection, SQLTransientConnectionException}
 import java.util.Properties
+import java.util.concurrent.TimeUnit
+import org.quartz.utils.HikariCpPoolingConnectionProvider
 import org.slf4j.LoggerFactory
 
-class ConnectionProvider(props: Properties) {
+class ConnectionProvider(props: Properties, causeFailoverEveryConnection: Boolean = false) {
+
+  private class Pool(ipToSet: String) {
+    val ip: String = ipToSet
+    val finalJdbcURL: String = if (detectIpAddressFailover && originalJdbcURL != null) originalJdbcURL.replace(dataSourceHostname, ip) else originalJdbcURL
+    val connectionProvider: Option[HikariCpPoolingConnectionProvider] = createNewConnectionProvider(finalJdbcURL)
+    logger.info(s"Initialized Db connection pool for ${finalJdbcURL}")
+    // Hikari takes about a second to add connections to the connection pool
+    // We are now going to warm-up connectionPool(with timelimit of 2500ms)
+    connectionProvider.foreach(warmUpCP)
+  }
+
   val logger = LoggerFactory.getLogger(this.getClass)
   private val dataSource = props.getProperty("org.quartz.jobStore.dataSource")
-  private val provider = if(dataSource != null) {
-    Some(new HikariCpPoolingConnectionProvider(
-      props.getProperty("org.quartz.dataSource." + dataSource + ".driver"),
-      props.getProperty("org.quartz.dataSource." + dataSource + ".URL"),
-      props.getProperty("org.quartz.dataSource." + dataSource + ".user"),
-      props.getProperty("org.quartz.dataSource." + dataSource + ".password"),
-      props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt,
-      props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery")
-    ))
-  } else {
-    logger.info("No job store found in config")
-    None
+  private val originalJdbcURL = props.getProperty("org.quartz.dataSource." + dataSource + ".URL")
+  private val detectIpAddressFailover = props.getProperty("org.quartz.dataSource." + dataSource + ".ipFailover") == "true"
+  // Removes "jdbc:mysql://" prefix and ":{port}..." suffix
+  private val dataSourceHostname = if (originalJdbcURL != null) originalJdbcURL.replace("jdbc:mysql://", "").split(":")(0) else null
+
+  // Needs to be mutable so that the pool can be exchanged during a failover AND kept in memory so that connections are long-lived
+  @volatile private var pool: Pool = new Pool(getIP)
+  private val poolLock = new Object()
+
+  def createNewConnectionProvider(finalJdbcURL: String): Option[HikariCpPoolingConnectionProvider] = {
+    if(dataSource != null) {
+      Some(new HikariCpPoolingConnectionProvider(
+        props.getProperty("org.quartz.dataSource." + dataSource + ".driver"),
+        finalJdbcURL,
+        props.getProperty("org.quartz.dataSource." + dataSource + ".user"),
+        props.getProperty("org.quartz.dataSource." + dataSource + ".password"),
+        props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt,
+        props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery")
+      ))
+    } else {
+      logger.info("No job store found in config to get connections")
+      None
+    }
   }
 
-  def getConnection = provider.get.getConnection
+  /**
+   * HikariCP connection pools don't automatically close when IP addresses for a hostname change. This function returns true if at
+   * least one of the following conditions is met:
+   *   - IP addresses have changed for the CNAME record used for DNS lookup
+   *   - causeFailoverEveryConnection is set to "true", which is used for testing failover functionality
+   *
+   * @param pool
+   *   the connection pool currently being used
+   * @param dnsIP
+   *   the IP returned when performing a DNS lookup
+   * @return true if a failover has occurred (or is being forced) and the pool must be replaced
+   */
+  private def hasIpAddressChanged(pool: Pool, dnsIP: String): Boolean = {
+    causeFailoverEveryConnection || pool.ip != dnsIP
+  }
+
+  def retryGettingIp[T](n: Int)(fn: => T): T = {
+    try {
+      fn
+    } catch {
+      // Failed to resolve it from JVM; retry while attempts remain, letting fatal errors propagate
+      case scala.util.control.NonFatal(_) if n > 1 => retryGettingIp(n - 1)(fn)
+    }
+  }
+
+  val numRetries = 10
+  def getIP: String = {
+    retryGettingIp(numRetries) {
+      // Get the ip address of the hostname. The result is cached in the JVM
+      java.net.InetAddress.getByName(dataSourceHostname).getHostAddress
+    }
+  }
+
+  def getConnection = {
+    if (detectIpAddressFailover && originalJdbcURL != null) {
+      // If the IP has changed, then we know a failover has occurred, and we need to create a new hikari config
+      val newIP: String = getIP
+      if (hasIpAddressChanged(pool, newIP)) {
+        // A failover has occurred, so we lock the pool and swap it out with a new hikari config
+        val (poolToUse, optionalOldPool) = poolLock synchronized {
+          val oldPool = pool
+          val newIP: String = getIP
+          if (hasIpAddressChanged(oldPool, newIP)) {
+            val newPool = new Pool(newIP)
+            pool = newPool
+            (newPool, Some(oldPool))
+          } else {
+            // Already up to date by another thread
+            (oldPool, None)
+          }
+        }
+
+        // Close the previous config
+        optionalOldPool.foreach { old =>
+          // TODO: Get "server.databaseName" from somewhere
+          logger.info(s"Closing DB connection pool for ${originalJdbcURL}: IP changed (${old.ip} -> ${poolToUse.ip}).")
+          old.connectionProvider.get.shutdown()
+        }
+
+        poolToUse.connectionProvider.get.getConnection
+      } else {
+        pool.connectionProvider.get.getConnection
+      }
+    } else {
+      pool.connectionProvider.get.getConnection
+    }
+  }
+
+  private def warmUpCP(connectionPool: HikariCpPoolingConnectionProvider): Unit = {
+    var testConn: Connection = null
+    val start = System.currentTimeMillis
+    while (testConn == null && (System.currentTimeMillis - start) < 2500) {
+      try {
+        testConn = connectionPool.getConnection()
+      } catch {
+        case _: SQLTransientConnectionException => { TimeUnit.MILLISECONDS.sleep(100) } // pool not warm yet; back off briefly and retry
+        case e: Exception => throw e
+      }
+    }
+    if (testConn != null) {
+      testConn.close()
+    }
+  }
 }
diff --git a/worker/src/main/scala/com/lucidchart/piezo/JobHistoryModel.scala b/worker/src/main/scala/com/lucidchart/piezo/JobHistoryModel.scala
index c7bb86c..5be5175 100644
--- a/worker/src/main/scala/com/lucidchart/piezo/JobHistoryModel.scala
+++ b/worker/src/main/scala/com/lucidchart/piezo/JobHistoryModel.scala
@@ -16,9 +16,9 @@ case class JobRecord(
   fire_instance_id: String
 )
 
-class JobHistoryModel(props: Properties) {
+class JobHistoryModel(props: Properties, causeFailoverEveryConnection: Boolean = false) {
   val logger = LoggerFactory.getLogger(this.getClass)
-  val connectionProvider = new ConnectionProvider(props)
+  val connectionProvider = new ConnectionProvider(props, causeFailoverEveryConnection)
 
   def addJob(
     fireInstanceId: String,
diff --git a/worker/src/test/resources/quartz_test_mysql.properties b/worker/src/test/resources/quartz_test_mysql.properties
index 2e2a1d0..2c82d48 100644
--- a/worker/src/test/resources/quartz_test_mysql.properties
+++ b/worker/src/test/resources/quartz_test_mysql.properties
@@ -31,3 +31,4 @@ org.quartz.dataSource.test_jobs.user: root
 org.quartz.dataSource.test_jobs.password: root
 org.quartz.dataSource.test_jobs.maxConnections: 10
 org.quartz.dataSource.test_jobs.validationQuery: select 0
+org.quartz.dataSource.test_jobs.ipFailover: true
diff --git a/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala b/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala
index b432828..68fe005 100644
--- a/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala
+++ b/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala
@@ -74,6 +74,16 @@ class ModelTest extends Specification with BeforeAll with AfterAll {
       jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome
       jobHistoryModel.getJobs().nonEmpty must beTrue
     }
+
+    "work correctly with a failover for every connection to the database" in {
+      val jobHistoryModel = new JobHistoryModel(properties, causeFailoverEveryConnection = true)
+      val jobKey = new JobKey("blahc", "blahc")
+      val triggerKey = new TriggerKey("blahtnc", "blahtgc")
+      jobHistoryModel.getJob(jobKey).headOption must beNone
+      jobHistoryModel.addJob("abc", jobKey, triggerKey, new Date(), 1000, true)
+      jobHistoryModel.getJob(jobKey).headOption must beSome
+      jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome
+    }
   }
 
   "TriggerMonitoringModel" should {