Skip to content

Commit

Permalink
Merge pull request #1070 from SEKOIA-IO/fix/TrellixRefreshToken
Browse files Browse the repository at this point in the history
Trellix: fix refresh of the access token
  • Loading branch information
squioc authored Aug 8, 2024
2 parents 4465747 + 47751c8 commit 0837bcb
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 10 deletions.
6 changes: 6 additions & 0 deletions Trellix/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## 2024-08-08 - 1.9.2

### Fixed

- Ensure we have a fresh access token when the previous one expired

## 2024-08-07 - 1.9.1

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Trellix/client/schemas/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,4 @@ def is_expired(self) -> bool:
Returns:
bool:
"""
return self.created_at + self.token.expires_in > (time() - 1)
return self.created_at + self.token.expires_in > time() - 300
12 changes: 10 additions & 2 deletions Trellix/client/token_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,16 @@ async def refresh_token(self) -> None:
if response.status >= 400:
raise AuthenticationFailed.from_http_response(response_data)

access_token = HttpToken(**response_data)
logger.info(
"Got new access token",
expires_in=access_token.expires_in,
token_type=access_token.token_type,
token_id=access_token.tid,
)

self._token = TrellixToken(
token=HttpToken(**response_data),
token=access_token,
scopes=self.scopes,
created_at=time.time(),
)
Expand Down Expand Up @@ -199,7 +207,7 @@ async def with_access_token(self) -> AsyncGenerator[TrellixToken, None]:
Yields:
TrellixToken:
"""
if self._token is None:
if self._token is None or self._token.is_expired():
await self.refresh_token()

if not self._token:
Expand Down
2 changes: 1 addition & 1 deletion Trellix/connectors/trellix_edr_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def run(self) -> None: # pragma: no cover
delta_sleep = self.configuration.frequency - processing_time
if len(message_ids) == 0 and delta_sleep > 0:
self.log(message=f"Next batch in the future. Waiting {delta_sleep} seconds", level="info")
time.sleep(delta_sleep)
loop.run_until_complete(asyncio.sleep(delta_sleep))

except Exception as e:
self.log_exception(e, message="Error while running Trellix EDR")
Expand Down
2 changes: 1 addition & 1 deletion Trellix/connectors/trellix_epo_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def run(self) -> None: # pragma: no cover
delta_sleep = self.configuration.frequency - processing_time
if len(message_ids) == 0 and delta_sleep > 0:
self.log(message=f"Next batch in the future. Waiting {delta_sleep} seconds", level="info")
time.sleep(delta_sleep)
loop.run_until_complete(asyncio.sleep(delta_sleep))

except Exception as e:
self.log_exception(e, message="Error while running Trellix EPO")
Expand Down
2 changes: 1 addition & 1 deletion Trellix/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@
"name": "Trellix",
"uuid": "888071f8-1456-11ee-be56-0242ac120002",
"slug": "trellix",
"version": "1.9.1"
"version": "1.9.2"
}
3 changes: 3 additions & 0 deletions Trellix/tests/client/schemas/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,7 @@ def test_trellix_token_expired(sample_trellix_token):
assert sample_trellix_token.is_expired() is True

sample_trellix_token.created_at = int(time()) - sample_trellix_token.token.expires_in - 1
assert sample_trellix_token.is_expired() is True

sample_trellix_token.created_at = int(time()) - sample_trellix_token.token.expires_in - 300
assert sample_trellix_token.is_expired() is False
33 changes: 33 additions & 0 deletions Trellix/tests/client/test_token_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,36 @@ async def test_trellix_refresher_auth_url(
)

assert token_refresher.auth_url.path == expected_auth_path


@pytest.mark.asyncio
async def test_trellix_refresher_always_provide_fresh_token(http_token, session_faker, token_refresher_session):
"""
Test TrellixTokenRefresher.with_access_token method.
Args:
http_token: HttpToken
session_faker: Faker
token_refresher_session: MagicMock
"""
token_refresher_session.post = MagicMock()
token_refresher_session.post.return_value.__aenter__.return_value.status = 200
token_refresher_session.post.return_value.__aenter__.return_value.json.side_effect = [
{"tid": 233264798, "token_type": "Bearer", "expires_in": 1, "access_token": "token_expired_quickly"},
{"tid": 233264799, "token_type": "Bearer", "expires_in": 600, "access_token": "fresh_token"},
]

token_refresher = TrellixTokenRefresher(
session_faker.word(),
session_faker.word(),
session_faker.word(),
session_faker.uri(),
Scope.complete_set_of_scopes(),
)

await token_refresher.refresh_token()

async with token_refresher.with_access_token() as token:
assert token.token.access_token == "fresh_token"

await token_refresher.close()
2 changes: 1 addition & 1 deletion Trellix/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def http_token(session_faker) -> HttpToken:
return HttpToken(
tid=session_faker.pyint(),
token_type=session_faker.word(),
expires_in=session_faker.pyint(min_value=100, max_value=1000),
expires_in=session_faker.pyint(min_value=500, max_value=1000),
access_token=session_faker.word(),
)

Expand Down
8 changes: 5 additions & 3 deletions Trellix/tests/connectors/test_trellix_edr_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ async def test_trellix_connector_get_detection_events(

token_refresher = await http_client._get_token_refresher(Scope.threats_set_of_scopes())

mocked_responses.post(token_refresher.auth_url, status=200, payload=http_token.dict())
# Each mock for each request for fresh token. We should have 3 requests
mocked_responses.post(token_refresher.auth_url, status=200, payload=http_token.dict(), repeat=3)

first_request_expected_detections_result = [
edr_detection_event_response.dict(exclude_none=True) for _ in range(0, session_faker.pyint(max_value=100))
Expand Down Expand Up @@ -234,7 +235,8 @@ async def test_trellix_connector_get_affectedhosts_events(

token_refresher = await http_client._get_token_refresher(Scope.threats_set_of_scopes())

mocked_responses.post(token_refresher.auth_url, status=200, payload=http_token.dict())
# Each mock for each request for fresh token. We should have 3 requests
mocked_responses.post(token_refresher.auth_url, status=200, payload=http_token.dict(), repeat=3)

first_request_expected_detections_result = [
edr_affectedhost_event_response.dict(exclude_none=True)
Expand Down Expand Up @@ -328,7 +330,7 @@ async def test_trellix_connector_get_threats_events(

# Mocks #1
# We mock request to get token
mocked_responses.post(token_refresher.auth_url, status=200, payload=http_token.dict())
mocked_responses.post(token_refresher.auth_url, status=200, payload=http_token.dict(), repeat=100)

# Mocks #2
# We mock request to get threats. Let`s say we should send 2 requests. First request
Expand Down

0 comments on commit 0837bcb

Please sign in to comment.