Skip to content

Commit

Permalink
nsxt: Typing fixes for Python 3.6 compatibility.
Browse files Browse the repository at this point in the history
Ability to run checked mainly by running `docker build --tag=capirca .` and
`docker run capirca`, since the current `Dockerfile` conveniently uses Python
3.6.

To run the tests, a separate `Dockerfile.tests` was used (but not submitted)
with the following changes:

    WORKDIR /app

    ENTRYPOINT ["python3", "-m", "unittest", "discover", "-s", ".", "-p", "nsxt_test.py", "-v"]

and running the test as follows:

    docker build -f Dockerfile.tests --tag capirca:test_nsxt .
    docker run capirca:test_nsxt

Fixes #345.
  • Loading branch information
ivucica committed Jan 11, 2024
1 parent 102b025 commit 65f896f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
57 changes: 29 additions & 28 deletions capirca/lib/nsxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import datetime
import json
from typing import Literal, TypedDict, Optional, Union, Tuple
from typing_extensions import Literal, TypedDict
from typing import Optional, Union, Tuple, List

from absl import logging
from capirca.lib import aclgenerator
Expand Down Expand Up @@ -90,9 +91,9 @@ class NsxtUnsupportedManyPoliciesError(Error):
class ServiceEntries:
"""Represents service entries for a rule."""

