Skip to content

Commit

Permalink
Fix: Resolve issue where case-sensitive column pruning removes real c…
Browse files Browse the repository at this point in the history
…olumns (#161)
  • Loading branch information
aaronsteers authored Apr 2, 2024
1 parent 0720822 commit 4672849
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 169 deletions.
14 changes: 6 additions & 8 deletions airbyte/_processors/file/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from airbyte import exceptions as exc
from airbyte._batch_handles import BatchHandle
from airbyte._util.protocol_util import airbyte_record_message_to_dict
from airbyte._util.name_normalizers import LowerCaseNormalizer, StreamRecord
from airbyte.progress import progress


Expand Down Expand Up @@ -146,9 +146,6 @@ def process_record_message(
"""Write a record to the cache.
This method is called for each record message, before the batch is written.
Returns:
A tuple of the stream name and the batch handle.
"""
stream_name = record_msg.stream

Expand All @@ -167,9 +164,10 @@ def process_record_message(
raise exc.AirbyteLibInternalError(message="Expected open file writer.")

self._write_record_dict(
record_dict=airbyte_record_message_to_dict(
record_message=record_msg,
stream_schema=stream_schema,
record_dict=StreamRecord(
from_dict=record_msg.data,
expected_keys=stream_schema["properties"].keys(),
normalizer=LowerCaseNormalizer,
prune_extra_fields=self.prune_extra_fields,
),
open_file_writer=batch_handle.open_file_writer,
Expand Down Expand Up @@ -216,7 +214,7 @@ def __del__(self) -> None:
@abc.abstractmethod
def _write_record_dict(
self,
record_dict: dict,
record_dict: StreamRecord,
open_file_writer: IO[bytes],
) -> None:
"""Write one record to a file."""
Expand Down
4 changes: 2 additions & 2 deletions airbyte/_processors/file/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from pathlib import Path

pass
from airbyte._util.name_normalizers import StreamRecord


class JsonlWriter(FileWriterBase):
Expand All @@ -34,7 +34,7 @@ def _open_new_file(

def _write_record_dict(
self,
record_dict: dict,
record_dict: StreamRecord,
open_file_writer: gzip.GzipFile | IO[bytes],
) -> None:
open_file_writer.write(orjson.dumps(record_dict) + b"\n")
4 changes: 2 additions & 2 deletions airbyte/_processors/sql/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def _write_files_to_new_table(
properties_list = list(self.get_stream_properties(stream_name).keys())
columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys())
columns_list_str = indent(
"\n, ".join([self._quote_identifier(c) for c in columns_list]),
"\n, ".join([self._quote_identifier(col) for col in columns_list]),
" ",
)
files_list = ", ".join([f"'{f!s}'" for f in files])
columns_type_map = indent(
"\n, ".join(
[
self._quote_identifier(prop_name)
self._quote_identifier(self.normalizer.normalize(prop_name))
+ ": "
+ str(
self._get_sql_column_definitions(stream_name)[
Expand Down
2 changes: 1 addition & 1 deletion airbyte/_processors/sql/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def path_str(path: Path) -> str:
files_list = ", ".join([f"'{f.name}'" for f in files])
columns_list_str: str = indent("\n, ".join(columns_list), " " * 12)
variant_cols_str: str = ("\n" + " " * 21 + ", ").join(
[f"$1:{col}" for col in properties_list]
[f"$1:{self.normalizer.normalize(col)}" for col in properties_list]
)
copy_statement = dedent(
f"""
Expand Down
53 changes: 22 additions & 31 deletions airbyte/_util/name_normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from collections.abc import Iterable


class NameNormalizerBase(abc.ABC):
Expand Down Expand Up @@ -50,16 +50,19 @@ def normalize(name: str) -> str:
return name.lower().replace(" ", "_").replace("-", "_")


class CaseInsensitiveDict(dict[str, Any]):
"""A case-aware, case-insensitive dictionary implementation.
class StreamRecord(dict[str, Any]):
"""The StreamRecord class is a case-aware, case-insensitive dictionary implementation.
It has these behaviors:
- When a key is retrieved, deleted, or checked for existence, it is always checked in a
case-insensitive manner.
- The original case is stored in a separate dictionary, so that the original case can be
retrieved when needed.
There are two ways to store keys internally:
This behavior mirrors how a case-aware, case-insensitive SQL database would handle column
references.
There are two ways this class can store keys internally:
- If normalize_keys is True, the keys are normalized using the given normalizer.
- If normalize_keys is False, the original case of the keys is stored.
Expand Down Expand Up @@ -88,14 +91,18 @@ def __init__(
self,
from_dict: dict,
*,
prune_extra_fields: bool,
normalize_keys: bool = True,
normalizer: type[NameNormalizerBase] | None = None,
expected_keys: list[str] | None = None,
) -> None:
"""Initialize the dictionary with the given data.
If normalize_keys is True, the keys will be normalized using the given normalizer.
If expected_keys is provided, the dictionary will be initialized with the given keys.
Args:
- normalize_keys: If `True`, the keys will be normalized using the given normalizer.
- expected_keys: If provided, the dictionary will be initialized with these given keys.
- expected_keys: If provided and `prune_extra_fields` is True, then unexpected fields
will be removed. This option is ignored if `expected_keys` is not provided.
"""
# If no normalizer is provided, use LowerCaseNormalizer.
self._normalize_keys = normalize_keys
Expand All @@ -104,6 +111,7 @@ def __init__(
# If no expected keys are provided, use all keys from the input dictionary.
if not expected_keys:
expected_keys = list(from_dict.keys())
prune_extra_fields = False # No expected keys provided.

# Store a lookup from normalized keys to pretty cased (originally cased) keys.
self._pretty_case_keys: dict[str, str] = {
Expand All @@ -118,7 +126,12 @@ def __init__(

self.update({k: None for k in index_keys}) # Start by initializing all values to None
for k, v in from_dict.items():
self[self._index_case(k)] = v
index_cased_key = self._index_case(k)
if prune_extra_fields and index_cased_key not in index_keys:
# Dropping undeclared field
continue

self[index_cased_key] = v

def __getitem__(self, key: str) -> Any: # noqa: ANN401
if super().__contains__(key):
Expand Down Expand Up @@ -166,7 +179,7 @@ def __len__(self) -> int:
return super().__len__()

def __eq__(self, other: object) -> bool:
if isinstance(other, CaseInsensitiveDict):
if isinstance(other, StreamRecord):
return dict(self) == dict(other)

if isinstance(other, dict):
Expand All @@ -176,30 +189,8 @@ def __eq__(self, other: object) -> bool:
return False


def normalize_records(
records: Iterable[dict[str, Any]],
expected_keys: list[str],
) -> Iterator[CaseInsensitiveDict]:
"""Add missing columns to the record with null values.
Also conform the column names to the case in the catalog.
This is a generator that yields CaseInsensitiveDicts, which allows for case-insensitive
lookups of columns. This is useful because the case of the columns in the records may
not match the case of the columns in the catalog.
"""
yield from (
CaseInsensitiveDict(
from_dict=record,
expected_keys=expected_keys,
)
for record in records
)


__all__ = [
"NameNormalizerBase",
"LowerCaseNormalizer",
"CaseInsensitiveDict",
"normalize_records",
"StreamRecord",
]
108 changes: 0 additions & 108 deletions airbyte/_util/protocol_util.py

This file was deleted.

18 changes: 9 additions & 9 deletions airbyte/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
)

from airbyte import exceptions as exc
from airbyte._util import protocol_util
from airbyte._util.name_normalizers import normalize_records
from airbyte._util.name_normalizers import StreamRecord
from airbyte._util.telemetry import (
EventState,
EventType,
Expand Down Expand Up @@ -73,7 +72,7 @@ def as_temp_files(files_contents: list[dict | str]) -> Generator[list[str], Any,
finally:
for temp_file in temp_files:
with suppress(Exception):
temp_file.unlink()
Path(temp_file.name).unlink()


class Source:
Expand Down Expand Up @@ -437,13 +436,14 @@ def _with_logging(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]
self._log_sync_success(cache=None)

iterator: Iterator[dict[str, Any]] = _with_logging(
normalize_records(
records=protocol_util.airbyte_messages_to_record_dicts(
self._read_with_catalog(configured_catalog),
stream_schema=self.get_stream_json_schema(stream),
records=( # Generator comprehension yields StreamRecord objects for each record
StreamRecord(
from_dict=record.record.data,
expected_keys=all_properties,
prune_extra_fields=True,
),
expected_keys=all_properties,
)
for record in self._read_with_catalog(configured_catalog)
if record.record
)
)
return LazyDataset(
Expand Down
24 changes: 16 additions & 8 deletions tests/unit_tests/test_text_normalization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from math import exp
import pytest
from airbyte._util.name_normalizers import CaseInsensitiveDict, LowerCaseNormalizer
from airbyte._util.name_normalizers import StreamRecord, LowerCaseNormalizer

def test_case_insensitive_dict():
# Initialize a CaseInsensitiveDict
cid = CaseInsensitiveDict({"Upper": 1, "lower": 2})
def test_case_insensitive_dict() -> None:
# Initialize a StreamRecord
cid = StreamRecord(
{"Upper": 1, "lower": 2},
prune_extra_fields=True,
)

# Test __getitem__
assert cid["Upper"] == 1
Expand Down Expand Up @@ -59,8 +62,12 @@ def test_case_insensitive_dict():


def test_case_insensitive_dict_w() -> None:
# Initialize a CaseInsensitiveDict
cid = CaseInsensitiveDict({"Upper": 1, "lower": 2}, expected_keys=["Upper", "lower", "other"])
# Initialize a StreamRecord
cid = StreamRecord(
{"Upper": 1, "lower": 2},
expected_keys=["Upper", "lower", "other"],
prune_extra_fields=True,
)

# Test __len__
assert len(cid) == 3
Expand All @@ -80,11 +87,12 @@ def test_case_insensitive_dict_w() -> None:


def test_case_insensitive_w_pretty_keys() -> None:
# Initialize a CaseInsensitiveDict
cid = CaseInsensitiveDict(
# Initialize a StreamRecord
cid = StreamRecord(
{"Upper": 1, "lower": 2},
expected_keys=["Upper", "lower", "other"],
normalize_keys=False,
prune_extra_fields=True,
)

# Test __len__
Expand Down

0 comments on commit 4672849

Please sign in to comment.