diff --git a/airbyte_cdk/sources/file_based/exceptions.py b/airbyte_cdk/sources/file_based/exceptions.py index cd727463a..5953850b2 100644 --- a/airbyte_cdk/sources/file_based/exceptions.py +++ b/airbyte_cdk/sources/file_based/exceptions.py @@ -24,8 +24,12 @@ class FileBasedSourceError(Enum): ) ERROR_VALIDATING_RECORD = "One or more records do not pass the schema validation policy. Please modify your input schema, or select a more lenient validation policy." ERROR_VALIDATION_STREAM_DISCOVERY_OPTIONS = "Only one of options 'Schemaless', 'Input Schema', 'Files To Read For Schema Discover' or 'Use First Found File For Schema Discover' can be provided at the same time." - ERROR_PARSING_RECORD_MISMATCHED_COLUMNS = "A header field has resolved to `None`. This indicates that the CSV has more rows than the number of header fields. If you input your schema or headers, please verify that the number of columns corresponds to the number of columns in your CSV's rows." - ERROR_PARSING_RECORD_MISMATCHED_ROWS = "A row's value has resolved to `None`. This indicates that the CSV has more columns in the header field than the number of columns in the row(s). If you input your schema or headers, please verify that the number of columns corresponds to the number of columns in your CSV's rows." + ERROR_PARSING_RECORD_MISMATCHED_COLUMNS = ( + "CSV data row contains more columns than the header row defines." + ) + ERROR_PARSING_RECORD_MISMATCHED_ROWS = ( + "CSV data row contains fewer columns than the header row defines." + ) STOP_SYNC_PER_SCHEMA_VALIDATION_POLICY = "Stopping sync in accordance with the configured validation policy. Records in file did not conform to the schema." NULL_VALUE_IN_SCHEMA = "Error during schema inference: no type was detected for key." UNRECOGNIZED_TYPE = "Error during schema inference: unrecognized type." 
diff --git a/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte_cdk/sources/file_based/file_types/csv_parser.py index edab346fe..3b3dc4a0f 100644 --- a/airbyte_cdk/sources/file_based/file_types/csv_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/csv_parser.py @@ -65,24 +65,26 @@ def read_data( doublequote=config_format.double_quote, quoting=csv.QUOTE_MINIMAL, ) - with stream_reader.open_file(file, file_read_mode, config_format.encoding, logger) as fp: - try: - headers = self._get_headers(fp, config_format, dialect_name) - except UnicodeError: - raise AirbyteTracedException( - message=f"{FileBasedSourceError.ENCODING_ERROR.value} Expected encoding: {config_format.encoding}", + try: + with stream_reader.open_file( + file, file_read_mode, config_format.encoding, logger + ) as fp: + try: + headers = self._get_headers(fp, config_format, dialect_name) + except UnicodeError: + raise AirbyteTracedException( + message=f"{FileBasedSourceError.ENCODING_ERROR.value} Expected encoding: {config_format.encoding}", + ) + + rows_to_skip = ( + config_format.skip_rows_before_header + + (1 if config_format.header_definition.has_header_row() else 0) + + config_format.skip_rows_after_header ) + self._skip_rows(fp, rows_to_skip) + lineno += rows_to_skip - rows_to_skip = ( - config_format.skip_rows_before_header - + (1 if config_format.header_definition.has_header_row() else 0) - + config_format.skip_rows_after_header - ) - self._skip_rows(fp, rows_to_skip) - lineno += rows_to_skip - - reader = csv.DictReader(fp, dialect=dialect_name, fieldnames=headers) # type: ignore - try: + reader = csv.DictReader(fp, dialect=dialect_name, fieldnames=headers) # type: ignore for row in reader: lineno += 1 @@ -111,14 +113,11 @@ def read_data( lineno=lineno, ) yield row - finally: - # due to RecordParseError or GeneratorExit - csv.unregister_dialect(dialect_name) + finally: + csv.unregister_dialect(dialect_name) def _get_headers(self, fp: IOBase, config_format: CsvFormat, 
dialect_name: str) -> List[str]: - """ - Assumes the fp is pointing to the beginning of the files and will reset it as such - """ + """Assumes the fp is pointing to the beginning of the files and will reset it as such.""" # Note that this method assumes the dialect has already been registered if we're parsing the headers if isinstance(config_format.header_definition, CsvHeaderUserProvided): return config_format.header_definition.column_names @@ -134,6 +133,14 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) reader = csv.reader(fp, dialect=dialect_name) # type: ignore headers = list(next(reader)) + empty_count = sum(1 for h in headers if not h or h.isspace()) + if empty_count: + raise AirbyteTracedException( + message="CSV header row contains empty column name(s). Remove trailing delimiters or empty columns from the header row.", + internal_message=f"Found {empty_count} empty/whitespace-only column name(s) in header: {headers}", + failure_type=FailureType.config_error, + ) + fp.seek(0) return headers @@ -227,7 +234,7 @@ def parse_records( logger: logging.Logger, discovered_schema: Optional[Mapping[str, SchemaType]], ) -> Iterable[Dict[str, Any]]: - line_no = 0 + data_generator = None try: config_format = _extract_format(config) if discovered_schema: @@ -244,19 +251,15 @@ def parse_records( config, file, stream_reader, logger, self.file_read_mode ) for row in data_generator: - line_no += 1 yield CsvParser._to_nullable( cast_fn(row), deduped_property_types, config_format.null_values, config_format.strings_can_be_null, ) - except RecordParseError as parse_err: - raise RecordParseError( - FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri, lineno=line_no - ) from parse_err finally: - data_generator.close() + if data_generator is not None: + data_generator.close() @property def file_read_mode(self) -> FileReadMode: diff --git a/unit_tests/sources/file_based/file_types/test_csv_parser.py 
b/unit_tests/sources/file_based/file_types/test_csv_parser.py index 0b49dd66d..3238a96bb 100644 --- a/unit_tests/sources/file_based/file_types/test_csv_parser.py +++ b/unit_tests/sources/file_based/file_types/test_csv_parser.py @@ -11,6 +11,7 @@ from typing import Any, Dict, Generator, List, Set from unittest import TestCase, mock from unittest.mock import Mock +from uuid import uuid4 import pytest @@ -785,3 +786,124 @@ def test_encoding_is_passed_to_stream_reader() -> None: mock.call().__exit__(None, None, None), ] ) + + +@pytest.mark.parametrize( + "header_row, expected_empty_count", + [ + pytest.param("col1,col2,col3,,,", 3, id="trailing_empty_columns"), + pytest.param("col1,,col3", 1, id="middle_empty_column"), + pytest.param(",col2,col3", 1, id="leading_empty_column"), + pytest.param("col1,col2, ", 1, id="whitespace_only_column"), + ], +) +def test_get_headers_raises_on_empty_column_names( + header_row: str, expected_empty_count: int +) -> None: + csv_reader = _CsvReader() + config_format = CsvFormat() + fp = io.StringIO(header_row) + + dialect_name = f"test_{uuid4()}" + csv.register_dialect( + dialect_name, + delimiter=config_format.delimiter, + quotechar=config_format.quote_char, + escapechar=config_format.escape_char, + doublequote=config_format.double_quote, + quoting=csv.QUOTE_MINIMAL, + ) + + try: + with pytest.raises(AirbyteTracedException) as exc_info: + csv_reader._get_headers(fp, config_format, dialect_name) + + assert exc_info.value.failure_type == FailureType.config_error + assert "empty column name" in exc_info.value.message + assert f"{expected_empty_count} empty" in exc_info.value.internal_message + finally: + csv.unregister_dialect(dialect_name) + + +def test_get_headers_accepts_valid_headers() -> None: + csv_reader = _CsvReader() + config_format = CsvFormat() + fp = io.StringIO("col1,col2,col3") + + dialect_name = f"test_{uuid4()}" + csv.register_dialect( + dialect_name, + delimiter=config_format.delimiter, + 
quotechar=config_format.quote_char, + escapechar=config_format.escape_char, + doublequote=config_format.double_quote, + quoting=csv.QUOTE_MINIMAL, + ) + + try: + headers = csv_reader._get_headers(fp, config_format, dialect_name) + assert headers == ["col1", "col2", "col3"] + finally: + csv.unregister_dialect(dialect_name) + + +def test_read_data_raises_on_empty_column_names() -> None: + config_format = CsvFormat() + config = Mock() + config.name = "config_name" + config.format = config_format + + file = RemoteFile(uri="test.csv", last_modified=datetime.now()) + stream_reader = Mock(spec=AbstractFileBasedStreamReader) + logger = Mock(spec=logging.Logger) + csv_reader = _CsvReader() + + stream_reader.open_file.return_value = ( + CsvFileBuilder().with_data(["col1,col2,col3,,,", "v1,v2,v3,v4,v5,v6"]).build() + ) + + with pytest.raises(AirbyteTracedException) as exc_info: + list( + csv_reader.read_data( + config, + file, + stream_reader, + logger, + FileReadMode.READ, + ) + ) + + assert exc_info.value.failure_type == FailureType.config_error + assert "empty column name" in exc_info.value.message + + +def test_parse_records_preserves_mismatch_error_detail() -> None: + config_format = CsvFormat() + config = FileBasedStreamConfig( + name="test", + validation_policy="Emit Record", + file_type="csv", + format=config_format, + ) + + file = RemoteFile(uri="test.csv", last_modified=datetime.now()) + stream_reader = Mock() + mock_obj = stream_reader.open_file.return_value + mock_obj.__enter__ = Mock(return_value=io.StringIO("header\ntoo many values,value,value,value")) + mock_obj.__exit__ = Mock(return_value=None) + + parser = CsvParser() + + with pytest.raises(RecordParseError) as exc_info: + list( + parser.parse_records( + config, + file, + stream_reader, + Mock(spec=logging.Logger), + {"properties": {"header": {"type": "string"}}}, + ) + ) + + error_msg = str(exc_info.value) + assert "more columns than the header" in error_msg