1515import json
1616import logging
1717import os
18+ import shutil
19+ import tempfile
1820import time
1921from collections import OrderedDict
20- from typing import Dict , Union
21- from typing import List
22+ from math import ceil
23+ from typing import Dict , List , Tuple , Union
2224from urllib .parse import urlparse
2325
2426import fastavro
@@ -642,11 +644,11 @@ def ingest(
642644 raise Exception (f"Feature set name must be provided" )
643645
644646 # Read table and get row count
645- tmp_table_name = _read_table_from_source (
647+ dir_path , dest_path = _read_table_from_source (
646648 source , chunk_size , max_workers
647649 )
648650
649- pq_file = pq .ParquetFile (tmp_table_name )
651+ pq_file = pq .ParquetFile (dest_path )
650652
651653 row_count = pq_file .metadata .num_rows
652654
@@ -688,7 +690,7 @@ def ingest(
688690 # Transform and push data to Kafka
689691 if feature_set .source .source_type == "Kafka" :
690692 for chunk in get_feature_row_chunks (
691- file = tmp_table_name ,
693+ file = dest_path ,
692694 row_groups = list (range (pq_file .num_row_groups )),
693695 fs = feature_set ,
694696 max_workers = max_workers ):
@@ -715,7 +717,7 @@ def ingest(
715717 finally :
716718 # Remove parquet file(s) that were created earlier
717719 print ("Removing temporary file(s)..." )
718- os . remove ( tmp_table_name )
720+ shutil . rmtree ( dir_path )
719721
720722 return None
721723
@@ -753,7 +755,7 @@ def _read_table_from_source(
753755 source : Union [pd .DataFrame , str ],
754756 chunk_size : int ,
755757 max_workers : int
756- ) -> str :
758+ ) -> Tuple [ str , str ] :
757759 """
758760 Infers a data source type (path or Pandas DataFrame) and reads it in as
759761 a PyArrow Table.
@@ -777,7 +779,9 @@ def _read_table_from_source(
777779 Amount of rows to load and ingest at a time.
778780
779781 Returns:
780- str: Path to parquet file that was created.
782+ Tuple[str, str]:
783+ Tuple containing parent directory path and destination path to
784+ parquet file.
781785 """
782786
783787 # Pandas DataFrame detected
@@ -807,12 +811,13 @@ def _read_table_from_source(
807811 assert isinstance (table , pa .lib .Table )
808812
809813 # Write table as parquet file with a specified row_group_size
814+ dir_path = tempfile .mkdtemp ()
810815 tmp_table_name = f"{ int (time .time ())} .parquet"
811- row_group_size = min ( int ( table . num_rows / max_workers ), chunk_size )
812- pq . write_table ( table = table , where = tmp_table_name ,
813- row_group_size = row_group_size )
816+ dest_path = f" { dir_path } / { tmp_table_name } "
817+ row_group_size = min ( ceil ( table . num_rows / max_workers ), chunk_size )
818+ pq . write_table ( table = table , where = dest_path , row_group_size = row_group_size )
814819
815820 # Remove table from memory
816821 del table
817822
818- return tmp_table_name
823+ return dir_path , dest_path
0 commit comments