Skip to content

Commit

Permalink
Merge pull request #1
Browse files Browse the repository at this point in the history
* adding boilerplate

* Add scope management and reorganize project structure

* Update .gitignore to exclude poetry.lock file

* remove

* Delete .idea/vcs.xml configuration file

* Refactor AuthRequest to handle `scope` construction internally

* Add .idea/vcs.xml to .gitignore

* remove .idea

* Add IntelliJ IDEA project configuration files
  • Loading branch information
cdot65 authored Oct 10, 2024
1 parent c9e4a80 commit 74a93a4
Show file tree
Hide file tree
Showing 17 changed files with 217 additions and 20 deletions.
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ atlassian-ide-plugin.xml
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

.idea/misc.xml
.idea/scm-sdk.iml

# Exclude secrets.yaml files
.secrets.yaml
secrets.yaml
secrets.yaml

# Exclude poetry lock
poetry.lock
/.idea/vcs.xml
4 changes: 0 additions & 4 deletions .idea/misc.xml

This file was deleted.

8 changes: 0 additions & 8 deletions .idea/scm-sdk.iml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/vcs.xml

This file was deleted.

1 change: 1 addition & 0 deletions pan_scm_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# pan_scm_sdk/__init__.py
1 change: 1 addition & 0 deletions pan_scm_sdk/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# pan_scm_sdk/auth/__init__.py
82 changes: 82 additions & 0 deletions pan_scm_sdk/auth/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# pan_scm_sdk/auth/oauth2.py

from requests_oauthlib import OAuth2Session
from oauthlib.oauth2 import BackendApplicationClient
from pan_scm_sdk.utils.logging import setup_logger
from pan_scm_sdk.models.auth import AuthRequest
import jwt
from jwt import PyJWKClient
from jwt.exceptions import ExpiredSignatureError

logger = setup_logger(__name__)

class OAuth2Client:
def __init__(self, auth_request: AuthRequest):
self.auth_request = auth_request
self.session = self._create_session()
self.signing_key = self._get_signing_key()

def _create_session(self):
client = BackendApplicationClient(client_id=self.auth_request.client_id)
oauth = OAuth2Session(client=client)
logger.debug("Fetching token...")

token = oauth.fetch_token(
token_url=self.auth_request.token_url,
client_id=self.auth_request.client_id,
client_secret=self.auth_request.client_secret,
scope=self.auth_request.scope,
include_client_id=True,
client_kwargs={'tsg_id': self.auth_request.tsg_id}
)
logger.debug(f"Token fetched successfully. {token}")
return oauth

def _get_signing_key(self):
jwks_uri = "/".join(
self.auth_request.token_url.split("/")[:-1]
) + "/connect/jwk_uri"
jwks_client = PyJWKClient(jwks_uri)
signing_key = jwks_client.get_signing_key_from_jwt(
self.session.token["access_token"]
)
return signing_key

def decode_token(self):
try:
payload = jwt.decode(
self.session.token["access_token"],
self.signing_key.key,
algorithms=["RS256"],
audience=self.auth_request.client_id,
)
return payload
except ExpiredSignatureError:
logger.error("Token has expired.")
raise

@property
def is_expired(self):
try:
jwt.decode(
self.session.token["access_token"],
self.signing_key.key,
algorithms=["RS256"],
audience=self.auth_request.client_id,
)
return False
except ExpiredSignatureError:
return True

def refresh_token(self):
logger.debug("Refreshing token...")
token = self.session.fetch_token(
token_url=self.auth_request.token_url,
client_id=self.auth_request.client_id,
client_secret=self.auth_request.client_secret,
scope=self.auth_request.scope,
include_client_id=True,
client_kwargs={'tsg_id': self.auth_request.tsg_id}
)
logger.debug(f"Token refreshed successfully. {token}")
self.signing_key = self._get_signing_key()
55 changes: 55 additions & 0 deletions pan_scm_sdk/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# pan_scm_sdk/client.py

from pan_scm_sdk.auth.oauth2 import OAuth2Client
from pan_scm_sdk.models.auth import AuthRequest
from pan_scm_sdk.utils.logging import setup_logger
from pan_scm_sdk.exceptions import APIError

