Skip to content

Commit

Permalink
feat(models): make models Schema ordered
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy Desanlis authored and Jguer committed Jun 21, 2021
1 parent dcec5a2 commit f9a3bb7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
23 changes: 14 additions & 9 deletions pygitguardian/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
from .config import DOCUMENT_SIZE_THRESHOLD_BYTES


class BaseSchema(Schema):
class Meta:
ordered = True


class Base:
SCHEMA: ClassVar[Schema]
SCHEMA: ClassVar[BaseSchema]

def __init__(self):
self.status_code = None
Expand All @@ -41,7 +46,7 @@ def __bool__(self) -> bool:
return self.status_code == 200


class DocumentSchema(Schema):
class DocumentSchema(BaseSchema):
class Meta:
unknown = EXCLUDE

Expand Down Expand Up @@ -89,7 +94,7 @@ def __repr__(self):
return "filename:{0}, document:{1}".format(self.filename, self.document)


class DetailSchema(Schema):
class DetailSchema(BaseSchema):
detail = fields.String(required=True)

@pre_load
Expand Down Expand Up @@ -124,7 +129,7 @@ def __repr__(self):
return "{0}:{1}".format(self.status_code, self.detail)


class MatchSchema(Schema):
class MatchSchema(BaseSchema):
match = fields.String(required=True)
match_type = fields.String(data_key="type", required=True)
line_start = fields.Int(allow_none=True)
Expand Down Expand Up @@ -188,7 +193,7 @@ def __repr__(self):
)


class PolicyBreakSchema(Schema):
class PolicyBreakSchema(BaseSchema):
break_type = fields.String(data_key="type", required=True)
policy = fields.String(required=True)
matches = fields.List(fields.Nested(MatchSchema), required=True)
Expand Down Expand Up @@ -226,7 +231,7 @@ def __repr__(self):
)


class ScanResultSchema(Schema):
class ScanResultSchema(BaseSchema):
policy_break_count = fields.Integer(required=True)
policies = fields.List(fields.String(), required=True)
policy_breaks = fields.List(fields.Nested(PolicyBreakSchema), required=True)
Expand Down Expand Up @@ -308,7 +313,7 @@ def __str__(self):
)


class MultiScanResultSchema(Schema):
class MultiScanResultSchema(BaseSchema):
scan_results = fields.List(
fields.Nested(ScanResultSchema),
required=True,
Expand Down Expand Up @@ -379,7 +384,7 @@ def __str__(self):
)


class QuotaSchema(Schema):
class QuotaSchema(BaseSchema):
count = fields.Int()
limit = fields.Int()
remaining = fields.Int()
Expand Down Expand Up @@ -419,7 +424,7 @@ def __repr__(self):
)


class QuotaResponseSchema(Schema):
class QuotaResponseSchema(BaseSchema):
content = fields.Nested(QuotaSchema)

@post_load
Expand Down
11 changes: 6 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections import OrderedDict
from datetime import date
from typing import Any, Dict, List, Optional, Type
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -272,7 +273,7 @@ def test_health_check(client: GGClient):
assert bool(health)
assert health.success

assert type(health.to_dict()) == dict
assert type(health.to_dict()) == OrderedDict
assert type(health.to_json()) == str


Expand All @@ -285,7 +286,7 @@ def test_health_check_error(client: GGClient):
assert bool(health) is False
assert health.success is False

assert type(health.to_dict()) == dict
assert type(health.to_dict()) == OrderedDict
assert type(health.to_json()) == str


Expand Down Expand Up @@ -331,7 +332,7 @@ def test_multi_content_scan(
pytest.fail("multiscan is not a MultiScanResult")
return

assert type(multiscan.to_dict()) == dict
assert type(multiscan.to_dict()) == OrderedDict
assert type(multiscan.to_json()) == str
assert type(repr(multiscan)) == str
assert type(str(multiscan)) == str
Expand Down Expand Up @@ -463,7 +464,7 @@ def test_content_scan(
else:
pytest.fail("returned should be a ScanResult")

assert type(scan_result.to_dict()) == dict
assert type(scan_result.to_dict()) == OrderedDict
scan_result_json = scan_result.to_json()
assert type(scan_result_json) == str
assert type(json.loads(scan_result_json)) == dict
Expand Down Expand Up @@ -564,7 +565,7 @@ def test_quota_overview(client: GGClient):
else:
pytest.fail("returned should be a QuotaResponse")

assert type(quota_response.to_dict()) == dict
assert type(quota_response.to_dict()) == OrderedDict
quota_response_json = quota_response.to_json()
assert type(quota_response_json) == str
assert type(json.loads(quota_response_json)) == dict

0 comments on commit f9a3bb7

Please sign in to comment.