From 90c338141f861cdc0fd2e26f724b78216a079d2a Mon Sep 17 00:00:00 2001 From: Josh Anderson Date: Fri, 12 Nov 2021 09:08:02 -0600 Subject: [PATCH] Add back util script edits --- smqtk_iqr/utils/compute_functions.py | 290 ++------------------ smqtk_iqr/utils/compute_many_descriptors.py | 48 ++-- smqtk_iqr/utils/generate_image_transform.py | 43 +-- smqtk_iqr/utils/nn_index_tool.py | 17 +- 4 files changed, 86 insertions(+), 312 deletions(-) diff --git a/smqtk_iqr/utils/compute_functions.py b/smqtk_iqr/utils/compute_functions.py index 8e74dafb..8fb3b52e 100644 --- a/smqtk_iqr/utils/compute_functions.py +++ b/smqtk_iqr/utils/compute_functions.py @@ -7,22 +7,28 @@ """ import collections import logging -import itertools -from typing import Deque, Hashable, List, Set - -import numpy - -from smqtk.representation import DataElement -from smqtk.utils import ( - cli, - bits, - parallel, +from typing import ( + Deque, Hashable, Set, Tuple, Generator, Iterable, Any, Optional ) - -def compute_many_descriptors(data_elements, descr_generator, descr_factory, - descr_set, batch_size=None, overwrite=False, - procs=None, **kwds): +from smqtk_dataprovider import ( + DataElement +) +from smqtk_descriptors import ( + DescriptorElement, DescriptorGenerator, DescriptorSet +) +from smqtk_descriptors.descriptor_element_factory import DescriptorElementFactory + + +def compute_many_descriptors(data_elements: Iterable[DataElement], + descr_generator: DescriptorGenerator, + descr_factory: DescriptorElementFactory, + descr_set: DescriptorSet, + batch_size: Optional[int] = None, + overwrite: bool = False, + procs: Optional[int] = None, + **kwds: Any) -> Iterable[Tuple[DataElement, + DescriptorElement]]: """ Compute descriptors for each data element, yielding (DataElement, DescriptorElement) tuple pairs in the order that they were @@ -33,21 +39,21 @@ def compute_many_descriptors(data_elements, descr_generator, descr_factory, :param data_elements: Iterable of DataElement instances of files to work on. - :type data_elements: collections.abc.Iterable[smqtk.representation.DataElement] + :type data_elements: collections.abc.Iterable[DataElement] :param descr_generator: DescriptorGenerator implementation instance to use to generate descriptor vectors. - :type descr_generator: smqtk.algorithms.DescriptorGenerator + :type descr_generator: DescriptorGenerator :param descr_factory: DescriptorElement factory to use when producing descriptor vectors. - :type descr_factory: smqtk.representation.DescriptorElementFactory + :type descr_factory: DescriptorElementFactory :param descr_set: DescriptorSet instance to add generated descriptors to. When given a non-zero batch size, we add descriptors to the given set in batches of that size. When a batch size is not given, we add all generated descriptors to the set after they have been generated. - :type descr_set: smqtk.representation.DescriptorSet + :type descr_set: DescriptorSet :param batch_size: Optional number of elements to asynchronously compute at a time. This is useful when it is desired for this function to yield @@ -72,8 +78,8 @@ def compute_many_descriptors(data_elements, descr_generator, descr_factory, :return: Generator that yields (DataElement, DescriptorElement) for each data element given, in the order they were provided. 
- :rtype: collections.abc.Iterable[(smqtk.representation.DataElement, - smqtk.representation.DescriptorElement)] + :rtype: collections.abc.Iterable[(DataElement, + DescriptorElement)] """ log = logging.getLogger(__name__) @@ -85,7 +91,7 @@ def compute_many_descriptors(data_elements, descr_generator, descr_factory, total = [0] unique: Set[Hashable] = set() - def iter_capture_elements(): + def iter_capture_elements() -> Generator: for d in data_elements: de_deque.append(d) yield d @@ -101,7 +107,7 @@ def iter_capture_elements(): if batch_size: log.debug("Computing in batches of size %d", batch_size) - def iterate_batch_results(): + def iterate_batch_results() -> Generator: descr_list_ = list(descr_generator.generate_elements( de_deque, descr_factory, overwrite )) @@ -149,243 +155,3 @@ def iterate_batch_results(): log.debug("yielding generated elements") for data, descr in zip(de_deque, descr_list): yield data, descr - - -class _CountedGenerator(object): - """ - Used to count elements of an iterable as they are accessed - - :param collections.abc.Iterable iterable: An iterable containing elements to be - accessed and counted. - :param list count_list: A list to which the count of items in iterable will - be added once the iterable has been exhausted. - """ - def __init__(self, iterable, count_list): - self.iterable = iterable - self.count_list = count_list - self.count = 0 - - def __call__(self): - for item in self.iterable: - self.count += 1 - yield item - self.count_list.append(self.count) - - -def compute_transformed_descriptors(data_elements, descr_generator, - descr_factory, descr_set, - transform_function, batch_size=None, - overwrite=False, procs=None, **kwds): - """ - Compute descriptors for copies of each data element generated by - a transform function, yielding a list of tuples containing the original - DataElement as the first element and a tuple of descriptors corresponding - to the transformed DataElements. - - *Note:* Please see the closely-related :func:`compute_many_descriptors` - for details on parameters and usage. - - *Note:* **This function currently only operates over images due to the - specific data validity check/filter performed.* - - :param transform_function: Takes in a DataElement and returns an iterable - of transformed DataElements. - :type transform_function: collections.abc.Callable - - :rtype: collections.abc.Iterable[ - (smqtk.representation.DataElement, - collections.abc.Iterable[smqtk.representation.DescriptorElement])] - """ - transformed_counts: List[int] = [] - - def transformed_elements(): - for elem in data_elements: - yield _CountedGenerator(transform_function(elem), - transformed_counts)() - - chained_elements = itertools.chain.from_iterable( - transformed_elements()) - descriptors = compute_many_descriptors(chained_elements, - descr_generator, descr_factory, - descr_set, batch_size=batch_size, - overwrite=overwrite, procs=procs, - **kwds) - for count, de in zip(transformed_counts, data_elements): - yield de, itertools.islice((d[1] for d in descriptors), count) - - -def compute_hash_codes(uuids, descr_set, functor, report_interval=1.0, - use_mp=False, ordered=False): - """ - Given an iterable of DescriptorElement UUIDs, asynchronously access them - from the given ``set``, asynchronously compute hash codes via ``functor`` - and convert to an integer, yielding (UUID, hash-int) pairs. - - :param uuids: Sequence of UUIDs to process - :type uuids: collections.abc.Iterable[collections.abc.Hashable] - - :param descr_set: Descriptor set to pull from. 
- :type descr_set: smqtk.representation.descriptor_set.DescriptorSet - - :param functor: LSH hash code functor instance - :type functor: smqtk.algorithms.LshFunctor - - :param report_interval: Frequency in seconds at which we report speed and - completion progress via logging. Reporting is disabled when logging - is not in debug and this value is greater than 0. - :type report_interval: float - - :param use_mp: If multiprocessing should be used for parallel - computation vs. threading. Reminder: This will copy currently loaded - objects onto worker processes (e.g. the given set), which could lead - to dangerously high RAM consumption. - :type use_mp: bool - - :param ordered: If the element-hash value pairs yielded are in the same - order as element UUID values input. This function should be slightly - faster when ordering is not required. - :type ordered: bool - - :return: Generator instance yielding (DescriptorElement, int) value pairs. - - """ - # TODO: parallel map fetch elements from set? - # -> separately from compute - - def get_hash(u): - v = descr_set.get_descriptor(u).vector() - return u, bits.bit_vector_to_int_large(functor.get_hash(v)) - - # Setup log and reporting function - log = logging.getLogger(__name__) - - if log.getEffectiveLevel() > logging.DEBUG or report_interval <= 0: - def log_func(*_, **__): - return - log.debug("Not logging progress") - else: - log.debug("Logging progress at %f second intervals", report_interval) - log_func = log.debug # type: ignore - - log.debug("Starting computation") - reporter = cli.ProgressReporter(log_func, report_interval) - reporter.start() - for uuid, hash_int in parallel.parallel_map(get_hash, uuids, - ordered=ordered, - use_multiprocessing=use_mp): - yield (uuid, hash_int) - # Progress reporting - reporter.increment_report() - - # Final report - reporter.report() - - -def mb_kmeans_build_apply(descr_set, mbkm, initial_fit_size): - """ - Build the MiniBatchKMeans centroids based on the descriptors in the given - set, then predicting descriptor clusters with the final result model. - - If the given set is empty, no fitting or clustering occurs and an empty - dictionary is returned. - - :param descr_set: set of descriptors - :type descr_set: smqtk.representation.DescriptorSet - - :param mbkm: Scikit-Learn MiniBatchKMeans instead to train and then use for - prediction - :type mbkm: sklearn.cluster.MiniBatchKMeans - - :param initial_fit_size: Number of descriptors to run an initial fit with. - This brings the advantage of choosing a best initialization point from - multiple. - :type initial_fit_size: int - - :return: Dictionary of the cluster label (integer) to the set of descriptor - UUIDs belonging to that cluster. - :rtype: dict[int, set[collections.abc.Hashable]] - - """ - log = logging.getLogger(__name__) - - ifit_completed = False - k_deque: Deque[Hashable] = collections.deque() - d_fitted = 0 - - log.info("Getting set keys (shuffled)") - set_keys = sorted(descr_set.keys()) - numpy.random.seed(mbkm.random_state) - numpy.random.shuffle(set_keys) - - def parallel_iter_vectors(descriptors): - """ Get the vectors for the descriptors given. - Not caring about order returned. 
- """ - return parallel.parallel_map(lambda d: d.vector(), descriptors, - use_multiprocessing=False) - - def get_vectors(k_iter): - """ Get numpy array of descriptor vectors (2D array returned) """ - return numpy.array(list( - parallel_iter_vectors(descr_set.get_many_descriptors(k_iter)) - )) - - log.info("Collecting iteratively fitting model") - pr = cli.ProgressReporter(log.debug, 1.0).start() - for i, k in enumerate(set_keys): - k_deque.append(k) - pr.increment_report() - - if initial_fit_size and not ifit_completed: - if len(k_deque) == initial_fit_size: - log.info("Initial fit using %d descriptors", len(k_deque)) - log.info("- collecting vectors") - vectors = get_vectors(k_deque) - log.info("- fitting model") - mbkm.fit(vectors) - log.info("- cleaning") - d_fitted += len(vectors) - k_deque.clear() - ifit_completed = True - elif len(k_deque) == mbkm.batch_size: - log.info("Partial fit with batch size %d", len(k_deque)) - log.info("- collecting vectors") - vectors = get_vectors(k_deque) - log.info("- fitting model") - mbkm.partial_fit(vectors) - log.info("- cleaning") - d_fitted += len(k_deque) - k_deque.clear() - pr.report() - - # Final fit with any remaining descriptors - if k_deque: - log.info("Final partial fit of size %d", len(k_deque)) - log.info('- collecting vectors') - vectors = get_vectors(k_deque) - log.info('- fitting model') - mbkm.partial_fit(vectors) - log.info('- cleaning') - d_fitted += len(k_deque) - k_deque.clear() - - log.info("Computing descriptor classes with final KMeans model") - mbkm.verbose = False - d_classes = collections.defaultdict(set) - d_uv_iter = parallel.parallel_map(lambda d: (d.uuid(), d.vector()), - descr_set, - use_multiprocessing=False, - name="uv-collector") - # TODO: Batch predict call inputs to something larger than one at a time. 
- d_uc_iter = parallel.parallel_map( - lambda u_v: (u_v[0], mbkm.predict(u_v[1][numpy.newaxis, :])[0]), - d_uv_iter, - use_multiprocessing=False, - name="uc-collector") - pr = cli.ProgressReporter(log.debug, 1.0).start() - for uuid, c in d_uc_iter: - d_classes[c].add(uuid) - pr.increment_report() - pr.report() - - return d_classes diff --git a/smqtk_iqr/utils/compute_many_descriptors.py b/smqtk_iqr/utils/compute_many_descriptors.py index 191cd650..69c1bc2d 100644 --- a/smqtk_iqr/utils/compute_many_descriptors.py +++ b/smqtk_iqr/utils/compute_many_descriptors.py @@ -8,30 +8,31 @@ import csv import logging import os -from typing import cast, Deque, Optional - -from smqtk.algorithms import DescriptorGenerator -from smqtk.compute_functions import compute_many_descriptors -from smqtk.representation import ( - DescriptorElementFactory, - DataSet, - DescriptorSet, -) -from smqtk.representation.data_element.file_element import DataFileElement -from smqtk.utils import parallel -from smqtk.utils.cli import ( +import argparse +from typing import cast, Deque, Optional, Dict, Union, Generator + +from smqtk_descriptors import DescriptorGenerator, DescriptorSet +from smqtk_descriptors.descriptor_element_factory import DescriptorElementFactory +from smqtk_descriptors.utils import parallel +from smqtk_iqr.utils.compute_functions import compute_many_descriptors + +from smqtk_dataprovider import DataSet +from smqtk_dataprovider.impls.data_element.file import DataFileElement + +from smqtk_iqr.utils.cli import ( utility_main_helper, ProgressReporter, basic_cli_parser, ) -from smqtk.utils.configuration import ( +from smqtk_core.configuration import ( from_config_dict, make_default_config, ) -from smqtk.utils.image import is_valid_element + +from smqtk_image_io.utils.image import is_valid_element -def default_config(): +def default_config() -> Dict: return { "descriptor_generator": make_default_config(DescriptorGenerator.get_impls()), @@ -43,8 +44,8 @@ def default_config(): } -def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None, - check_image=False): +def run_file_list(c: dict, filelist_filepath: str, checkpoint_filepath: str, + batch_size: Optional[int] = None, check_image: bool = False) -> None: """ Top level function handling configuration and inputs/outputs. @@ -104,8 +105,8 @@ def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None, DescriptorGenerator.get_impls()) ) - def iter_valid_elements(): - def is_valid(file_path): + def iter_valid_elements() -> Generator: + def is_valid(file_path: str) -> Union[DataFileElement, bool]: e = DataFileElement(file_path) if is_valid_element( @@ -122,6 +123,7 @@ def is_valid(file_path): use_multiprocessing=True) for dfe in valid_files_filter: if dfe: + assert isinstance(dfe, DataFileElement) yield dfe if data_set is not None: data_elements.append(dfe) @@ -155,7 +157,7 @@ def is_valid(file_path): # compute_many_descriptors, so we can assume that's what comes out # of it as well. 
# noinspection PyProtectedMember - cf_writer.writerow([de._filepath, descr.uuid()]) + cf_writer.writerow([de._filepath, descr.uuid()]) # type: ignore pr.increment_report() pr.report() finally: @@ -165,7 +167,7 @@ def is_valid(file_path): log.info("Done") -def cli_parser(): +def cli_parser() -> argparse.ArgumentParser: parser = basic_cli_parser(__doc__) parser.add_argument('-b', '--batch-size', @@ -205,9 +207,9 @@ def cli_parser(): return parser -def main(): +def main() -> None: args = cli_parser().parse_args() - config = utility_main_helper(default_config, args) + config = utility_main_helper(default_config(), args) log = logging.getLogger(__name__) completed_files_fp = args.completed_files diff --git a/smqtk_iqr/utils/generate_image_transform.py b/smqtk_iqr/utils/generate_image_transform.py index 88d80d57..6a347c6a 100644 --- a/smqtk_iqr/utils/generate_image_transform.py +++ b/smqtk_iqr/utils/generate_image_transform.py @@ -52,25 +52,30 @@ import logging import os +import argparse +from typing import Dict, List, Optional, Tuple import PIL.Image -import smqtk.utils.cli -import smqtk.utils.file -import smqtk.utils.parallel -from smqtk.utils.image import ( +import smqtk_iqr.utils.cli +import smqtk_dataprovider.utils.file +import smqtk_descriptors.utils.parallel + +from smqtk_image_io.utils.image import ( image_crop_center_levels, image_crop_quadrant_pyramid, image_crop_tiles, image_brightness_intervals, image_contrast_intervals ) -def generate_image_transformations(image_path, - crop_center_n, crop_quadrant_levels, - crop_tile_shape, crop_tile_stride, - brightness_intervals, - contrast_intervals, - output_dir=None, - output_ext='.png'): +def generate_image_transformations(image_path: str, + crop_center_n: Optional[int], + crop_quadrant_levels: Optional[int], + crop_tile_shape: Optional[Tuple[int, int]], + crop_tile_stride: Optional[Tuple[int, int]], + brightness_intervals: Optional[int], + contrast_intervals: Optional[int], + output_dir: str = None, + output_ext: str = '.png') -> None: """ Transform an input image into different crops or other transforms, outputting results to the given output directory without overwriting or @@ -86,13 +91,13 @@ def generate_image_transformations(image_path, abs_path = os.path.abspath(image_path) output_dir = output_dir or os.path.dirname(abs_path) - smqtk.utils.file.safe_create_dir(output_dir) + smqtk_dataprovider.utils.file.safe_create_dir(output_dir) p_base = os.path.splitext(os.path.basename(abs_path))[0] p_ext = output_ext p_base = os.path.join(output_dir, p_base) image = PIL.Image.open(image_path).convert('RGB') - def save_image(img, suffixes): + def save_image(img: PIL.Image.Image, suffixes: List[str]) -> None: """ Save an image based on source image basename and an iterable of suffix parts that will be separated by periods. @@ -120,7 +125,7 @@ def save_image(img, suffixes): log.info("Cropping %dx%d pixel tiles from images with stride %s" % (t_width, t_height, crop_tile_stride)) # List needed to iterate generator. 
- list(smqtk.utils.parallel.parallel_map( + list(smqtk_descriptors.utils.parallel.parallel_map( lambda x, y, ii: save_image(ii, [tag, '%dx%d+%d+%d' % (t_width, t_height, x, y)]), @@ -138,7 +143,7 @@ def save_image(img, suffixes): save_image(i, ['contrast', str(c)]) -def default_config(): +def default_config() -> Dict: return { "crop": { # 0 means disabled @@ -156,8 +161,8 @@ def default_config(): } -def cli_parser(): - parser = smqtk.utils.cli.basic_cli_parser(__doc__) +def cli_parser() -> argparse.ArgumentParser: + parser = smqtk_iqr.utils.cli.basic_cli_parser(__doc__) g_io = parser.add_argument_group("Input/Output") g_io.add_argument("-i", "--image", @@ -174,9 +179,9 @@ def cli_parser(): return parser -def main(): +def main() -> None: args = cli_parser().parse_args() - config = smqtk.utils.cli.utility_main_helper(default_config, args) + config = smqtk_iqr.utils.cli.utility_main_helper(default_config(), args) input_image_path = args.image output_dir = args.output diff --git a/smqtk_iqr/utils/nn_index_tool.py b/smqtk_iqr/utils/nn_index_tool.py index e98469a8..4ff097ae 100644 --- a/smqtk_iqr/utils/nn_index_tool.py +++ b/smqtk_iqr/utils/nn_index_tool.py @@ -3,16 +3,17 @@ """ import click import logging +from typing import Dict -from smqtk.algorithms import NearestNeighborsIndex -from smqtk.representation import DescriptorSet -from smqtk.utils.cli import initialize_logging, load_config, output_config -from smqtk.utils.configuration import from_config_dict, make_default_config +from smqtk_indexing import NearestNeighborsIndex +from smqtk_descriptors import DescriptorSet +from smqtk_iqr.utils.cli import initialize_logging, load_config, output_config +from smqtk_core.configuration import from_config_dict, make_default_config LOG = logging.getLogger(__name__) -def build_default_config(): +def build_default_config() -> Dict: return { 'descriptor_set': make_default_config(DescriptorSet.get_impls()), 'neighbor_index': make_default_config(NearestNeighborsIndex.get_impls()), @@ -25,7 +26,7 @@ def build_default_config(): help="This option must be provided before any command. " "Provide once for additional informational logging. " "Provide a second time for additional debug logging.") -def cli_group(verbose): +def cli_group(verbose: int) -> None: """ Tool for building a nearest neighbors index from an input descriptor set. @@ -51,7 +52,7 @@ def cli_group(verbose): default=False, is_flag=True, help='If the given filepath should be overwritten if it ' 'already exists.') -def cli_config(output_filepath, input_config, overwrite): +def cli_config(output_filepath: str, input_config: str, overwrite: bool) -> None: """ Generate a default or template JSON configuration file for this tool. """ @@ -69,7 +70,7 @@ def cli_config(output_filepath, input_config, overwrite): @click.command('build') @click.argument('config_filepath', type=click.Path(exists=True, dir_okay=False)) -def cli_build(config_filepath): +def cli_build(config_filepath: str) -> None: """ Build a new nearest-neighbors index from the configured descriptor set's contents.
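
Reviewer note: below is a minimal usage sketch of `compute_many_descriptors` under the new package layout this patch imports from. The in-memory `DescriptorMemoryElement` / `MemoryDescriptorSet` import paths, the `generator` instance, and the file paths are illustrative assumptions, not part of this change.

    from smqtk_dataprovider.impls.data_element.file import DataFileElement
    from smqtk_descriptors.descriptor_element_factory import DescriptorElementFactory
    # Assumed in-memory implementations from smqtk-descriptors; substitute whatever
    # implementations your deployment is actually configured with.
    from smqtk_descriptors.impls.descriptor_element.memory import DescriptorMemoryElement
    from smqtk_descriptors.impls.descriptor_set.memory import MemoryDescriptorSet
    from smqtk_iqr.utils.compute_functions import compute_many_descriptors

    # `generator` must be a concrete DescriptorGenerator implementation, e.g.
    # instantiated via smqtk_core.configuration.from_config_dict(...); none ships
    # with this patch, so it is left abstract here.
    generator = ...

    factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    descr_set = MemoryDescriptorSet()
    elements = [DataFileElement(p) for p in ("a.png", "b.png")]  # hypothetical paths

    # Yields (DataElement, DescriptorElement) pairs in input order and adds the
    # generated descriptors to `descr_set` in batches of 128.
    for data_elem, descr_elem in compute_many_descriptors(
            elements, generator, factory, descr_set,
            batch_size=128, overwrite=False):
        print(data_elem.uuid(), descr_elem.uuid())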
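
Similarly, a sketch of calling the updated `generate_image_transformations` directly rather than through the CLI; the input/output paths and parameter values are hypothetical and simply exercise the options the updated signature exposes.

    from smqtk_iqr.utils.generate_image_transform import generate_image_transformations

    # Hypothetical paths and settings; when output_dir is None the transformed
    # copies are written alongside the source image.
    generate_image_transformations(
        "input.jpg",
        crop_center_n=2,              # number of center-crop levels
        crop_quadrant_levels=1,       # quadrant pyramid depth
        crop_tile_shape=(128, 128),   # tile width/height in pixels
        crop_tile_stride=(64, 64),    # step between tile origins
        brightness_intervals=3,
        contrast_intervals=3,
        output_dir="transformed",
        output_ext=".png",
    )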