Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InheritableMDC #190

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ val useRequireIOPlugin =
lazy val rootProject = (project in file("."))
.settings(commonSettings)
.settings(publishArtifact := false, name := "ox")
.aggregate(core, plugin, pluginTest, examples, kafka)
.aggregate(core, plugin, pluginTest, examples, kafka, mdcLogback)

lazy val core: Project = (project in file("core"))
.settings(commonSettings)
Expand Down Expand Up @@ -123,6 +123,17 @@ lazy val kafka: Project = (project in file("kafka"))
)
.dependsOn(core)

lazy val mdcLogback: Project = (project in file("mdc-logback"))
.settings(commonSettings)
.settings(
name := "mdc-logback",
libraryDependencies ++= Seq(
logback,
scalaTest
)
)
.dependsOn(core)

lazy val documentation: Project = (project in file("generated-doc")) // important: it must not be doc/
.enablePlugins(MdocPlugin)
.settings(commonSettings)
Expand All @@ -140,5 +151,6 @@ lazy val documentation: Project = (project in file("generated-doc")) // importan
)
.dependsOn(
core,
kafka
kafka,
mdcLogback
)
3 changes: 2 additions & 1 deletion doc/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ In addition to this documentation, ScalaDocs can be browsed at [https://javadoc.

.. toctree::
:maxdepth: 2
:caption: Kafka integration
:caption: Integrations

kafka
mdc-logback

.. toctree::
:maxdepth: 2
Expand Down
46 changes: 46 additions & 0 deletions doc/mdc-logback.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Inheritable MDC using Logback

Ox provides support for setting inheritable MDC (mapped diagnostic context) values, when using the [Logback](https://logback.qos.ch)
logging library. Normally, value set using `MDC.put` aren't inherited across (virtual) threads, which includes forks
created in concurrency contexts.

Inheritable values are especially useful e.g. when setting a correlation id in an HTTP request interceptor, or at any
entrypoint to the application. Such correlation id values can be then added automatically to each log message, provided
the appropriate log encoder pattern is used.

To enable using inheritable MDC values, the application's code should call `InheritableMDC.init` as soon as possible.
The best place would be the application's entrypoint (the `main` method).

Once this is done, inheritable MDC values can be set in a scoped & structured manner using `InheritableMDC.supervisedWhere`
and variants.

As inheritable MDC values use a [`ForkLocal`](structured-concurrency/fork-local.md) under the hood, their usage
restrictions apply: outer concurrency scopes should not be used to create forks within the scopes. Only newly created
scopes, or the provided scope can be used to create forks. That's why `supervisedError`, `unsupervisedError` and
`supervisedErrorWhere` methods are provided.

"Normal" MDC usage is not affected. That is, values set using `MDC.put` are not inherited, and are only available in
the thread where they are set.

For example:

```scala mdoc:compile-only
import org.slf4j.MDC

import ox.fork
import ox.logback.InheritableMDC

InheritableMDC.supervisedWhere("a" -> "1", "b" -> "2") {
MDC.put("c", "3") // not inherited

fork {
MDC.get("a") // "1"
MDC.get("b") // "2"
MDC.get("c") // null
}.join()

MDC.get("a") // "1"
MDC.get("b") // "2"
MDC.get("c") // "3"
}
```
122 changes: 122 additions & 0 deletions mdc-logback/src/main/scala/ox/logback/InheritableMDC.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package ox.logback

import ch.qos.logback.classic.LoggerContext
import ch.qos.logback.classic.spi.LogbackServiceProvider
import org.slf4j.spi.{MDCAdapter, SLF4JServiceProvider}
import org.slf4j.{ILoggerFactory, IMarkerFactory, Logger, LoggerFactory, MDC}
import ox.{ErrorMode, ForkLocal, Ox, OxError, OxUnsupervised, pipe, tap}

/** Provides support for MDC which is inheritable across (virtual) threads. Only MDC values set using the [[where]] method will be
* inherited; this method also defines the scope, within which the provided MDC values are available.
*
* The semantics of [[MDC.put]] are unchanged: values set using this method will only be visible in the original thread. That is because
* the "usual" [[MDC]] usage is unstructured, and we don't want to set values for the entire scope (which might exceed the calling thread).
*
* Internally, a [[ForkLocal]] (backed by a ScopedValue) is used, to store the scoped context.
*
* Prior to using inheritable MDCs, the [[init]] method has to be called. This performs some operations using the reflection API, to
* substitute Logback's MDC support with one that is scope-aware.
*/
object InheritableMDC:
private val logger: Logger = LoggerFactory.getLogger(getClass.getName)
private[logback] val currentContext: ForkLocal[Option[MDCAdapter]] = ForkLocal(None)

/** Initialise inheritable MDCs. Must be called as early in the app's code as possible. */
lazy val init: Unit =
// Obtaining the current provider, to replace it later with our implementation returning the correct MDCAdapter.
val getProviderMethod = classOf[LoggerFactory].getDeclaredMethod("getProvider")
getProviderMethod.setAccessible(true)
getProviderMethod.invoke(null) match {
case currentProvider: LogbackServiceProvider =>
// Creating and setting the correct MDCAdapter on the LoggerContext; this is used internally by Logback to
// obtain the MDC values.
val ctx = currentProvider.getLoggerFactory.asInstanceOf[LoggerContext]
val scopedValuedMDCAdapter = new DelegateToCurrentMDCAdapter(ctx.getMDCAdapter)
ctx.setMDCAdapter(scopedValuedMDCAdapter)

// Second, we need to override the provider so that its .getMDCAdapter method returns our instance. This is used
// when setting/clearing the MDC values. Whether in a scope or not, this will delegate to the "root" MDCAdapter,
// because of ScopedMDCAdapter's implementation.
val providerField = classOf[LoggerFactory].getDeclaredField("PROVIDER")
providerField.setAccessible(true)
providerField.set(null, new OverrideMDCAdapterDelegateProvider(currentProvider, scopedValuedMDCAdapter))

logger.info(s"Scoped-value based MDC initialized")
case currentProvider =>
logger.warn(s"A non-Logback SLF4J provider ($currentProvider) is being used, unable to initialize scoped-value based MDC")
}
end init

/** Set the given MDC key-value mappings as passed in `kvs`, for the duration of evaluating `f`. The values will be available for any
* forks created within `f`.
*
* @see
* Usage notes on [[ForkLocal.unsupervisedWhere()]].
*/
def unsupervisedWhere[T](kvs: (String, String)*)(f: OxUnsupervised ?=> T): T = currentContext.unsupervisedWhere(adapterWith(kvs: _*))(f)

/** Set the given MDC key-value mappings as passed in `kvs`, for the duration of evaluating `f`. The values will be available for any
* forks created within `f`.
*
* @see
* Usage notes on [[ForkLocal.supervisedWhere()]].
*/
def supervisedWhere[T](kvs: (String, String)*)(f: Ox ?=> T): T = currentContext.supervisedWhere(adapterWith(kvs: _*))(f)

/** Set the given MDC key-value mappings as passed in `kvs`, for the duration of evaluating `f`. The values will be available for any
* forks created within `f`.
*
* @see
* Usage notes on [[ForkLocal.supervisedErrorWhere()]].
*/
def supervisedErrorWhere[E, F[_], U](errorMode: ErrorMode[E, F])(kvs: (String, String)*)(f: OxError[E, F] ?=> F[U]): F[U] =
currentContext.supervisedErrorWhere(errorMode)(adapterWith(kvs: _*))(f)

private def adapterWith(kvs: (String, String)*): Option[ScopedMDCAdapter] =
// unwrapping the MDC adapter, so that we get the "target" one; using DelegateToCurrentMDCAdapter would lead to
// infinite loops when delegating
val currentAdapter = MDC.getMDCAdapter.asInstanceOf[DelegateToCurrentMDCAdapter].currentAdapter()
Some(new ScopedMDCAdapter(Map(kvs: _*), currentAdapter))
end InheritableMDC

private class OverrideMDCAdapterDelegateProvider(delegate: SLF4JServiceProvider, mdcAdapter: MDCAdapter) extends SLF4JServiceProvider:
override def getMDCAdapter: MDCAdapter = mdcAdapter

override def getLoggerFactory: ILoggerFactory = delegate.getLoggerFactory
override def getMarkerFactory: IMarkerFactory = delegate.getMarkerFactory
override def getRequestedApiVersion: String = delegate.getRequestedApiVersion
override def initialize(): Unit = delegate.initialize()

/** An [[MDCAdapter]] which delegates to a [[ScopedMDCAdapter]] if one is available, or falls back to the root one otherwise. */
private class DelegateToCurrentMDCAdapter(rootAdapter: MDCAdapter) extends MDCAdapter:
def currentAdapter(): MDCAdapter = InheritableMDC.currentContext.get().getOrElse(rootAdapter)

override def put(key: String, `val`: String): Unit = currentAdapter().put(key, `val`)
override def get(key: String): String = currentAdapter().get(key)
override def remove(key: String): Unit = currentAdapter().remove(key)
override def clear(): Unit = currentAdapter().clear()
override def getCopyOfContextMap: java.util.Map[String, String] = currentAdapter().getCopyOfContextMap
override def setContextMap(contextMap: java.util.Map[String, String]): Unit = currentAdapter().setContextMap(contextMap)
override def pushByKey(key: String, value: String): Unit = currentAdapter().pushByKey(key, value)
override def popByKey(key: String): String = currentAdapter().popByKey(key)
override def getCopyOfDequeByKey(key: String): java.util.Deque[String] = currentAdapter().getCopyOfDequeByKey(key)
override def clearDequeByKey(key: String): Unit = currentAdapter().clearDequeByKey(key)

/** An [[MDCAdapter]] that is used within a structured scope. Stores an (immutable) map of values that are set within this scope. All other
* operations are delegated to the parent adapter (might be either another scoped, or the root Logback, adapter).
*/
private class ScopedMDCAdapter(mdcValues: Map[String, String], delegate: MDCAdapter) extends MDCAdapter:
override def get(key: String): String = mdcValues.getOrElse(key, delegate.get(key))
override def getCopyOfContextMap: java.util.Map[String, String] =
delegate.getCopyOfContextMap
.pipe(v => if v == null then new java.util.HashMap() else new java.util.HashMap[String, String](v))
.tap(copy => mdcValues.foreach((k, v) => copy.put(k, v)))

override def put(key: String, `val`: String): Unit = delegate.put(key, `val`)
override def remove(key: String): Unit = delegate.remove(key)
override def clear(): Unit = delegate.clear()
override def setContextMap(contextMap: java.util.Map[String, String]): Unit = delegate.setContextMap(contextMap)
override def pushByKey(key: String, value: String): Unit = delegate.pushByKey(key, value)
override def popByKey(key: String): String = delegate.popByKey(key)
override def getCopyOfDequeByKey(key: String): java.util.Deque[String] = delegate.getCopyOfDequeByKey(key)
override def clearDequeByKey(key: String): Unit = delegate.clearDequeByKey(key)
26 changes: 26 additions & 0 deletions mdc-logback/src/test/scala/ox/logback/InheritableMDCTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ox.logback

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.slf4j.MDC
import ox.fork

class InheritableMDCTest extends AnyFlatSpec with Matchers {
InheritableMDC.init

it should "make MDC values available in forks" in {
InheritableMDC.supervisedWhere("a" -> "1", "b" -> "2") {
MDC.put("c", "3") // should not be inherited

fork {
MDC.get("a") shouldBe "1"
MDC.get("b") shouldBe "2"
MDC.get("c") shouldBe null
}.join()

MDC.get("a") shouldBe "1"
MDC.get("b") shouldBe "2"
MDC.get("c") shouldBe "3"
}
}
}
Loading