Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for custom domains #12

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions aiosfstream/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,40 @@
from aiosfstream.exceptions import AuthenticationError


TOKEN_URL = "https://login.salesforce.com/services/oauth2/token"
SANDBOX_TOKEN_URL = "https://test.salesforce.com/services/oauth2/token"
# formatted later with a domain
BASE_URL = "https://{}.salesforce.com/services/oauth2/token"
LOGIN_DOMAIN = "login"
SANDBOX_DOMAIN = "test"


# pylint: disable=too-many-instance-attributes

class AuthenticatorBase(AuthExtension):
"""Abstract base class to serve as a base for implementing concrete
authenticators"""
def __init__(self, sandbox: bool = False,
def __init__(self, sandbox: bool = None, domain: str = None,
json_dumps: JsonDumper = json.dumps,
json_loads: JsonLoader = json.loads) -> None:
"""
:param sandbox: Marks whether the authentication has to be done \
for a sandbox org or for a production org
:param sandbox: Marks whether the connection has to be made with \
a sandbox org or with a production org. Cannot be used concurrently with \
a value for domain.
:param domain: A custom salesforce domain instead of 'login' or 'test'. \
Cannot be used concurrently with a value for sandbox
:param json_dumps: Function for JSON serialization, the default is \
:func:`json.dumps`
:param json_loads: Function for JSON deserialization, the default is \
:func:`json.loads`
"""
#: Marks whether the authentication has to be done for a sandbox org \
#: or for a production org
self._sandbox = sandbox
if sandbox is not None and domain is not None:
raise ValueError('You cannot specify a value for sandbox AND domain. Please use just one.')
elif domain is not None:
self._domain = domain
elif sandbox is True:
self._domain = SANDBOX_DOMAIN
else:
self._domain = LOGIN_DOMAIN

#: Salesforce session ID that can be used with the web services API
self.access_token: Optional[str] = None
#: Value is Bearer for all responses that include an access token
Expand All @@ -60,9 +71,7 @@ def __init__(self, sandbox: bool = False,
@property
def _token_url(self) -> str:
"""The URL that should be used for token requests"""
if self._sandbox:
return SANDBOX_TOKEN_URL
return TOKEN_URL
return BASE_URL.format(self._domain)

async def outgoing(self, payload: Payload, headers: Headers) -> None:
"""Process outgoing *payload* and *headers*
Expand Down Expand Up @@ -124,7 +133,8 @@ async def _authenticate(self) -> Tuple[int, JsonObject]:
class PasswordAuthenticator(AuthenticatorBase):
"""Authenticator for using the OAuth 2.0 Username-Password Flow"""
def __init__(self, consumer_key: str, consumer_secret: str,
username: str, password: str, sandbox: bool = False,
username: str, password: str,
sandbox: bool = None, domain: str = None,
json_dumps: JsonDumper = json.dumps,
json_loads: JsonLoader = json.loads) -> None:
"""
Expand All @@ -134,14 +144,18 @@ def __init__(self, consumer_key: str, consumer_secret: str,
connected app definition
:param username: Salesforce username
:param password: Salesforce password
:param sandbox: Marks whether the authentication has to be done \
for a sandbox org or for a production org
:param sandbox: Marks whether the connection has to be made with \
a sandbox org or with a production org. Cannot be used concurrently with \
a value for domain.
:param domain: A custom salesforce domain instead of 'login' or 'test'. \
Cannot be used concurrently with a value for sandbox
:param json_dumps: Function for JSON serialization, the default is \
:func:`json.dumps`
:param json_loads: Function for JSON deserialization, the default is \
:func:`json.loads`
"""
super().__init__(sandbox=sandbox,
domain=domain,
json_dumps=json_dumps,
json_loads=json_loads)
#: OAuth2 client id
Expand Down Expand Up @@ -178,7 +192,8 @@ async def _authenticate(self) -> Tuple[int, JsonObject]:
class RefreshTokenAuthenticator(AuthenticatorBase):
"""Authenticator for using the OAuth 2.0 Refresh Token Flow"""
def __init__(self, consumer_key: str, consumer_secret: str,
refresh_token: str, sandbox: bool = False,
refresh_token: str,
sandbox: bool = None, domain: str = None,
json_dumps: JsonDumper = json.dumps,
json_loads: JsonLoader = json.loads) -> None:
"""
Expand All @@ -189,14 +204,18 @@ def __init__(self, consumer_key: str, consumer_secret: str,
:param refresh_token: A refresh token obtained from Salesforce \
by using one of its authentication methods (for example with the \
OAuth 2.0 Web Server Authentication Flow)
:param sandbox: Marks whether the authentication has to be done \
for a sandbox org or for a production org
:param sandbox: Marks whether the connection has to be made with \
a sandbox org or with a production org. Cannot be used concurrently with \
a value for domain.
:param domain: A custom salesforce domain instead of 'login' or 'test'. \
Cannot be used concurrently with a value for sandbox
:param json_dumps: Function for JSON serialization, the default is \
:func:`json.dumps`
:param json_loads: Function for JSON deserialization, the default is \
:func:`json.loads`
"""
super().__init__(sandbox=sandbox,
domain=domain,
json_dumps=json_dumps,
json_loads=json_loads)
#: OAuth2 client id
Expand Down
9 changes: 7 additions & 2 deletions aiosfstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def __init__(self, *, # pylint: disable=too-many-locals
replay_storage_policy: ReplayMarkerStoragePolicy
= ReplayMarkerStoragePolicy.AUTOMATIC,
connection_timeout: Union[int, float] = 10.0,
max_pending_count: int = 100, sandbox: bool = False,
max_pending_count: int = 100,
sandbox: bool = None, domain: str = None,
json_dumps: JsonDumper = json.dumps,
json_loads: JsonLoader = json.loads,
loop: Optional[asyncio.AbstractEventLoop] = None):
Expand Down Expand Up @@ -326,7 +327,10 @@ def __init__(self, *, # pylint: disable=too-many-locals
consumed. \
If it is less than or equal to zero, the count is infinite.
:param sandbox: Marks whether the connection has to be made with \
a sandbox org or with a production org
a sandbox org or with a production org. Cannot be used concurrently with \
a value for domain.
:param domain: A custom salesforce domain instead of 'login' or 'test'. \
Cannot be used concurrently with a value for sandbox
:param json_dumps: Function for JSON serialization, the default is \
:func:`json.dumps`
:param json_loads: Function for JSON deserialization, the default is \
Expand All @@ -342,6 +346,7 @@ def __init__(self, *, # pylint: disable=too-many-locals
username=username,
password=password,
sandbox=sandbox,
domain=domain,
json_dumps=json_dumps,
json_loads=json_loads,
)
Expand Down
25 changes: 22 additions & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aiohttp.client_exceptions import ClientError

from aiosfstream.auth import AuthenticatorBase, PasswordAuthenticator, \
TOKEN_URL, SANDBOX_TOKEN_URL, RefreshTokenAuthenticator
LOGIN_DOMAIN, SANDBOX_DOMAIN, BASE_URL, RefreshTokenAuthenticator
from aiosfstream.exceptions import AuthenticationError


Expand Down Expand Up @@ -131,12 +131,31 @@ async def test_incoming(self):
def test_token_url_non_sandbox(self):
auth = Authenticator()

self.assertEqual(auth._token_url, TOKEN_URL)
self.assertEqual(auth._token_url, BASE_URL.format(LOGIN_DOMAIN))

def test_token_url_sandbox(self):
auth = Authenticator(sandbox=True)

self.assertEqual(auth._token_url, SANDBOX_TOKEN_URL)
self.assertEqual(auth._token_url, BASE_URL.format(SANDBOX_DOMAIN))

def test_custom_url_sandbox(self):
domain = 'sparkles'
auth = Authenticator(domain=domain)
self.assertEqual(auth._token_url, BASE_URL.format(domain))

def test_sandbox_true_with_custom_domain(self):
domain = 'sparkles'
sandbox = True
with self.assertRaisesRegex(ValueError,
"You cannot specify a value for sandbox AND domain"):
auth = Authenticator(sandbox=sandbox, domain=domain)

def test_sandbox_false_with_custom_domain(self):
domain = 'sparkles'
sandbox = False
with self.assertRaisesRegex(ValueError,
"You cannot specify a value for sandbox AND domain"):
auth = Authenticator(sandbox=sandbox, domain=domain)


class TestPasswordAuthenticator(TestCase):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def test_init(self, super_init, authenticator_cls):
json_loads = object()
loop = object()
sandbox_enabled = True
domain = None

SalesforceStreamingClient(
consumer_key=consumer_key,
Expand All @@ -430,6 +431,7 @@ def test_init(self, super_init, authenticator_cls):
connection_timeout=connection_timeout,
max_pending_count=max_pending_count,
sandbox=sandbox_enabled,
domain=domain,
json_dumps=json_dumps,
json_loads=json_loads,
loop=loop
Expand All @@ -441,6 +443,7 @@ def test_init(self, super_init, authenticator_cls):
username=username,
password=password,
sandbox=sandbox_enabled,
domain=domain,
json_dumps=json_dumps,
json_loads=json_loads,
)
Expand Down