Skip to content

Commit

Permalink
Improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
dustalov committed Sep 6, 2024
1 parent a955fb6 commit 1c2fccd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
18 changes: 16 additions & 2 deletions python/evalica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Collection, Hashable
from dataclasses import dataclass
from types import MappingProxyType
from typing import Generic, Literal, TypeVar
from typing import Generic, Literal, Protocol, TypeVar

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -158,6 +158,20 @@ def matrices(
)


class ResultProtocol(Protocol[T]):
"""
The result protocol.
Attributes:
scores: The element scores.
index: The index.
"""

scores: pd.Series[float]
index: dict[T, int]


@dataclass(frozen=True)
class CountingResult(Generic[T]):
"""
Expand Down Expand Up @@ -969,7 +983,7 @@ def pairwise_scores(
return pairwise_scores_pyo3(scores)


def pairwise_frame(scores: pd.Series[T]) -> pd.DataFrame: # type: ignore[type-var]
def pairwise_frame(scores: pd.Series[float]) -> pd.DataFrame:
"""
Create a data frame out of the estimated pairwise scores.
Expand Down
20 changes: 13 additions & 7 deletions python/evalica/gradio/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
__license__ = "Apache 2.0"

import argparse
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Protocol, cast

import evalica
import gradio as gr
Expand Down Expand Up @@ -56,15 +56,21 @@ def visualize(df_pairwise: pd.DataFrame) -> Figure:
}


class ResultScoresProtocol(Protocol):
scores: pd.Series[str]


class CallableAlgorithm(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> ResultScoresProtocol: ... # noqa: ANN401
def __call__(
self,
xs: pd.Series[str],
ys: pd.Series[str],
winners: pd.Series[Winner], # type: ignore[type-var]
) -> evalica.ResultProtocol[str]: ...


def invoke(algorithm: str, xs: pd.Series[str], ys: pd.Series[str], winners: pd.Series[str]) -> pd.Series[str]:
def invoke(
algorithm: str,
xs: pd.Series[str],
ys: pd.Series[str],
winners: pd.Series[Winner], # type: ignore[type-var]
) -> pd.Series[float]:
algorithm_impl = cast("CallableAlgorithm", ALGORITHMS[algorithm])

return algorithm_impl(xs=xs, ys=ys, winners=winners).scores
Expand Down

0 comments on commit 1c2fccd

Please sign in to comment.