Skip to content

Commit

Permalink
[generator][python] Added diff_complex.
Browse files Browse the repository at this point in the history
  • Loading branch information
maksimandrianov committed Sep 7, 2020
1 parent eaa7355 commit 70721a0
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 0 deletions.
Empty file.
138 changes: 138 additions & 0 deletions tools/python/diff_complex/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import argparse
import csv
import logging
import sys
from itertools import islice
from typing import Dict

import zss

from diff_complex.trees_builder import Node
from diff_complex.trees_builder import read_complexes_from_csv

logger = logging.getLogger("diff_complex")

csv.field_size_limit(sys.maxsize)


def parse_args():
parser = argparse.ArgumentParser(description="Compare comples files.")
parser.add_argument(
"--old", metavar="PATH", type=str, help="Path to old file", required=True
)
parser.add_argument(
"--new", metavar="PATH", type=str, help="Path to new file", required=True
)

parser.add_argument(
"--popularity", metavar="PATH", type=str, help="Path to popularity file"
)
parser.add_argument(
"--num",
type=int,
help="Number of objects from popularity file that to be compared",
default=50,
)

parser.add_argument(
"--threshold", type=int, help="Threshold of tree distance", default=1,
)

parser.add_argument(
"--from_root",
default=False,
action="store_true",
help="Compare trees from roots.",
)

return parser.parse_args()


def label_dist(a, b):
if a == b:
return 0
else:
return 1


def diff(
old_complexes_map: Dict[str, Node],
new_complexes_map: Dict[str, Node],
id_: str,
threshold: int = 1,
from_root: bool = False,
):
old_tree = old_complexes_map.get(id_)
if old_tree is None:
logger.warning(f"{id_} is not found in old complexes.")
return

new_tree = new_complexes_map.get(id_)
if new_tree is None:
logger.warning(f"{id_} is not found in new complexes.")
return

if from_root:
p = old_tree.parent
while p is not None:
old_tree = p.parent

p = new_tree.parent
while p is not None:
new_tree = p.parent

operations = zss.simple_distance(
old_tree, new_tree, label_dist=label_dist, return_operations=True
)

if operations[0] >= threshold:
op_o = {
o.arg1: o
for o in operations[1]
if o.type == zss.Operation.remove or o.type == zss.Operation.update
}
op_n = {
o.arg2: o
for o in operations[1]
if o.type == zss.Operation.insert or o.type == zss.Operation.update
}

logger.warning(
f"Differences found for id[{id_}]: distance is {operations[0]}\n"
f"Old:\n"
f"{old_tree.to_string_with_operations(op_o)}\n"
f"New:\n"
f"{new_tree.to_string_with_operations(op_n)}"
)


def main():
logging.basicConfig(
level=logging.INFO, format="[%(asctime)s] %(levelname)s %(module)s %(message)s"
)

args = parse_args()
old_complexes_map = read_complexes_from_csv(args.old)
logger.info(f"{len(old_complexes_map)} old complexes was read from {args.old}")

new_complexes_map = read_complexes_from_csv(args.new)
logger.info(f"{len(new_complexes_map)} new complexes was read from {args.new}")

if args.popularity:
with open(args.popularity) as csvfile:
rows = csv.reader(csvfile, delimiter=",")
ids = [row[0] for row in islice(rows, args.num)]
else:
old_complexes_map = {
k: v for k, v in old_complexes_map.items() if v.parent is None
}
new_complexes_map = {
k: v for k, v in new_complexes_map.items() if v.parent is None
}
ids = list(old_complexes_map.keys())

for id_ in ids:
diff(old_complexes_map, new_complexes_map, id_, args.threshold, args.from_root)


main()
3 changes: 3 additions & 0 deletions tools/python/diff_complex/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-r ../mwm/requirements.txt
zss==1.2.0
numpy>=1.7
126 changes: 126 additions & 0 deletions tools/python/diff_complex/trees_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import csv
import logging
from typing import Dict
from typing import List
from typing import Optional

import zss

from mwm.decode_id import decode_id

logger = logging.getLogger("diff_complex")


class Node(zss.Node):
def __init__(self, label, children=None):
self.parent = None
super().__init__(label, children)

def addpar(self, node):
self.parent = node

def __str__(self):
return self.to_string_with_operations(operations={})

def to_string_with_operations(self, operations):
lines = []
_print_tree(self, lines, is_root=True, operations=operations)
return "".join(lines)

def __hash__(self):
return id(self)


def _print_tree(node, lines, prefix="", is_tail=True, is_root=False, operations=None):
lines.append(prefix)
if not is_root:
if is_tail:
lines.append("└───")
prefix += " "
else:
lines.append("├───")
prefix += "│ "

label = str(node.label).replace("\n", "")
if operations is not None and node in operations:
t = operations[node].type
if t == zss.Operation.remove:
lines.append("(-)")
elif t == zss.Operation.insert:
lines.append("(+)")
elif t == zss.Operation.update:
lines.append("(-+)")
lines.append(label)
lines.append("\n")

l = len(node.children) - 1
for i, c in enumerate(node.children):
_print_tree(c, lines, prefix, i == l, False, operations)


def link_nodes(node: Node, parent: Node):
node.addpar(parent)
parent.addkid(node)


class HierarchyEntry:
__slots__ = ("id", "parent_id", "depth", "lat", "lon", "type", "name", "country")

@staticmethod
def make_from_csv_row(csv_row: List[str]) -> Optional["HierarchyEntry"]:
e = HierarchyEntry()
if len(csv_row) == 8:
e.id = csv_row[0]
e.parent_id = csv_row[1]
e.depth = int(csv_row[2])
e.lat = float(csv_row[3])
e.lon = float(csv_row[4])
e.type = csv_row[5]
e.name = csv_row[6]
e.country = csv_row[7]
return e
# For old format:
elif len(csv_row) == 6:
e.id = csv_row[0]
e.parent_id = csv_row[1]
e.lat = float(csv_row[2])
e.lon = float(csv_row[3])
e.type = csv_row[4]
e.name = csv_row[5]
return e
logger.error(f"Row [{csv_row}] - {len(csv_row)} cannot be parsed.")
return None

def __eq__(self, other):
if isinstance(other, (HierarchyEntry, str)):
self_id = self.id
other_id = other if isinstance(other, str) else other.id
return self_id == other_id

raise TypeError(f"{other}:{type(other)} is not supported.")

def __str__(self):
return (
f"{self.id}[{self.type}]:{self.name} "
f"({decode_id(self.id.split()[0] if ' ' in self.id else self.id)})"
)


def read_complexes_from_csv(path: str) -> Dict[str, Node]:
m = {}
with open(path) as csvfile:
rows = csv.reader(csvfile, delimiter=";")

for row in rows:
e = HierarchyEntry.make_from_csv_row(row)
m[e.id] = Node(e)

for id_, node in m.items():
if node.label.parent_id:
try:
link_nodes(node, m[node.label.parent_id])
except KeyError:
logger.error(f"Id {node.label.parent_id} was not found in dict.")
pass

return m

0 comments on commit 70721a0

Please sign in to comment.