diff --git a/build.sbt b/build.sbt index d3dc8eb..0662b75 100644 --- a/build.sbt +++ b/build.sbt @@ -36,6 +36,7 @@ lazy val core = project stdSettings("zio-profiling"), libraryDependencies ++= Seq( "dev.zio" %% "zio" % zioVersion, + "dev.zio" %% "zio-streams" % zioVersion, "org.scala-lang.modules" %% "scala-collection-compat" % collectionCompatVersion, "dev.zio" %% "zio-test" % zioVersion % Test, "dev.zio" %% "zio-test-sbt" % zioVersion % Test diff --git a/project/BuildHelper.scala b/project/BuildHelper.scala index 896ac98..4c68033 100644 --- a/project/BuildHelper.scala +++ b/project/BuildHelper.scala @@ -31,13 +31,13 @@ object BuildHelper { val Scala213 = versions("2.13") val Scala3 = versions("3") - val defaulScalaVersion = Scala213 + val defaultScalaVersion = Scala213 def stdSettings(prjName: String) = Seq( name := s"$prjName", crossScalaVersions := List(Scala212, Scala213, Scala3), - ThisBuild / scalaVersion := defaulScalaVersion, + ThisBuild / scalaVersion := defaultScalaVersion, scalacOptions := stdOptions ++ extraOptions(scalaVersion.value, optimize = !isSnapshot.value), libraryDependencies ++= { if (scalaVersion.value == Scala3) @@ -50,7 +50,7 @@ object BuildHelper { compilerPlugin("com.github.ghik" % "silencer-plugin" % silencerVersion cross CrossVersion.full) ) }, - semanticdbEnabled := scalaVersion.value == defaulScalaVersion, + semanticdbEnabled := scalaVersion.value == defaultScalaVersion, semanticdbOptions ++= (if (scalaVersion.value != Scala3) List("-P:semanticdb:synthetics:on") else Nil), semanticdbVersion := scalafixSemanticdb.revision, ThisBuild / scalafixScalaBinaryVersion := CrossVersion.binaryScalaVersion(scalaVersion.value), diff --git a/zio-profiling-jmh/src/main/scala/zio/profiling/jmh/BenchmarkUtils.scala b/zio-profiling-jmh/src/main/scala/zio/profiling/jmh/BenchmarkUtils.scala index f8200ec..f558b58 100644 --- a/zio-profiling-jmh/src/main/scala/zio/profiling/jmh/BenchmarkUtils.scala +++ b/zio-profiling-jmh/src/main/scala/zio/profiling/jmh/BenchmarkUtils.scala @@ -14,6 +14,11 @@ object BenchmarkUtils { if (customRt ne null) customRt else Runtime.default } + def getSupervisor(): Supervisor[Any] = { + val customRt = runtimeRef.get() + if (customRt ne null) customRt.environment.get[SamplingProfilerSupervisor] else Supervisor.none + } + def unsafeRun[E, A](zio: ZIO[Any, E, A]): A = Unsafe.unsafe { implicit unsafe => getRuntime().unsafe.run(zio).getOrThrowFiberFailure() diff --git a/zio-profiling-tagging-plugin/src/main/scala-2/zio/profiling/plugins/TaggingPlugin.scala b/zio-profiling-tagging-plugin/src/main/scala-2/zio/profiling/plugins/TaggingPlugin.scala index 29d0737..420c65c 100644 --- a/zio-profiling-tagging-plugin/src/main/scala-2/zio/profiling/plugins/TaggingPlugin.scala +++ b/zio-profiling-tagging-plugin/src/main/scala-2/zio/profiling/plugins/TaggingPlugin.scala @@ -27,24 +27,23 @@ class TaggingPlugin(val global: Global) extends Plugin { class TaggingTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { override def transform(tree: Tree): Tree = tree match { - case valDef @ ValDef(_, _, ZioTypeTree(t1, t2, t3), rhs) if isNonAbstract(valDef) => - val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, t1, t2, t3) + case valDef @ ValDef(_, _, TaggableTypeTree(taggingTarget), rhs) if rhs.nonEmpty => + val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, taggingTarget) val typedRhs = localTyper.typed(transformedRhs) val updated = treeCopy.ValDef(tree, valDef.mods, valDef.name, valDef.tpt, rhs = typedRhs) + super.transform(updated) - case defDef @ DefDef(_, _, _, _, ZioTypeTree(t1, t2, t3), rhs) if isNonAbstract(defDef) => - val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, t1, t2, t3) + case defDef @ DefDef(_, _, _, _, TaggableTypeTree(taggingTarget), rhs) if rhs.nonEmpty => + val transformedRhs = tagEffectTree(descriptiveName(tree), rhs, taggingTarget) val typedRhs = localTyper.typed(transformedRhs) val updated = treeCopy.DefDef(tree, defDef.mods, defDef.name, defDef.tparams, defDef.vparamss, defDef.tpt, rhs = typedRhs) + super.transform(updated) case _ => super.transform(tree) } - private def isNonAbstract(tree: ValOrDefDef): Boolean = - !tree.mods.isDeferred - private def descriptiveName(tree: Tree): String = { val fullName = tree.symbol.fullNameString val sourceFile = tree.pos.source.file.name @@ -53,21 +52,37 @@ class TaggingPlugin(val global: Global) extends Plugin { s"$fullName($sourceFile:$sourceLine)" } - private def tagEffectTree(name: String, tree: Tree, t1: Type, t2: Type, t3: Type): Tree = { + private def tagEffectTree(name: String, tree: Tree, taggingTarget: TaggingTarget): Tree = { val costCenterModule = rootMirror.getRequiredModule("_root_.zio.profiling.CostCenter") val traceModule = rootMirror.getRequiredModule("_root_.zio.Trace") - q"$costCenterModule.withChildCostCenter[$t1,$t2,$t3]($name)($tree)($traceModule.empty)" + taggingTarget match { + case ZioTaggingTarget(t1, t2, t3) => + q"$costCenterModule.withChildCostCenter[$t1,$t2,$t3]($name)($tree)($traceModule.empty)" + case ZStreamTaggingTarget(t1, t2, t3) => + println(name) + q"$costCenterModule.withChildCostCenterStream[$t1,$t2,$t3]($name)($tree)($traceModule.empty)" + } + } - private object ZioTypeTree { - private def zioTypeRef: Type = - rootMirror.getRequiredClass("zio.ZIO").tpe + private sealed trait TaggingTarget + + private case class ZioTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget + private case class ZStreamTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget + + private object TaggableTypeTree { + private def zioTypeRef: Type = rootMirror.getRequiredClass("_root_.zio.ZIO").tpe + + private def zStreamTypeRef: Type = rootMirror.getRequiredClass("_root_.zio.stream.ZStream").tpe - def unapply(tpt: Tree): Option[(Type, Type, Type)] = + def unapply(tpt: Tree): Option[TaggingTarget] = tpt.tpe.dealias match { - case TypeRef(_, sym, t1 :: t2 :: t3 :: Nil) if sym == zioTypeRef.typeSymbol => Some((t1, t2, t3)) - case _ => None + case TypeRef(_, sym, t1 :: t2 :: t3 :: Nil) if sym == zioTypeRef.typeSymbol => + Some(ZioTaggingTarget(t1, t2, t3)) + case TypeRef(_, sym, t1 :: t2 :: t3 :: Nil) if sym == zStreamTypeRef.typeSymbol => + Some(ZStreamTaggingTarget(t1, t2, t3)) + case _ => None } } diff --git a/zio-profiling-tagging-plugin/src/main/scala-3/zio/profiling/plugins/TaggingPhase.scala b/zio-profiling-tagging-plugin/src/main/scala-3/zio/profiling/plugins/TaggingPhase.scala index 6515dfd..8fe5bb4 100644 --- a/zio-profiling-tagging-plugin/src/main/scala-3/zio/profiling/plugins/TaggingPhase.scala +++ b/zio-profiling-tagging-plugin/src/main/scala-3/zio/profiling/plugins/TaggingPhase.scala @@ -24,16 +24,16 @@ object TaggingPhase extends PluginPhase { override val runsBefore = Set(Staging.name) override def transformValDef(tree: tpd.ValDef)(using Context): tpd.Tree = tree match { - case ValDef(_, ZioTypeTree(t1, t2, t3), _) if !tree.mods.flags.is(Flags.DeferredTerm) => - val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, t1, t2, t3) + case ValDef(_, TaggableTypeTree(taggingTarget), rhs) if !tree.rhs.isEmpty => + val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, taggingTarget) cpy.ValDef(tree)(rhs = transformedRhs) case _ => tree } override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree = tree match { - case DefDef(_, _, tpt @ ZioTypeTree(t1, t2, t3), _) if !tree.mods.flags.is(Flags.DeferredTerm) => - val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, t1, t2, t3) + case DefDef(_, _, TaggableTypeTree(taggingTarget), rhs) if !tree.rhs.isEmpty => + val transformedRhs = tagEffectTree(descriptiveName(tree), tree.rhs, taggingTarget) cpy.DefDef(tree)(rhs = transformedRhs) case _ => tree @@ -47,27 +47,46 @@ object TaggingPhase extends PluginPhase { s"$fullName($sourceFile:$sourceLine)" } - private def tagEffectTree(name: String, tree: tpd.Tree, t1: Type, t2: Type, t3: Type)(using Context): tpd.Tree = { - val costcenterSym = requiredModule("zio.profiling.CostCenter") - val withChildCostCenterSym = costcenterSym.requiredMethod("withChildCostCenter") - - val traceSym = requiredModule("zio.Trace") + private def tagEffectTree(name: String, tree: tpd.Tree, taggingTarget: TaggingTarget)(using Context): tpd.Tree = { + val costcenterSym = requiredModule("_root_.zio.profiling.CostCenter") + val traceSym = requiredModule("_root_.zio.Trace") val emptyTraceSym = traceSym.requiredMethodRef("empty") - tpd.ref(withChildCostCenterSym) - .appliedToTypes(List(t1, t2, t3)) - .appliedTo(tpd.Literal(Constant(name))) - .appliedTo(tree) - .appliedTo(tpd.ref(emptyTraceSym)) + taggingTarget match { + case ZioTaggingTarget(t1, t2, t3) => + val withChildCostCenterSym = costcenterSym.requiredMethod("withChildCostCenter") + + tpd.ref(withChildCostCenterSym) + .appliedToTypes(List(t1, t2, t3)) + .appliedTo(tpd.Literal(Constant(name))) + .appliedTo(tree) + .appliedTo(tpd.ref(emptyTraceSym)) + + case ZStreamTaggingTarget(t1, t2, t3) => + val withChildCostCenterSym = costcenterSym.requiredMethod("withChildCostCenterStream") + + tpd.ref(withChildCostCenterSym) + .appliedToTypes(List(t1, t2, t3)) + .appliedTo(tpd.Literal(Constant(name))) + .appliedTo(tree) + .appliedTo(tpd.ref(emptyTraceSym)) + } } - private object ZioTypeTree { - private def zioTypeRef(using Context): TypeRef = - requiredClassRef("zio.ZIO") + private sealed trait TaggingTarget + + private case class ZioTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget + private case class ZStreamTaggingTarget(rType: Type, eType: Type, aType: Type) extends TaggingTarget + + private object TaggableTypeTree { + private def zioTypeRef(using Context): TypeRef = requiredClassRef("_root_.zio.ZIO") + + private def zStreamTypeRef(using Context): TypeRef = requiredClassRef("_root_.stream.ZStream") - def unapply(tp: Tree[Type])(using Context): Option[(Type, Type, Type)] = + def unapply(tp: Tree[Type])(using Context): Option[TaggingTarget] = tp.tpe.dealias match { - case AppliedType(at, t1 :: t2 :: t3 :: Nil) if at.isRef(zioTypeRef.symbol) => Some((t1, t2, t3)) + case AppliedType(at, t1 :: t2 :: t3 :: Nil) if at.isRef(zioTypeRef.symbol) => Some(ZioTaggingTarget(t1, t2, t3)) + case AppliedType(at, t1 :: t2 :: t3 :: Nil) if at.isRef(zioTypeRef.symbol) => Some(ZStreamTaggingTarget(t1, t2, t3)) case _ => None } diff --git a/zio-profiling/src/main/scala/zio/profiling/CostCenter.scala b/zio-profiling/src/main/scala/zio/profiling/CostCenter.scala index c6e5063..35d8fe8 100644 --- a/zio-profiling/src/main/scala/zio/profiling/CostCenter.scala +++ b/zio-profiling/src/main/scala/zio/profiling/CostCenter.scala @@ -1,6 +1,7 @@ package zio.profiling import zio._ +import zio.stream.ZStream /** * A CostCenter allows grouping multiple source code locations into one unit for reporting and targeting purposes. @@ -75,6 +76,14 @@ object CostCenter { def withChildCostCenter[R, E, A](name: String)(zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, A] = globalRef.locallyWith(_ / name)(zio) + /** + * Run an effect with a child cost center nested under the current one. + */ + def withChildCostCenterStream[R, E, A](name: String)(stream: ZStream[R, E, A])(implicit + trace: Trace + ): ZStream[R, E, A] = + ZStream.scoped[R](globalRef.locallyScopedWith(_ / name)) *> stream + private final val globalRef: FiberRef[CostCenter] = Unsafe.unsafe(implicit u => FiberRef.unsafe.make(CostCenter.Root, identity, (old, _) => old)) } diff --git a/zio-profiling/src/main/scala/zio/profiling/sampling/SamplingProfiler.scala b/zio-profiling/src/main/scala/zio/profiling/sampling/SamplingProfiler.scala index 2cbf476..0662027 100644 --- a/zio-profiling/src/main/scala/zio/profiling/sampling/SamplingProfiler.scala +++ b/zio-profiling/src/main/scala/zio/profiling/sampling/SamplingProfiler.scala @@ -30,7 +30,7 @@ final case class SamplingProfiler( /** * Create a runtime that will profile all effects executed with it. Use `runtime.environment.get` in order to get a - * reference to the supervisor. Make sure to shut down the runtime when down. + * reference to the supervisor. Make sure to shut down the runtime when done. */ def supervisedRuntime(implicit unsafe: Unsafe): Runtime.Scoped[SamplingProfilerSupervisor] = { val layer = ZLayer.scoped[Any](makeSupervisor).flatMap(env => Runtime.addSupervisor(env.get).map(_ => env))