Skip to content

Commit

Permalink
Merge pull request #14 from AutoResearch/feat-make-compatible-with-state
Browse files Browse the repository at this point in the history
feat: make compatible with state and pd.DataFrames
  • Loading branch information
younesStrittmatter authored Sep 3, 2023
2 parents 8cb9bfe + ce87313 commit 9b2b9b8
Showing 1 changed file with 45 additions and 35 deletions.
80 changes: 45 additions & 35 deletions src/autora/experimentalist/inequality/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Iterable, Literal
from typing import Literal, Union

import numpy as np
import pandas as pd
from sklearn.metrics import DistanceMetric

from autora.utils.deprecation import deprecated_alias
Expand Down Expand Up @@ -30,8 +31,8 @@


def sample(
condition_pool: np.ndarray,
reference_conditions: np.ndarray,
conditions: Union[pd.DataFrame, np.ndarray],
reference_conditions: Union[pd.DataFrame, np.ndarray],
num_samples: int = 1,
equality_distance: float = 0,
metric: str = "euclidean",
Expand All @@ -43,7 +44,7 @@ def sample(
into reference_conditions and are included in the summed equality calculation.
Args:
condition_pool: pool of IV conditions to evaluate inequality
conditions: pool of IV conditions to evaluate inequality
reference_conditions: reference pool of IV conditions
num_samples: number of samples to select
equality_distance: the distance to decide if two data points are equal.
Expand All @@ -58,89 +59,98 @@ def sample(
Examples:
The value 1 is not in the reference. Therefore it is choosen.
>>> summed_inequality_sampler([1, 2, 3], [2, 3, 4])
>>> summed_inequality_sample([1, 2, 3], [2, 3, 4])
array([[1]])
The equality distance is set to 0.4. 1 and 1.3 are considered equal, so are 3 and 3.1.
Therefore 2 is choosen.
>>> summed_inequality_sampler([1, 2, 3], [1.3, 2.7, 3.1], 1, .4)
>>> summed_inequality_sample([1, 2, 3], [1.3, 2.7, 3.1], 1, .4)
array([[2]])
The value 3 appears least often in the reference.
>>> summed_inequality_sampler([1, 2, 3], [1, 1, 1, 2, 2, 2, 3, 3])
>>> summed_inequality_sample([1, 2, 3], [1, 1, 1, 2, 2, 2, 3, 3])
array([[3]])
The experimentalist "fills up" the reference array so the values are contributed evenly
>>> summed_inequality_sampler([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 1, 2, 2, 2, 2, 3, 3, 3], 3)
>>> summed_inequality_sample([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 1, 2, 2, 2, 2, 3, 3, 3], 3)
array([[1],
[3],
[1]])
The experimentalist samples without replacemnt!
>>> summed_inequality_sampler([1, 2, 3], [1, 1, 1], 3)
>>> summed_inequality_sample([1, 2, 3], [1, 1, 1], 3)
array([[3],
[2],
[1]])
"""

if isinstance(condition_pool, Iterable):
condition_pool = np.array(list(condition_pool))
X = np.array(conditions)

if isinstance(reference_conditions, Iterable):
reference_conditions = np.array(list(reference_conditions))
_reference_conditions = reference_conditions.copy()
if isinstance(reference_conditions, pd.DataFrame):
if set(conditions.columns) != set(reference_conditions.columns):
raise Exception(
f"Variable names {set(conditions.columns)} in conditions"
f"and {set(reference_conditions.columns)} in allowed values don't match. "
)

if condition_pool.ndim == 1:
condition_pool = condition_pool.reshape(-1, 1)
_reference_conditions = _reference_conditions[conditions.columns]

if reference_conditions.ndim == 1:
reference_conditions = reference_conditions.reshape(-1, 1)
X_reference_conditions = np.array(_reference_conditions)

if condition_pool.shape[1] != reference_conditions.shape[1]:
if X.ndim == 1:
X = X.reshape(-1, 1)

if X_reference_conditions.ndim == 1:
X_reference_conditions = X_reference_conditions.reshape(-1, 1)

if X.shape[1] != X_reference_conditions.shape[1]:
raise ValueError(
f"condition_pool and reference_conditions must have the same number of columns.\n"
f"condition_pool has {condition_pool.shape[1]} columns, "
f"while reference_conditions has {reference_conditions.shape[1]} columns."
f"conditions and reference_conditions must have the same number of columns.\n"
f"conditions has {X.shape[1]} columns, "
f"while reference_conditions has {X_reference_conditions.shape[1]} columns."
)

if condition_pool.shape[0] < num_samples:
if X.shape[0] < num_samples:
raise ValueError(
f"condition_pool must have at least {num_samples} rows matching the number "
f"conditions must have at least {num_samples} rows matching the number "
f"of requested samples."
)

dist = DistanceMetric.get_metric(metric)

# create a list to store the n condition_pool values with the highest inequality scores
# create a list to store the n conditions values with the highest inequality scores
condition_pool_res = []
# choose the canditate with the highest inequality score n-times
for _ in range(num_samples):
summed_equalities = []
# loop over all IV values
for row in condition_pool:
for row in X:

# calculate the distances between the current row in matrix1
# and all other rows in matrix2
summed_equality = 0
for reference_conditions_row in reference_conditions:
for reference_conditions_row in X_reference_conditions:
distance = dist.pairwise([row, reference_conditions_row])[0, 1]
summed_equality += distance > equality_distance

# store the summed distance for the current row
summed_equalities.append(summed_equality)

# sort the rows in matrix1 by their summed distances
condition_pool = condition_pool[np.argsort(summed_equalities)[::-1]]
X = X[np.argsort(summed_equalities)[::-1]]
# append the first value of the sorted list to the result
condition_pool_res.append(condition_pool[0])
condition_pool_res.append(X[0])
# add the chosen value to reference_conditions
reference_conditions = np.append(
reference_conditions, [condition_pool[0]], axis=0
)
# remove the chosen value from condition_pool
condition_pool = condition_pool[1:]

return np.array(condition_pool_res[:num_samples])
X_reference_conditions = np.append(X_reference_conditions, [X[0]], axis=0)
# remove the chosen value from X
X = X[1:]

new_conditions = np.array(condition_pool_res[:num_samples])
if isinstance(conditions, pd.DataFrame):
new_conditions = pd.DataFrame(new_conditions, columns=conditions.columns)
return new_conditions


summed_inequality_sample = sample
Expand Down

0 comments on commit 9b2b9b8

Please sign in to comment.