Skip to content

Commit

Permalink
Create a new Filter class to support richer query filtering functiona…
Browse files Browse the repository at this point in the history
…lity.
  • Loading branch information
BirchKwok committed Apr 23, 2024
1 parent 21b3541 commit 28df03e
Showing 1 changed file with 156 additions and 0 deletions.
156 changes: 156 additions & 0 deletions min_vec/structures/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import operator
from typing import Union

import numpy as np


class MatchField:
def __init__(self, value, comparator=operator.eq):
"""
Initialize a MatchField instance with a specific value and comparator.
.. versionadded:: 0.3.0
Parameters:
value: The value to compare the data attribute with.
comparator: The comparator function to apply to the data attribute.
"""
self.value = value
self.comparator = comparator

def match(self, data_value):
"""
Apply the comparator to compare the data attribute with the specified value.
Parameters:
data_value: The value of the data attribute to compare.
Returns:
bool: True if the data attribute matches the specified value, otherwise False.
"""
return self.comparator(data_value, self.value)

def __hash__(self):
return hash(self.value) + hash(self.comparator)


class FieldCondition:
def __init__(self, key, matcher):
"""
Initialize a FieldCondition instance with a specific key and matcher.
.. versionadded:: 0.3.0
Parameters:
key: The key of the data attribute to compare.
matcher: The MatchField instance to compare the data attribute with.
"""
self.key = key
self.matcher = matcher

def evaluate(self, data):
"""
Evaluate the condition against a given dictionary.
Parameters:
data: A dictionary containing the data attributes to compare.
Returns:
bool: True if the data attribute matches the specified value, otherwise False.
"""
attribute_value = data.get(self.key)
if attribute_value is not None:
return self.matcher.match(attribute_value)
return False

def __hash__(self):
return hash(self.key) + hash(self.matcher)


class MatchID:
def __init__(self, indices: Union[list, np.ndarray]):
"""
Initialize an MatchID instance.
.. versionadded:: 0.3.0
Parameters:
indices (list or np.ndarray): The indices to filter the numpy array.
"""
self.indices = indices

def match(self, array):
"""
Filter the numpy array according to the specified indices.
Parameters:
array (np.ndarray): The numpy array to filter.
Returns:
np.ndarray: A numpy array filtered according to the specified indices.
"""
return array[np.isin(array, self.indices, assume_unique=True)]

def __hash__(self):
if isinstance(self.indices, list):
return hash(tuple(self.indices))
else:
return hash(self.indices.tobytes())


class IDCondition:
def __init__(self, matcher):
"""
Initialize an IDCondition instance.
.. versionadded:: 0.3.0
Parameters:
matcher: The MatchID instance to filter the numpy array.
"""
self.matcher = matcher

def evaluate(self, array):
"""
Evaluate the condition against a given numpy array.
:param array: A numpy array that needs to be filtered.
:return: A numpy array filtered according to the specified indices.
"""
return self.matcher.match(array)

def __hash__(self):
return hash(self.matcher)


class Filter:
__slots__ = ['must', 'any']

def __init__(self, must=None, any=None):
"""
Initialize a Filter instance with must and any conditions.
.. versionadded:: 0.3.0
Parameters:
must: A list of conditions that must be satisfied.
any: A list of conditions where at least one must be satisfied.
"""
self.must = must if must is not None else []
self.any = any if any is not None else []

def apply(self, data):
"""
Apply the filter to the given data.
Parameters:
data: The data to apply the filter to.
Returns:
bool: True if the data satisfies the filter conditions, otherwise False.
"""
if not self.must and not self.any:
must_pass = True
any_pass = False
else:
must_pass = all(condition.evaluate(data) for condition in self.must) if self.must else True
any_pass = any(condition.evaluate(data) for condition in self.any) if self.any else False

return must_pass and (any_pass or not self.any)

def __hash__(self):
return hash((tuple(self.must), tuple(self.any)))

0 comments on commit 28df03e

Please sign in to comment.