Skip to content
This repository has been archived by the owner on Oct 14, 2024. It is now read-only.

Feature/support sync credentials #45

Merged
merged 5 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kiota_authentication_azure/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION: str = '0.1.1'
VERSION: str = '0.2.0'
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import TYPE_CHECKING, Dict, List, Optional
import inspect
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from urllib.parse import urlparse

from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from kiota_abstractions.authentication import AccessTokenProvider, AllowedHostsValidator

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential


class AzureIdentityAccessTokenProvider(AccessTokenProvider):
"""Access token provider that leverages the Azure Identity library to retrieve an access token.
"""

def __init__(
self,
credentials: "AsyncTokenCredential",
credentials: Union[TokenCredential, AsyncTokenCredential],
options: Optional[Dict],
scopes: List[str] = ['https://graph.microsoft.com/.default'],
allowed_hosts: List[str] = [
Expand Down Expand Up @@ -45,11 +45,18 @@ async def get_authorization_token(self, uri: str) -> str:
parsed_url = urlparse(uri)
if not parsed_url.scheme == 'https':
raise Exception("Only https is supported")
#async credentials
if inspect.iscoroutinefunction(self._credentials.get_token):
if self._options:
result = await self._credentials.get_token(*self._scopes, **self._options)
result = await self._credentials.get_token(*self._scopes)
await self._credentials.close() #type: ignore
baywet marked this conversation as resolved.
Show resolved Hide resolved
# sync credentials
else:
if self._options:
result = self._credentials.get_token(*self._scopes, **self._options)
result = self._credentials.get_token(*self._scopes)

if self._options:
result = await self._credentials.get_token(*self._scopes, **self._options)
result = await self._credentials.get_token(*self._scopes)
await self._credentials.close()
if result and result.token:
return result.token
return ""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider

from .azure_identity_access_token_provider import AzureIdentityAccessTokenProvider

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential


class AzureIdentityAuthenticationProvider(BaseBearerTokenAuthenticationProvider):

def __init__(
self,
credentials: "AsyncTokenCredential",
credentials: Union["TokenCredential", "AsyncTokenCredential"],
options: Optional[Dict] = None,
scopes: List[str] = ['https://graph.microsoft.com/.default'],
allowed_hosts: List[str] = [
Expand Down
8 changes: 7 additions & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ class DummyToken:
token: str


class DummyAzureTokenCredential():
class DummySyncAzureTokenCredential():

def get_token(self, *args):
return DummyToken(token="This is a dummy token")


class DummyAsyncAzureTokenCredential():

async def get_token(self, *args):
return DummyToken(token="This is a dummy token")
Expand Down
26 changes: 18 additions & 8 deletions tests/test_azure_identity_access_token_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
AzureIdentityAccessTokenProvider,
)

from .helpers import DummyAzureTokenCredential
from .helpers import DummyAsyncAzureTokenCredential, DummySyncAzureTokenCredential


def test_invalid_instantiation_without_credentials():
Expand All @@ -14,17 +14,19 @@ def test_invalid_instantiation_without_credentials():


def test_valid_instantiation_without_options():
token_provider = AzureIdentityAccessTokenProvider(DummyAzureTokenCredential(), None)
token_provider = AzureIdentityAccessTokenProvider(DummyAsyncAzureTokenCredential(), None)
assert not token_provider._options


def test_invalid_instatiation_without_scopes():
with pytest.raises(Exception):
token_provider = AzureIdentityAccessTokenProvider(DummyAzureTokenCredential(), None, None)
token_provider = AzureIdentityAccessTokenProvider(
DummyAsyncAzureTokenCredential(), None, None
)


def test_get_allowed_hosts_validator():
token_provider = AzureIdentityAccessTokenProvider(DummyAzureTokenCredential(), None)
token_provider = AzureIdentityAccessTokenProvider(DummySyncAzureTokenCredential(), None)
validator = token_provider.get_allowed_hosts_validator()
hosts = validator.get_allowed_hosts()
assert isinstance(validator, AllowedHostsValidator)
Expand All @@ -36,23 +38,31 @@ def test_get_allowed_hosts_validator():


@pytest.mark.asyncio
async def test_get_authorization_token():
async def test_get_authorization_token_async():

token_provider = AzureIdentityAccessTokenProvider(DummyAzureTokenCredential(), None)
token_provider = AzureIdentityAccessTokenProvider(DummyAsyncAzureTokenCredential(), None)
token = await token_provider.get_authorization_token('https://graph.microsoft.com')
assert token == "This is a dummy token"


@pytest.mark.asyncio
async def test_get_authorization_token_sync():

token_provider = AzureIdentityAccessTokenProvider(DummySyncAzureTokenCredential(), None)
token = await token_provider.get_authorization_token('https://graph.microsoft.com')
assert token == "This is a dummy token"


@pytest.mark.asyncio
async def test_get_authorization_token_invalid_url():

token_provider = AzureIdentityAccessTokenProvider(DummyAzureTokenCredential(), None)
token_provider = AzureIdentityAccessTokenProvider(DummyAsyncAzureTokenCredential(), None)
token = await token_provider.get_authorization_token('https://example.com')
assert token == ""


@pytest.mark.asyncio
async def test_get_authorization_token_invalid_scheme():
with pytest.raises(Exception):
token_provider = AzureIdentityAccessTokenProvider(DummyAzureTokenCredential(), None)
token_provider = AzureIdentityAccessTokenProvider(DummySyncAzureTokenCredential(), None)
token = await token_provider.get_authorization_token('http://graph.microsoft.com')
4 changes: 2 additions & 2 deletions tests/test_azure_identity_authentication_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
AzureIdentityAuthenticationProvider,
)

from .helpers import DummyAzureTokenCredential
from .helpers import DummyAsyncAzureTokenCredential, DummySyncAzureTokenCredential


def test_invalid_instantiation_without_credentials():
Expand All @@ -15,7 +15,7 @@ def test_invalid_instantiation_without_credentials():

@pytest.mark.asyncio
async def test_valid_instantiation_without_options():
auth_provider = AzureIdentityAuthenticationProvider(DummyAzureTokenCredential())
auth_provider = AzureIdentityAuthenticationProvider(DummyAsyncAzureTokenCredential())
request_info = RequestInformation()
request_info.url = "https://graph.microsoft.com"
await auth_provider.authenticate_request(request_info)
Expand Down