Commit
passing around the trace group identifier (#382 #383); harden the definition of a valid trace group ID
vreuter committed Dec 10, 2024
1 parent 617408a commit 0a390ef
Showing 18 changed files with 502 additions and 239 deletions.
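The second half of the commit message, "harden the definition of a valid trace group ID", points at validating a TraceGroupId when it is constructed (note how the Scala diff below drops a raw TraceGroupId(r.name) wrapping in favor of values that are already typed as TraceGroupId). A minimal sketch of that idea, assuming only that TraceGroupId wraps a String exposed via .get as seen later in this diff; the repo's actual validation rule may differ:

    // Hypothetical sketch, not the repo's actual definition: reject empty
    // names at construction so downstream code never sees an invalid ID.
    final case class TraceGroupId private (get: String)
    object TraceGroupId:
      def fromString(raw: String): Either[String, TraceGroupId] =
        if raw.isEmpty then Left("Trace group ID cannot be empty")
        else Right(TraceGroupId(raw))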
15 changes: 9 additions & 6 deletions looptrace/SpotPicker.py
@@ -708,15 +708,18 @@ def build_locus_spot_data_extraction_table(
)

# roi.name is the index value.
all_rois.append([fov, roi["index"], timepoint, ref_timepoint, ch, roi["traceId"], roi["tracePartners"],
z_min, z_max, y_min, y_max, x_min, x_max,
pad_z_min, pad_z_max, pad_y_min, pad_y_max, pad_x_min, pad_x_max,
z_drift_coarse, y_drift_coarse, x_drift_coarse,
dc_row["zDriftFinePixels"], dc_row["yDriftFinePixels"], dc_row["xDriftFinePixels"]])
all_rois.append([
fov, roi["index"], timepoint, ref_timepoint, ch,
roi["traceGroup"], roi["traceId"], roi["tracePartners"],
z_min, z_max, y_min, y_max, x_min, x_max,
pad_z_min, pad_z_max, pad_y_min, pad_y_max, pad_x_min, pad_x_max,
z_drift_coarse, y_drift_coarse, x_drift_coarse,
dc_row["zDriftFinePixels"], dc_row["yDriftFinePixels"], dc_row["xDriftFinePixels"]
])

return pd.DataFrame(all_rois, columns=[
FIELD_OF_VIEW_COLUMN, "roiId", "timepoint", "ref_timepoint", SPOT_CHANNEL_COLUMN_NAME,
"traceId", "tracePartners",
"traceGroup", "traceId", "tracePartners",
"zMin", "zMax", "yMin", "yMax", "xMin", "xMax",
"pad_z_min", "pad_z_max", "pad_y_min", "pad_y_max", "pad_x_min", "pad_x_max",
"zDriftCoarsePixels", "yDriftCoarsePixels", "xDriftCoarsePixels",
119 changes: 82 additions & 37 deletions src/main/scala/AssignTraceIds.scala
@@ -33,6 +33,7 @@ import at.ac.oeaw.imba.gerlich.gerlib.io.csv.{
readCsvToCaseClasses,
writeCaseClassesToCsv,
}
import at.ac.oeaw.imba.gerlich.gerlib.json.syntax.asJson
import at.ac.oeaw.imba.gerlich.gerlib.numeric.*
import at.ac.oeaw.imba.gerlich.gerlib.numeric.instances.all.given
import at.ac.oeaw.imba.gerlich.gerlib.syntax.all.*
@@ -52,6 +53,7 @@ import at.ac.oeaw.imba.gerlich.looptrace.csv.ColumnNames.{
}
import at.ac.oeaw.imba.gerlich.looptrace.csv.getCsvRowDecoderForImagingChannel
import at.ac.oeaw.imba.gerlich.looptrace.csv.instances.all.given
import at.ac.oeaw.imba.gerlich.looptrace.csv.instances.tracing.getCsvRowEncoderForTraceIdAssignmentWithoutRoiIndex
import at.ac.oeaw.imba.gerlich.looptrace.instances.all.given
import at.ac.oeaw.imba.gerlich.looptrace.internal.BuildInfo
import at.ac.oeaw.imba.gerlich.looptrace.roi.MergeAndSplitRoiTools.IndexedDetectedSpot
@@ -96,7 +98,7 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:
.required()
.action((f, c) => c.copy(outputFile = f))
.text("Path to file to which to write main output"),
opt[os.Path]("skipsFile")
opt[os.Path]("skipsFileBase")
.required()
.action((f, c) => c.copy(skipsFile = f))
.text("Path to location to which to write skips"),
@@ -241,7 +243,7 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:

val lookupTraceGroupId: NonEmptySet[ImagingTimepoint] => Either[AtLeast2[List, (Set[ImagingTimepoint], TraceGroupId)], TraceGroupOptional] =
val traceGroupIdByRegTimeSet: NonEmptyMap[Set[ImagingTimepoint], TraceGroupId] =
val membersNamePairs = rules.map(r => r.mergeGroup.members.toSet -> TraceGroupId(r.name))
val membersNamePairs = rules.map(r => r.mergeGroup.members.toSet -> r.name)
val (k1, v1) = membersNamePairs.head
given Order[Set[ImagingTimepoint]] = Order.by(_.toList)
membersNamePairs.tail
@@ -311,10 +313,18 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:
.fold(!discardIfNotInGroupOfInterest)(_.requirement === RoiPartnersRequirementType.Lackadaisical)
.option{
val query = NonEmptySet.one(singleInputRecord.timepoint)
lookupTraceGroupId(query).bimap(
multiHit => (singleInputRecord.index, query -> multiHit.map(_._2)),
singleHit => OutputRecord(singleInputRecord, singleHit, currId, None, None)
)
val roiId = singleInputRecord.index
lookupTraceGroupId(query) match {
case Left(multiHit) =>
// problem case --> siphon off separately
(roiId, query -> multiHit.map(_._2)).asLeft
case Right(TraceGroupOptional(None)) =>
val assignment = TraceIdAssignment.UngroupedRecord(roiId, currId)
OutputRecord(singleInputRecord, assignment).asRight
case Right(TraceGroupOptional(Some(groupId))) =>
val assignment = TraceIdAssignment.GroupedAndUnmerged(roiId, currId, groupId)
OutputRecord(singleInputRecord, assignment).asRight
}
}
List(maybeOutputRecord).flatten
},
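Taken together, the constructors used here and in the multi-timepoint branch below suggest roughly this shape for TraceIdAssignment. The following is a hypothetical reconstruction from the call sites in this diff only (field names and the partners type are assumptions), with the surrounding package's RoiIndex, TraceId, and TraceGroupId types in scope:

    // Reconstructed from usage; the repo's actual definition may differ.
    enum TraceIdAssignment:
      case UngroupedRecord(roiId: RoiIndex, traceId: TraceId)
      case GroupedAndUnmerged(roiId: RoiIndex, traceId: TraceId, groupId: TraceGroupId)
      case GroupedAndMerged(
        roiId: RoiIndex,
        traceId: TraceId,
        groupId: TraceGroupId,
        partners: Set[RoiIndex],   // assumed type of multiIds.remove(r.index)
        hasAllPartners: Boolean,
      )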
@@ -362,16 +372,52 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:
val emitElem: InputRecord => InputRecordFate = lookupTraceGroupId(observedTimes) match {
case Left(multiHit) =>
(r: InputRecord) => (r.index, observedTimes -> multiHit.map(_._2)).asLeft
case Right(traceGroupIdOpt) =>
case Right(groupIdOpt) =>
val groupId = groupIdOpt.toOption.getOrElse{
// TODO: should this be made a valid case?
throw new Exception(s"No trace group ID found for multi-timepoint group (${observedTimes})")
}
(r: InputRecord) =>
val partners = multiIds.remove(r.index).some
OutputRecord(r, traceGroupIdOpt, currId, partners, groupHasAllTimepoints.some).asRight
val partners = multiIds.remove(r.index)
val assignment = TraceIdAssignment.GroupedAndMerged(r.index, currId, groupId, partners, groupHasAllTimepoints)
OutputRecord(r, assignment).asRight
}
recGroup.map(emitElem).toList
else List()
else List() // The group is missing at least one required regional imaging timepoint, which this branch requires, so emit nothing.
)

private type RecordFailure = (RoiIndex, (NonEmptySet[ImagingTimepoint], AtLeast2[List, TraceGroupId]))
private type InputRecordFate = Either[RecordFailure, OutputRecord]

/** Helpers for handling a record that cannot be processed */
object InputRecordFate:
given upickle.default.Writer[RecordFailure] =
import at.ac.oeaw.imba.gerlich.gerlib.collections.AtLeast2.syntax.toList
import at.ac.oeaw.imba.gerlich.gerlib.imaging.instances.all.given
import at.ac.oeaw.imba.gerlich.gerlib.json.JsonValueWriter

upickle.default.readwriter[ujson.Obj].bimap(
{ case (roiId, (groupTimes, groupIds)) =>
ujson.Obj(
"roiId" -> roiId.get,
"groupTimes" -> groupTimes
.toNonEmptyList
.toList
.sorted(using Order[ImagingTimepoint].toOrdering)
.map(_.asJson),
"groupIds" -> groupIds.toList.map(_.get),
)
},
obj =>
// Don't bother implementing this since we don't need it here, and
// we're typing the value of this expression as just a Writer, not a ReadWriter,
// so we don't need the Reader side of the equation.
???
)

private type InputRecordFate = Either[(RoiIndex, (NonEmptySet[ImagingTimepoint], AtLeast2[List, TraceGroupId])), OutputRecord]
/** Print the given failures as a JSON string, using the specified indentation. */
def printJsonFails(fails: List[RecordFailure], indent: Int = 2): String =
upickle.default.write(fails, indent = indent)
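
The Writer[RecordFailure] above leans on a common upickle pattern: take a ReadWriter[ujson.Obj], bimap it with a serialization function, and leave the deserialization side as ??? because the value is only ever used at type Writer. The same pattern on a self-contained toy type (hypothetical example, not repo code):

    import upickle.default.{Writer, readwriter}

    final case class ToyFailure(id: Int, reasons: List[String])

    given Writer[ToyFailure] = readwriter[ujson.Obj].bimap(
      f => ujson.Obj("id" -> f.id, "reasons" -> f.reasons),
      _ => ???, // Reader direction is never exercised when typed as Writer
    )

    // upickle.default.write(ToyFailure(1, List("ambiguous group")), indent = 2)
    // produces a JSON object with "id" and "reasons" fields.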

def workflow(roundsConfig: ImagingRoundsConfiguration, roisFile: os.Path, pixels: Pixels3D, outputFile: os.Path, skipsFile: os.Path): Unit = {
import InputRecord.given
@@ -396,7 +442,7 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:
records.zipWithIndex.map{ (r, i) =>
val newTid = TraceId.unsafe(NonnegativeInt.unsafe(i) + initTraceId.get)
checkTraceId(traceIdsOffLimits)(newTid)
OutputRecord(r, TraceGroupOptional.empty, newTid, None, None).asRight
OutputRecord(r, TraceIdAssignment.UngroupedRecord(r.index, newTid)).asRight
}.toList
case Some(rules) =>
labelRecordsWithTraceId(rules, roundsConfig.discardRoisNotInGroupsOfInterest, pixels)(records)
@@ -406,18 +452,30 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:
case Nil => IO{ logger.error("No output to write!") }
case inputFates =>
import OutputRecord.given
import InputRecordFate.given

given CsvRowEncoder[ImagingChannel, String] =
// for derivation of CsvRowEncoder[ImagingContext, String]
SpotChannelColumnName.toNamedEncoder


val (skips, records) = Alternative[List].separate(inputFates)
logger.info(s"Writing output file: $outputFile")
fs2.Stream
.emits(records.sortBy(_.inputRecord.index)(using Order[RoiIndex].toOrdering).toList)
.through(writeCaseClassesToCsv[OutputRecord](outputFile))
.compile
.drain
// TODO: write skips file

IO{ logger.info(s"Writing main output file: $outputFile") }.flatMap(
Function.const{
fs2.Stream
.emits(records.sortBy(_.inputRecord.index)(using Order[RoiIndex].toOrdering).toList)
.through(writeCaseClassesToCsv[OutputRecord](outputFile))
.compile
.drain
}
)

IO { logger.info(s"Writing skips file: $skipsFile") }.flatMap(
Function.const{
IO { os.write(skipsFile, upickle.default.write(skips), createFolders = true) }
}
)
})
.unsafeRunSync()
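
The flatMap(Function.const{ ... }) chain above sequences IO actions while discarding each intermediate result, equivalent to chaining with cats' *> operator. A minimal standalone illustration (hypothetical example; the printed strings are placeholders):

    import cats.effect.IO
    import cats.effect.unsafe.implicits.global

    val first  = IO { println("write main output") }
    val second = IO { println("write skips") }

    // Runs first, ignores its result, then runs second.
    first.flatMap(Function.const(second)).unsafeRunSync()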

@@ -426,16 +484,8 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:

final case class OutputRecord(
inputRecord: InputRecord, // NB: this part of the record contains the ACTUAL merge partners (if any).
traceGroupId: TraceGroupOptional,
traceId: TraceId,
// NB: these are pointers to actual ROIs, not just their timepoints; these are for TRACE merge, NOT ACTUAL merge.
maybePartners: Option[NonEmptySet[RoiIndex]],
hasAllPartners: Option[Boolean] // empty if and only if the record's a singleton
assignment: TraceIdAssignment,
):
require(
(maybePartners.isEmpty && hasAllPartners.isEmpty) || (maybePartners.nonEmpty && hasAllPartners.nonEmpty),
s"The hasAllPartners optional and maybePartners optional must either both be empty or both be nonempty; got ${hasAllPartners} and ${maybePartners}"
)
def index: RoiIndex = inputRecord.index
def context: ImagingContext = inputRecord.context
def centroid: Centroid[Double] = inputRecord.centroid
@@ -468,15 +518,10 @@ object AssignTraceIds extends ScoptCliReaders, StrictLogging:
case Some(nuclearDesignation) =>
NucleusDesignationColumnName.write(nuclearDesignation)
}
val traceGroupRow = TraceGroupColumnName.write(elem.traceGroupId)
val tidRow = TraceIdColumnName.write(elem.traceId)
val tracePartnersRow =
TracePartnersColumName.write(elem.maybePartners.fold(Set())(_.toSortedSet.toSet))
val allPartnersFlagRow =
given CellEncoder[Option[Boolean]] with
override def apply(cell: Option[Boolean]): String = cell.fold("")(encPartnersFlag.apply)
TracePartnersAreAllPresentColumnName.write(elem.hasAllPartners)
idRow |+| ctxRow |+| centerRow |+| boxRow |+| mergeInputsRow |+| nucRow |+| traceGroupRow |+| tidRow |+| allPartnersFlagRow |+| tracePartnersRow
val encAssignment: CsvRowEncoder[TraceIdAssignment, String] =
getCsvRowEncoderForTraceIdAssignmentWithoutRoiIndex
val traceIdAssignmentRow = encAssignment(elem.assignment)
idRow |+| ctxRow |+| centerRow |+| boxRow |+| mergeInputsRow |+| nucRow |+| traceIdAssignmentRow
end OutputRecord

final case class InputRecord(