From b49795bb4d6005f90c846eaaf8b7eac3d6c4e2e6 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Tue, 20 Feb 2024 14:24:01 +0100 Subject: [PATCH] Cache TensorFlow Metadata proto files (#921) --- build.sbt | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/build.sbt b/build.sbt index a9ef2f4f..fd89641a 100644 --- a/build.sbt +++ b/build.sbt @@ -14,6 +14,7 @@ * limitations under the License. */ import sbt._ +import sbt.util.CacheImplicits._ import sbtprotoc.ProtocPlugin.ProtobufConfig import com.github.sbt.git.SbtGit.GitKeys.gitRemoteRepo import com.typesafe.tools.mima.core._ @@ -550,6 +551,11 @@ lazy val protobuf = project description := "Magnolia add-on for Google Protocol Buffer" ) +val tensorflowMetadataSourcesDir = + settingKey[File]("Directory containing TensorFlow metadata proto files") +val tensorflowMetadata = + taskKey[Seq[File]]("Retrieve TensorFlow metadata proto files") + lazy val tensorflow = project .in(file("tensorflow")) .dependsOn( @@ -570,14 +576,34 @@ lazy val tensorflow = project // remove compilation warnings for generated java files javacOptions ~= { _.filterNot(_ == "-Xlint:all") }, // tensorflow metadata protos are not packaged into a jar. Manually extract them as external + Compile / tensorflowMetadataSourcesDir := target.value / s"metadata-$tensorflowMetadataVersion", Compile / PB.protoSources += target.value / s"metadata-$tensorflowMetadataVersion", + Compile / tensorflowMetadata := { + def work(tensorFlowMetadataVersion: String) = { + val tfMetadata = url( + s"https://github.com/tensorflow/metadata/archive/refs/tags/v$tensorFlowMetadataVersion.zip" + ) + IO.unzipURL(tfMetadata, target.value, "*.proto").toSeq + } + + val cacheStoreFactory = streams.value.cacheStoreFactory + val root = (Compile / tensorflowMetadataSourcesDir).value + val tracker = + Tracked.inputChanged(cacheStoreFactory.make("input")) { (versionChanged, version: String) => + val cached = Tracked.outputChanged(cacheStoreFactory.make("output")) { + (outputChanged: Boolean, files: Seq[HashFileInfo]) => + if (versionChanged || outputChanged) work(version) + else files.map(_.file) + } + cached(() => (root ** "*.proto").get().map(FileInfo.hash(_))) + } + + tracker(tensorflowMetadataVersion) + }, Compile / PB.unpackDependencies := { - val tfMetadata = new URL( - s"https://github.com/tensorflow/metadata/archive/refs/tags/v$tensorflowMetadataVersion.zip" - ) - val protoFiles = IO.unzipURL(tfMetadata, target.value, _.endsWith(".proto")) - val root = target.value / s"metadata-$tensorflowMetadataVersion" - val metadataDep = ProtocPlugin.UnpackedDependency(protoFiles.toSeq, Seq.empty) + val protoFiles = (Compile / tensorflowMetadata).value + val root = (Compile / tensorflowMetadataSourcesDir).value + val metadataDep = ProtocPlugin.UnpackedDependency(protoFiles, Seq.empty) val deps = (Compile / PB.unpackDependencies).value new ProtocPlugin.UnpackedDependencies(deps.mappedFiles ++ Map(root -> metadataDep)) },