Skip to content

Commit

Permalink
[uss_qualifier] async session queries also return headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Shastick committed Nov 30, 2023
1 parent d00bc7c commit 35afabc
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
30 changes: 25 additions & 5 deletions monitoring/monitorlib/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
from typing import Dict, List, Optional
import urllib.parse
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientResponse

import jwt
import requests
Expand Down Expand Up @@ -190,32 +190,52 @@ def adjust_request_kwargs(self, url, method, kwargs):
return kwargs

async def put(self, url, **kwargs):
"""Returns (status, headers, json)"""
url = self._prefix_url + url
if "auth" not in kwargs:
kwargs = self.adjust_request_kwargs(url, "PUT", kwargs)
async with self._client.put(url, **kwargs) as response:
return response.status, await response.json()
return (
response.status,
{k: v for k, v in response.headers.items()},
await response.json(),
)

async def get(self, url, **kwargs):
"""Returns (status, headers, json)"""
url = self._prefix_url + url
if "auth" not in kwargs:
kwargs = self.adjust_request_kwargs(url, "GET", kwargs)
async with self._client.get(url, **kwargs) as response:
return response.status, await response.json()
return (
response.status,
{k: v for k, v in response.headers.items()},
await response.json(),
)

async def post(self, url, **kwargs):
"""Returns (status, headers, json)"""
url = self._prefix_url + url
if "auth" not in kwargs:
kwargs = self.adjust_request_kwargs(url, "POST", kwargs)
async with self._client.post(url, **kwargs) as response:
return response.status, await response.json()
return (
response.status,
{k: v for k, v in response.headers.items()},
await response.json(),
)

async def delete(self, url, **kwargs):
"""Returns (status, headers, json)"""
url = self._prefix_url + url
if "auth" not in kwargs:
kwargs = self.adjust_request_kwargs(url, "DELETE", kwargs)
async with self._client.delete(url, **kwargs) as response:
return response.status, await response.json()
return (
response.status,
{k: v for k, v in response.headers.items()},
await response.json(),
)


def default_scopes(scopes: List[str]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def test_create_isa_concurrent(ids, session_ridv1_async):
)
)
for isa_id, resp in results:
assert resp[0] == 200, resp[1]
data = resp[1]
assert resp[0] == 200, resp[2]
data = resp[2]
assert data["service_area"]["id"] == isa_id
assert data["service_area"]["flights_url"] == "https://example.com/dss"
assert_datetimes_are_equal(
Expand All @@ -133,9 +133,9 @@ def test_get_isa_by_ids_concurrent(ids, session_ridv1_async):
)
)
for isa_id, resp in results:
assert resp[0] == 200, resp[1]
assert resp[0] == 200, resp[2]

data = resp[1]
data = resp[2]
assert data["service_area"]["id"] == isa_id
assert data["service_area"]["flights_url"] == FLIGHTS_URL

Expand All @@ -162,8 +162,8 @@ def test_delete_isa_concurrent(ids, session_ridv1_async):
)

for isa_id, resp in results:
assert resp[0] == 200, resp[1]
version = resp[1]["service_area"]["version"]
assert resp[0] == 200, resp[2]
version = resp[2]["service_area"]["version"]
version_map[isa_id] = version

# Delete ISAs concurrently
Expand All @@ -178,4 +178,4 @@ def test_delete_isa_concurrent(ids, session_ridv1_async):
)

for isa_id, resp in results:
assert resp[0], resp[1]
assert resp[0], resp[2]
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_create_ops_concurrent(ids, scd_api, scd_session_async):
op_id = req_map[0]
op_resp_map[op_id] = {}
op_resp_map[op_id]["status_code"] = resp[0][0]
op_resp_map[op_id]["content"] = resp[0][1]
op_resp_map[op_id]["content"] = resp[0][2]
for op_id, resp in op_resp_map.items():
if resp["status_code"] != 201:
try:
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_get_ops_by_ids_concurrent(ids, scd_api, scd_session_async):
for op_id, resp in zip(map(ids, OP_TYPES), results):
op_resp_map[op_id] = {}
op_resp_map[op_id]["status_code"] = resp[0]
op_resp_map[op_id]["content"] = resp[1]
op_resp_map[op_id]["content"] = resp[2]

for op_id, resp in op_resp_map.items():
assert resp["status_code"] == 200, resp["content"]
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_get_ops_by_search_concurrent(ids, scd_api, scd_session_async):
for idx, resp in zip(range(len(OP_TYPES)), results):
op_resp_map[idx] = {}
op_resp_map[idx]["status_code"] = resp[0]
op_resp_map[idx]["content"] = resp[1]
op_resp_map[idx]["content"] = resp[2]

for idx, resp in op_resp_map.items():
assert resp["status_code"] == 200, resp["content"]
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_mutate_ops_concurrent(ids, scd_api, scd_session, scd_session_async):
op_id = req_map[0]
op_resp_map[op_id] = {}
op_resp_map[op_id]["status_code"] = resp[0][0]
op_resp_map[op_id]["content"] = resp[0][1]
op_resp_map[op_id]["content"] = resp[0][2]

ovn_map.clear()

Expand Down Expand Up @@ -486,7 +486,7 @@ def test_delete_op_concurrent(ids, scd_api, scd_session_async):
for op_id, resp in zip(map(ids, OP_TYPES), results):
op_resp_map[op_id] = {}
op_resp_map[op_id]["status_code"] = resp[0]
op_resp_map[op_id]["content"] = resp[1]
op_resp_map[op_id]["content"] = resp[2]

assert len(op_resp_map) == len(OP_TYPES)

Expand Down

0 comments on commit 35afabc

Please sign in to comment.