diff --git a/libtbx/mpi4py.py b/libtbx/mpi4py.py index edb047ff73..da39bd9834 100644 --- a/libtbx/mpi4py.py +++ b/libtbx/mpi4py.py @@ -54,6 +54,8 @@ def Gatherv(self, sendbuf, recvbuf, root=0): for item, count in zip(sendbuf, counts): rbuff[counter:counter+count] = item counter += count + def allgather(self, sendobj): + return [sendobj] def Abort(self,errorcode=0): import sys sys.exit() diff --git a/xfel/merging/application/mpi_helper.py b/xfel/merging/application/mpi_helper.py index c8302dd982..215a8d0260 100644 --- a/xfel/merging/application/mpi_helper.py +++ b/xfel/merging/application/mpi_helper.py @@ -72,7 +72,7 @@ def count(self, data, root=0): Return total `Counter` of occurrences of each element in data across ranks. Example: (a1, a1, a2) + (a1, a2, a3) = {a1: 3, a2: 2, a1: 1} """ - counters = self.comm.gather(Counter(data), rank=root) + counters = self.comm.gather(Counter(data), root=root) return sum(counters, Counter()) if self.rank == root else None def sum(self, data, root=0): diff --git a/xfel/merging/application/postrefine/postrefinement_rs.py b/xfel/merging/application/postrefine/postrefinement_rs.py index 59aec9bca7..d7dc97f8ba 100644 --- a/xfel/merging/application/postrefine/postrefinement_rs.py +++ b/xfel/merging/application/postrefine/postrefinement_rs.py @@ -2,6 +2,7 @@ import six from six.moves import range from six.moves import cStringIO as StringIO +from collections import Counter import math from xfel.merging.application.worker import worker from libtbx import adopt_init_args, group_args @@ -46,7 +47,7 @@ def run(self, experiments, reflections): new_experiments = ExperimentList() new_reflections = flex.reflection_table() - experiments_rejected_by_reason = {} # reason:how_many_rejected + experiments_rejected_by_reason = Counter() # reason:how_many_rejected for expt_id, experiment in enumerate(experiments): @@ -191,10 +192,7 @@ def run(self, experiments, reflections): reason = repr(e) if not reason: reason = "Unknown error" - if not reason in experiments_rejected_by_reason: - experiments_rejected_by_reason[reason] = 1 - else: - experiments_rejected_by_reason[reason] += 1 + experiments_rejected_by_reason[reason] += 1 if not error_detected: new_experiments.append(experiment) @@ -228,27 +226,12 @@ def run(self, experiments, reflections): self.logger.log("Experiments rejected by post-refinement: %d"%experiments_rejected_by_postrefinement) self.logger.log("Reflections rejected by post-refinement: %d"%reflections_rejected_by_postrefinement) - all_reasons = [] for reason, count in six.iteritems(experiments_rejected_by_reason): self.logger.log("Experiments rejected due to %s: %d"%(reason,count)) - all_reasons.append(reason) - - comm = self.mpi_helper.comm - MPI = self.mpi_helper.MPI - - # Collect all rejection reasons from all ranks. Use allreduce to let each rank have all reasons. - all_reasons = comm.allreduce(all_reasons, MPI.SUM) - all_reasons = set(all_reasons) # Now that each rank has all reasons from all ranks, we can treat the reasons in a uniform way. - total_experiments_rejected_by_reason = {} - for reason in all_reasons: - rejected_experiment_count = 0 - if reason in experiments_rejected_by_reason: - rejected_experiment_count = experiments_rejected_by_reason[reason] - total_experiments_rejected_by_reason[reason] = comm.reduce(rejected_experiment_count, MPI.SUM, 0) - - total_accepted_experiment_count = comm.reduce(len(new_experiments), MPI.SUM, 0) + total_experiments_rejected_by_reason = self.mpi_helper.count(experiments_rejected_by_reason) + total_accepted_experiment_count = self.mpi_helper.sum(len(new_experiments)) # how many reflections have we rejected due to post-refinement? rejected_reflections = len(reflections) - len(new_reflections); diff --git a/xfel/merging/application/postrefine/postrefinement_rs2.py b/xfel/merging/application/postrefine/postrefinement_rs2.py index e431d413b8..d9531d7c43 100644 --- a/xfel/merging/application/postrefine/postrefinement_rs2.py +++ b/xfel/merging/application/postrefine/postrefinement_rs2.py @@ -2,6 +2,7 @@ import six from six.moves import range from six.moves import cStringIO as StringIO +from collections import Counter import math from libtbx import adopt_init_args from dials.array_family import flex @@ -56,7 +57,7 @@ def run(self, experiments, reflections): new_experiments = ExperimentList() new_reflections = flex.reflection_table() - experiments_rejected_by_reason = {} # reason:how_many_rejected + experiments_rejected_by_reason = Counter() # reason:how_many_rejected for expt_id, experiment in enumerate(experiments): @@ -192,10 +193,7 @@ def run(self, experiments, reflections): reason = repr(e) if not reason: reason = "Unknown error" - if not reason in experiments_rejected_by_reason: - experiments_rejected_by_reason[reason] = 1 - else: - experiments_rejected_by_reason[reason] += 1 + experiments_rejected_by_reason[reason] += 1 if not error_detected: new_experiments.append(experiment) @@ -229,27 +227,12 @@ def run(self, experiments, reflections): self.logger.log("Experiments rejected by post-refinement: %d"%experiments_rejected_by_postrefinement) self.logger.log("Reflections rejected by post-refinement: %d"%reflections_rejected_by_postrefinement) - all_reasons = [] for reason, count in six.iteritems(experiments_rejected_by_reason): self.logger.log("Experiments rejected due to %s: %d"%(reason,count)) - all_reasons.append(reason) - - comm = self.mpi_helper.comm - MPI = self.mpi_helper.MPI - - # Collect all rejection reasons from all ranks. Use allreduce to let each rank have all reasons. - all_reasons = comm.allreduce(all_reasons, MPI.SUM) - all_reasons = set(all_reasons) # Now that each rank has all reasons from all ranks, we can treat the reasons in a uniform way. - total_experiments_rejected_by_reason = {} - for reason in all_reasons: - rejected_experiment_count = 0 - if reason in experiments_rejected_by_reason: - rejected_experiment_count = experiments_rejected_by_reason[reason] - total_experiments_rejected_by_reason[reason] = comm.reduce(rejected_experiment_count, MPI.SUM, 0) - - total_accepted_experiment_count = comm.reduce(len(new_experiments), MPI.SUM, 0) + total_experiments_rejected_by_reason = self.mpi_helper.count(experiments_rejected_by_reason) + total_accepted_experiment_count = self.mpi_helper.sum(len(new_experiments)) # how many reflections have we rejected due to post-refinement? rejected_reflections = len(reflections) - len(new_reflections);