Skip to content

Commit e0e3cca

Browse files
jperez999benfred
andauthored
adding cvg temp fix (NVIDIA-Merlin#427)
* adding cvg temp fix * add tensor check Co-authored-by: Ben Frederickson <github@benfrederickson.com>
1 parent 9aa70ca commit e0e3cca

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

nvtabular/loader/torch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def _LONG_DTYPE(self):
107107
def _FLOAT32_DTYPE(self):
108108
return torch.float32
109109

110+
def _handle_tensors(self, cats, conts, labels):
111+
if isinstance(conts, torch.Tensor):
112+
conts = conts.clone()
113+
return cats, conts, labels
114+
110115

111116
class DLDataLoader(torch.utils.data.DataLoader):
112117
"""

0 commit comments

Comments
 (0)