Skip to content

Commit

Permalink
Add new initialization arguments for mask class to support set mask w…
Browse files Browse the repository at this point in the history
…ith less code (#206)

* init mask with dont care all

* Reorder functions to get better diff display

* Reorder functions to get better diff display

* Reorder functions to get better diff display
  • Loading branch information
w1nda authored Apr 18, 2024
1 parent 23ebe72 commit c554f83
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 32 deletions.
88 changes: 56 additions & 32 deletions src/ptf/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,28 @@ class MaskException(Exception):


class Mask:
def __init__(self, exp_pkt, ignore_extra_bytes=False):
def __init__(self, exp_pkt, ignore_extra_bytes=False, dont_care_all=False):
self.exp_pkt = exp_pkt
self.size = len(exp_pkt)
self.valid = True
self.mask = [0xFF] * self.size
self.mask = [0] * self.size if dont_care_all else [0xFF] * self.size
self.ignore_extra_bytes = ignore_extra_bytes

def set_care(self, offset, bitwidth):
for idx in range(offset, offset + bitwidth):
offsetB = idx // 8
offsetb = idx % 8
self.mask[offsetB] = self.mask[offsetB] | (1 << (7 - offsetb))

def set_care_all(self):
self.mask = [0xFF] * self.size

def set_care_packet(self, hdr_type, field_name):
offset, bitwidth = self._calculate_fields_offset_and_bitwidth(
hdr_type, field_name
)
self.set_care(offset, bitwidth)

def set_do_not_care(self, offset, bitwidth):
# a very naive but simple method
# we do it bit by bit :)
Expand All @@ -27,7 +42,45 @@ def set_do_not_care(self, offset, bitwidth):
offsetb = idx % 8
self.mask[offsetB] = self.mask[offsetB] & (~(1 << (7 - offsetb)))

def set_do_not_care_all(self):
self.mask = [0] * self.size

def set_do_not_care_packet(self, hdr_type, field_name):
offset, bitwidth = self._calculate_fields_offset_and_bitwidth(
hdr_type, field_name
)
self.set_do_not_care(offset, bitwidth)

def set_do_not_care_scapy(self, hdr_type, field_name):
warnings.warn(
'"set_do_not_care_scapy" is going to be deprecated, please '
'switch to the new one: "set_do_not_care_packet"',
DeprecationWarning,
)
self.set_do_not_care_packet(hdr_type, field_name)

def set_ignore_extra_bytes(self):
self.ignore_extra_bytes = True

def is_valid(self):
return self.valid

def pkt_match(self, pkt):
# just to be on the safe side
pkt = bytearray(bytes(pkt))
# we fail if we don't match on sizes, or if ignore_extra_bytes is set,
# fail if we have not received at least size bytes
if (not self.ignore_extra_bytes and len(pkt) != self.size) or len(
pkt
) < self.size:
return False
exp_pkt = bytearray(bytes(self.exp_pkt))
for i in range(self.size):
if (exp_pkt[i] & self.mask[i]) != (pkt[i] & self.mask[i]):
return False
return True

def _calculate_fields_offset_and_bitwidth(self, hdr_type, field_name):
if hdr_type not in self.exp_pkt:
self.valid = False
raise MaskException("Unknown header type")
Expand Down Expand Up @@ -64,36 +117,7 @@ def set_do_not_care_packet(self, hdr_type, field_name):
break
else:
offset += bits
self.set_do_not_care(hdr_offset * 8 + offset, bitwidth)

def set_do_not_care_scapy(self, hdr_type, field_name):
warnings.warn(
'"set_do_not_care_scapy" is going to be deprecated, please '
'switch to the new one: "set_do_not_care_packet"',
DeprecationWarning,
)
self.set_do_not_care_packet(hdr_type, field_name)

def set_ignore_extra_bytes(self):
self.ignore_extra_bytes = True

def is_valid(self):
return self.valid

def pkt_match(self, pkt):
# just to be on the safe side
pkt = bytearray(bytes(pkt))
# we fail if we don't match on sizes, or if ignore_extra_bytes is set,
# fail if we have not received at least size bytes
if (not self.ignore_extra_bytes and len(pkt) != self.size) or len(
pkt
) < self.size:
return False
exp_pkt = bytearray(bytes(self.exp_pkt))
for i in range(self.size):
if (exp_pkt[i] & self.mask[i]) != (pkt[i] & self.mask[i]):
return False
return True
return hdr_offset * 8 + offset, bitwidth

def __str__(self):
old_stdout = sys.stdout
Expand Down
33 changes: 33 additions & 0 deletions utests/tests/ptf/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,46 @@ def test_mask__mask_simple_packet(self, scapy_simple_tcp_packet):
mask_packet.set_do_not_care_packet(TCP, "chksum")
assert mask_packet.pkt_match(modified_packet)

def test_mask__set_do_not_care_all(self):
expected_packet = "\x01\x02\x03\x04\x05\x06"
packet = "\x08\x07\x06\x05\x04\x03\x02\x01"
mask = Mask(expected_packet.encode(), ignore_extra_bytes=True)
mask.set_do_not_care_all()
assert mask.pkt_match(packet.encode())

def test_mask__set_do_not_care(self):
expected_packet = "\x01\x02\x03\x04\x05\x06"
packet = "\x01\x00\x00\x04\x05\x06\x07\x08"
mask = Mask(expected_packet.encode(), ignore_extra_bytes=True)
mask.set_do_not_care(8, 16)
assert mask.pkt_match(packet.encode())

def test_mask__set_care_all(self):
expected_packet = "\x01\x02\x03\x04\x05\x06"
packet = "\x00\x02\x03\x04\x05\x06"
mask = Mask(expected_packet.encode(), ignore_extra_bytes=True)
mask.set_care_all()
assert not mask.pkt_match(packet.encode())

def test_mask__set_care(self):
expected_packet = "\x01\x02\x03\x04\x05\x06"
packet = "\x01\x02\x00\x04\x05\x06\x07\x08"
mask = Mask(
expected_packet.encode(), ignore_extra_bytes=True, dont_care_all=True
)
mask.set_care(0, 16)
assert mask.pkt_match(packet.encode())
mask.set_care(16, 8)
assert not mask.pkt_match(packet.encode())

def test_mask__set_care_packet(self):
packet = IP(src="1.1.1.1")
mask = Mask(packet.copy(), ignore_extra_bytes=True, dont_care_all=True)
packet[IP].src = "2.2.2.2"
assert mask.pkt_match(packet)
mask.set_care_packet(IP, "src")
assert not mask.pkt_match(packet)

def test_mask__check_masking_conditional_field(self, scapy_simple_vxlan_packet):
simple_vxlan = scapy_simple_vxlan_packet
simple_vxlan[VXLAN].flags = "G"
Expand Down

0 comments on commit c554f83

Please sign in to comment.