Skip to content

Commit

Permalink
Merge pull request #72 from open-traffic-generator/defaults_required_…
Browse files Browse the repository at this point in the history
…and_type_assertion

Defaults, Required Params and Type assertion handling
  • Loading branch information
ashutshkumr authored Jun 15, 2021
2 parents add72e1 + f2f1680 commit f38eb67
Show file tree
Hide file tree
Showing 17 changed files with 645 additions and 224 deletions.
198 changes: 189 additions & 9 deletions snappi/snappicommon.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import importlib
import json
from typing import Pattern, ValuesView
import yaml
import requests
import io
import sys
import re

if sys.version_info[0] == 3:
unicode = str
Expand Down Expand Up @@ -120,7 +122,95 @@ def _decode(self, dict_object):
raise NotImplementedError()


class SnappiObject(SnappiBase):
class SnappiValidator(object):

_MAC_REGEX = re.compile(
r'^([\da-fA-F]{2}[:]){5}[\da-fA-F]{2}$')
_IPV6_REP1 = re.compile(r'^:[\da-fA-F].+')
_IPV6_REP2 = re.compile(r'.+[\da-fA-F]:$')
_IPV6_REP3 = re.compile(
r'^[\da-fA-F]{1,4}:' *7 + r'[\da-fA-F]{1,4}$')
_HEX_REGEX = re.compile(r'^0?x?[\da-fA-F]+$')

__slots__ = ()

def __init__(self) -> None:
pass

def validate_mac(self, mac):
if mac is None:
return False
if isinstance(mac, list):
return all([
True if self._MAC_REGEX.match(m) else False
for m in mac
])
if self._MAC_REGEX.match(mac):
return True
return False

def validate_ipv4(self, ip):
if ip is None:
return False
try:
if isinstance(ip, list):
return all([
all([0 <= int(oct) <= 255 for oct in i.split('.', 3)])
for i in ip
])
else:
return all([0 <= int(oct) <= 255 for oct in ip.split('.', 3)])
except Exception:
return False

def _validate_ipv6(self, ip):
if ip is None:
return False
if self._IPV6_REP1.match(ip) or self._IPV6_REP2.match(ip):
return False
if ip.count('::') == 0:
if self._IPV6_REP3.match(ip):
return True
else:
return False
if ip.count(':') > 7 or ip.count('::') > 1 or ip.count(':::') > 0:
return False
return True

def validate_ipv6(self, ip):
if isinstance(ip, list):
return all([
self._validate_ipv6(i) for i in ip
])
return self._validate_ipv6(ip)

def validate_hex(self, hex):
if isinstance(hex, list):
return all([
self._HEX_REGEX(h)
for h in hex
])
if self._HEX_REGEX.match(hex):
return True
return False

def validate_integer(self, value):
if not isinstance(value, list):
value = [value]
return all([isinstance(i, int) for i in value])

def validate_float(self, value):
if not isinstance(value, list):
value = [value]
return all([isinstance(i, float) for i in value])

def validate_double(self, value):
if not isinstance(value, list):
value = [value]
return all([isinstance(i, float) for i in value])