def __init__(self, protocol: int, source_ports: list[Tuple[str, str]],
destination_ports: list[Tuple[str, str]],
icmp_types: list[int]):
def __init__(self, protocol: int, source_ports: List[Tuple[str, str]],
destination_ports: List[Tuple[str, str]],
icmp_types: List[int]):
"""Setting things up.
Args:
Expand Down Expand Up @@ -239,8 +240,8 @@ def __str__(self):
af_list = [self.af]

# There can be many source and destination addresses.
source_address: list[nacaddr.IPType] = []
destination_address: list[nacaddr.IPType] = []
source_address: List[nacaddr.IPType] = []
destination_address: List[nacaddr.IPType] = []
source_addr = []
destination_addr = []

Expand All @@ -257,32 +258,32 @@ def __str__(self):
# cannot be a part of a netblock passed into NSX-T API. Currently only
# addressing IPv4 as that's where the issue has been identified.
# https://github.com/google/capirca/issues/348
zero_ip_address: list[nacaddr.IPType] = []
zero_ip_address: List[nacaddr.IPType] = []
if af == 4:
zero_ip_address: list[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')]
zero_ip_address: List[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')]

# source address
if self.term.source_address:
source_address: list[nacaddr.IPType] = self.term.GetAddressOfVersion(
source_address: List[nacaddr.IPType] = self.term.GetAddressOfVersion(
'source_address', af)
source_address_exclude: list[nacaddr.IPType] = (
source_address_exclude: List[nacaddr.IPType] = (
self.term.GetAddressOfVersion('source_address_exclude', af))

if source_address_exclude:
source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address,
source_address_exclude + zero_ip_address)
else:
if (af == 4 and source_address and
'0.0.0.0/0' not in [str(a) for a in source_address]):
# Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However,
# do so only if we would not already have 'ANY' in the list.
source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address, zero_ip_address)
if source_address:
if af == 4:
source_address: list[nacaddr.IPv4]
source_v4_addr: list[nacaddr.IPv4] = source_address
source_address: List[nacaddr.IPv4]
source_v4_addr: List[nacaddr.IPv4] = source_address
if (source_v4_addr and
'0.0.0.0/0' in [str(a) for a in source_address]):
# Once we make the address list empty, it'll be set to ANY later
Expand All @@ -292,35 +293,35 @@ def __str__(self):
# later, we'll correctly not use ANY.)
#
# See https://github.com/google/capirca/issues/348
source_v4_addr: list[nacaddr.IPv4] = []
source_v4_addr: List[nacaddr.IPv4] = []
else:
source_address: list[nacaddr.IPv6]
source_v6_addr: list[nacaddr.IPv6] = source_address
source_address: List[nacaddr.IPv6]
source_v6_addr: List[nacaddr.IPv6] = source_address
source_addr = source_v4_addr + source_v6_addr

# destination address
if self.term.destination_address:
destination_address: list[
destination_address: List[
nacaddr.IPType] = self.term.GetAddressOfVersion(
'destination_address', af)
destination_address_exclude: list[nacaddr.IPType] = (
destination_address_exclude: List[nacaddr.IPType] = (
self.term.GetAddressOfVersion('destination_address_exclude', af))

if destination_address_exclude:
destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address,
destination_address_exclude + zero_ip_address)
else:
if (af == 4 and source_address and
'0.0.0.0/0' not in [str(a) for a in source_address]):
# Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However,
# do so only if we would not already have 'ANY' in the list.
destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address, zero_ip_address)
if destination_address:
if af == 4:
destination_address: list[nacaddr.IPv4]
dest_v4_addr: list[nacaddr.IPv4] = destination_address
destination_address: List[nacaddr.IPv4]
dest_v4_addr: List[nacaddr.IPv4] = destination_address
if (dest_v4_addr and
'0.0.0.0/0' in [str(a) for a in destination_address]):
# Once we make the address list empty, it'll be set to ANY later
Expand All @@ -330,10 +331,10 @@ def __str__(self):
# later, we'll correctly not use ANY.)
#
# See https://github.com/google/capirca/issues/348
dest_v4_addr: list[nacaddr.IPv4] = []
dest_v4_addr: List[nacaddr.IPv4] = []
else:
destination_address: list[nacaddr.IPv6]
dest_v6_addr: list[nacaddr.IPv6] = destination_address
destination_address: List[nacaddr.IPv6]
dest_v6_addr: List[nacaddr.IPv6] = destination_address
destination_addr = dest_v4_addr + dest_v6_addr

# Check for mismatch IP for source and destination address for mixed filter
Expand Down Expand Up @@ -420,13 +421,13 @@ class Nsxt(aclgenerator.ACLGenerator):
_FILTER_OPTIONS_DICT = {}

def _TranslatePolicy(self, pol: policy.Policy, exp_info: int):
self.nsxt_policies: list[Tuple[policy.Header, str, list[Term]]] = []
self.nsxt_policies: List[Tuple[policy.Header, str, List[Term]]] = []
current_date = datetime.datetime.utcnow().date()

# Warn about policies that will expire in less than exp_info weeks.
exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

filters: list[Tuple[policy.Header, list[policy.Term]]] = pol.filters
filters: List[Tuple[policy.Header, List[policy.Term]]] = pol.filters
for header, terms in filters:
if self._PLATFORM not in header.platforms:
continue
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ absl-py
ply
PyYAML
six>=1.12.0
typing_extensions
17 changes: 9 additions & 8 deletions tests/lib/nsxt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import copy
import json
from typing import Any, Literal, Tuple, Union
from typing import Any, Tuple, Union, Dict, List
from typing_extensions import Literal
from unittest import mock

from absl.testing import absltest
Expand Down Expand Up @@ -864,14 +865,14 @@ class TestTrafficKindGrid(parameterized.TestCase):

# Which address set should be put into the policy, based on the type of policy
# we're testing?
KIND_TO_ADDRESS: dict[_TRAFFIC_KIND, _ADDRESSES] = {
KIND_TO_ADDRESS: Dict[_TRAFFIC_KIND, _ADDRESSES] = {
'mixed': 'GOOGLE_DNS',
'v4': 'INTERNAL_V4',
'v6': 'INTERNAL_V6'}

# Which expanded address group (e.g. netblocks) is expected, based on the type
# of policy we're testing?
KIND_TO_ADDRESS_GROUPS: dict[
KIND_TO_ADDRESS_GROUPS: Dict[
_TRAFFIC_KIND, Union[nacaddr.IPv4, nacaddr.IPv6, Literal['ANY']]] = {
# 'GOOGLE_DNS'
'mixed': [nacaddr.IP('8.8.4.4/32'), nacaddr.IP('8.8.8.8/32'),
Expand Down Expand Up @@ -961,11 +962,11 @@ def test_generator_works(self):
' destination-address:: INTERNAL_V6',
'}']))

def get_source_dest_addresses(self, nsxt_json: dict[str, Any]) -> (
Tuple[list[str], list[str]]):
rules: list[dict[str, Any]] = nsxt_json['rules']
src: list[str] = []
dst: list[str] = []
def get_source_dest_addresses(self, nsxt_json: Dict[str, Any]) -> (
Tuple[List[str], List[str]]):
rules: List[Dict[str, Any]] = nsxt_json['rules']
src: List[str] = []
dst: List[str] = []

for rule in rules:
src.extend(i for i in rule['source_groups'])
Expand Down

0 comments on commit 65f896f

Please sign in to comment.