2424from fsspec .core import get_fs_token_paths
2525
2626from nvtabular .io .dask import _ddf_to_dataset
27- from nvtabular .io .dataset import Dataset
27+ from nvtabular .io .dataset import Dataset , _set_dtypes
2828from nvtabular .io .shuffle import Shuffle , _check_shuffle_arg
2929from nvtabular .io .writer_factory import writer_factory
3030from nvtabular .ops import DFOperator , StatOperator , TransformOperator
@@ -483,7 +483,9 @@ def _run_trans_ops_for_phase(self, gdf, tasks):
483483 gdf = op .apply_op (gdf , self .columns_ctx , cols_grp , target_cols = target_cols )
484484 return gdf
485485
486- def apply_ops (self , gdf , start_phase = None , end_phase = None , writer = None , output_path = None ):
486+ def apply_ops (
487+ self , gdf , start_phase = None , end_phase = None , writer = None , output_path = None , dtypes = None
488+ ):
487489 """
488490 gdf: cudf dataframe
489491 Controls the application of registered preprocessing phase op
@@ -508,6 +510,8 @@ def apply_ops(self, gdf, start_phase=None, end_phase=None, writer=None, output_p
508510 writer .need_cal_col_names = False
509511
510512 start_write = time .time ()
513+ # Special dtype conversion
514+ gdf = _set_dtypes (gdf , dtypes )
511515 writer .add_data (gdf )
512516 self .timings ["write_df" ] += time .time () - start_write
513517
@@ -714,6 +718,7 @@ def apply(
714718 output_format = "parquet" ,
715719 out_files_per_proc = None ,
716720 num_io_threads = 0 ,
721+ dtypes = None ,
717722 ):
718723 """
719724 Runs all the preprocessing and feature engineering operators.
@@ -753,6 +758,9 @@ def apply(
753758 num_io_threads : integer
754759 Number of IO threads to use for writing the output dataset.
755760 For `0` (default), no dedicated IO threads will be used.
761+ dtypes : dict
762+ Dictionary containing desired datatypes for output columns.
763+ Keys are column names, values are datatypes.
756764 """
757765
758766 # Check shuffle argument
@@ -773,6 +781,7 @@ def apply(
773781 output_format = output_format ,
774782 out_files_per_proc = out_files_per_proc ,
775783 num_io_threads = num_io_threads ,
784+ dtypes = dtypes ,
776785 )
777786 else :
778787 self .iterate_online (
@@ -782,6 +791,7 @@ def apply(
782791 output_format = output_format ,
783792 out_files_per_proc = out_files_per_proc ,
784793 num_io_threads = num_io_threads ,
794+ dtypes = dtypes ,
785795 )
786796
787797 def iterate_online (
@@ -794,6 +804,7 @@ def iterate_online(
794804 out_files_per_proc = None ,
795805 apply_ops = True ,
796806 num_io_threads = 0 ,
807+ dtypes = None ,
797808 ):
798809 """Iterate through dataset and (optionally) apply/shuffle/write."""
799810 # Check shuffle argument
@@ -813,8 +824,9 @@ def iterate_online(
813824
814825 # Iterate through dataset, apply ops, and write out processed data
815826 if apply_ops :
816- for gdf in dataset .to_iter (shuffle = (shuffle is not None )):
817- self .apply_ops (gdf , output_path = output_path , writer = writer )
827+ columns = self .columns_ctx ["all" ]["base" ]
828+ for gdf in dataset .to_iter (shuffle = (shuffle is not None ), columns = columns ):
829+ self .apply_ops (gdf , output_path = output_path , writer = writer , dtypes = dtypes )
818830
819831 # Close writer and write general/specialized metadata
820832 if writer :
@@ -844,6 +856,7 @@ def build_and_process_graph(
844856 out_files_per_proc = None ,
845857 apply_ops = True ,
846858 num_io_threads = 0 ,
859+ dtypes = None ,
847860 ):
848861 """Build Dask-task graph for workflow.
849862
@@ -873,6 +886,12 @@ def build_and_process_graph(
873886 for idx , _ in enumerate (self .phases [:end ]):
874887 self .exec_phase (idx , record_stats = record_stats , update_ddf = (idx == (end - 1 )))
875888 self ._base_phase = 0 # Re-Set _base_phase
889+
890+ if dtypes :
891+ ddf = self .get_ddf ()
892+ _meta = _set_dtypes (ddf ._meta , dtypes )
893+ self .set_ddf (ddf .map_partitions (_set_dtypes , dtypes , meta = _meta ))
894+
876895 if output_format :
877896 output_path = output_path or "./"
878897 output_path = str (output_path )
@@ -895,6 +914,7 @@ def write_to_dataset(
895914 iterate = False ,
896915 nfiles = None ,
897916 num_io_threads = 0 ,
917+ dtypes = None ,
898918 ):
899919 """Write data to shuffled parquet dataset.
900920
@@ -919,6 +939,7 @@ def write_to_dataset(
919939 out_files_per_proc = out_files_per_proc ,
920940 apply_ops = apply_ops ,
921941 num_io_threads = num_io_threads ,
942+ dtypes = dtypes ,
922943 )
923944 else :
924945 self .build_and_process_graph (
@@ -930,6 +951,7 @@ def write_to_dataset(
930951 out_files_per_proc = out_files_per_proc ,
931952 apply_ops = apply_ops ,
932953 num_io_threads = num_io_threads ,
954+ dtypes = dtypes ,
933955 )
934956
935957 def ddf_to_dataset (
0 commit comments