Skip to content

Commit

Permalink
rename BaseModel to ExtrasBaseModel
Browse files Browse the repository at this point in the history
  • Loading branch information
ruthenian8 committed Dec 4, 2023
1 parent bd7355a commit 4614c93
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 32 deletions.
4 changes: 2 additions & 2 deletions dff/script/extras/conditions/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from .dataset import DatasetItem
from .utils import LABEL_KEY
from .models.base_model import BaseModel
from .models.base_model import ExtrasBaseModel


@singledispatch
Expand Down Expand Up @@ -79,7 +79,7 @@ def has_cls_label_innner(ctx: Context, pipeline: Pipeline) -> bool:


def has_match(
model: BaseModel,
model: ExtrasBaseModel,
positive_examples: Optional[List[str]],
negative_examples: Optional[List[str]] = None,
threshold: float = 0.9,
Expand Down
4 changes: 2 additions & 2 deletions dff/script/extras/conditions/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Base Model
-----------
This module defines an abstract interface for label-scoring models, :py:class:`~BaseModel`.
This module defines an abstract interface for label-scoring models, :py:class:`~ExtrasBaseModel`.
When defining custom label-scoring models, always inherit from this class.
"""
from copy import copy
Expand All @@ -13,7 +13,7 @@
from ..utils import LABEL_KEY


class BaseModel(ABC):
class ExtrasBaseModel(ABC):
"""
Base class for label-scoring models.
Namespace key should be declared, if you want the scores of your model
Expand Down
4 changes: 2 additions & 2 deletions dff/script/extras/conditions/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
except ImportError:
hf_available = False

from .base_model import BaseModel
from .base_model import ExtrasBaseModel
from ..dataset import Dataset


class BaseHFModel(BaseModel):
class BaseHFModel(ExtrasBaseModel):
"""
Base class for Hugging Face-based annotator models.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
from typing import Optional, Union

from ...base_model import BaseModel
from ...base_model import ExtrasBaseModel
from ....dataset import Dataset


Expand All @@ -34,7 +34,7 @@ def __call__(self, request: str, **re_kwargs):
return result


class RegexClassifier(BaseModel):
class RegexClassifier(ExtrasBaseModel):
"""
RegexClassifier wraps a :py:class:`~RegexModel` for label annotation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
gensim_available = False
ALL_MODELS = []

from ...base_model import BaseModel
from ...base_model import ExtrasBaseModel
from ....dataset import Dataset
from ....utils import DefaultTokenizer
from .cosine_matcher_mixin import CosineMatcherMixin


class GensimMatcher(CosineMatcherMixin, BaseModel):
class GensimMatcher(CosineMatcherMixin, ExtrasBaseModel):
"""
GensimMatcher utilizes embeddings from Gensim models to measure
proximity between utterances and pre-defined labels.
Expand All @@ -47,7 +47,7 @@ def __init__(
if not gensim_available:
raise ImportError("Required packages missing. Try `pip install dff[ext,gensim]`")
CosineMatcherMixin.__init__(self, dataset=dataset)
BaseModel.__init__(self, namespace_key=namespace_key)
ExtrasBaseModel.__init__(self, namespace_key=namespace_key)
self.model = model
self.tokenizer = tokenizer or DefaultTokenizer()
# self.fit(self.dataset, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions dff/script/extras/conditions/models/remote_api/async_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
"""
from dff.script import Context

from ..base_model import BaseModel
from ..base_model import ExtrasBaseModel
from ...utils import LABEL_KEY


class AsyncMixin(BaseModel):
class AsyncMixin(ExtrasBaseModel):
"""
This class overrides the :py:meth:`~__call__` method
allowing for asynchronous calls to annotator models.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
from pathlib import Path

from ..base_model import BaseModel
from ..base_model import ExtrasBaseModel
from .async_mixin import AsyncMixin

try:
Expand All @@ -24,7 +24,7 @@
dialogflow_available = False


class AbstractGDFModel(BaseModel):
class AbstractGDFModel(ExtrasBaseModel):
"""
Abstract class for a Google Dialogflow model.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
except ImportError:
hf_api_available = False

from ..base_model import BaseModel
from ..base_model import ExtrasBaseModel
from .async_mixin import AsyncMixin


class AbstractHFAPIModel(BaseModel):
class AbstractHFAPIModel(ExtrasBaseModel):
"""
Abstract class for an HF API annotator.
"""
Expand Down
4 changes: 2 additions & 2 deletions dff/script/extras/conditions/models/remote_api/rasa_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

from http import HTTPStatus
from ...utils import RasaResponse
from ..base_model import BaseModel
from ..base_model import ExtrasBaseModel
from .async_mixin import AsyncMixin


class AbstractRasaModel(BaseModel):
class AbstractRasaModel(ExtrasBaseModel):
"""
Abstract class for a RASA annotator.
"""
Expand Down
4 changes: 2 additions & 2 deletions dff/script/extras/conditions/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
joblib = object
sklearn_available = False

from .base_model import BaseModel
from .base_model import ExtrasBaseModel


class BaseSklearnModel(BaseModel):
class BaseSklearnModel(ExtrasBaseModel):
"""
Base class for Sklearn-based annotator models.
Expand Down
22 changes: 11 additions & 11 deletions docs/source/user_guides/extended_conditions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ The following code snippets demonstrate how you can write a connector for an ext
from dff.script.extras.conditions.models.remote_api.async_mixin import (
AsyncMixin,
)
from dff.script.extras.conditions.models.base_model import BaseModel
from dff.script.extras.conditions.models.base_model import ExtrasBaseModel
To create a synchronous connector to an API, we recommend you to inherit the class from ``BaseModel``.
To create a synchronous connector to an API, we recommend you to inherit the class from ``ExtrasBaseModel``.
The only method that you have to override is the ``predict`` method.
It takes a request string and returns a {label: probability} dictionary.
In case the request has not been successful, an empty dictionary can be returned.
Expand All @@ -118,7 +118,7 @@ We use `httpx` as an asynchronous http client.

.. code-block:: python
class CustomAPIConnector(BaseModel):
class CustomAPIConnector(ExtrasBaseModel):
def __init__(self, url: str, namespace_key: str = "default") -> None:
super().__init__(namespace_key)
self.url = url
Expand Down Expand Up @@ -156,12 +156,12 @@ In this section, we show the way you can adapt a classifier model to DFF's class
import pickle
from dff.script.extras.conditions.models.base_model import BaseModel
from dff.script.extras.conditions.models.base_model import ExtrasBaseModel
In order to create your own classifier, create a child class of the ``BaseModel`` abstract type.
In order to create your own classifier, create a child class of the ``ExtrasBaseModel`` abstract type.

``BaseModel`` only has one abstract method, ``predict``, that should necessarily be overridden.
``ExtrasBaseModel`` only has one abstract method, ``predict``, that should necessarily be overridden.
The method takes a request string and returns a dictionary of class labels
and their respective probabilities.

Expand All @@ -174,7 +174,7 @@ at your own convenience, e.g. lack of those will not raise an error.

.. code-block:: python
class MyCustomClassifier(BaseModel):
class MyCustomClassifier(ExtrasBaseModel):
def __init__(
self, swear_words: list, namespace_key: str = "default"
) -> None:
Expand Down Expand Up @@ -207,13 +207,13 @@ The following code snippets demonstrate the way in which a custom matcher can be

.. code-block:: python
from dff.script.extras.conditions.models.base_model import BaseModel
from dff.script.extras.conditions.models.base_model import ExtrasBaseModel
from dff.script.extras.conditions.models.local.cosine_matchers.cosine_matcher_mixin import (
CosineMatcherMixin,
)
To build your own cosine matcher, you should inherit
from the ``CosineMatcherMixin`` and from the ``BaseModel``,
from the ``CosineMatcherMixin`` and from the ``ExtrasBaseModel``,
with the former taking precedence.
This requires the ``__init__`` method to take ``dataset`` argument.

Expand All @@ -233,10 +233,10 @@ e.g. lack of those will not raise an error.

.. code-block:: python
class MyCustomMatcher(CosineMatcherMixin, BaseModel):
class MyCustomMatcher(CosineMatcherMixin, ExtrasBaseModel):
def __init__(self, model, dataset, namespace_key) -> None:
CosineMatcherMixin.__init__(self, dataset)
BaseModel.__init__(self, namespace_key)
ExtrasBaseModel.__init__(self, namespace_key)
self.model = model
def transform(self, request: str):
Expand Down

0 comments on commit 4614c93

Please sign in to comment.