class SnappiObject(SnappiBase, SnappiValidator):
"""Base class for any /components/schemas object
Every SnappiObject is reuseable within the schema so it can
Expand All @@ -139,17 +229,38 @@ def __init__(self, parent=None, choice=None):
@property
def parent(self):
return self._parent

def _set_choice(self, name):
if 'choice' in dir(self) and '_TYPES' in dir(self) \
and 'choice' in self._TYPES and name in self._TYPES['choice']['enum']:
for enum in self._TYPES['choice']['enum']:
if enum in self._properties and name != enum:
self._properties.pop(enum)
self._properties['choice'] = name

def _get_property(self, name, default_value=None, parent=None, choice=None):
if name not in self._properties or self._properties[name] is None:
if isinstance(default_value, type) is True:
self._properties[name] = default_value(parent=parent, choice=choice)
if name in self._properties and self._properties[name] is not None:
return self._properties[name]
if isinstance(default_value, type) is True:
self._set_choice(name)
self._properties[name] = default_value(parent=parent, choice=choice)

if '_DEFAULTS' in dir(self._properties[name]) and\
'choice' in self._properties[name]._DEFAULTS:
getattr(self._properties[name], self._properties[name]._DEFAULTS['choice'])
else:
if default_value is None and name in self._DEFAULTS:
self._set_choice(name)
self._properties[name] = self._DEFAULTS[name]
else:
self._properties[name] = default_value
return self._properties[name]

def _set_property(self, name, value, choice=None):
self._properties[name] = value
if name in self._DEFAULTS and value is None:
self._properties[name] = self._DEFAULTS[name]
else:
self._properties[name] = value
if choice is not None:
self._properties['choice'] = choice
elif self._parent is not None and self._choice is not None and value is not None:
Expand All @@ -159,15 +270,18 @@ def _encode(self):
"""Helper method for serialization
"""
output = {}
self._validate_required()
for key, value in self._properties.items():
self._validate_types(key, value)
if isinstance(value, (SnappiObject, SnappiIter)):
output[key] = value._encode()
else:
elif value is not None:
output[key] = value
return output

def _decode(self, obj):
snappi_names = dir(self)
dtypes = [list, str, int, float, bool]
for property_name, property_value in obj.items():
if property_name in snappi_names:
if isinstance(property_value, dict):
Expand All @@ -176,20 +290,26 @@ def _decode(self, obj):
property_value = child[1](self, property_name)._decode(property_value)
else:
property_value = child[1]()._decode(property_value)
elif isinstance(property_value,
list) and property_name in self._TYPES:
elif isinstance(property_value, list) and \
property_name in self._TYPES and \
self._TYPES[property_name]['type'] not in dtypes:
child = self._get_child_class(property_name, True)
snappi_list = child[0]()
for item in property_value:
item = child[1]()._decode(item)
snappi_list._items.append(item)
property_value = snappi_list
elif property_name in self._DEFAULTS and property_value is None:
if isinstance(self._DEFAULTS[property_name], tuple(dtypes)):
property_value = self._DEFAULTS[property_name]
self._properties[property_name] = property_value
self._validate_types(property_name, property_value)
self._validate_required()
return self

def _get_child_class(self, property_name, is_property_list=False):
list_class = None
class_name = self._TYPES[property_name]
class_name = self._TYPES[property_name]['type']
module = importlib.import_module(self.__module__)
object_class = getattr(module, class_name)
if is_property_list is True:
Expand Down Expand Up @@ -217,6 +337,66 @@ def clone(self):
"""Creates a deep copy of the current object
"""
return self.__deepcopy__(None)

def _validate_required(self):
"""Validates the required properties are set
"""
if getattr(self, "_REQUIRED", None) is None:
return
for p in self._REQUIRED:
if p in self._properties and self._properties[p] is not None:
continue
msg = "{} is a mandatory property of {}"\
" and should not be set to None".format(
p, self.__class__
)
raise ValueError(msg)

def _validate_types(self, property_name, property_value):
common_data_types = [list, str, int, float, bool]
if property_name not in self._TYPES:
raise ValueError("Invalid Property {}".format(property_name))
details = self._TYPES[property_name]
if property_value is None and property_name not in self._DEFAULTS and \
property_name not in self._REQUIRED:
return
if 'enum' in details and property_value not in details['enum']:
msg = 'property {} shall be one of these' \
' {} enum, but got {} at {}'
raise TypeError(msg.format(
property_name, details['enum'], property_value, self.__class__
))
if details['type'] in common_data_types and \
'format' not in details:
if not isinstance(property_value, details['type']):
msg = 'property {} shall be of type {},' \
' but got {} at {}'
raise TypeError(msg.format(
property_name, details['type'], type(property_value), self.__class__
))
if details['type'] not in common_data_types:
class_name = details['type']
# TODO Need to revisit importlib
module = importlib.import_module(self.__module__)
object_class = getattr(module, class_name)
if not isinstance(property_value, object_class):
msg = 'property {} shall be of type {},' \
' but got {} at {}'
raise TypeError(msg.format(
property_name, class_name, type(property_value),
self.__class__
))
if 'format' in details:
validate_obj = getattr(self, 'validate_%s' % details['format'], None)
if validate_obj is None:
raise TypeError('{} is not a valid or unsupported format'.format(
details['format']
))
if validate_obj(property_value) is False:
msg = 'Invalid {} format, expected {} at {}'.format(
property_value, details['format'], self.__class__
)
raise TypeError(msg)


class SnappiIter(SnappiBase):
Expand Down
84 changes: 81 additions & 3 deletions snappi/snappigenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,21 @@ def _write_snappi_object(self, ref, choice_method_name=None):
if len(snappi_types) > 0:
self._write(1, '_TYPES = {')
for name, value in snappi_types:
self._write(2, "'%s': '%s'," % (name, value))
if len(value) == 1:
self._write(2, "'%s': {'%s': %s}," % (
name, list(value.keys())[0], list(value.values())[0]
))
continue
self._write(2, "'%s': %s" % (name, '{'))
for n, v in value.items():
if isinstance(v, list):
self._write(3, "'%s': [" % n)
for i in v:
self._write(4, "'%s'," % i)
self._write(3, "],")
continue
self._write(3, "'%s': %s," % (n, v))
self._write(2, "},")
self._write(1, '} # type: Dict[str, str]')
self._write()
else:
Expand All @@ -414,15 +428,43 @@ def _write_snappi_object(self, ref, choice_method_name=None):
self._write(1, '_TYPES = {} # type: Dict[str, str]')
self._write()

required, defaults = self._get_required_and_defaults(schema_object)

if len(required) > 0:
self._write(1, '_REQUIRED = {} # type: tuple(str)'.format(required))
self._write()
else:
self._write(1, '_REQUIRED= () # type: tuple(str)')
self._write()

if len(defaults) > 0:
self._write(1, '_DEFAULTS = {')
for name, value in defaults:
if isinstance(value, (list, bool, int, float, tuple)):
self._write(2, "'%s': %s," % (name, value))
else:
self._write(2, "'%s': '%s'," % (name, value))
self._write(1, '} # type: Dict[str, Union(type)]')
self._write()
else:
self._write(1, '_DEFAULTS= {} # type: Dict[str, Union(type)]')
self._write()
# write constants
# search for all simple properties with enum or
# x-constant and add them here

for enum in parse('$..enum | x-constants').find(schema_object):
for name in enum.value:
value = name
value_type = 'string'
if isinstance(enum.value, dict):
value = enum.value[name]
self._write(1, '%s = \'%s\' # type: str' % (name.upper(), value))
value_type = enum.context.value['type'] \
if 'type' in enum.context.value else 'string'
if value_type == 'string':
self._write(1, '%s = \'%s\' # type: str' % (name.upper(), value))
else:
self._write(1, '%s = %s #' % (name.upper(), value))
if len(enum.value) > 0:
self._write()

Expand Down Expand Up @@ -724,20 +766,56 @@ def _get_description(self, yobject):
# if len(line) > 0:
# doc_string.append('%s ' % line)
# return doc_string
def _get_data_types(self, yproperty):
data_type_map = {
'integer': 'int', 'string': 'str',
'boolean': 'bool', 'array': 'list',
'number': 'float', 'float': 'float',
'double': 'float'
}
if yproperty['type'] in data_type_map:
return data_type_map[yproperty['type']]
else:
return yproperty['type']

def _get_snappi_types(self, yobject):
types = []
if 'properties' in yobject:
for name in yobject['properties']:
yproperty = yobject['properties'][name]
ref = parse("$..'$ref'").find(yproperty)
pt = {}
if 'type' in yproperty:
pt.update({'type': self._get_data_types(yproperty)})
pt.update({'enum': yproperty['enum']}) if 'enum' in yproperty else None
pt.update({
'format': "\'%s\'" % yproperty['format']
}) if 'format' in yproperty else None
if len(ref) > 0:
object_name = ref[0].value.split('/')[-1]
class_name = object_name.replace('.', '')
if 'type' in yproperty and yproperty['type'] == 'array':
class_name += 'Iter'
types.append((name, class_name))
pt.update({'type': "\'%s\'" % class_name})
if len(pt) > 0:
types.append((name, pt))

return types

def _get_required_and_defaults(self, yobject):
required = []
defaults = []
if 'required' in yobject:
required = yobject['required']
if 'properties' in yobject:
for name in yobject['properties']:
yproperty = yobject['properties'][name]
if 'default' in yproperty:
default = yproperty['default']
if 'type' in yproperty and yproperty['type'] == 'number':
default = float(default)
defaults.append((name, default))
return (tuple(required), defaults)

def _get_default_value(self, property):
if 'default' in property:
Expand Down
Loading

0 comments on commit f38eb67

Please sign in to comment.