diff --git a/ansible_base/jwt_consumer/common/util.py b/ansible_base/jwt_consumer/common/util.py index 13b2002e3..ded237d66 100644 --- a/ansible_base/jwt_consumer/common/util.py +++ b/ansible_base/jwt_consumer/common/util.py @@ -1,11 +1,13 @@ import logging import time +from base64 import b64encode from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding from ansible_base.jwt_consumer.common.cert import JWTCert, JWTCertException +from ansible_base.lib.utils.settings import get_setting logger = logging.getLogger('ansible_base.jwt_consumer.common.util') @@ -42,6 +44,15 @@ def validate_x_trusted_proxy_header(header_value: str, ignore_cache=False) -> bo logger.warning("Failed to validate x-trusted-proxy-header, malformed, expected value to contain a -") return False + # Validate that the header has been cut within the last 300ms (by default) + try: + if time.time_ns() - int(timestamp) > get_setting('trusted_header_timeout_in_ns', 300000000): + logger.warning(f"Timestamp {timestamp} was too old to be valid alter trusted_header_timeout_in_ns if needed") + return False + except ValueError: + logger.warning(f"Unable to convert timestamp (base64) {b64encode(timestamp.encode('UTF-8'))} into an integer") + return False + try: public_key.verify( bytes.fromhex(signature), diff --git a/test_app/tests/jwt_consumer/common/test_util.py b/test_app/tests/jwt_consumer/common/test_util.py index fdf4cb5a1..19956a614 100644 --- a/test_app/tests/jwt_consumer/common/test_util.py +++ b/test_app/tests/jwt_consumer/common/test_util.py @@ -1,3 +1,4 @@ +import time from unittest import mock from django.test.utils import override_settings @@ -29,3 +30,24 @@ def test_validate_trusted_proxy_header_fail_load_public_key(self, mock_load_pem_ def test_validate_trusted_proxy_header_bad_public_key(self, random_public_key): with override_settings(ANSIBLE_BASE_JWT_KEY=random_public_key): assert not validate_x_trusted_proxy_header("0-12345123451234512345") + + def test_header_timeout(self, expected_log, rsa_keypair): + header = generate_x_trusted_proxy_header(rsa_keypair.private) + with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public): + # Assert this header is valid if used right away + assert validate_x_trusted_proxy_header(header) is True + + # By default the header is only valid for 300ms so a 1/2 second sleep will expire it + time.sleep(0.5) + with expected_log( + 'ansible_base.jwt_consumer.common.util.logger', 'warning', 'was too old to be valid alter trusted_header_timeout_in_ns if needed' + ): + assert validate_x_trusted_proxy_header(header) is False + + def test_invalid_header_timestamp(self, expected_log, rsa_keypair): + header = generate_x_trusted_proxy_header(rsa_keypair.private) + _, signed_part = header.split('-') + header = f'asdf-{signed_part}' + with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public): + with expected_log('ansible_base.jwt_consumer.common.util.logger', 'warning', 'Unable to convert timestamp (base64)'): + assert validate_x_trusted_proxy_header(header) is False