Skip to content

Commit 1edf2e0

Browse files
author
Alberto Alvarez
authored
Adding dtypes option to Workflow (NVIDIA-Merlin#392)
1 parent 3424aef commit 1edf2e0

2 files changed

Lines changed: 94 additions & 4 deletions

File tree

nvtabular/workflow.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from fsspec.core import get_fs_token_paths
2525

2626
from nvtabular.io.dask import _ddf_to_dataset
27-
from nvtabular.io.dataset import Dataset
27+
from nvtabular.io.dataset import Dataset, _set_dtypes
2828
from nvtabular.io.shuffle import Shuffle, _check_shuffle_arg
2929
from nvtabular.io.writer_factory import writer_factory
3030
from nvtabular.ops import DFOperator, StatOperator, TransformOperator
@@ -483,7 +483,9 @@ def _run_trans_ops_for_phase(self, gdf, tasks):
483483
gdf = op.apply_op(gdf, self.columns_ctx, cols_grp, target_cols=target_cols)
484484
return gdf
485485

486-
def apply_ops(self, gdf, start_phase=None, end_phase=None, writer=None, output_path=None):
486+
def apply_ops(
487+
self, gdf, start_phase=None, end_phase=None, writer=None, output_path=None, dtypes=None
488+
):
487489
"""
488490
gdf: cudf dataframe
489491
Controls the application of registered preprocessing phase op
@@ -508,6 +510,8 @@ def apply_ops(self, gdf, start_phase=None, end_phase=None, writer=None, output_p
508510
writer.need_cal_col_names = False
509511

510512
start_write = time.time()
513+
# Special dtype conversion
514+
gdf = _set_dtypes(gdf, dtypes)
511515
writer.add_data(gdf)
512516
self.timings["write_df"] += time.time() - start_write
513517

@@ -714,6 +718,7 @@ def apply(
714718
output_format="parquet",
715719
out_files_per_proc=None,
716720
num_io_threads=0,
721+
dtypes=None,
717722
):
718723
"""
719724
Runs all the preprocessing and feature engineering operators.
@@ -753,6 +758,9 @@ def apply(
753758
num_io_threads : integer
754759
Number of IO threads to use for writing the output dataset.
755760
For `0` (default), no dedicated IO threads will be used.
761+
dtypes : dict
762+
Dictionary containing desired datatypes for output columns.
763+
Keys are column names, values are datatypes.
756764
"""
757765

758766
# Check shuffle argument
@@ -773,6 +781,7 @@ def apply(
773781
output_format=output_format,
774782
out_files_per_proc=out_files_per_proc,
775783
num_io_threads=num_io_threads,
784+
dtypes=dtypes,
776785
)
777786
else:
778787
self.iterate_online(
@@ -782,6 +791,7 @@ def apply(
782791
output_format=output_format,
783792
out_files_per_proc=out_files_per_proc,
784793
num_io_threads=num_io_threads,
794+
dtypes=dtypes,
785795
)
786796

787797
def iterate_online(
@@ -794,6 +804,7 @@ def iterate_online(
794804
out_files_per_proc=None,
795805
apply_ops=True,
796806
num_io_threads=0,
807+
dtypes=None,
797808
):
798809
"""Iterate through dataset and (optionally) apply/shuffle/write."""
799810
# Check shuffle argument
@@ -813,8 +824,9 @@ def iterate_online(
813824

814825
# Iterate through dataset, apply ops, and write out processed data
815826
if apply_ops:
816-
for gdf in dataset.to_iter(shuffle=(shuffle is not None)):
817-
self.apply_ops(gdf, output_path=output_path, writer=writer)
827+
columns = self.columns_ctx["all"]["base"]
828+
for gdf in dataset.to_iter(shuffle=(shuffle is not None), columns=columns):
829+
self.apply_ops(gdf, output_path=output_path, writer=writer, dtypes=dtypes)
818830

819831
# Close writer and write general/specialized metadata
820832
if writer:
@@ -844,6 +856,7 @@ def build_and_process_graph(
844856
out_files_per_proc=None,
845857
apply_ops=True,
846858
num_io_threads=0,
859+
dtypes=None,
847860
):
848861
"""Build Dask-task graph for workflow.
849862
@@ -873,6 +886,12 @@ def build_and_process_graph(
873886
for idx, _ in enumerate(self.phases[:end]):
874887
self.exec_phase(idx, record_stats=record_stats, update_ddf=(idx == (end - 1)))
875888
self._base_phase = 0 # Re-Set _base_phase
889+
890+
if dtypes:
891+
ddf = self.get_ddf()
892+
_meta = _set_dtypes(ddf._meta, dtypes)
893+
self.set_ddf(ddf.map_partitions(_set_dtypes, dtypes, meta=_meta))
894+
876895
if output_format:
877896
output_path = output_path or "./"
878897
output_path = str(output_path)
@@ -895,6 +914,7 @@ def write_to_dataset(
895914
iterate=False,
896915
nfiles=None,
897916
num_io_threads=0,
917+
dtypes=None,
898918
):
899919
"""Write data to shuffled parquet dataset.
900920
@@ -919,6 +939,7 @@ def write_to_dataset(
919939
out_files_per_proc=out_files_per_proc,
920940
apply_ops=apply_ops,
921941
num_io_threads=num_io_threads,
942+
dtypes=dtypes,
922943
)
923944
else:
924945
self.build_and_process_graph(
@@ -930,6 +951,7 @@ def write_to_dataset(
930951
out_files_per_proc=out_files_per_proc,
931952
apply_ops=apply_ops,
932953
num_io_threads=num_io_threads,
954+
dtypes=dtypes,
933955
)
934956

935957
def ddf_to_dataset(

tests/unit/test_workflow.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,71 @@ def test_chaining_3():
442442
assert all(
443443
x in result.columns for x in ["ad_id_count", "ad_id_clicked_sum_ctr", "ad_id_clicked_sum"]
444444
)
445+
446+
447+
@pytest.mark.parametrize("shuffle", [nvt.io.Shuffle.PER_WORKER, nvt.io.Shuffle.PER_PARTITION, None])
448+
@pytest.mark.parametrize("use_client", [True, False])
449+
@pytest.mark.parametrize("apply_offline", [True, False])
450+
def test_workflow_apply(client, use_client, tmpdir, shuffle, apply_offline):
451+
out_files_per_proc = 2
452+
out_path = str(tmpdir.mkdir("processed"))
453+
path = str(tmpdir.join("simple.parquet"))
454+
455+
size = 25
456+
row_group_size = 5
457+
458+
cont_columns = ["cont1", "cont2"]
459+
cat_columns = ["cat1", "cat2"]
460+
label_column = ["label"]
461+
462+
df = pd.DataFrame(
463+
{
464+
"cont1": np.arange(size, dtype=np.float64),
465+
"cont2": np.arange(size, dtype=np.float64),
466+
"cat1": np.arange(size, dtype=np.int32),
467+
"cat2": np.arange(size, dtype=np.int32),
468+
"label": np.arange(size, dtype=np.float64),
469+
}
470+
)
471+
df.to_parquet(path, row_group_size=row_group_size, engine="pyarrow")
472+
473+
dataset = nvt.Dataset(path, engine="parquet", row_groups_per_part=1)
474+
processor = nvt.Workflow(
475+
cat_names=cat_columns,
476+
cont_names=cont_columns,
477+
label_name=label_column,
478+
client=client if use_client else None,
479+
)
480+
processor.add_cont_feature([ops.FillMissing(), ops.Clip(min_value=0), ops.LogOp()])
481+
processor.add_cat_preprocess(ops.Categorify())
482+
483+
processor.finalize()
484+
# Force dtypes
485+
dict_dtypes = {}
486+
for col in cont_columns:
487+
dict_dtypes[col] = np.float32
488+
for col in cat_columns:
489+
dict_dtypes[col] = np.float32
490+
for col in label_column:
491+
dict_dtypes[col] = np.int64
492+
493+
if not apply_offline:
494+
processor.apply(
495+
dataset,
496+
output_format=None,
497+
record_stats=True,
498+
)
499+
processor.apply(
500+
dataset,
501+
apply_offline=apply_offline,
502+
record_stats=apply_offline,
503+
output_path=out_path,
504+
shuffle=shuffle,
505+
out_files_per_proc=out_files_per_proc,
506+
dtypes=dict_dtypes,
507+
)
508+
509+
# Check dtypes
510+
for filename in glob.glob(os.path.join(out_path, "*.parquet")):
511+
gdf = cudf.io.read_parquet(filename)
512+
assert dict(gdf.dtypes) == dict_dtypes

0 commit comments

Comments
 (0)