diff --git a/src/idpyoidc/client/oauth2/__init__.py b/src/idpyoidc/client/oauth2/__init__.py index 620608b0..20c5c13b 100755 --- a/src/idpyoidc/client/oauth2/__init__.py +++ b/src/idpyoidc/client/oauth2/__init__.py @@ -14,6 +14,7 @@ from idpyoidc.client.service import SUCCESSFUL from idpyoidc.client.service import Service from idpyoidc.client.util import do_add_ons +from idpyoidc.client.util import get_content_type from idpyoidc.client.util import get_deserialization_method from idpyoidc.configure import Configuration from idpyoidc.context import OidcContext @@ -254,12 +255,13 @@ def parse_request_response(self, service, reqresp, response_body_type="", state= if reqresp.status_code in SUCCESSFUL: logger.debug('response_body_type: "{}"'.format(response_body_type)) - _deser_method = get_deserialization_method(reqresp) + _content_type = get_content_type(reqresp) + _deser_method = get_deserialization_method(_content_type) - if _deser_method != response_body_type: + if _content_type != response_body_type: logger.warning( "Not the body type I expected: {} != {}".format( - _deser_method, response_body_type + _content_type, response_body_type ) ) if _deser_method in ["json", "jwt", "urlencoded"]: @@ -282,7 +284,9 @@ def parse_request_response(self, service, reqresp, response_body_type="", state= elif 400 <= reqresp.status_code < 500: logger.error("Error response ({}): {}".format(reqresp.status_code, reqresp.text)) # expecting an error response - _deser_method = get_deserialization_method(reqresp) + _content_type = get_content_type(reqresp) + _deser_method = get_deserialization_method(_content_type) + if not _deser_method: _deser_method = "json" diff --git a/src/idpyoidc/client/util.py b/src/idpyoidc/client/util.py index 6441af49..939f5e18 100755 --- a/src/idpyoidc/client/util.py +++ b/src/idpyoidc/client/util.py @@ -1,8 +1,8 @@ """Utilities""" -import logging -import secrets from http.cookiejar import Cookie from http.cookiejar import http2time +import logging +import secrets from urllib.parse import parse_qs from urllib.parse import urlsplit from urllib.parse import urlunsplit @@ -16,7 +16,6 @@ from idpyoidc.defaults import BASECHR from idpyoidc.exception import UnSupported from idpyoidc.util import importer - from .exception import TimeFormatError from .exception import WrongContentType @@ -202,7 +201,7 @@ def verify_header(reqresp, body_type): logger.debug("resp.txt: %s" % (sanitize(reqresp.text),)) try: - _ctype = reqresp.headers["content-type"] + _ctype = get_content_type(reqresp) except KeyError: if body_type: return body_type @@ -249,45 +248,54 @@ def verify_header(reqresp, body_type): return body_type -def get_deserialization_method(reqresp): - """ - - :param reqresp: Class instance with attributes: ['status', 'text', - 'headers', 'url'] - :return: Verified body content type - """ +def get_content_type(reqresp) -> str: logger.debug("resp.headers: %s" % (sanitize(reqresp.headers),)) logger.debug("resp.txt: %s" % (sanitize(reqresp.text),)) + ctype = reqresp.headers.get("content-type") - _ctype = reqresp.headers.get("content-type") - if not _ctype: + if not ctype: # let's try to detect the format try: reqresp.json() - return "json" + return "application/json" except Exception: try: _jwt = factory(reqresp.txt) - return "jwt" + return "application/jwt" except Exception: - return "urlencoded" # reasonable default ?? - elif ';' in _ctype: - for _typ in _ctype.split(";"): + try: + _info = parse_qs(reqresp.txt) + return "application/x-www-form-urlencoded" + except Exception: + return "text/html" # reasonable default ?? + elif ';' in ctype: + for _typ in ctype.split(";"): if _typ.startswith("application") or _typ.startswith("text"): - _ctype = _typ + ctype = _typ break - if match_to_("application/json", _ctype) or match_to_("application/jrd+json", _ctype): + return ctype + + +def get_deserialization_method(ctype): + """ + + :param reqresp: Class instance with attributes: ['status', 'text', + 'headers', 'url'] + :return: Verified body content type + """ + + if match_to_("application/json", ctype) or match_to_("application/jrd+json", ctype): deser_method = "json" - elif match_to_("application/jwt", _ctype): + elif match_to_("application/jwt", ctype): deser_method = "jwt" - elif match_to_("application/jose", _ctype): + elif match_to_("application/jose", ctype): deser_method = "jose" - elif match_to_(URL_ENCODED, _ctype): + elif match_to_(URL_ENCODED, ctype): deser_method = "urlencoded" - elif match_to_("text/plain", _ctype) or match_to_("test/html", _ctype): + elif match_to_("text/plain", ctype) or match_to_("test/html", ctype): deser_method = "" - elif _ctype.startswith("application/") and _ctype.endswith("+jwt"): + elif ctype.startswith("application/") and ctype.endswith("+jwt"): deser_method = "jwt" else: deser_method = "" # reasonable default ?? diff --git a/tests/test_client_05_util.py b/tests/test_client_05_util.py index 3a22416a..057c9545 100644 --- a/tests/test_client_05_util.py +++ b/tests/test_client_05_util.py @@ -7,6 +7,7 @@ import pytest from idpyoidc.client.exception import WrongContentType +from idpyoidc.client.util import get_content_type from idpyoidc.client.util import get_deserialization_method from idpyoidc.client.util import get_http_body from idpyoidc.client.util import get_http_url @@ -139,28 +140,35 @@ def test_verify_header(): def test_get_deserialization_method_json(): resp = FakeResponse("application/json") - assert get_deserialization_method(resp) == "json" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "json" resp = FakeResponse("application/json; charset=utf-8") - assert get_deserialization_method(resp) == "json" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "json" resp.headers["content-type"] = "application/jrd+json" - assert get_deserialization_method(resp) == "json" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "json" def test_get_deserialization_method_jwt(): resp = FakeResponse("application/jwt") - assert get_deserialization_method(resp) == "jwt" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "jwt" def test_get_deserialization_method_urlencoded(): resp = FakeResponse(URL_ENCODED) - assert get_deserialization_method(resp) == "urlencoded" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "urlencoded" def test_get_deserialization_method_text(): resp = FakeResponse("text/html") - assert get_deserialization_method(resp) == "" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "" resp = FakeResponse("text/plain") - assert get_deserialization_method(resp) == "" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "" diff --git a/tests/test_client_16_util.py b/tests/test_client_16_util.py index a09d65a5..57c4bf64 100644 --- a/tests/test_client_16_util.py +++ b/tests/test_client_16_util.py @@ -12,6 +12,7 @@ from idpyoidc.client.exception import WrongContentType from idpyoidc.client.util import JSON_ENCODED from idpyoidc.client.util import URL_ENCODED +from idpyoidc.client.util import get_content_type from idpyoidc.client.util import get_deserialization_method from idpyoidc.message.oauth2 import AccessTokenRequest from idpyoidc.message.oauth2 import AuthorizationRequest @@ -145,31 +146,38 @@ def test_verify_header(): def test_get_deserialization_method_json(): resp = FakeResponse("application/json") - assert get_deserialization_method(resp) == "json" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "json" resp = FakeResponse("application/json; charset=utf-8") - assert get_deserialization_method(resp) == "json" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "json" resp.headers["content-type"] = "application/jrd+json" - assert get_deserialization_method(resp) == "json" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "json" def test_get_deserialization_method_jwt(): resp = FakeResponse("application/jwt") - assert get_deserialization_method(resp) == "jwt" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "jwt" def test_get_deserialization_method_urlencoded(): resp = FakeResponse(URL_ENCODED) - assert get_deserialization_method(resp) == "urlencoded" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "urlencoded" def test_get_deserialization_method_text(): resp = FakeResponse("text/html") - assert get_deserialization_method(resp) == "" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "" resp = FakeResponse("text/plain") - assert get_deserialization_method(resp) == "" + ctype = get_content_type(resp) + assert get_deserialization_method(ctype) == "" def test_verify_no_content_type():