-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[generator][python] Added diff_complex.
- Loading branch information
1 parent
eaa7355
commit 70721a0
Showing
4 changed files
with
267 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import argparse | ||
import csv | ||
import logging | ||
import sys | ||
from itertools import islice | ||
from typing import Dict | ||
|
||
import zss | ||
|
||
from diff_complex.trees_builder import Node | ||
from diff_complex.trees_builder import read_complexes_from_csv | ||
|
||
logger = logging.getLogger("diff_complex")


def _raise_csv_field_size_limit(limit=sys.maxsize):
    """Set the csv field size limit as high as the platform allows.

    ``csv.field_size_limit`` raises ``OverflowError`` when the requested
    value does not fit into a C ``long`` (for example ``sys.maxsize`` on
    platforms where ``long`` is 32 bits wide), so the value is halved until
    it is accepted.

    Args:
        limit: The first value to try.

    Returns:
        The limit that was finally installed.
    """
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 2


# Lift the default csv field size limit before any file is read.
_raise_csv_field_size_limit()
|
||
|
||
def parse_args(argv=None):
    """Parse the command line options of the complex diff tool.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, which is the behaviour
            callers that pass no argument get.

    Returns:
        An ``argparse.Namespace`` with the attributes ``old``, ``new``,
        ``popularity``, ``num``, ``threshold`` and ``from_root``.
    """
    parser = argparse.ArgumentParser(description="Compare complex files.")
    parser.add_argument(
        "--old", metavar="PATH", type=str, help="Path to old file", required=True
    )
    parser.add_argument(
        "--new", metavar="PATH", type=str, help="Path to new file", required=True
    )

    parser.add_argument(
        "--popularity", metavar="PATH", type=str, help="Path to popularity file"
    )
    parser.add_argument(
        "--num",
        type=int,
        help="Number of objects from the popularity file to be compared",
        default=50,
    )

    parser.add_argument(
        "--threshold", type=int, help="Threshold of tree distance", default=1,
    )

    parser.add_argument(
        "--from_root",
        default=False,
        action="store_true",
        help="Compare trees from roots.",
    )

    return parser.parse_args(argv)
|
||
|
||
def label_dist(a, b):
    """Return the cost of turning label ``a`` into ``b``: 0 if equal, else 1."""
    return 0 if a == b else 1
|
||
|
||
def diff(
    old_complexes_map: Dict[str, Node],
    new_complexes_map: Dict[str, Node],
    id_: str,
    threshold: int = 1,
    from_root: bool = False,
):
    """Compare the old and new trees of one complex and log the differences.

    Args:
        old_complexes_map: Nodes of the old file keyed by id.
        new_complexes_map: Nodes of the new file keyed by id.
        id_: Id of the node to compare.
        threshold: Differences are logged only when the tree edit distance is
            greater than or equal to this value.
        from_root: When true, the comparison starts from the topmost ancestor
            of the node instead of the node itself.

    Returns:
        None. A warning is logged when the id is missing from either map or
        when the distance reaches the threshold.
    """
    old_tree = old_complexes_map.get(id_)
    if old_tree is None:
        logger.warning(f"{id_} is not found in old complexes.")
        return

    new_tree = new_complexes_map.get(id_)
    if new_tree is None:
        logger.warning(f"{id_} is not found in new complexes.")
        return

    if from_root:
        # Climb the parent links until the node without a parent is reached.
        # The tree variable itself must advance on every step, otherwise the
        # loop never terminates for a non-root node.
        while old_tree.parent is not None:
            old_tree = old_tree.parent

        while new_tree.parent is not None:
            new_tree = new_tree.parent

    distance, operations = zss.simple_distance(
        old_tree, new_tree, label_dist=label_dist, return_operations=True
    )

    if distance < threshold:
        return

    # Operations that touch a node of the old tree, keyed by that node.
    op_o = {
        o.arg1: o
        for o in operations
        if o.type == zss.Operation.remove or o.type == zss.Operation.update
    }
    # Operations that touch a node of the new tree, keyed by that node.
    op_n = {
        o.arg2: o
        for o in operations
        if o.type == zss.Operation.insert or o.type == zss.Operation.update
    }

    logger.warning(
        f"Differences found for id[{id_}]: distance is {distance}\n"
        f"Old:\n"
        f"{old_tree.to_string_with_operations(op_o)}\n"
        f"New:\n"
        f"{new_tree.to_string_with_operations(op_n)}"
    )
