Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
ft: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
urimandujano committed Jan 17, 2024
1 parent 3839119 commit c2a4ca4
Showing 1 changed file with 118 additions and 1 deletion.
119 changes: 118 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import logging
import uuid
from typing import Type
from unittest import mock

import kubernetes
import pytest
import urllib3
from kubernetes.client import models as k8s_models
from prefect.infrastructure.kubernetes import KubernetesJob

from prefect_kubernetes.utilities import convert_manifest_to_model
from prefect_kubernetes.utilities import (
ResilientStreamWatcher,
convert_manifest_to_model,
)

base_path = "tests/sample_k8s_resources"

Expand Down Expand Up @@ -200,3 +210,110 @@ def test_bad_model_type_raises(v1_model_name):
match="`v1_model` must be the name of a valid Kubernetes client model.",
):
convert_manifest_to_model(sample_deployment_manifest, v1_model_name)


def test_resilient_streaming_retries_on_configured_errors(caplog):
watcher = ResilientStreamWatcher(logger=logging.getLogger("test"))

with mock.patch.object(
watcher.watch,
"stream",
side_effect=[
watcher.reconnect_exceptions[0],
watcher.reconnect_exceptions[0],
["random_success"],
],
) as mocked_stream:
for log in watcher.api_object_stream(str):
assert log == "random_success"

assert mocked_stream.call_count == 3
assert "Unable to connect, retrying..." in caplog.text


@pytest.mark.parametrize(
"exc", [Exception, TypeError, ValueError, urllib3.exceptions.ProtocolError]
)
def test_resilient_streaming_raises_on_unconfigured_errors(
exc: Type[Exception], caplog
):
watcher = ResilientStreamWatcher(
logger=logging.getLogger("test"), reconnect_exceptions=[]
)

with mock.patch.object(watcher.watch, "stream", side_effect=[exc]) as mocked_stream:
with pytest.raises(exc):
for _ in watcher.api_object_stream(str):
pass

assert mocked_stream.call_count == 1
assert "Unexpected error" in caplog.text
assert exc.__name__ in caplog.text


def _create_api_objects_mocks(n: int = 3):
objects = []
for _ in range(n):
o = mock.MagicMock(spec=kubernetes.client.V1Pod)
o.metadata = mock.PropertyMock()
o.metadata.uid = uuid.uuid4()
objects.append(o)
return objects


def test_resilient_streaming_deduplicates_api_objects_on_reconnects():
watcher = ResilientStreamWatcher(logger=logging.getLogger("test"))

object_pool = _create_api_objects_mocks()
thrown_exceptions = 0

def my_stream(*args, **kwargs):
"""
Simulate a stream that throws exceptions after yielding the first
object before yielding the rest of the objects.
"""
for o in object_pool:
yield {"object": o}

nonlocal thrown_exceptions
if thrown_exceptions < 3:
thrown_exceptions += 1
raise watcher.reconnect_exceptions[0]

watcher.watch.stream = my_stream
results = [obj for obj in watcher.api_object_stream(str)]

assert len(object_pool) == len(results)


def test_resilient_streaming_pulls_all_logs_on_reconnects():
watcher = ResilientStreamWatcher(logger=logging.getLogger("test"))

logs = ["log1", "log2", "log3", "log4"]
thrown_exceptions = 0

def my_stream(*args, **kwargs):
"""
Simulate a stream that throws exceptions after yielding the first
object before yielding the rest of the objects.
"""
for log in logs:
yield log

nonlocal thrown_exceptions
if thrown_exceptions < 3:
thrown_exceptions += 1
raise watcher.reconnect_exceptions[0]

watcher.watch.stream = my_stream
results = [obj for obj in watcher.log_stream(str)]

assert results == [
"log1", # Before first exception
"log1", # Before second exception
"log1", # Before third exception
"log1", # No more exceptions from here onward
"log2",
"log3",
"log4",
]

0 comments on commit c2a4ca4

Please sign in to comment.