Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update with mypy typing #427

Merged
merged 1 commit into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pyrainbird/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _device_busy_retry() -> JitterRetry:
return JitterRetry(
attempts=_retry_attempts(),
start_timeout=_retry_delay(),
statuses=[HTTPStatus.SERVICE_UNAVAILABLE.value],
statuses=set([HTTPStatus.SERVICE_UNAVAILABLE.value]),
retry_all_server_errors=False,
)

Expand All @@ -114,10 +114,10 @@ def __init__(
self._password = password
self._coder = encryption.PayloadCoder(password, _LOGGER)

def with_retry_options(self, retry_options: RetryOptions) -> "AsyncRainbirdClient":
def with_retry_options(self, retry_options: RetryOptions) -> "AsyncRainbirdClient": # type: ignore[valid-type]
"""Create a new AsyncRainbirdClient with retry options."""
return AsyncRainbirdClient(
RetryClient(client_session=self._websession, retry_options=retry_options),
RetryClient(client_session=self._websession, retry_options=retry_options), # type: ignore[arg-type]
self._host,
self._password,
)
Expand Down Expand Up @@ -147,7 +147,7 @@ async def request(
"Error communicating with Rain Bird device"
) from err
content = await resp.read()
return self._coder.decode_command(content)
return self._coder.decode_command(content) # type: ignore


def CreateController(
Expand All @@ -165,7 +165,7 @@ class AsyncRainbirdController:
def __init__(
self,
local_client: AsyncRainbirdClient,
cloud_client: AsyncRainbirdClient = None,
cloud_client: AsyncRainbirdClient | None = None,
) -> None:
"""Initialize AsyncRainbirdController."""
self._local_client = local_client
Expand Down Expand Up @@ -418,15 +418,15 @@ async def get_schedule(self) -> Schedule:
commands.append("%04x" % (0x80 | zone_page))
_LOGGER.debug("Sending schedule commands: %s", commands)
# Run command serially to avoid overwhelming the controller
schedule_data = {
schedule_data: dict[str, Any] = {
"controllerInfo": {},
"programInfo": [],
"programStartInfo": [],
"durations": [],
}
for command in commands:
result = await self._process_command(
None, "RetrieveScheduleRequest", int(command, 16) # Disable validation
None, "RetrieveScheduleRequest", int(command, 16) # type: ignore
)
if not isinstance(result, dict):
continue
Expand Down Expand Up @@ -509,4 +509,4 @@ async def _cacheable_command(
return result
result = await self._process_command(funct, command, *args)
self._cache[key] = result
return result
return result # type: ignore
152 changes: 106 additions & 46 deletions pyrainbird/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class States:
"""Rainbird controller response containing a bitmask string e.g. active zones."""

count: int
mask: str
mask: int
states: tuple

def __init__(self, mask: str) -> None:
Expand Down Expand Up @@ -195,20 +195,42 @@ class WaterBudget:
class WifiParams(DataClassDictMixin):
"""Wifi parameters for the device."""

mac_address: Optional[str] = field(metadata=field_options(alias="macAddress"), default=None)
mac_address: Optional[str] = field(
metadata=field_options(alias="macAddress"), default=None
)
"""The mac address for the device, also referred to as the stick id."""

local_ip_address: Optional[str] = field(metadata=field_options(alias="localIpAddress"), default=None)
local_netmask: Optional[str] = field(metadata=field_options(alias="localNetmask"), default=None)
local_gateway: Optional[str] = field(metadata=field_options(alias="localGateway"), default=None)
local_ip_address: Optional[str] = field(
metadata=field_options(alias="localIpAddress"), default=None
)
local_netmask: Optional[str] = field(
metadata=field_options(alias="localNetmask"), default=None
)
local_gateway: Optional[str] = field(
metadata=field_options(alias="localGateway"), default=None
)
rssi: Optional[int] = None
wifi_ssid: Optional[str] = field(metadata=field_options(alias="wifiSsid"), default=None)
wifi_password: Optional[str] = field(metadata=field_options(alias="wifiPassword"), default=None)
wifi_security: Optional[str] = field(metadata=field_options(alias="wifiSecurity"), default=None)
ap_timeout_no_lan: Optional[int] = field(metadata=field_options(alias="apTimeoutNoLan"), default=None)
ap_timeout_idle: Optional[int] = field(metadata=field_options(alias="apTimeoutIdle"), default=None)
ap_security: Optional[str] = field(metadata=field_options(alias="apSecurity"), default=None)
sick_version: Optional[str] = field(metadata=field_options(alias="stickVersion"), default=None)
wifi_ssid: Optional[str] = field(
metadata=field_options(alias="wifiSsid"), default=None
)
wifi_password: Optional[str] = field(
metadata=field_options(alias="wifiPassword"), default=None
)
wifi_security: Optional[str] = field(
metadata=field_options(alias="wifiSecurity"), default=None
)
ap_timeout_no_lan: Optional[int] = field(
metadata=field_options(alias="apTimeoutNoLan"), default=None
)
ap_timeout_idle: Optional[int] = field(
metadata=field_options(alias="apTimeoutIdle"), default=None
)
ap_security: Optional[str] = field(
metadata=field_options(alias="apSecurity"), default=None
)
sick_version: Optional[str] = field(
metadata=field_options(alias="stickVersion"), default=None
)


class SoilType(IntEnum):
Expand All @@ -227,9 +249,15 @@ class ProgramInfo(DataClassDictMixin):
The values are repeated once for each program.
"""

soil_types: list[SoilType] = field(default_factory=list, metadata=field_options(alias="SoilTypes"))
flow_rates: list[int] = field(default_factory=list, metadata=field_options(alias="FlowRates"))
flow_units: list[int] = field(default_factory=list, metadata=field_options(alias="FlowUnits"))
soil_types: list[SoilType] = field(
default_factory=list, metadata=field_options(alias="SoilTypes")
)
flow_rates: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowRates")
)
flow_units: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowUnits")
)

@classmethod
def __pre_deserialize__(cls, values: dict[Any, Any]) -> dict[Any, Any]:
Expand All @@ -253,9 +281,15 @@ class Settings(DataClassDictMixin):
"""Country location of the device."""

# Program information
soil_types: list[SoilType] = field(default_factory=list, metadata=field_options(alias="SoilTypes"))
flow_rates: list[int] = field(default_factory=list, metadata=field_options(alias="FlowRates"))
flow_units: list[int] = field(default_factory=list, metadata=field_options(alias="FlowUnits"))
soil_types: list[SoilType] = field(
default_factory=list, metadata=field_options(alias="SoilTypes")
)
flow_rates: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowRates")
)
flow_units: list[int] = field(
default_factory=list, metadata=field_options(alias="FlowUnits")
)

@classmethod
def __pre_deserialize__(cls, values: dict[Any, Any]) -> dict[Any, Any]:
Expand Down Expand Up @@ -294,7 +328,7 @@ def __init__(self, status: Optional[str], settings: Optional[Settings]) -> None:
@property
def status(self) -> str:
"""Return device status."""
return self._status
return self._status or "unknown"

@property
def settings(self) -> Optional[Settings]:
Expand All @@ -316,7 +350,9 @@ class Controller(DataClassDictMixin):
available_stations: list[int] = field(
metadata=field_options(alias="availableStations"), default_factory=list
)
custom_name: Optional[str] = field(metadata=field_options(alias="customName"), default=None)
custom_name: Optional[str] = field(
metadata=field_options(alias="customName"), default=None
)
custom_program_names: dict[str, str] = field(
metadata=field_options(alias="customProgramNames"), default_factory=dict
)
Expand Down Expand Up @@ -345,18 +381,30 @@ class Weather(DataClassDictMixin):
city: Optional[str] = None
forecast: list[Forecast] = field(default_factory=list)
location: Optional[str] = None
time_zone_id: Optional[str] = field(metadata=field_options(alias="timeZoneId"), default=None)
time_zone_raw_offset: Optional[str] = field(metadata=field_options(alias="timeZoneRawOffset"), default=None)
time_zone_id: Optional[str] = field(
metadata=field_options(alias="timeZoneId"), default=None
)
time_zone_raw_offset: Optional[str] = field(
metadata=field_options(alias="timeZoneRawOffset"), default=None
)


@dataclass
class WeatherAndStatus(DataClassDictMixin):
"""Weather and status from the cloud API."""

stick_id: Optional[str] = field(metadata=field_options(alias="StickId"), default=None)
controller: Optional[Controller] = field(metadata=field_options(alias="Controller"), default=None)
forecasted_rain: Optional[dict[str, Any]] = field(metadata=field_options(alias="ForecastedRain"), default=None)
weather: Optional[Weather] = field(metadata=field_options(alias="Weather"), default=None)
stick_id: Optional[str] = field(
metadata=field_options(alias="StickId"), default=None
)
controller: Optional[Controller] = field(
metadata=field_options(alias="Controller"), default=None
)
forecasted_rain: Optional[dict[str, Any]] = field(
metadata=field_options(alias="ForecastedRain"), default=None
)
weather: Optional[Weather] = field(
metadata=field_options(alias="Weather"), default=None
)


@dataclass
Expand Down Expand Up @@ -398,6 +446,7 @@ def deserialize(self, values: dict[str, Any]) -> datetime.datetime:
int(values["second"]),
)


@dataclass
class ControllerState(DataClassDictMixin):
"""Details about the controller state."""
Expand All @@ -417,13 +466,14 @@ class ControllerState(DataClassDictMixin):
# TODO: Likely need to make this a mask w/ States
active_station: int = field(metadata=field_options(alias="activeStation"))

device_time: datetime.datetime = field(metadata=field_options(serialization_strategy=DeviceTime()))
device_time: datetime.datetime = field(
metadata=field_options(serialization_strategy=DeviceTime())
)

@classmethod
def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]:
d["device_time"] = {
k: d[k]
for k in ("year", "month", "day", "hour", "minute", "second")
k: d[k] for k in ("year", "month", "day", "hour", "minute", "second")
}
return d

Expand Down Expand Up @@ -459,7 +509,7 @@ def name(self) -> str:
@classmethod
def __pre_deserialize__(cls, values: dict[Any, Any]) -> dict[Any, Any]:
if duration := values.get("duration"):
values["duration"] = duration * 60 #datetime.timedelta(minutes=duration)
values["duration"] = duration * 60 # datetime.timedelta(minutes=duration)
return values


Expand All @@ -479,15 +529,13 @@ def deserialize(self, starts: list[int]) -> list[datetime.time]:
return result




class DayOfWeekSerializationStrategy(SerializationStrategy):
"""Validate different ways the device time parameter is handled."""

def serialize(self, value: Any) -> str:
raise ValueError("Serialization not implemented")

def deserialize(self, mask: int) -> list[DayOfWeek]:
def deserialize(self, mask: int) -> set[DayOfWeek]:
"""Deserialize the device time fields."""
_LOGGER.debug("DayOfWeekSerializationStrategy=%s", mask)
result: set[DayOfWeek] = set()
Expand All @@ -512,7 +560,13 @@ class Program(DataClassDictMixin):
frequency: ProgramFrequency
"""Determines how often the program runs."""

days_of_week: set[DayOfWeek] = field(metadata=field_options(alias="daysOfWeekMask", serialization_strategy=DayOfWeekSerializationStrategy()), default_factory=set)
days_of_week: set[DayOfWeek] = field(
metadata=field_options(
alias="daysOfWeekMask",
serialization_strategy=DayOfWeekSerializationStrategy(),
),
default_factory=set,
)
"""For a CUSTOM program determines the days of the week."""

period: Optional[int] = None
Expand All @@ -521,13 +575,18 @@ class Program(DataClassDictMixin):
synchro: Optional[int] = None
"""Days from today before starting the first day of the program."""

starts: list[datetime.time] = field(default_factory=list, metadata=field_options(serialization_strategy=TimeSerializationStrategy()))
starts: list[datetime.time] = field(
default_factory=list,
metadata=field_options(serialization_strategy=TimeSerializationStrategy()),
)
"""Time of day the program starts."""

durations: list[ZoneDuration] = field(default_factory=list)
"""Durations for run times for each zone."""

controller_info: Optional[ControllerInfo] = field(metadata=field_options(alias="controllerInfo"), default=None)
controller_info: Optional[ControllerInfo] = field(
metadata=field_options(alias="controllerInfo"), default=None
)
"""Information about the controller as input into the programs."""

@property
Expand All @@ -541,7 +600,7 @@ def timeline(self) -> ProgramTimeline:
"""Return a timeline of events for the program."""
return self.timeline_tz(datetime.datetime.now().tzinfo)

def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
def timeline_tz(self, tzinfo: datetime.tzinfo | None) -> ProgramTimeline:
"""Return a timeline of events for the program."""
iters: list[Iterable[SortableItem[Timespan, ProgramEvent]]] = []
now = datetime.datetime.now(tzinfo)
Expand All @@ -553,9 +612,9 @@ def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
self.frequency,
dtstart,
self.duration,
self.synchro,
self.synchro or 0,
self.days_of_week,
self.period,
self.period or 0,
delay_days=self.delay_days,
),
)
Expand All @@ -575,9 +634,9 @@ def zone_timeline(self) -> ProgramTimeline:
self.frequency,
dtstart,
zone_duration.duration,
self.synchro,
self.synchro or 0,
self.days_of_week,
self.period,
self.period or 0,
delay_days=self.delay_days,
)
)
Expand All @@ -604,12 +663,13 @@ def __post_init__(self):
self.period = None



@dataclass
class Schedule(DataClassDictMixin):
"""Details about program schedules."""

controller_info: Optional[ControllerInfo] = field(metadata=field_options(alias="controllerInfo"))
controller_info: Optional[ControllerInfo] = field(
metadata=field_options(alias="controllerInfo")
)
"""Information about the controller used in the schedule."""

programs: list[Program] = field(metadata=field_options(alias="programInfo"))
Expand All @@ -620,7 +680,7 @@ def timeline(self) -> ProgramTimeline:
"""Return a timeline of all programs."""
return self.timeline_tz(datetime.datetime.now().tzinfo)

def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
def timeline_tz(self, tzinfo: datetime.tzinfo | None) -> ProgramTimeline:
"""Return a timeline of all programs."""
iters: list[Iterable[SortableItem[Timespan, ProgramEvent]]] = []
now = datetime.datetime.now(tzinfo)
Expand All @@ -633,9 +693,9 @@ def timeline_tz(self, tzinfo: datetime.tzinfo) -> ProgramTimeline:
program.frequency,
dtstart,
program.duration,
program.synchro,
program.synchro or 0,
program.days_of_week,
program.period,
program.period or 0,
delay_days=self.delay_days,
)
)
Expand Down
4 changes: 2 additions & 2 deletions pyrainbird/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def encode_command(self, method: str, params: dict[str, Any]) -> str:
return send_data
return encrypt(send_data, self._password)

def decode_command(self, content: bytes) -> str:
def decode_command(self, content: bytes) -> str | dict[str, Any]:
"""Decode a response payload."""
if self._password is not None:
decrypted_data = (
Expand All @@ -112,7 +112,7 @@ def decode_command(self, content: bytes) -> str:
.rstrip()
)
content = decrypted_data
self._logger.debug("Response: %s" % content)
self._logger.debug("Response: %r" % content)
response = json.loads(content)
if error := response.get("error"):
msg = ["Error from controller"]
Expand Down
Loading
Loading