Skip to content

Commit

Permalink
pearson correlation metric (#338)
Browse files Browse the repository at this point in the history
* pearson correlation metric

* formatting

* add optional mask to pearsonr metric

* formatting

* fixed random vector
  • Loading branch information
alex-golts authored Dec 4, 2023
1 parent 8f3f538 commit f6606be
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
55 changes: 55 additions & 0 deletions fuse/eval/examples/examples_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on Nov 30, 2023
"""

from fuse.eval.metrics.stat.metrics_stat_common import MetricPearsonCorrelation
import numpy as np
import pandas as pd
from collections import OrderedDict
from fuse.eval.evaluator import EvaluatorDefault


def example_pearson_correlation() -> float:
    """
    Example of evaluating the Pearson correlation coefficient between two
    data columns. Both columns are affine transformations of the same random
    vector, so they are perfectly linearly related (expected r == 1.0).
    """

    # build a small synthetic dataset
    num_samples = 1000
    np.random.seed(0)
    noise = np.random.randn(num_samples)

    data_df = pd.DataFrame(
        {
            "id": range(num_samples),
            "x1": 100 * np.ones(num_samples) + 10 * noise,
            "x2": -10 * np.ones(num_samples) + 3 * noise,
        }
    )

    # single metric: Pearson correlation between columns "x1" and "x2"
    metrics = OrderedDict(
        [
            ("pearsonr", MetricPearsonCorrelation(pred="x1", target="x2")),
        ]
    )

    # evaluate over the whole dataframe (ids=None -> use all rows)
    evaluator = EvaluatorDefault()
    return evaluator.eval(ids=None, data=data_df, metrics=metrics)
46 changes: 46 additions & 0 deletions fuse/eval/metrics/libs/stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
from typing import Sequence, Union


class Stat:
    """
    Statistical metrics
    """

    @staticmethod
    def pearson_correlation(
        pred: Union[np.ndarray, Sequence],
        target: Union[np.ndarray, Sequence],
        mask: Union[np.ndarray, Sequence, None] = None,
    ) -> float:
        """
        Pearson correlation coefficient measuring the linear relationship between two datasets/vectors.
        :param pred: prediction values - must be 1D after squeezing
        :param target: target values - must be 1D after squeezing, same length as pred
        :param mask: optional boolean mask. if it is provided, the metric will be applied only to the masked samples
        :return: correlation coefficient in [-1, 1] as a builtin float
                 (NaN with a runtime warning if either input has zero variance)
        :raises ValueError: if pred or target has more than one dimension after squeezing
        """
        # np.asarray accepts both sequences and ndarrays (no copy for ndarrays),
        # replacing the previous per-type isinstance conversions
        pred = np.asarray(pred)
        target = np.asarray(target)
        if mask is not None:
            # cast to bool unconditionally so an integer mask performs boolean
            # selection rather than fancy indexing (previously only Sequence
            # masks were cast, not ndarray masks)
            mask = np.asarray(mask).astype(bool)
            pred = pred[mask]
            target = target[mask]

        pred = pred.squeeze()
        target = target.squeeze()
        if pred.ndim > 1 or target.ndim > 1:
            raise ValueError(
                f"expected 1D vectors. got pred shape: {pred.shape}, target shape: {target.shape}"
            )

        # center once and reuse for both the covariance and variance terms
        pred_centered = pred - np.mean(pred)
        target_centered = target - np.mean(target)

        r = np.sum(pred_centered * target_centered) / np.sqrt(
            np.sum(pred_centered**2) * np.sum(target_centered**2)
        )

        # cast to a builtin float to honor the declared return type
        # (np.sum/np.sqrt otherwise yield np.float64)
        return float(r)
16 changes: 15 additions & 1 deletion fuse/eval/metrics/stat/metrics_stat_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Hashable, Optional, Sequence
from collections import Counter
from fuse.eval.metrics.metrics_common import MetricWithCollectorBase
from fuse.eval.metrics.metrics_common import MetricDefault, MetricWithCollectorBase
from fuse.eval.metrics.libs.stat import Stat


class MetricUniqueValues(MetricWithCollectorBase):
Expand All @@ -19,3 +20,16 @@ def eval(
counter = Counter(values)

return list(counter.items())


class MetricPearsonCorrelation(MetricDefault):
    """
    Pearson correlation coefficient metric, backed by Stat.pearson_correlation.
    """

    def __init__(
        self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict
    ) -> None:
        """
        :param pred: key of the prediction values (e.g. a dataframe column name)
        :param target: key of the target values
        :param mask: optional key of a boolean mask - when given, the metric is
                     computed only over the masked samples
        :param kwargs: forwarded to MetricDefault
        """
        super().__init__(
            metric_func=Stat.pearson_correlation,
            pred=pred,
            target=target,
            mask=mask,
            **kwargs,
        )
6 changes: 6 additions & 0 deletions fuse/eval/tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

from fuse.eval.examples.examples_seq_gen import example_seq_gen_0, example_seq_gen_1

from fuse.eval.examples.examples_stats import example_pearson_correlation


class TestEval(unittest.TestCase):
def test_eval_example_0(self) -> None:
Expand Down Expand Up @@ -232,6 +234,10 @@ def test_eval_example_seq_gen_1(self) -> None:
results = example_seq_gen_1(seed=1234)
self.assertAlmostEqual(results["metrics.perplexity"], 162.87, places=2)

def test_pearson_correlation(self) -> None:
res = example_pearson_correlation()
self.assertAlmostEqual(res["metrics.pearsonr"], 1.0, places=2)


if __name__ == "__main__":
unittest.main()

0 comments on commit f6606be

Please sign in to comment.