diff --git a/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala b/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala index 337794a..d381c56 100644 --- a/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala +++ b/worker/src/main/scala/com/lucidchart/piezo/ConnectionProvider.scala @@ -1,25 +1,134 @@ package com.lucidchart.piezo -import org.quartz.utils.HikariCpPoolingConnectionProvider +import java.net.UnknownHostException +import java.sql.{Connection, SQLTransientConnectionException} import java.util.Properties +import java.util.concurrent.TimeUnit +import org.quartz.utils.HikariCpPoolingConnectionProvider import org.slf4j.LoggerFactory +import scala.annotation.tailrec class ConnectionProvider(props: Properties) { + + private class Pool(val ip: String) { + val connectionProvider: Option[HikariCpPoolingConnectionProvider] = createNewConnectionProvider() + logger.info(s"Initialized Db connection pool for ${jdbcURL}") + // Hikari takes about a second to add connections to the connection pool + // We are now going to warm-up connectionPool(with timelimit of 2500ms) + connectionProvider.map(warmUpCP) + } + val logger = LoggerFactory.getLogger(this.getClass) private val dataSource = props.getProperty("org.quartz.jobStore.dataSource") - private val provider = if(dataSource != null) { - Some(new HikariCpPoolingConnectionProvider( - props.getProperty("org.quartz.dataSource." + dataSource + ".driver"), - props.getProperty("org.quartz.dataSource." + dataSource + ".URL"), - props.getProperty("org.quartz.dataSource." + dataSource + ".user"), - props.getProperty("org.quartz.dataSource." + dataSource + ".password"), - props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt, - props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery") - )) - } else { - logger.info("No job store found in config") - None - } - - def getConnection = provider.get.getConnection + private val jdbcURL = if (dataSource != null) props.getProperty("org.quartz.dataSource." + dataSource + ".URL") else null + private val detectIpAddressFailover = if (dataSource != null) props.getProperty("org.quartz.dataSource." + dataSource + ".ipFailover") == "true" else false + // Removes "jdbc:mysql://" prefix and ":{port}..." suffix + private val dataSourceHostname = if (jdbcURL != null) jdbcURL.replace("jdbc:mysql://", "").split(":")(0) else null + + // Time (in milliseconds) that the in-memory cache will retain the ip address + private val cachedIpTTL: Long = 1000 + private val getIpNumRetries = 10 + // Class for storing the ip address of the host, along with an expiration date + private case class CachedIpWithExpiration(ip: String, expiration: Long) + // Cache for ip address and its expiration for a host + @volatile + private var cachedIpWithExpiration: Option[CachedIpWithExpiration] = None + + private val pool: Pool = new Pool(getIP) + + // Intended to be used only for tests. This mocks an IP failover every time a connection is retreived + private val causeFailoverEveryConnection = if (dataSource != null) props.getProperty("org.quartz.dataSource." + dataSource + ".causeFailoverEveryConnection") == "true" else false + + def createNewConnectionProvider(): Option[HikariCpPoolingConnectionProvider] = { + if(dataSource != null) { + Some(new HikariCpPoolingConnectionProvider( + props.getProperty("org.quartz.dataSource." + dataSource + ".driver"), + jdbcURL, + props.getProperty("org.quartz.dataSource." + dataSource + ".user"), + props.getProperty("org.quartz.dataSource." + dataSource + ".password"), + props.getProperty("org.quartz.dataSource." + dataSource + ".maxConnections").toInt, + props.getProperty("org.quartz.dataSource." + dataSource + ".validationQuery") + )) + } else { + logger.info("No job store found in config to get connections") + None + } + } + + /** + * HikariCP connection pools don't automatically close when IP addresses for a hostname change. This function returns True, iff at + * least one of the following conditions is met: + * - IP addresses have changed for the CNAME record used for DNS lookup + * - causeFailoverEveryConnection is set to "true", which is used for testing failover functionality + * + * @param pool + * the connection pool currently being used + * @param dnsIP + * the IP returned when performing a DNS lookup + * @return + */ + private def hasIpAddressChanged(pool: Pool, dnsIP: String): Boolean = { + causeFailoverEveryConnection || pool.ip != dnsIP + } + + @tailrec + private def retryGettingIp(n: Int)(fn: => String): String = { + try { + return fn + } catch { + // Failed to resolve it from JVM + case e: UnknownHostException if n > 1 => + } + // Wait 10 milliseconds between retries + Thread.sleep(10) + retryGettingIp(n - 1)(fn) + } + + def _getIp: String = { + retryGettingIp(getIpNumRetries) { + // Get the ip address of the hostname. The result is cached in the JVM + val ip = java.net.InetAddress.getByName(dataSourceHostname).getHostAddress + cachedIpWithExpiration = Some(CachedIpWithExpiration(ip, System.currentTimeMillis() + cachedIpTTL)) + ip + } + } + + def getIP: String = { + synchronized { + cachedIpWithExpiration.map { cachedValue => + if (System.currentTimeMillis() > cachedValue.expiration) { + _getIp + } else { + cachedValue.ip + } + }.getOrElse(_getIp) + } + } + + def getConnection = { + if (detectIpAddressFailover && dataSourceHostname != null) { + // If the IP has changed, then we know a failover has occurred, and we need to create a new hikari config + val newIP: String = getIP + if (hasIpAddressChanged(pool, newIP)) { + // A failover has occurred, so we evict connections softly. New connectons look up the new IP address + pool.connectionProvider.map(_.getDataSource().getHikariPoolMXBean().softEvictConnections()) + } + } + pool.connectionProvider.get.getConnection + } + + private def warmUpCP(connectionPool: HikariCpPoolingConnectionProvider): Unit = { + var testConn: Connection = null + val start = System.currentTimeMillis + while (testConn == null && (System.currentTimeMillis - start) < 2500) { + try { + testConn = connectionPool.getConnection() + } catch { + case _: SQLTransientConnectionException => { TimeUnit.MILLISECONDS.sleep(100) } // do nothing + } + } + if (testConn != null) { + testConn.close() + } + } } diff --git a/worker/src/test/resources/quartz_test_mysql.properties b/worker/src/test/resources/quartz_test_mysql.properties index 2e2a1d0..2c82d48 100644 --- a/worker/src/test/resources/quartz_test_mysql.properties +++ b/worker/src/test/resources/quartz_test_mysql.properties @@ -31,3 +31,4 @@ org.quartz.dataSource.test_jobs.user: root org.quartz.dataSource.test_jobs.password: root org.quartz.dataSource.test_jobs.maxConnections: 10 org.quartz.dataSource.test_jobs.validationQuery: select 0 +org.quartz.dataSource.test_jobs.ipFailover: true diff --git a/worker/src/test/resources/quartz_test_mysql_failover_every_connection.properties b/worker/src/test/resources/quartz_test_mysql_failover_every_connection.properties new file mode 100644 index 0000000..02cba1d --- /dev/null +++ b/worker/src/test/resources/quartz_test_mysql_failover_every_connection.properties @@ -0,0 +1,37 @@ +# Same configuration as quartz_test_mysql.properties +#============================================================================ +# Configure Main Scheduler Properties +#============================================================================ + +org.quartz.scheduler.instanceName: TestScheduler +org.quartz.scheduler.instanceId: AUTO + +org.quartz.scheduler.skipUpdateCheck: true + +#============================================================================ +# Configure ThreadPool +#============================================================================ + +org.quartz.threadPool.class: org.quartz.simpl.SimpleThreadPool +org.quartz.threadPool.threadCount: 2 +org.quartz.threadPool.threadPriority: 5 + +#============================================================================ +# Configure JobStore +#============================================================================ + +org.quartz.jobStore.misfireThreshold: 60000 +org.quartz.jobStore.class=org.quartz.impl.jdbcjobstore.JobStoreTX +org.quartz.jobStore.dataSource=test_jobs + + +org.quartz.dataSource.test_jobs.driver: com.mysql.cj.jdbc.Driver +org.quartz.dataSource.test_jobs.URL: jdbc:mysql://localhost:3306/test_jobs +org.quartz.dataSource.test_jobs.user: root +org.quartz.dataSource.test_jobs.password: root +org.quartz.dataSource.test_jobs.maxConnections: 10 +org.quartz.dataSource.test_jobs.validationQuery: select 0 +org.quartz.dataSource.test_jobs.ipFailover: true + +# Unique configuration to this file +org.quartz.dataSource.test_jobs.causeFailoverEveryConnection: true diff --git a/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala b/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala index b432828..90cb34c 100644 --- a/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala +++ b/worker/src/test/scala/com/lucidchart/piezo/ModelTest.scala @@ -17,6 +17,10 @@ class ModelTest extends Specification with BeforeAll with AfterAll { val properties = new Properties properties.load(propertiesStream) + val propertiesStreamFailoverEveryConnection = getClass().getResourceAsStream("/quartz_test_mysql_failover_every_connection.properties") + val propertiesWithFailoverEveryConnection = new Properties + propertiesWithFailoverEveryConnection.load(propertiesStreamFailoverEveryConnection) + val username = properties.getProperty("org.quartz.dataSource.test_jobs.user") val password = properties.getProperty("org.quartz.dataSource.test_jobs.password") val dbUrl = properties.getProperty("org.quartz.dataSource.test_jobs.URL") @@ -65,6 +69,7 @@ class ModelTest extends Specification with BeforeAll with AfterAll { "JobHistoryModel" should { "work correctly" in { + properties.getProperty("org.quartz.dataSource.test_jobs.causeFailoverEveryConnection") must beNull val jobHistoryModel = new JobHistoryModel(properties) val jobKey = new JobKey("blah", "blah") val triggerKey = new TriggerKey("blahtn", "blahtg") @@ -74,6 +79,17 @@ class ModelTest extends Specification with BeforeAll with AfterAll { jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome jobHistoryModel.getJobs().nonEmpty must beTrue } + + "work correctly with a failover for every connection to the database" in { + propertiesWithFailoverEveryConnection.getProperty("org.quartz.dataSource.test_jobs.causeFailoverEveryConnection") mustEqual("true") + val jobHistoryModel = new JobHistoryModel(propertiesWithFailoverEveryConnection) + val jobKey = new JobKey("blahc", "blahc") + val triggerKey = new TriggerKey("blahtnc", "blahtgc") + jobHistoryModel.getJob(jobKey).headOption must beNone + jobHistoryModel.addJob("abc", jobKey, triggerKey, new Date(), 1000, true) + jobHistoryModel.getJob(jobKey).headOption must beSome + jobHistoryModel.getLastJobSuccessByTrigger(triggerKey) must beSome + } } "TriggerMonitoringModel" should {