From c554f83685186be4cfa9387eb5d6d700d2bbd7c0 Mon Sep 17 00:00:00 2001 From: Wenda Chu <32250288+w1nda@users.noreply.github.com> Date: Thu, 18 Apr 2024 12:45:26 -0700 Subject: [PATCH] Add new initialization arguments for mask class to support set mask with less code (#206) * init mask with dont care all * Reorder functions to get better diff display * Reorder functions to get better diff display * Reorder functions to get better diff display --- src/ptf/mask.py | 88 ++++++++++++++++++++++------------- utests/tests/ptf/test_mask.py | 33 +++++++++++++ 2 files changed, 89 insertions(+), 32 deletions(-) diff --git a/src/ptf/mask.py b/src/ptf/mask.py index 832dac6..a1f45d7 100644 --- a/src/ptf/mask.py +++ b/src/ptf/mask.py @@ -12,13 +12,28 @@ class MaskException(Exception): class Mask: - def __init__(self, exp_pkt, ignore_extra_bytes=False): + def __init__(self, exp_pkt, ignore_extra_bytes=False, dont_care_all=False): self.exp_pkt = exp_pkt self.size = len(exp_pkt) self.valid = True - self.mask = [0xFF] * self.size + self.mask = [0] * self.size if dont_care_all else [0xFF] * self.size self.ignore_extra_bytes = ignore_extra_bytes + def set_care(self, offset, bitwidth): + for idx in range(offset, offset + bitwidth): + offsetB = idx // 8 + offsetb = idx % 8 + self.mask[offsetB] = self.mask[offsetB] | (1 << (7 - offsetb)) + + def set_care_all(self): + self.mask = [0xFF] * self.size + + def set_care_packet(self, hdr_type, field_name): + offset, bitwidth = self._calculate_fields_offset_and_bitwidth( + hdr_type, field_name + ) + self.set_care(offset, bitwidth) + def set_do_not_care(self, offset, bitwidth): # a very naive but simple method # we do it bit by bit :) @@ -27,7 +42,45 @@ def set_do_not_care(self, offset, bitwidth): offsetb = idx % 8 self.mask[offsetB] = self.mask[offsetB] & (~(1 << (7 - offsetb))) + def set_do_not_care_all(self): + self.mask = [0] * self.size + def set_do_not_care_packet(self, hdr_type, field_name): + offset, bitwidth = self._calculate_fields_offset_and_bitwidth( + hdr_type, field_name + ) + self.set_do_not_care(offset, bitwidth) + + def set_do_not_care_scapy(self, hdr_type, field_name): + warnings.warn( + '"set_do_not_care_scapy" is going to be deprecated, please ' + 'switch to the new one: "set_do_not_care_packet"', + DeprecationWarning, + ) + self.set_do_not_care_packet(hdr_type, field_name) + + def set_ignore_extra_bytes(self): + self.ignore_extra_bytes = True + + def is_valid(self): + return self.valid + + def pkt_match(self, pkt): + # just to be on the safe side + pkt = bytearray(bytes(pkt)) + # we fail if we don't match on sizes, or if ignore_extra_bytes is set, + # fail if we have not received at least size bytes + if (not self.ignore_extra_bytes and len(pkt) != self.size) or len( + pkt + ) < self.size: + return False + exp_pkt = bytearray(bytes(self.exp_pkt)) + for i in range(self.size): + if (exp_pkt[i] & self.mask[i]) != (pkt[i] & self.mask[i]): + return False + return True + + def _calculate_fields_offset_and_bitwidth(self, hdr_type, field_name): if hdr_type not in self.exp_pkt: self.valid = False raise MaskException("Unknown header type") @@ -64,36 +117,7 @@ def set_do_not_care_packet(self, hdr_type, field_name): break else: offset += bits - self.set_do_not_care(hdr_offset * 8 + offset, bitwidth) - - def set_do_not_care_scapy(self, hdr_type, field_name): - warnings.warn( - '"set_do_not_care_scapy" is going to be deprecated, please ' - 'switch to the new one: "set_do_not_care_packet"', - DeprecationWarning, - ) - self.set_do_not_care_packet(hdr_type, field_name) - - def set_ignore_extra_bytes(self): - self.ignore_extra_bytes = True - - def is_valid(self): - return self.valid - - def pkt_match(self, pkt): - # just to be on the safe side - pkt = bytearray(bytes(pkt)) - # we fail if we don't match on sizes, or if ignore_extra_bytes is set, - # fail if we have not received at least size bytes - if (not self.ignore_extra_bytes and len(pkt) != self.size) or len( - pkt - ) < self.size: - return False - exp_pkt = bytearray(bytes(self.exp_pkt)) - for i in range(self.size): - if (exp_pkt[i] & self.mask[i]) != (pkt[i] & self.mask[i]): - return False - return True + return hdr_offset * 8 + offset, bitwidth def __str__(self): old_stdout = sys.stdout diff --git a/utests/tests/ptf/test_mask.py b/utests/tests/ptf/test_mask.py index 83c2119..2b4a70d 100644 --- a/utests/tests/ptf/test_mask.py +++ b/utests/tests/ptf/test_mask.py @@ -23,6 +23,13 @@ def test_mask__mask_simple_packet(self, scapy_simple_tcp_packet): mask_packet.set_do_not_care_packet(TCP, "chksum") assert mask_packet.pkt_match(modified_packet) + def test_mask__set_do_not_care_all(self): + expected_packet = "\x01\x02\x03\x04\x05\x06" + packet = "\x08\x07\x06\x05\x04\x03\x02\x01" + mask = Mask(expected_packet.encode(), ignore_extra_bytes=True) + mask.set_do_not_care_all() + assert mask.pkt_match(packet.encode()) + def test_mask__set_do_not_care(self): expected_packet = "\x01\x02\x03\x04\x05\x06" packet = "\x01\x00\x00\x04\x05\x06\x07\x08" @@ -30,6 +37,32 @@ def test_mask__set_do_not_care(self): mask.set_do_not_care(8, 16) assert mask.pkt_match(packet.encode()) + def test_mask__set_care_all(self): + expected_packet = "\x01\x02\x03\x04\x05\x06" + packet = "\x00\x02\x03\x04\x05\x06" + mask = Mask(expected_packet.encode(), ignore_extra_bytes=True) + mask.set_care_all() + assert not mask.pkt_match(packet.encode()) + + def test_mask__set_care(self): + expected_packet = "\x01\x02\x03\x04\x05\x06" + packet = "\x01\x02\x00\x04\x05\x06\x07\x08" + mask = Mask( + expected_packet.encode(), ignore_extra_bytes=True, dont_care_all=True + ) + mask.set_care(0, 16) + assert mask.pkt_match(packet.encode()) + mask.set_care(16, 8) + assert not mask.pkt_match(packet.encode()) + + def test_mask__set_care_packet(self): + packet = IP(src="1.1.1.1") + mask = Mask(packet.copy(), ignore_extra_bytes=True, dont_care_all=True) + packet[IP].src = "2.2.2.2" + assert mask.pkt_match(packet) + mask.set_care_packet(IP, "src") + assert not mask.pkt_match(packet) + def test_mask__check_masking_conditional_field(self, scapy_simple_vxlan_packet): simple_vxlan = scapy_simple_vxlan_packet simple_vxlan[VXLAN].flags = "G"