From a95811b10c4eb3acb45f2407aa894da4d0232faf Mon Sep 17 00:00:00 2001
From: Krzysztof Kowalczyk
Date: Mon, 8 Jun 2020 17:39:09 +0200
Subject: [PATCH] Fixed COUNTIFS algorithm + custom key feature in TeraSort

---
 spark_minimal_algorithms/algorithm.py         |   8 +-
 spark_minimal_algorithms/examples/countifs.py | 260 ++++++++++--------
 .../examples/tera_sort.py                     |  34 ++-
 tests/examples/test_countifs.py               | 165 ++++++++++-
 tests/examples/test_tera_sort.py              |  27 +-
 5 files changed, 342 insertions(+), 152 deletions(-)

diff --git a/spark_minimal_algorithms/algorithm.py b/spark_minimal_algorithms/algorithm.py
index 0124384..d5fe25a 100644
--- a/spark_minimal_algorithms/algorithm.py
+++ b/spark_minimal_algorithms/algorithm.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod, abstractstaticmethod
-from typing import Any, Optional, Iterable, Type, List, Tuple, Dict
+from typing import Any, Optional, Iterable, Type, List, Tuple, Dict, Union
 
 from pyspark import SparkContext, RDD
 from pyspark.broadcast import Broadcast
@@ -124,7 +124,7 @@ def unwrap_emit(kv: Tuple[Any, Iterable[Any]]) -> Optional[Tuple[Any, Any]]:
             new_v = step_cls.emit_by_group(k, v, **kwargs)
             return new_v
 
-        emitted = list(rdd.map(unwrap_emit).collect())
+        emitted = list(rdd.map(unwrap_emit, preservesPartitioning=True).collect())
 
         to_broadcast = step_cls.broadcast(emitted, **kwargs)
         broadcast: Broadcast = self._sc.broadcast(to_broadcast)
@@ -133,7 +133,7 @@ def unwrap_step(kv: Tuple[Any, Iterable[Any]]) -> Iterable[Any]:
                 for new_v in step_cls.step(k, v, broadcast, **kwargs):
                     yield new_v
 
-        rdd = rdd.flatMap(unwrap_step)
+        rdd = rdd.flatMap(unwrap_step, preservesPartitioning=True)
 
         return rdd
 
@@ -190,7 +190,7 @@ def run(self, rdd: RDD, n_dim: int) -> RDD:
         ```
         """
 
-    __steps__: Dict[str, Type[Step]] = dict()
+    __steps__: Dict[str, Union[Type[Step], Type["Algorithm"]]] = dict()
 
     def __init__(self, sc: SparkContext, n_partitions: int):
         """
diff --git a/spark_minimal_algorithms/examples/countifs.py b/spark_minimal_algorithms/examples/countifs.py
index 81b4811..0fb9d2c 100644
--- a/spark_minimal_algorithms/examples/countifs.py
+++ b/spark_minimal_algorithms/examples/countifs.py
@@ -3,6 +3,7 @@
 
 from pyspark import RDD, Broadcast
 
+from spark_minimal_algorithms.examples.tera_sort import SampleAndAssignBuckets
 from spark_minimal_algorithms.algorithm import Step, Algorithm
 
 
@@ -17,140 +18,165 @@ def _get_format_str(n_elements: int) -> str:
     return binary_format_str
 
 
-class SortAndAssignLabel(Step):
-    """
-    IN: (point coords, point type info) where:
-
-    - point coords: (x1, x2, ...) have same number of dimensions for all points
-    - point type info: (point type, index of point in the collection of points with this type)
-
-    OUT: (label for the 1st dimension, point without 1st dimension)
+def _label_first_coord_and_type(point: Any) -> Any:
+    label, coords, type_info = point
+    return label, coords[0], type_info[0]
 
-    for each data point and each query point such that data point > query point
-    at the first dimension
 
+class SortAndAssignLabels(Step):
     """
+    Replaces 2nd iteration of TeraSort to assign labels (on top of sorting the input).
- @staticmethod - def select_key(coords_typeinfo: Any) -> Any: - coords, type_info = coords_typeinfo - t, order_for_t = type_info - return (coords[0], t), (coords[1:], order_for_t) - - @staticmethod - def unselect_key(selected_key_and_rest: Any) -> Any: - selected_key, rest = selected_key_and_rest - coord_0, t = selected_key - other_coords, order_for_t = rest - return (coord_0, other_coords, (t, order_for_t)) + IN: (bucket index, (label(s), point coords, point type info)) + OUT: (label(s) + new label, point coords without first coord, point type) + or just (labels, point type) if there are no more coords + """ @staticmethod - def extract_partition_idx( - idx: int, points: Iterable[Any] - ) -> Iterable[Tuple[int, Any]]: - for point in points: - yield idx, point + def _sort_within_partition(bucket_and_points: Tuple[int, Iterable[Any]]) -> Tuple[int, Iterable[Any]]: + bucket, points = bucket_and_points + points = sorted(points, key=_label_first_coord_and_type) + return bucket, points @staticmethod def group(rdd: RDD) -> RDD: # type: ignore - # sort by values - todo: consider using custom terasort implementation - cls = SortAndAssignLabel - rdd = rdd.map(cls.select_key).sortByKey().map(cls.unselect_key) - rdd = rdd.mapPartitionsWithIndex(cls.extract_partition_idx).groupByKey() + rdd = rdd.groupByKey().sortByKey() + rdd = rdd.map(SortAndAssignLabels._sort_within_partition, preservesPartitioning=True) + # for k, v in rdd.collect(): # todo: remove debug + # print(f"{k}: {list(v)}") return rdd @staticmethod def emit_by_group(group_key: int, group_items: Iterable[Any]) -> Optional[Any]: # type: ignore - return group_key, len(list(group_items)) + bucket_idx = group_key - @staticmethod - def broadcast( # type: ignore - emitted_items: List[Tuple[int, int]] - ) -> Dict[str, Union[str, List[int]]]: - parition_counts = [ - idx_count[1] - for idx_count in sorted(emitted_items, key=lambda idx_count: idx_count[0]) - ] - partition_prefix_counts = [ - sum(parition_counts[:i]) for i in range(len(parition_counts)) - ] - - total_count = partition_prefix_counts[-1] + parition_counts[-1] - label_format_str = _get_format_str(total_count) + first_label: Optional[str] = None + n_points_for_first_label: Optional[int] = None + last_label: Optional[str] = None + n_points_for_last_label: Optional[int] = None + for point in group_items: + label, coords, type_info = point - return { - "partition_prefix_count": partition_prefix_counts, - "label_format_str": label_format_str, - } + if first_label is None: + first_label = label + n_points_for_first_label = 1 + elif first_label == label: + n_points_for_first_label += 1 # noqa: T484 + + if last_label == label: + n_points_for_last_label += 1 # noqa: T484 + else: + last_label = label + n_points_for_last_label = 1 + + return bucket_idx, (first_label, n_points_for_first_label), (last_label, n_points_for_last_label) @staticmethod - def step( # type: ignore - group_key: int, group_items: Iterable[Any], broadcast: Broadcast - ) -> Iterable[Any]: - prefix_counts: List[int] = broadcast.value["partition_prefix_count"] - partition_prefix_count: int = prefix_counts[group_key] + def broadcast(emitted_items: List[List[Any]]) -> Dict[str, Any]: # type: ignore + bucket_label_counts = sorted(emitted_items, key=lambda bucket_count: bucket_count[0]) - label_format_str: str = broadcast.value["label_format_str"] + # print(f"bucket_label_counts: {bucket_label_counts}") # todo: remove debug - for idx, point in enumerate(group_items): - coord_0, coords, type_info = point - t, _ = type_info + 
previous_label = () # empty tuple is never assigned as a label + previous_count = 0 + bucket_prefix_counts = dict() # i => (last label in (i-1)-th bucket, count of points with this label in previous buckets) + total_label_counts = dict() # label => total count of points with this label (only for multi-bucket labels) + for bucket_count in bucket_label_counts: + bucket_partition_idx = bucket_count[0] + bucket_prefix_counts[bucket_partition_idx] = (previous_label, previous_count) - label = label_format_str.format(partition_prefix_count + idx) - if t == DATA: - for prefix_len in range(len(label)): - if label[prefix_len] == "1": - if len(coords) > 0: - yield label[:prefix_len], (coords, type_info) - else: - yield label[:prefix_len], type_info + first_label, first_label_count = bucket_count[1] + last_label, last_label_count = bucket_count[2] - elif t == QUERY: - for prefix_len in range(len(label)): - if label[prefix_len] == "0": - if len(coords) > 0: - yield label[:prefix_len], (coords, type_info) - else: - yield label[:prefix_len], type_info + if last_label == previous_label: + # entire bucket consists of point with that one label + previous_count += last_label_count + else: + # current bucket ends with different label than previous bucket + total_label_counts[previous_label] = previous_count + if first_label == previous_label: + # last label ends inside current bucket so we need to increase its count + total_label_counts[previous_label] += first_label_count + previous_label = last_label + previous_count = last_label_count -class AssignNestedLabel(Step): - """ - IN: (label, collection of points with label) + # after iteration ends, we still need to assign total count for last label + total_label_counts[previous_label] = previous_count - OUT: (old label + new label for the 1st dimension, point without 1st dimension) - for each data point and each query point such that data point > query point - at the first dimension + keys_to_delete = {k for k in total_label_counts if total_label_counts[k] == 0} + for k in keys_to_delete: + del total_label_counts[k] - """ + # print(f"bucket_prefix_counts: {bucket_prefix_counts}") # todo: remove debug + # print(f"total_label_counts: {total_label_counts}") # todo: remove debug - @staticmethod - def first_coord_and_point_type(point: Any) -> Any: - coords, type_info = point - return coords[0], type_info[0] + return { + "bucket_prefix_count": bucket_prefix_counts, + "total_label_count": total_label_counts, + } @staticmethod def step( # type: ignore - group_key: Union[str, Tuple[str, ...]], - group_items: Iterable[Any], - broadcast: Broadcast, + group_key: int, group_items: Iterable[Any], broadcast: Broadcast ) -> Iterable[Any]: - points = sorted(group_items, key=AssignNestedLabel.first_coord_and_point_type) - label_format_str = _get_format_str(len(points)) - old_label = group_key - - for idx, (coords, type_info) in enumerate(points): - new_label = label_format_str.format(idx) + bucket_idx = group_key + prefix_counts: List[Tuple[str, int]] = broadcast.value["bucket_prefix_count"] + bucket_prefix_count: Tuple[str, int] = prefix_counts[bucket_idx] + previous_label, prefix_count_for_previous_label = bucket_prefix_count + + # get number of points for labels which span beyond current partition + global_label_count: Dict[str, int] = broadcast.value["total_label_count"] + global_labels = set(global_label_count.keys()) + + # calculate number of points for each label (locally) + local_label_count: Dict[str, int] = dict() + for point in group_items: + label, _, _ = point + if 
label not in global_labels: + try: + local_label_count[label] += 1 + except KeyError: + if label == previous_label: + local_label_count[label] = 1 + prefix_count_for_previous_label + else: + local_label_count[label] = 1 + + # todo: label format strings for global labels can be pre-computed before broadcast + # todo: we can probably get rid of few intermediate dicts to save memory + label_count: Dict[str, int] = { + **global_label_count, + **local_label_count + } + # print(f"Caclulating label in partition: {group_key}") + # print(f"available label counts: {label_count}") + # print("") + label_format_str = { + label: _get_format_str(n_points_for_label) + for label, n_points_for_label in label_count.items() + } + # assign new labels to points, based on combined counts of points per old label + point_idx_within_label = prefix_count_for_previous_label + for idx, point in enumerate(group_items): + old_label, coords, type_info = point t, _ = type_info + + if old_label == previous_label: + new_label = label_format_str[old_label].format(point_idx_within_label) + point_idx_within_label += 1 + else: + new_label = label_format_str[old_label].format(0) + point_idx_within_label = 1 + previous_label = old_label + + # print(f"Point {point} (#{idx} in bucket #{bucket_idx}) got label {new_label}") # todo: remove debug + if t == DATA: for prefix_len in range(len(new_label)): if new_label[prefix_len] == "1": if len(coords) > 1: - yield (old_label, new_label[:prefix_len]), ( - coords[1:], - type_info, - ) + yield (old_label, new_label[:prefix_len]), coords[1:], type_info else: yield (old_label, new_label[:prefix_len]), type_info @@ -158,14 +184,32 @@ def step( # type: ignore for prefix_len in range(len(new_label)): if new_label[prefix_len] == "0": if len(coords) > 1: - yield (old_label, new_label[:prefix_len]), ( - coords[1:], - type_info, - ) + yield (old_label, new_label[:prefix_len]), coords[1:], type_info else: yield (old_label, new_label[:prefix_len]), type_info +class TeraSortWithLabels(Algorithm): + __steps__ = { + "assign_buckets": SampleAndAssignBuckets, + "sort_and_assign_labels": SortAndAssignLabels, + } + + def run(self, rdd: RDD) -> RDD: # type: ignore + rdd = rdd.cache() + + n_points = rdd.count() + m = n_points / self.n_partitions + optimal_p = math.log(n_points * self.n_partitions) / m + + rdd = self.assign_buckets( # type: ignore + rdd, p=optimal_p, key_func=_label_first_coord_and_type + ) + rdd = self.sort_and_assign_labels(rdd) # type: ignore + + return rdd + + class GetResultsByLabel(Step): """ IN: (label, points with this label) @@ -221,8 +265,7 @@ class Countifs(Algorithm): """ __steps__ = { - "sort_and_assign_label": SortAndAssignLabel, - "assign_nested_label": AssignNestedLabel, + "assign_next_label": TeraSortWithLabels, "get_results_by_label": GetResultsByLabel, "aggregate_results_by_query": AggregateResultsByQuery, } @@ -231,16 +274,15 @@ def run(self, data_rdd: RDD, query_rdd: RDD, n_dim: int) -> RDD: # type: ignore empty_result_rdd = query_rdd.map(lambda idx_coords: (idx_coords[0], 0)) data_rdd = data_rdd.map( - lambda idx_coords: (idx_coords[1], (DATA, idx_coords[0])) + lambda idx_coords: ((), idx_coords[1], (DATA, idx_coords[0])) ) query_rdd = query_rdd.map( - lambda idx_coords: (idx_coords[1], (QUERY, idx_coords[0])) + lambda idx_coords: ((), idx_coords[1], (QUERY, idx_coords[0])) ) rdd = data_rdd.union(query_rdd) - rdd = self.sort_and_assign_label(rdd) # type: ignore - for _ in range(n_dim - 1): - rdd = self.assign_nested_label(rdd) # type: ignore + for _ in range(n_dim): + rdd 
= self.assign_next_label(rdd=rdd) # type: ignore rdd = empty_result_rdd.union(self.get_results_by_label(rdd)) # type: ignore rdd = self.aggregate_results_by_query(rdd).sortByKey() # type: ignore diff --git a/spark_minimal_algorithms/examples/tera_sort.py b/spark_minimal_algorithms/examples/tera_sort.py index 7c0b260..60c15c9 100644 --- a/spark_minimal_algorithms/examples/tera_sort.py +++ b/spark_minimal_algorithms/examples/tera_sort.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple, Any, List +from typing import Iterable, Tuple, Any, List, Callable from bisect import bisect_left import random import math @@ -28,23 +28,27 @@ def extract_idx( @staticmethod def group(rdd: RDD, **kwargs: Any) -> RDD: - rdd = rdd.mapPartitionsWithIndex(SampleAndAssignBuckets.extract_idx).groupByKey() + rdd = rdd.mapPartitionsWithIndex( + SampleAndAssignBuckets.extract_idx, preservesPartitioning=True + ) + rdd = rdd.groupByKey() return rdd @staticmethod def emit_by_group(group_key: int, group_items: Iterable[Any], **kwargs: Any) -> Any: samples = list() p: float = kwargs.get("p", SampleAndAssignBuckets.p) + key_func: Callable[[Any], Any] = kwargs.get("key_func", lambda x: x) for point in group_items: if random.random() < p: - samples.append(point) + sample_key = key_func(point) + samples.append(sample_key) return samples @staticmethod def broadcast(emitted_items: List[List[Any]], **kwargs: Any) -> List[Any]: - n_dim = kwargs["n_dim"] - zero_point = tuple(0 for _ in range(n_dim)) + zero_point = () # empty tuple is always smaller than any n-dimensional point for n >= 1 buckets = [zero_point] + [ point for samples in emitted_items for point in samples ] @@ -54,8 +58,10 @@ def broadcast(emitted_items: List[List[Any]], **kwargs: Any) -> List[Any]: def step( # type: ignore group_key: int, group_items: Iterable[Any], broadcast: Broadcast, **kwargs: Any ) -> Iterable[Tuple[int, Any]]: + key_func: Callable[[Tuple[Any]], Tuple[Any]] = kwargs.get("key_func", lambda x: x) for point in group_items: - point_bucket = bisect_left(broadcast.value, point) + point_key = key_func(point) + point_bucket = bisect_left(broadcast.value, point_key) yield point_bucket, point @@ -66,15 +72,16 @@ class SortByKeyAndValue(Step): """ @staticmethod - def group(rdd: RDD) -> RDD: # type: ignore + def group(rdd: RDD, **kwargs: Any) -> RDD: # type: ignore rdd = rdd.groupByKey().sortByKey() return rdd @staticmethod def step( # type: ignore - group_key: int, group_items: Iterable[Any], broadcast: Broadcast + group_key: int, group_items: Iterable[Any], broadcast: Broadcast, **kwargs: Any ) -> Iterable[Any]: - sorted_points = sorted(group_items) + key_func: Callable[[Tuple[Any]], Tuple[Any]] = kwargs.get("key_func", lambda x: x) + sorted_points = sorted(group_items, key=key_func) for point in sorted_points: yield point @@ -85,8 +92,7 @@ class TeraSort(Algorithm): Input: - - `rdd`: RDD[point], where each point is a tuple with `n_dim` elements - - `n_dim`: int - number of dimensions (coordinates) + - `rdd`: RDD[point], where each point is a tuple of integers >= 1 (non-zero) Output: @@ -98,14 +104,14 @@ class TeraSort(Algorithm): "sort": SortByKeyAndValue, } - def run(self, rdd: RDD, n_dim: int) -> RDD: # type: ignore + def run(self, rdd: RDD, key_func: Callable[[Tuple[Any]], Tuple[Any]] = lambda x: x) -> RDD: # type: ignore rdd = rdd.cache() n_points = rdd.count() m = n_points / self.n_partitions optimal_p = math.log(n_points * self.n_partitions) / m - rdd = self.assign_buckets(rdd, p=optimal_p, n_dim=n_dim) # type: ignore - rdd = self.sort(rdd) # 
type: ignore + rdd = self.assign_buckets(rdd, p=optimal_p, key_func=key_func) # type: ignore + rdd = self.sort(rdd, key_func=key_func) # type: ignore return rdd diff --git a/tests/examples/test_countifs.py b/tests/examples/test_countifs.py index 577c9ae..d74cadf 100644 --- a/tests/examples/test_countifs.py +++ b/tests/examples/test_countifs.py @@ -6,8 +6,7 @@ from spark_minimal_algorithms.algorithm import Step, Algorithm from spark_minimal_algorithms.examples.countifs import ( - SortAndAssignLabel, - AssignNestedLabel, + TeraSortWithLabels, GetResultsByLabel, AggregateResultsByQuery, Countifs, @@ -19,8 +18,7 @@ @pytest.mark.parametrize( "cls", [ - SortAndAssignLabel, - AssignNestedLabel, + TeraSortWithLabels, GetResultsByLabel, AggregateResultsByQuery, ], @@ -29,24 +27,20 @@ def test_step_creation(cls, spark_context, n_partitions): instance = cls(spark_context, n_partitions) - assert isinstance(instance, Step) + assert isinstance(instance, Step) or isinstance(instance, Algorithm) assert instance._n_partitions == n_partitions -@pytest.mark.parametrize("n_partitions", [1]) +@pytest.mark.parametrize("n_partitions", [1, 2]) def test_algorithm_creation(spark_context, n_partitions): instance = Countifs(spark_context, n_partitions) assert isinstance(instance, Algorithm) assert instance._n_partitions == n_partitions - assert hasattr(instance, "sort_and_assign_label") - assert type(instance.sort_and_assign_label) == SortAndAssignLabel - assert instance.sort_and_assign_label._n_partitions == n_partitions - - assert hasattr(instance, "assign_nested_label") - assert type(instance.assign_nested_label) == AssignNestedLabel - assert instance.assign_nested_label._n_partitions == n_partitions + assert hasattr(instance, "assign_next_label") + assert type(instance.assign_next_label) == TeraSortWithLabels + assert instance.assign_next_label._n_partitions == n_partitions assert hasattr(instance, "get_results_by_label") assert type(instance.get_results_by_label) == GetResultsByLabel @@ -57,6 +51,151 @@ def test_algorithm_creation(spark_context, n_partitions): assert instance.aggregate_results_by_query._n_partitions == n_partitions +@pytest.mark.parametrize("n_partitions", [1, 2, 3, 4]) +def test_tera_sort_label_assignment_1d(spark_context, n_partitions): + rdd = spark_context.parallelize([ + ((), (1,), (0, 0)), # D0: data at x = 1 with empty label + ((), (1,), (1, 0)), # Q0: query at x = 1 with empty label + ((), (2,), (0, 1)), # D1: data at x = 2 with empty label + ((), (2,), (1, 1)), # Q1: query at x = 2 with empty label + ]) + # Q0, D1 is the only pair of results that matches the COUNTIF criteria + expected_result = [ + (((), ''), (1, 0)), + (((), ''), (0, 1)), + ] + + algorithm = TeraSortWithLabels(spark_context, n_partitions) + result = algorithm(rdd=rdd).collect() + + assert result == expected_result + + +@pytest.mark.parametrize("n_partitions", [1, 2, 3, 4]) +def test_tera_sort_label_assignment_2d_round_1_case_1(spark_context, n_partitions): + rdd = spark_context.parallelize([ + ((), (3, 6), (0, 0)), # D0 + ((), (4, 2), (0, 1)), # D1 + ((), (0, 5), (1, 0)), # Q0 + ((), (7, 1), (1, 1)), # Q1 + ]) + # after 1st dimension, for Q0 both D1 and D2 are feasible + expected_result_1st_round = [ + (((), ''), (5,), (1, 0)), + (((), '0'), (5,), (1, 0)), + (((), '0'), (6,), (0, 0)), + (((), ''), (2,), (0, 1)), + ] + + algorithm = TeraSortWithLabels(spark_context, n_partitions) + result = algorithm(rdd=rdd).collect() + + assert result == expected_result_1st_round + + +@pytest.mark.parametrize("n_partitions", [1, 2, 3, 
4]) +def test_tera_sort_label_assignment_2d_round_2_case_1(spark_context, n_partitions): + rdd_after_1st_round = spark_context.parallelize([ + (((), ''), (5,), (1, 0)), # Q0 + (((), '0'), (5,), (1, 0)), # Q0 + (((), '0'), (6,), (0, 0)), # D0 + (((), ''), (2,), (0, 1)), # D1 + ]) + # after 1st dimension, for Q0 both D1 and D2 are feasible + expected_result_2nd_round = [ + ((((), '0'), ''), (1, 0)), + ((((), '0'), ''), (0, 0)), + ] + + algorithm = TeraSortWithLabels(spark_context, n_partitions) + result = algorithm(rdd=rdd_after_1st_round).collect() + + assert result == expected_result_2nd_round + + +@pytest.mark.parametrize("n_partitions", [1, 2, 3, 4]) +def test_tera_sort_label_assignment_2d_round_1_case_2(spark_context, n_partitions): + # 'data_points': [(103, 480), (105, 1771), (1178, 101), (1243, 107)], + # 'query_points': [(100, 100), (102, 102), (104, 104), (106, 106)] + rdd = spark_context.parallelize([ + ((), (1178, 101), (0, 2)), + ((), (103, 480), (0, 0)), + ((), (105, 1771), (0, 1)), + ((), (1243, 107), (0, 3)), + ((), (104, 104), (1, 2)), + ((), (100, 100), (1, 0)), + ((), (102, 102), (1, 1)), + ((), (106, 106), (1, 3)) + ]) + # after 1st dimension, for Q0 both D1 and D2 are feasible + expected_result_1st_round = [ + (((), ''), (100,), (1, 0)), + (((), '0'), (100,), (1, 0)), + (((), '00'), (100,), (1, 0)), + (((), ''), (102,), (1, 1)), + (((), '0'), (102,), (1, 1)), + (((), '0'), (480,), (0, 0)), + (((), ''), (104,), (1, 2)), + (((), ''), (1771,), (0, 1)), + (((), '1'), (106,), (1, 3)), + (((), ''), (101,), (0, 2)), + (((), '1'), (101,), (0, 2)), + (((), ''), (107,), (0, 3)), + (((), '1'), (107,), (0, 3)), + (((), '11'), (107,), (0, 3)), + ] + + algorithm = TeraSortWithLabels(spark_context, n_partitions) + result = algorithm(rdd=rdd).collect() + + assert result == expected_result_1st_round + + +@pytest.mark.parametrize("n_partitions", [1, 2, 3, 4]) +def test_tera_sort_label_assignment_2d_round_2_case_2(spark_context, n_partitions): + rdd_after_1st_round = spark_context.parallelize([ + (((), ''), (100,), (1, 0)), + (((), '0'), (100,), (1, 0)), + (((), '00'), (100,), (1, 0)), + (((), ''), (102,), (1, 1)), + (((), '0'), (102,), (1, 1)), + (((), '0'), (480,), (0, 0)), + (((), ''), (104,), (1, 2)), + (((), ''), (1771,), (0, 1)), + (((), '1'), (106,), (1, 3)), + (((), ''), (101,), (0, 2)), + (((), '1'), (101,), (0, 2)), + (((), ''), (107,), (0, 3)), + (((), '1'), (107,), (0, 3)), + (((), '11'), (107,), (0, 3)), + ]) + # after 1st dimension, for Q0 both D1 and D2 are feasible + expected_result_2nd_round = [ + ((((), ''), ''), (1, 0)), + ((((), ''), '0'), (1, 0)), + ((((), ''), '00'), (1, 0)), + ((((), ''), '00'), (0, 2)), + ((((), ''), ''), (1, 1)), + ((((), ''), '01'), (1, 1)), + ((((), ''), ''), (1, 2)), + ((((), ''), ''), (0, 3)), + ((((), ''), ''), (0, 1)), + ((((), ''), '10'), (0, 1)), + ((((), '0'), ''), (1, 0)), + ((((), '0'), '0'), (1, 0)), + ((((), '0'), ''), (1, 1)), + ((((), '0'), ''), (0, 0)), + ((((), '00'), ''), (1, 0)), + ((((), '1'), ''), (1, 3)), + ((((), '1'), ''), (0, 3)), + ] + + algorithm = TeraSortWithLabels(spark_context, n_partitions) + result = algorithm(rdd=rdd_after_1st_round).collect() + + assert result == expected_result_2nd_round + + TESTS_1D = [ { "query_points": [1, 4, 5, 6, 7, 8, 11, 13, 14, 17], diff --git a/tests/examples/test_tera_sort.py b/tests/examples/test_tera_sort.py index b1d2bb5..01773c7 100644 --- a/tests/examples/test_tera_sort.py +++ b/tests/examples/test_tera_sort.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Any +from typing import 
List, Tuple, Any, Callable import random import pytest @@ -8,37 +8,38 @@ random.seed(42) -def create_test_case(n_points: int, n_dim: int) -> List[Tuple[Any]]: +def create_test_case( + n_points: int, n_dim: int, key_func: Callable[[Tuple[Any]], Tuple[Any]] = lambda x: x +) -> Tuple[List[Tuple[Any]], List[Tuple[Any]], Callable[[Tuple[Any]], Tuple[Any]]]: max_point = 100 * n_points points = [ tuple(random.randint(1, max_point) for _ in range(n_dim)) for _ in range(n_points) ] - return points, sorted(points) + return points, sorted(points, key=key_func), key_func TESTS = [ create_test_case(5, 1), create_test_case(10, 1), - create_test_case(100, 1), + create_test_case(10, 1, lambda x: (x[0],)), create_test_case(5, 2), create_test_case(10, 2), - create_test_case(100, 2), + create_test_case(10, 2, lambda x: (x[0],)), create_test_case(5, 3), create_test_case(10, 3), - create_test_case(100, 3), + create_test_case(10, 3, lambda x: (x[0],)), ] @pytest.mark.parametrize("test_case", TESTS) @pytest.mark.parametrize("n_partitions", [1, 2, 4]) def test_tera_sort(spark_context, n_partitions, test_case): - points, sorted_points = test_case - n_dim = len(points[0]) + points, sorted_points, key_func = test_case rdd = spark_context.parallelize(points) tera_sort = TeraSort(spark_context, n_partitions) - result = tera_sort(rdd=rdd, n_dim=n_dim).collect() + result = tera_sort(rdd=rdd, key_func=key_func).collect() assert len(result) == len(sorted_points) assert result == sorted_points @@ -46,10 +47,13 @@ def test_tera_sort(spark_context, n_partitions, test_case): LONG_TESTS = [ create_test_case(100, 1), + create_test_case(100, 1, lambda x: (x[0],)), create_test_case(1_000, 1), create_test_case(100, 2), + create_test_case(100, 2, lambda x: (x[0],)), create_test_case(1_000, 2), create_test_case(100, 3), + create_test_case(100, 3, lambda x: (x[0],)), create_test_case(1_000, 3), ] @@ -58,12 +62,11 @@ def test_tera_sort(spark_context, n_partitions, test_case): @pytest.mark.parametrize("test_case", LONG_TESTS) @pytest.mark.parametrize("n_partitions", [1, 2, 3, 4, 8, 16]) def test_tera_sort_performance(spark_context, n_partitions, test_case): - points, sorted_points = test_case - n_dim = len(points[0]) + points, sorted_points, key_func = test_case rdd = spark_context.parallelize(points) tera_sort = TeraSort(spark_context, n_partitions) - result = tera_sort(rdd=rdd, n_dim=n_dim).collect() + result = tera_sort(rdd=rdd, key_func=key_func).collect() assert len(result) == len(sorted_points) assert result == sorted_points
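
Note (outside the patch above): a minimal usage sketch of the two entry points this change touches, based only on the keyword-argument call convention exercised in the tests (`TeraSort(sc, n_partitions)(rdd=..., key_func=...)` and, by the same pattern, `Countifs(sc, n_partitions)(data_rdd=..., query_rdd=..., n_dim=...)`). The SparkContext setup and the sample data are illustrative assumptions, not part of this patch.

```python
from pyspark import SparkContext

from spark_minimal_algorithms.examples.countifs import Countifs
from spark_minimal_algorithms.examples.tera_sort import TeraSort

sc = SparkContext("local[2]", "usage-sketch")  # assumed local setup

# TeraSort with the new custom key feature: sort 2-D integer points by their
# first coordinate only, by passing key_func instead of the old n_dim argument.
points = [(3, 60), (1, 10), (2, 40)]
tera_sort = TeraSort(sc, n_partitions=2)
sorted_points = tera_sort(
    rdd=sc.parallelize(points), key_func=lambda point: (point[0],)
).collect()
# expected: [(1, 10), (2, 40), (3, 60)]

# Countifs: for every query point, count the data points that are strictly
# greater on every coordinate. Inputs are (index, coords) pairs.
data_rdd = sc.parallelize(list(enumerate([(103, 480), (105, 1771)])))
query_rdd = sc.parallelize(list(enumerate([(100, 100), (106, 106)])))
countifs = Countifs(sc, n_partitions=2)
counts = countifs(data_rdd=data_rdd, query_rdd=query_rdd, n_dim=2).collect()
# expected: [(0, 2), (1, 0)] -- both data points dominate query 0, none dominate query 1
```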