Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support configurable token endpoint #137

Merged
merged 11 commits into from
Nov 4, 2024
43 changes: 30 additions & 13 deletions openfga_sdk/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT.
"""

from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse

from openfga_sdk.exceptions import ApiValueError

Expand Down Expand Up @@ -160,6 +160,33 @@ def configuration(self, value):
"""
self._configuration = value

def _parse_issuer(self, issuer: str):
default_endpoint_path = "/oauth/token"

parsed_url = urlparse(issuer.strip())

try:
parsed_url.port
except ValueError as e:
raise ApiValueError(e)

if parsed_url.netloc is None and parsed_url.path is None:
raise ApiValueError("Invalid issuer")

if parsed_url.scheme == "":
parsed_url = urlparse(f"https://{issuer}")
elif parsed_url.scheme not in ("http", "https"):
raise ApiValueError(
f"Invalid issuer scheme {parsed_url.scheme} must be HTTP or HTTPS"
)

if parsed_url.path in ("", "/"):
parsed_url = parsed_url._replace(path=default_endpoint_path)

valid_url = urlunparse(parsed_url)

return valid_url

def validate_credentials_config(self):
"""
Check whether credentials configuration is valid
Expand Down Expand Up @@ -190,15 +217,5 @@ def validate_credentials_config(self):
"configuration `{}` requires client_id, client_secret, api_audience and api_issuer defined for client_credentials method."
)
# validate token issuer
combined_url = "https://" + self.configuration.api_issuer
parsed_url = None
try:
parsed_url = urlparse(combined_url)
except ValueError:
raise ApiValueError(
f"api_issuer `{self.configuration.api_issuer}` is invalid"
)
if parsed_url.netloc == "":
raise ApiValueError(
f"api_issuer `{self.configuration.api_issuer}` is invalid"
)

self._parse_issuer(self.configuration.api_issuer)
14 changes: 9 additions & 5 deletions test/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2314,12 +2314,16 @@ async def test_list_relations(self, mock_request):
Check whether a user is authorized to access an object
"""

def mock_check_requests(*args, **kwargs):
body = kwargs.get("body")
tuple_key = body.get("tuple_key")
if tuple_key["relation"] == "owner":
return mock_response('{"allowed": false, "resolution": "1234"}', 200)
return mock_response('{"allowed": true, "resolution": "1234"}', 200)

# First, mock the response
mock_request.side_effect = [
mock_response('{"allowed": true, "resolution": "1234"}', 200),
mock_response('{"allowed": false, "resolution": "1234"}', 200),
mock_response('{"allowed": true, "resolution": "1234"}', 200),
]
mock_request.side_effect = mock_check_requests

configuration = self.configuration
configuration.store_id = store_id
async with OpenFgaClient(configuration) as api_client:
Expand Down
56 changes: 56 additions & 0 deletions test/credentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import openfga_sdk
from openfga_sdk.credentials import CredentialConfiguration, Credentials
from openfga_sdk.exceptions import ApiValueError


class TestCredentials(IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -172,3 +173,58 @@ def test_configuration_client_credentials_missing_api_audience(self):
)
with self.assertRaises(openfga_sdk.ApiValueError):
credential.validate_credentials_config()


class TestCredentialsIssuer(IsolatedAsyncioTestCase):
def setUp(self):
# Setup a basic configuration that can be modified per test case
self.configuration = CredentialConfiguration(api_issuer="https://example.com")
self.credentials = Credentials(
method="client_credentials", configuration=self.configuration
)

def test_valid_issuer_https(self):
# Test a valid HTTPS URL
self.configuration.api_issuer = "issuer.fga.example "
result = self.credentials._parse_issuer(self.configuration.api_issuer)
self.assertEqual(result, "https://issuer.fga.example/oauth/token")

def test_valid_issuer_with_oauth_endpoint_https(self):
# Test a valid HTTPS URL
self.configuration.api_issuer = "https://example.com/oauth/token"
result = self.credentials._parse_issuer(self.configuration.api_issuer)
self.assertEqual(result, "https://example.com/oauth/token")

def test_valid_issuer_with_some_endpoint_https(self):
# Test a valid HTTPS URL
self.configuration.api_issuer = "https://example.com/oauth/some/endpoint"
result = self.credentials._parse_issuer(self.configuration.api_issuer)
self.assertEqual(result, "https://example.com/oauth/some/endpoint")

def test_valid_issuer_http(self):
# Test a valid HTTP URL
self.configuration.api_issuer = "fga.example/some_endpoint"
result = self.credentials._parse_issuer(self.configuration.api_issuer)
self.assertEqual(result, "https://fga.example/some_endpoint")

def test_invalid_issuer_no_scheme(self):
# Test an issuer URL without a scheme
self.configuration.api_issuer = "https://issuer.fga.example:8080/some_endpoint "
result = self.credentials._parse_issuer(self.configuration.api_issuer)
self.assertEqual(result, "https://issuer.fga.example:8080/some_endpoint")

def test_invalid_issuer_bad_scheme(self):
# Test an issuer with an unsupported scheme
self.configuration.api_issuer = "ftp://example.com"
with self.assertRaises(ApiValueError):
self.credentials._parse_issuer(self.configuration.api_issuer)

def test_invalid_issuer_with_port(self):
# Test an issuer with an unsupported scheme
self.configuration.api_issuer = "https://issuer.fga.example:8080 "
result = self.credentials._parse_issuer(self.configuration.api_issuer)
self.assertEqual(result, "https://issuer.fga.example:8080/oauth/token")


if __name__ == "__main__":
unittest.main()
14 changes: 9 additions & 5 deletions test/sync/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2317,12 +2317,16 @@ def test_list_relations(self, mock_request):
Check whether a user is authorized to access an object
"""

def mock_check_requests(*args, **kwargs):
body = kwargs.get("body")
tuple_key = body.get("tuple_key")
if tuple_key["relation"] == "owner":
return mock_response('{"allowed": false, "resolution": "1234"}', 200)
return mock_response('{"allowed": true, "resolution": "1234"}', 200)

# First, mock the response
mock_request.side_effect = [
mock_response('{"allowed": true, "resolution": "1234"}', 200),
mock_response('{"allowed": false, "resolution": "1234"}', 200),
mock_response('{"allowed": true, "resolution": "1234"}', 200),
]
mock_request.side_effect = mock_check_requests

configuration = self.configuration
configuration.store_id = store_id
with OpenFgaClient(configuration) as api_client:
Expand Down