|
||
|
||
def main():
    """Run the comparison of two complex files from the command line.

    Both files are read with ``read_complexes_from_csv``. When a popularity
    file is given, the first ``--num`` ids listed in it are compared;
    otherwise only nodes without a parent are kept and every such id of the
    old file is compared.
    """
    logging.basicConfig(
        level=logging.INFO, format="[%(asctime)s] %(levelname)s %(module)s %(message)s"
    )

    args = parse_args()
    old_complexes_map = read_complexes_from_csv(args.old)
    logger.info(f"{len(old_complexes_map)} old complexes was read from {args.old}")

    new_complexes_map = read_complexes_from_csv(args.new)
    logger.info(f"{len(new_complexes_map)} new complexes was read from {args.new}")

    if args.popularity:
        # The csv module expects files to be opened with newline="".
        with open(args.popularity, newline="") as csvfile:
            rows = csv.reader(csvfile, delimiter=",")
            # csv.reader yields an empty list for a blank line; indexing it
            # would raise IndexError, so such rows are skipped.
            filled_rows = (row for row in rows if row)
            ids = [row[0] for row in islice(filled_rows, args.num)]
    else:
        old_complexes_map = {
            k: v for k, v in old_complexes_map.items() if v.parent is None
        }
        new_complexes_map = {
            k: v for k, v in new_complexes_map.items() if v.parent is None
        }
        ids = list(old_complexes_map)

    for id_ in ids:
        diff(old_complexes_map, new_complexes_map, id_, args.threshold, args.from_root)
|
||
|
||
# Run the tool only when the module is executed as a script, so that
# importing it does not parse the command line or read any files.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
-r ../mwm/requirements.txt | ||
zss==1.2.0 | ||
numpy>=1.7 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import csv | ||
import logging | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
|
||
import zss | ||
|
||
from mwm.decode_id import decode_id | ||
|
||
logger = logging.getLogger("diff_complex") | ||
|
||
|
||
class Node(zss.Node):
    """A ``zss.Node`` that keeps a link to its parent and can render itself."""

    def __init__(self, label, children=None):
        super().__init__(label, children)
        # Set later through addpar(); None means no parent has been assigned.
        self.parent = None

    def addpar(self, node):
        """Remember ``node`` as the parent of this node."""
        self.parent = node

    def to_string_with_operations(self, operations):
        """Render the subtree as text, marking nodes found in ``operations``.

        ``operations`` maps nodes to the operation that touches them.
        """
        chunks = []
        _print_tree(self, chunks, is_root=True, operations=operations)
        return "".join(chunks)

    def __str__(self):
        # Plain rendering: no node is marked.
        return self.to_string_with_operations({})

    def __hash__(self):
        # Identity-based hash, so nodes can be used as dictionary keys.
        return id(self)
|
||
|
||
def _print_tree(node, lines, prefix="", is_tail=True, is_root=False, operations=None): | ||
lines.append(prefix) | ||
if not is_root: | ||
if is_tail: | ||
lines.append("└───") | ||
prefix += " " | ||
else: | ||
lines.append("├───") | ||
prefix += "│ " | ||
|
||
label = str(node.label).replace("\n", "") | ||
if operations is not None and node in operations: | ||
t = operations[node].type | ||
if t == zss.Operation.remove: | ||
lines.append("(-)") | ||
elif t == zss.Operation.insert: | ||
lines.append("(+)") | ||
elif t == zss.Operation.update: | ||
lines.append("(-+)") | ||
lines.append(label) | ||
lines.append("\n") | ||
|
||
l = len(node.children) - 1 | ||
for i, c in enumerate(node.children): | ||
_print_tree(c, lines, prefix, i == l, False, operations) | ||
|
||
|
||
def link_nodes(node: Node, parent: Node):
    """Connect ``node`` and ``parent`` in both directions.

    ``node`` is added to the children of ``parent`` and ``parent`` is stored
    as the parent of ``node``.
    """
    parent.addkid(node)
    node.addpar(parent)