logger = setup_logger(__name__)

class APIClient:
def __init__(
self,
client_id: str,
client_secret: str,
tsg_id: str,
api_base_url: str = "https://api.strata.paloaltonetworks.com",
):
self.api_base_url = api_base_url

# Create the AuthRequest object
try:
auth_request = AuthRequest(
client_id=client_id,
client_secret=client_secret,
tsg_id=tsg_id
)
except ValueError as e:
logger.error(f"Authentication initialization failed: {e}")
raise APIError(f"Authentication initialization failed: {e}")

self.oauth_client = OAuth2Client(auth_request)
self.session = self.oauth_client.session

def request(self, method: str, endpoint: str, **kwargs):
url = f"{self.api_base_url}{endpoint}"
logger.debug(f"Making {method} request to {url} with params {kwargs}")
try:
response = self.session.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"API request failed: {str(e)}")
raise APIError(f"API request failed: {str(e)}") from e

def get(self, endpoint: str, **kwargs):
if self.oauth_client.is_expired:
self.oauth_client.refresh_token()
return self.request('GET', endpoint, **kwargs)

def post(self, endpoint: str, **kwargs):
if self.oauth_client.is_expired:
self.oauth_client.refresh_token()
return self.request('POST', endpoint, **kwargs)

# Implement other methods as needed
7 changes: 7 additions & 0 deletions pan_scm_sdk/endpoints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pan_scm_sdk/endpoints/__init__.py

API_ENDPOINTS = {
'get_example': '/example/endpoint',
'post_example': '/example/endpoint',
# Add other endpoints as needed
}
Empty file removed pan_scm_sdk/exceptions.py
Empty file.
3 changes: 3 additions & 0 deletions pan_scm_sdk/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# pan_scm_sdk/exceptions/__init__.py

from .authentication import APIError, AuthenticationError, ValidationError
10 changes: 10 additions & 0 deletions pan_scm_sdk/exceptions/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# pan_scm_sdk/exceptions/authentication.py

class APIError(Exception):
"""Base class for API exceptions."""

class AuthenticationError(APIError):
"""Raised when authentication fails."""

class ValidationError(APIError):
"""Raised when data validation fails."""
1 change: 1 addition & 0 deletions pan_scm_sdk/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# pan_scm_sdk/models/__init__.py
22 changes: 22 additions & 0 deletions pan_scm_sdk/models/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# pan_scm_sdk/models/auth.py

from pydantic import BaseModel, Field, model_validator

class AuthRequest(BaseModel):
client_id: str
client_secret: str
tsg_id: str
scope: str = Field(default=None)
token_url: str = Field(
default="https://auth.apps.paloaltonetworks.com/am/oauth2/access_token"
)

@model_validator(mode='before')
@classmethod
def construct_scope(cls, values):
if values.get('scope') is None:
tsg_id = values.get('tsg_id')
if tsg_id is None:
raise ValueError('tsg_id is required to construct scope')
values['scope'] = f"tsg_id:{tsg_id}"
return values
1 change: 1 addition & 0 deletions pan_scm_sdk/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# pan_scm_sdk/utils/__init__.py
22 changes: 22 additions & 0 deletions pan_scm_sdk/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# pan_scm_sdk/utils/logging.py

import logging
import sys

def setup_logger(name: str) -> logging.Logger:
"""Set up and return a logger with the given name."""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)

# Console handler
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.DEBUG)

# Formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)

# Add handler to logger
logger.addHandler(ch)

return logger
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
jwt = "^1.3.1"
oauthlib = "^3.2.2"
requests-oauthlib = "^2.0.0"
setuptools = "^75.1.0"
pydantic = "^2.9.2"
pyjwt = "^2.9.0"
cryptography = "^43.0.1"


[tool.poetry.group.dev.dependencies]
Expand All @@ -24,6 +26,7 @@ pytest = "^8.3.3"
factory-boy = "^3.3.1"
termynal = "^0.12.1"
invoke = "^2.2.0"
ipython = "^8.28.0"

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit 74a93a4

Please sign in to comment.