From a4f93c78a247dc785068cdbb5f7dca5324c2d30e Mon Sep 17 00:00:00 2001 From: Alex North Date: Tue, 24 May 2016 18:10:59 +1000 Subject: [PATCH] Implement case-insensitive key comparison for csvjoin --- .gitignore | 1 + csvkit/join.py | 84 ++++++++++++++++++++++++++++--------- csvkit/utilities/csvjoin.py | 11 +++-- tests/test_join.py | 38 +++++++++++++++-- 4 files changed, 108 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 16a4d30ae..763e62cc4 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ docs/_build .coverage .tox cover +env diff --git a/csvkit/join.py b/csvkit/join.py index 9f78bf6d1..7e9c638fd 100644 --- a/csvkit/join.py +++ b/csvkit/join.py @@ -1,15 +1,16 @@ #!/usr/bin/env python -def _get_ordered_keys(rows, column_index): +def _get_keys(rows, column_index, lowercase=False): """ - Get ordered keys from rows, given the key column index. + Get keys from rows as keys in a dictionary (i.e. unordered), given the key column index. """ - return [r[column_index] for r in rows] + pairs = ((r[column_index], True) for r in rows) + return CaseInsensitiveDict(pairs) if lowercase else dict(pairs) -def _get_mapped_keys(rows, column_index): - mapped_keys = {} +def _get_mapped_keys(rows, column_index, case_insensitive=False): + mapped_keys = CaseInsensitiveDict() if case_insensitive else {} for r in rows: key = r[column_index] @@ -21,6 +22,11 @@ def _get_mapped_keys(rows, column_index): return mapped_keys +def _lower(key): + """Transforms a string to lowercase, leaves other types alone.""" + keyfn = getattr(key, 'lower', None) + return keyfn() if keyfn else key + def sequential_join(left_rows, right_rows, header=True): """ @@ -49,7 +55,7 @@ def sequential_join(left_rows, right_rows, header=True): return output -def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True): +def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False): """ Execute an inner join on two tables and return the combined table. """ @@ -63,7 +69,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr output = [] # Map right rows to keys - right_mapped_keys = _get_mapped_keys(right_rows, right_column_id) + right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case) for left_row in left_rows: len_left_row = len(left_row) @@ -80,7 +86,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr return output -def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True): +def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False): """ Execute full outer join on two tables and return the combined table. """ @@ -94,11 +100,11 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head else: output = [] - # Get ordered keys - left_ordered_keys = _get_ordered_keys(left_rows, left_column_id) + # Get left keys + left_keys = _get_keys(left_rows, left_column_id, ignore_case) # Get mapped keys - right_mapped_keys = _get_mapped_keys(right_rows, right_column_id) + right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case) for left_row in left_rows: len_left_row = len(left_row) @@ -116,13 +122,13 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head for right_row in right_rows: right_key = right_row[right_column_id] - if right_key not in left_ordered_keys: + if right_key not in left_keys: output.append(([u''] * len_left_headers) + right_row) return output -def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True): +def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False): """ Execute left outer join on two tables and return the combined table. """ @@ -137,7 +143,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head output = [] # Get mapped keys - right_mapped_keys = _get_mapped_keys(right_rows, right_column_id) + right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case) for left_row in left_rows: len_left_row = len(left_row) @@ -155,7 +161,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head return output -def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True): +def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False): """ Execute right outer join on two tables and return the combined table. """ @@ -168,11 +174,11 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea else: output = [] - # Get ordered keys - left_ordered_keys = _get_ordered_keys(left_rows, left_column_id) + # Get left keys + left_keys = _get_keys(left_rows, left_column_id, ignore_case) # Get mapped keys - right_mapped_keys = _get_mapped_keys(right_rows, right_column_id) + right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case) for left_row in left_rows: len_left_row = len(left_row) @@ -188,7 +194,47 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea for right_row in right_rows: right_key = right_row[right_column_id] - if right_key not in left_ordered_keys: + if right_key not in left_keys: output.append(([u''] * len_left_headers) + right_row) return output + + + +class CaseInsensitiveDict(dict): + """ + Adapted from http://stackoverflow.com/a/32888599/1583437 + """ + def __init__(self, *args, **kwargs): + super(CaseInsensitiveDict, self).__init__(*args, **kwargs) + self._convert_keys() + + def __getitem__(self, key): + return super(CaseInsensitiveDict, self).__getitem__(_lower(key)) + + def __setitem__(self, key, value): + super(CaseInsensitiveDict, self).__setitem__(_lower(key), value) + + def __delitem__(self, key): + return super(CaseInsensitiveDict, self).__delitem__(_lower(key)) + + def __contains__(self, key): + return super(CaseInsensitiveDict, self).__contains__(_lower(key)) + + def pop(self, key, *args, **kwargs): + return super(CaseInsensitiveDict, self).pop(_lower(key), *args, **kwargs) + + def get(self, key, *args, **kwargs): + return super(CaseInsensitiveDict, self).get(_lower(key), *args, **kwargs) + + def setdefault(self, key, *args, **kwargs): + return super(CaseInsensitiveDict, self).setdefault(_lower(key), *args, **kwargs) + + def update(self, single_arg=None, **kwargs): + super(CaseInsensitiveDict, self).update(self.__class__(single_arg)) + super(CaseInsensitiveDict, self).update(self.__class__(**kwargs)) + + def _convert_keys(self): + for k in list(self.keys()): + v = super(CaseInsensitiveDict, self).pop(k) + self.__setitem__(k, v) diff --git a/csvkit/utilities/csvjoin.py b/csvkit/utilities/csvjoin.py index ab3010666..4276cf9ce 100644 --- a/csvkit/utilities/csvjoin.py +++ b/csvkit/utilities/csvjoin.py @@ -22,6 +22,8 @@ def add_arguments(self): help='Perform a left outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of left outer joins, starting at the left.') self.argparser.add_argument('--right', dest='right_join', action='store_true', help='Perform a right outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of right outer joins, starting at the right.') + self.argparser.add_argument('--ignorecase', dest='ignore_case', action='store_true', + help='Whether to ignore string case when comparing keys.') def main(self): self.input_files = [] @@ -62,10 +64,11 @@ def main(self): jointab = tables[0] + ignore_case = self.args.ignore_case if self.args.left_join: # Left outer join for i, t in enumerate(tables[1:]): - jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header) + jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case) elif self.args.right_join: # Right outer join jointab = tables[-1] @@ -74,15 +77,15 @@ def main(self): remaining_tables.reverse() for i, t in enumerate(remaining_tables): - jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header) + jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header, ignore_case=ignore_case) elif self.args.outer_join: # Full outer join for i, t in enumerate(tables[1:]): - jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header) + jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case) elif self.args.columns: # Inner join for i, t in enumerate(tables[1:]): - jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header) + jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case) else: # Sequential join for t in tables[1:]: diff --git a/tests/test_join.py b/tests/test_join.py index 41d67b0bd..0340cb460 100644 --- a/tests/test_join.py +++ b/tests/test_join.py @@ -25,9 +25,9 @@ def setUp(self): [u'1', u'second', u'0'], [u'2', u'only', u'0', u'0']] # Note extra value in this column - def test_get_ordered_keys(self): - self.assertEqual(join._get_ordered_keys(self.tab1[1:], 0), [u'1', u'2', u'3', u'1']) - self.assertEqual(join._get_ordered_keys(self.tab2[1:], 0), [u'1', u'4', u'1', u'2']) + def test_get_keys(self): + self.assertEqual(join._get_keys(self.tab1[1:], 0).keys(), set([u'1', u'2', u'3', u'1'])) + self.assertEqual(join._get_keys(self.tab2[1:], 0).keys(), set([u'1', u'4', u'1', u'2'])) def test_get_mapped_keys(self): self.assertEqual(join._get_mapped_keys(self.tab1[1:], 0), { @@ -35,6 +35,13 @@ def test_get_mapped_keys(self): u'2': [[u'2', u'Chicago Sun-Times', u'only']], u'3': [[u'3', u'Chicago Tribune', u'only']]}) + def test_get_mapped_keys_ignore_case(self): + mapped_keys = join._get_mapped_keys(self.tab1[1:], 1, case_insensitive=True) + assert u'Chicago Reader' in mapped_keys + assert u'chicago reader' in mapped_keys + assert u'CHICAGO SUN-TIMES' in mapped_keys + assert u'1' not in mapped_keys + def test_sequential_join(self): self.assertEqual(join.sequential_join(self.tab1, self.tab2), [ ['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'], @@ -82,3 +89,28 @@ def test_right_outer_join(self): [u'1', u'Chicago Reader', u'second', u'1', u'first', u'0'], [u'1', u'Chicago Reader', u'second', u'1', u'second', u'0'], [u'', u'', u'', u'4', u'only', u'0']]) + + def test_right_outer_join_ignore_case(self): + # Right outer join exercises all the case dependencies + tab1 = [ + ['id', 'name', 'i_work_here'], + [u'a', u'Chicago Reader', u'first'], + [u'b', u'Chicago Sun-Times', u'only'], + [u'c', u'Chicago Tribune', u'only'], + [u'a', u'Chicago Reader', u'second']] + + tab2 = [ + ['id', 'age', 'i_work_here'], + [u'A', u'first', u'0'], + [u'D', u'only', u'0'], + [u'A', u'second', u'0'], + [u'B', u'only', u'0', u'0']] # Note extra value in this column + + self.assertEqual(join.right_outer_join(tab1, 0, tab2, 0, ignore_case=True), [ + ['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'], + [u'a', u'Chicago Reader', u'first', u'A', u'first', u'0'], + [u'a', u'Chicago Reader', u'first', u'A', u'second', u'0'], + [u'b', u'Chicago Sun-Times', u'only', u'B', u'only', u'0', u'0'], + [u'a', u'Chicago Reader', u'second', u'A', u'first', u'0'], + [u'a', u'Chicago Reader', u'second', u'A', u'second', u'0'], + [u'', u'', u'', u'D', u'only', u'0']])