Commit

Add back util script edits
joshanderson-kw committed Nov 12, 2021
1 parent 6cbd35d commit 90c3381
Showing 4 changed files with 86 additions and 312 deletions.
smqtk_iqr/utils/compute_functions.py: 290 changes (28 additions & 262 deletions)
@@ -7,22 +7,28 @@
"""
import collections
import logging
import itertools
from typing import Deque, Hashable, List, Set

import numpy

from smqtk.representation import DataElement
from smqtk.utils import (
    cli,
    bits,
    parallel,
from typing import (
    Deque, Hashable, Set, Tuple, Generator, Iterable, Any, Optional
)


def compute_many_descriptors(data_elements, descr_generator, descr_factory,
                             descr_set, batch_size=None, overwrite=False,
                             procs=None, **kwds):
from smqtk_dataprovider import (
    DataElement
)
from smqtk_descriptors import (
    DescriptorElement, DescriptorGenerator, DescriptorSet
)
from smqtk_descriptors.descriptor_element_factory import DescriptorElementFactory


def compute_many_descriptors(data_elements: Iterable[DataElement],
                             descr_generator: DescriptorGenerator,
                             descr_factory: DescriptorElementFactory,
                             descr_set: DescriptorSet,
                             batch_size: Optional[int] = None,
                             overwrite: bool = False,
                             procs: Optional[int] = None,
                             **kwds: Any) -> Iterable[Tuple[DataElement,
                                                            DescriptorElement]]:
"""
Compute descriptors for each data element, yielding
(DataElement, DescriptorElement) tuple pairs in the order that they were
@@ -33,21 +39,21 @@ def compute_many_descriptors(data_elements, descr_generator, descr_factory,
    :param data_elements: Iterable of DataElement instances of files to
        work on.
    :type data_elements: collections.abc.Iterable[smqtk.representation.DataElement]
    :type data_elements: collections.abc.Iterable[DataElement]
    :param descr_generator: DescriptorGenerator implementation instance
        to use to generate descriptor vectors.
    :type descr_generator: smqtk.algorithms.DescriptorGenerator
    :type descr_generator: DescriptorGenerator
    :param descr_factory: DescriptorElement factory to use when producing
        descriptor vectors.
    :type descr_factory: smqtk.representation.DescriptorElementFactory
    :type descr_factory: DescriptorElementFactory
    :param descr_set: DescriptorSet instance to add generated descriptors
        to. When given a non-zero batch size, we add descriptors to the given
        set in batches of that size. When a batch size is not given, we add
        all generated descriptors to the set after they have been generated.
    :type descr_set: smqtk.representation.DescriptorSet
    :type descr_set: DescriptorSet
    :param batch_size: Optional number of elements to asynchronously compute
        at a time. This is useful when it is desired for this function to yield
@@ -72,8 +78,8 @@ def compute_many_descriptors(data_elements, descr_generator, descr_factory,
    :return: Generator that yields (DataElement, DescriptorElement) for each
        data element given, in the order they were provided.
    :rtype: collections.abc.Iterable[(smqtk.representation.DataElement,
        smqtk.representation.DescriptorElement)]
    :rtype: collections.abc.Iterable[(DataElement,
        DescriptorElement)]
    """
    log = logging.getLogger(__name__)
@@ -85,7 +91,7 @@ def compute_many_descriptors(data_elements, descr_generator, descr_factory,
    total = [0]
    unique: Set[Hashable] = set()

    def iter_capture_elements():
    def iter_capture_elements() -> Generator:
        for d in data_elements:
            de_deque.append(d)
            yield d
@@ -101,7 +107,7 @@ def iter_capture_elements():
    if batch_size:
        log.debug("Computing in batches of size %d", batch_size)

        def iterate_batch_results():
        def iterate_batch_results() -> Generator:
            descr_list_ = list(descr_generator.generate_elements(
                de_deque, descr_factory, overwrite
            ))
@@ -149,243 +155,3 @@ def iterate_batch_results():
log.debug("yielding generated elements")
for data, descr in zip(de_deque, descr_list):
yield data, descr
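
For reference, a minimal usage sketch of the retained function. The concrete implementation classes and import paths below are assumptions for illustration, not part of this diff:

    # Sketch only: implementation class names and import paths are assumptions.
    from smqtk_dataprovider.impls.data_element.file import DataFileElement
    from smqtk_descriptors.descriptor_element_factory import DescriptorElementFactory
    from smqtk_descriptors.impls.descriptor_element.memory import DescriptorMemoryElement
    from smqtk_descriptors.impls.descriptor_set.memory import MemoryDescriptorSet

    data = [DataFileElement(fp) for fp in ("img_a.png", "img_b.png")]
    factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    dset = MemoryDescriptorSet()
    generator = ...  # any DescriptorGenerator implementation instance

    # Pairs come back in the order the data elements were provided.
    for data_elem, descr_elem in compute_many_descriptors(
            data, generator, factory, dset, batch_size=128):
        print(data_elem.uuid(), descr_elem.vector().shape)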


class _CountedGenerator(object):
    """
    Used to count elements of an iterable as they are accessed

    :param collections.abc.Iterable iterable: An iterable containing elements
        to be accessed and counted.
    :param list count_list: A list to which the count of items in iterable
        will be added once the iterable has been exhausted.
    """
    def __init__(self, iterable, count_list):
        self.iterable = iterable
        self.count_list = count_list
        self.count = 0

    def __call__(self):
        for item in self.iterable:
            self.count += 1
            yield item
        self.count_list.append(self.count)

def compute_transformed_descriptors(data_elements, descr_generator,
                                    descr_factory, descr_set,
                                    transform_function, batch_size=None,
                                    overwrite=False, procs=None, **kwds):
    """
    Compute descriptors for copies of each data element generated by
    a transform function, yielding a list of tuples containing the original
    DataElement as the first element and a tuple of descriptors corresponding
    to the transformed DataElements.

    *Note:* Please see the closely-related :func:`compute_many_descriptors`
    for details on parameters and usage.

    *Note:* **This function currently only operates over images due to the
    specific data validity check/filter performed.**

    :param transform_function: Takes in a DataElement and returns an iterable
        of transformed DataElements.
    :type transform_function: collections.abc.Callable

    :rtype: collections.abc.Iterable[
        (smqtk.representation.DataElement,
         collections.abc.Iterable[smqtk.representation.DescriptorElement])]
    """
    transformed_counts: List[int] = []

    def transformed_elements():
        for elem in data_elements:
            yield _CountedGenerator(transform_function(elem),
                                    transformed_counts)()

    chained_elements = itertools.chain.from_iterable(transformed_elements())
    descriptors = compute_many_descriptors(chained_elements,
                                           descr_generator, descr_factory,
                                           descr_set, batch_size=batch_size,
                                           overwrite=overwrite, procs=procs,
                                           **kwds)
    for count, de in zip(transformed_counts, data_elements):
        yield de, itertools.islice((d[1] for d in descriptors), count)
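
A usage sketch under the same assumptions as the earlier example; `mirror_and_crop` and the helpers it calls are hypothetical, not defined in this repository:

    def mirror_and_crop(elem):
        # Hypothetical transform: yield the original plus altered copies.
        yield elem
        yield mirror(elem)       # hypothetical helper returning a DataElement
        yield center_crop(elem)  # hypothetical helper returning a DataElement

    for orig_elem, descr_iter in compute_transformed_descriptors(
            data, generator, factory, dset, mirror_and_crop, batch_size=32):
        vectors = [d.vector() for d in descr_iter]
        print(orig_elem.uuid(), len(vectors))  # 3 descriptors per element here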


def compute_hash_codes(uuids, descr_set, functor, report_interval=1.0,
                       use_mp=False, ordered=False):
    """
    Given an iterable of DescriptorElement UUIDs, asynchronously access them
    from the given ``set``, asynchronously compute hash codes via ``functor``
    and convert to an integer, yielding (UUID, hash-int) pairs.

    :param uuids: Sequence of UUIDs to process
    :type uuids: collections.abc.Iterable[collections.abc.Hashable]

    :param descr_set: Descriptor set to pull from.
    :type descr_set: smqtk.representation.descriptor_set.DescriptorSet

    :param functor: LSH hash code functor instance
    :type functor: smqtk.algorithms.LshFunctor

    :param report_interval: Frequency in seconds at which we report speed and
        completion progress via logging. Reporting is disabled when this value
        is non-positive or when logging is not at the debug level.
    :type report_interval: float

    :param use_mp: If multiprocessing should be used for parallel
        computation vs. threading. Reminder: This will copy currently loaded
        objects onto worker processes (e.g. the given set), which could lead
        to dangerously high RAM consumption.
    :type use_mp: bool

    :param ordered: If the element-hash value pairs yielded are in the same
        order as element UUID values input. This function should be slightly
        faster when ordering is not required.
    :type ordered: bool

    :return: Generator instance yielding (UUID, int) value pairs.
    """
    # TODO: parallel map fetch elements from set?
    #       -> separately from compute

    def get_hash(u):
        v = descr_set.get_descriptor(u).vector()
        return u, bits.bit_vector_to_int_large(functor.get_hash(v))

    # Setup log and reporting function
    log = logging.getLogger(__name__)

    if log.getEffectiveLevel() > logging.DEBUG or report_interval <= 0:
        def log_func(*_, **__):
            return
        log.debug("Not logging progress")
    else:
        log.debug("Logging progress at %f second intervals", report_interval)
        log_func = log.debug  # type: ignore

    log.debug("Starting computation")
    reporter = cli.ProgressReporter(log_func, report_interval)
    reporter.start()
    for uuid, hash_int in parallel.parallel_map(get_hash, uuids,
                                                ordered=ordered,
                                                use_multiprocessing=use_mp):
        yield (uuid, hash_int)
        # Progress reporting
        reporter.increment_report()

    # Final report
    reporter.report()

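A usage sketch for this removed helper; `functor` stands in for any LshFunctor implementation (an assumption), and `dset` is a populated DescriptorSet as in the earlier examples:

    # Sketch: functor is an LshFunctor instance; which one is an assumption.
    uuid_to_hash = dict(
        compute_hash_codes(dset.keys(), dset, functor,
                           report_interval=0,  # non-positive disables reporting
                           use_mp=False, ordered=True)
    )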

def mb_kmeans_build_apply(descr_set, mbkm, initial_fit_size):
    """
    Build the MiniBatchKMeans centroids based on the descriptors in the given
    set, then predict descriptor clusters with the final result model.

    If the given set is empty, no fitting or clustering occurs and an empty
    dictionary is returned.

    :param descr_set: set of descriptors
    :type descr_set: smqtk.representation.DescriptorSet

    :param mbkm: Scikit-Learn MiniBatchKMeans instance to train and then use
        for prediction
    :type mbkm: sklearn.cluster.MiniBatchKMeans

    :param initial_fit_size: Number of descriptors to run an initial fit with.
        This brings the advantage of choosing a best initialization point from
        multiple.
    :type initial_fit_size: int

    :return: Dictionary of the cluster label (integer) to the set of
        descriptor UUIDs belonging to that cluster.
    :rtype: dict[int, set[collections.abc.Hashable]]
    """
    log = logging.getLogger(__name__)

    ifit_completed = False
    k_deque: Deque[Hashable] = collections.deque()
    d_fitted = 0

    log.info("Getting set keys (shuffled)")
    set_keys = sorted(descr_set.keys())
    numpy.random.seed(mbkm.random_state)
    numpy.random.shuffle(set_keys)

    def parallel_iter_vectors(descriptors):
        """ Get the vectors for the descriptors given.
        Not caring about order returned.
        """
        return parallel.parallel_map(lambda d: d.vector(), descriptors,
                                     use_multiprocessing=False)

    def get_vectors(k_iter):
        """ Get numpy array of descriptor vectors (2D array returned) """
        return numpy.array(list(
            parallel_iter_vectors(descr_set.get_many_descriptors(k_iter))
        ))

    log.info("Collecting iteratively fitting model")
    pr = cli.ProgressReporter(log.debug, 1.0).start()
    for i, k in enumerate(set_keys):
        k_deque.append(k)
        pr.increment_report()

        if initial_fit_size and not ifit_completed:
            if len(k_deque) == initial_fit_size:
                log.info("Initial fit using %d descriptors", len(k_deque))
                log.info("- collecting vectors")
                vectors = get_vectors(k_deque)
                log.info("- fitting model")
                mbkm.fit(vectors)
                log.info("- cleaning")
                d_fitted += len(vectors)
                k_deque.clear()
                ifit_completed = True
        elif len(k_deque) == mbkm.batch_size:
            log.info("Partial fit with batch size %d", len(k_deque))
            log.info("- collecting vectors")
            vectors = get_vectors(k_deque)
            log.info("- fitting model")
            mbkm.partial_fit(vectors)
            log.info("- cleaning")
            d_fitted += len(k_deque)
            k_deque.clear()
    pr.report()

    # Final fit with any remaining descriptors
    if k_deque:
        log.info("Final partial fit of size %d", len(k_deque))
        log.info('- collecting vectors')
        vectors = get_vectors(k_deque)
        log.info('- fitting model')
        mbkm.partial_fit(vectors)
        log.info('- cleaning')
        d_fitted += len(k_deque)
        k_deque.clear()

    log.info("Computing descriptor classes with final KMeans model")
    mbkm.verbose = False
    d_classes = collections.defaultdict(set)
    d_uv_iter = parallel.parallel_map(lambda d: (d.uuid(), d.vector()),
                                      descr_set,
                                      use_multiprocessing=False,
                                      name="uv-collector")
    # TODO: Batch predict call inputs to something larger than one at a time.
    d_uc_iter = parallel.parallel_map(
        lambda u_v: (u_v[0], mbkm.predict(u_v[1][numpy.newaxis, :])[0]),
        d_uv_iter,
        use_multiprocessing=False,
        name="uc-collector")
    pr = cli.ProgressReporter(log.debug, 1.0).start()
    for uuid, c in d_uc_iter:
        d_classes[c].add(uuid)
        pr.increment_report()
    pr.report()

    return d_classes
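
A usage sketch for this removed helper, assuming scikit-learn is installed; the parameter values are illustrative only:

    from sklearn.cluster import MiniBatchKMeans

    mbkm = MiniBatchKMeans(n_clusters=32, batch_size=100, random_state=0)
    cluster_to_uuids = mb_kmeans_build_apply(dset, mbkm, initial_fit_size=1000)
    for label, uuids in sorted(cluster_to_uuids.items()):
        print("cluster %d: %d descriptors" % (label, len(uuids)))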