|
||
|
||
class HierarchyEntry:
    """One row of a complexes CSV file.

    Two row layouts are understood:

    * 8 columns: id, parent_id, depth, lat, lon, type, name, country;
    * 6 columns (old format): id, parent_id, lat, lon, type, name.

    Entries compare equal to other entries, and to plain strings, by ``id``.
    """

    __slots__ = ("id", "parent_id", "depth", "lat", "lon", "type", "name", "country")

    @staticmethod
    def make_from_csv_row(csv_row: List[str]) -> Optional["HierarchyEntry"]:
        """Build an entry from a CSV row.

        Args:
            csv_row: The column values of one row.

        Returns:
            The parsed entry, or ``None`` (after logging an error) when the
            row has neither 8 nor 6 columns.

        Raises:
            ValueError: If a numeric column cannot be converted.
        """
        e = HierarchyEntry()
        if len(csv_row) == 8:
            e.id, e.parent_id, depth, lat, lon, e.type, e.name, e.country = csv_row
            e.depth = int(depth)
            e.lat = float(lat)
            e.lon = float(lon)
            return e
        # For old format:
        elif len(csv_row) == 6:
            e.id, e.parent_id, lat, lon, e.type, e.name = csv_row
            e.lat = float(lat)
            e.lon = float(lon)
            # The old format has no depth and country columns. The slots are
            # still initialised so that reading them does not raise
            # AttributeError.
            e.depth = None
            e.country = None
            return e
        logger.error(f"Row [{csv_row}] - {len(csv_row)} cannot be parsed.")
        return None

    def __eq__(self, other):
        if isinstance(other, str):
            return self.id == other
        if isinstance(other, HierarchyEntry):
            return self.id == other.id
        # Let Python try the reflected comparison and fall back to "not
        # equal" instead of raising for unrelated types.
        return NotImplemented

    def __hash__(self):
        # Consistent with __eq__: equal entries (and the equal id string)
        # share the same hash.
        return hash(self.id)

    def __str__(self):
        return (
            f"{self.id}[{self.type}]:{self.name} "
            f"({decode_id(self.id.split()[0] if ' ' in self.id else self.id)})"
        )
|
||
|
||
def read_complexes_from_csv(path: str) -> Dict[str, Node]:
    """Read a ``;``-separated complexes file and build the node trees.

    Every parsable row becomes a ``Node`` whose label is the
    ``HierarchyEntry`` of the row. Nodes with a non-empty ``parent_id`` are
    then linked to their parent node.

    Args:
        path: Path to the CSV file.

    Returns:
        A dict that maps every entry id to its node (roots and non-roots).
        Rows that cannot be parsed are skipped; a node whose parent id is
        unknown is kept without a parent and an error is logged.
    """
    m = {}
    # The csv module expects files to be opened with newline="".
    with open(path, newline="") as csvfile:
        rows = csv.reader(csvfile, delimiter=";")

        for row in rows:
            e = HierarchyEntry.make_from_csv_row(row)
            if e is None:
                # make_from_csv_row has already logged the malformed row.
                continue
            m[e.id] = Node(e)

    for node in m.values():
        parent_id = node.label.parent_id
        if not parent_id:
            continue

        parent = m.get(parent_id)
        if parent is None:
            logger.error(f"Id {parent_id} was not found in dict.")
            continue
        link_nodes(node, parent)

    return m