Skip to content

Commit

Permalink
improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Nov 26, 2023
1 parent cd231bd commit 1831785
Show file tree
Hide file tree
Showing 18 changed files with 186 additions and 173 deletions.
8 changes: 4 additions & 4 deletions tsgm/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import typing
import typing as T
import logging
import numpy as np

Expand All @@ -13,7 +13,7 @@ class DatasetProperties:
"""
Stores the properties of a dataset. Along with dimensions it can store properties of the covariates.
"""
def __init__(self, N: int, D: int, T: int, variables=None):
def __init__(self, N: int, D: int, T: int, variables: T.Optional[T.List] = None) -> None:
"""
:param N: The number of samples.
:type N: int
Expand All @@ -35,7 +35,7 @@ class Dataset(DatasetProperties):
"""
Wrapper for time-series datasets. Additional information is stored in `metadata` field.
"""
def __init__(self, x: tsgm.types.Tensor, y: tsgm.types.Tensor, metadata: typing.Optional[typing.Dict] = None):
def __init__(self, x: tsgm.types.Tensor, y: tsgm.types.Tensor, metadata: T.Optional[T.Dict] = None):
"""
:param x: The matrix of time series with dimensions NxDxT
:type x: tsgm.types.Tensor
Expand Down Expand Up @@ -155,4 +155,4 @@ def output_dim(self) -> int:
return len(set(self.y))


DatasetOrTensor = typing.Union[Dataset, tsgm.types.Tensor]
DatasetOrTensor = T.Union[Dataset, tsgm.types.Tensor]
20 changes: 10 additions & 10 deletions tsgm/metrics/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
import typing
import typing as T
import logging
import numpy as np
import itertools
Expand All @@ -16,7 +16,7 @@
n_splits=3, random_state=42, shuffle=True)


def _dataset_or_tensor_to_tensor(D1):
def _dataset_or_tensor_to_tensor(D1: tsgm.dataset.DatasetOrTensor) -> tsgm.types.Tensor:
if isinstance(D1, tsgm.dataset.Dataset):
return D1.X
else:
Expand All @@ -33,7 +33,7 @@ class DistanceMetric(Metric):
"""
Metric that measures similarity between synthetic and real time series
"""
def __init__(self, statistics: list, discrepancy: typing.Callable):
def __init__(self, statistics: list, discrepancy: T.Callable) -> None:
"""
:param statistics: A list of summary statistics (callable)
:type statistics: list
Expand Down Expand Up @@ -87,14 +87,14 @@ class ConsistencyMetric(Metric):
"""
Predictive consistency metric measures whether a set of evaluators yield consistent results on real and synthetic data.
"""
def __init__(self, evaluators: list):
def __init__(self, evaluators: T.List) -> None:
"""
:param evaluators: A list of evaluators (each item should implement method `.evaluate(D)`)
:type evaluators: list
"""
self._evaluators = evaluators

def _apply_models(self, D: tsgm.dataset.DatasetOrTensor, D_test: tsgm.dataset.DatasetOrTensor) -> list:
def _apply_models(self, D: tsgm.dataset.DatasetOrTensor, D_test: tsgm.dataset.DatasetOrTensor) -> T.List:
return [e.evaluate(D, D_test) for e in self._evaluators]

def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrTensor, D_test: tsgm.dataset.DatasetOrTensor) -> float:
Expand Down Expand Up @@ -131,14 +131,14 @@ class DownstreamPerformanceMetric(Metric):
The downstream performance metric evaluates the performance of a model on a downstream task.
It returns performance gains achieved with the addition of synthetic data.
"""
def __init__(self, evaluator: BaseDownstreamEvaluator):
def __init__(self, evaluator: BaseDownstreamEvaluator) -> None:
"""
:param evaluator: An evaluator, should implement method `.evaluate(D)`
:type evaluator: BaseDownstreamEvaluator
"""
self._evaluator = evaluator

def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrTensor, D_test: typing.Optional[tsgm.dataset.DatasetOrTensor], return_std=False) -> float:
def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrTensor, D_test: T.Optional[tsgm.dataset.DatasetOrTensor], return_std: bool = False) -> float:
"""
:param D1: A time series dataset.
:type D1: tsgm.dataset.DatasetOrTensor.
Expand Down Expand Up @@ -169,7 +169,7 @@ class PrivacyMembershipInferenceMetric(Metric):
"""
The metric that measures the possibility of membership inference attacks.
"""
def __init__(self, attacker: typing.Any, metric: typing.Callable = None):
def __init__(self, attacker: T.Any, metric: T.Optional[T.Callable] = None) -> None:
"""
:param attacker: An attacker, one class classififier (OCC) that implements methods `.fit` and `.predict`
:type attacker: typing.Any
Expand Down Expand Up @@ -201,7 +201,7 @@ class MMDMetric(Metric):
This metric calculated MMD between real and synthetic samples
"""

def __init__(self, kernel: typing.Callable = tsgm.utils.mmd.exp_quad_kernel) -> None:
def __init__(self, kernel: T.Callable = tsgm.utils.mmd.exp_quad_kernel) -> None:
self.kernel = kernel

def __call__(self, D1: tsgm.dataset.DatasetOrTensor, D2: tsgm.dataset.DatasetOrTensor) -> float:
Expand All @@ -215,7 +215,7 @@ class DiscriminativeMetric(Metric):
"""
The discriminative metric measures how accurately a discriminative model can separate synthetic and real data.
"""
def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, d_syn: tsgm.dataset.DatasetOrTensor, model, test_size, n_epochs, metric=None, random_seed=None) -> float:
def __call__(self, d_hist: tsgm.dataset.DatasetOrTensor, d_syn: tsgm.dataset.DatasetOrTensor, model: T.Callable, test_size: T.Union[float, int], n_epochs: int, metric: T.Optional[T.Callable] = None, random_seed: T.Optional[int] = None) -> float:
X_hist, X_syn = _dataset_or_tensor_to_tensor(d_hist), _dataset_or_tensor_to_tensor(d_syn)
X_all, y_all = np.concatenate([X_hist, X_syn]), np.concatenate([[1] * len(d_hist), [0] * len(d_syn)])
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X_all, y_all, test_size=test_size, random_state=random_seed)
Expand Down
6 changes: 3 additions & 3 deletions tsgm/metrics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
'''


def _validate_axis(axis: typing.Optional[int]):
def _validate_axis(axis: typing.Optional[int]) -> int:
assert axis == 1 or axis == 2 or axis is None


def _apply_percacf(x):
def _apply_percacf(x: tsgm.types.Tensor) -> tsgm.types.Tensor:
return np.percentile(acf(x), .75)


def _apply_power(x):
def _apply_power(x: tsgm.types.Tensor) -> tsgm.types.Tensor:
return np.power(x, 2).sum() / len(x)


Expand Down
Loading

0 comments on commit 1831785

Please sign in to comment.