|
1 | 1 | #include "py_conversion.h" |
2 | 2 |
|
3 | 3 | #include "cached_import/py_cached_import.h" |
| 4 | +#include "common/case_insensitive_map.h" |
4 | 5 | #include "common/exception/not_implemented.h" |
5 | 6 | #include "common/type_utils.h" |
6 | 7 | #include "common/types/uuid.h" |
| 8 | +#include "py_objects.h" |
7 | 9 |
|
8 | 10 | namespace kuzu { |
9 | 11 |
|
@@ -34,6 +36,8 @@ PythonObjectType getPythonObjectType(py::handle& ele) { |
34 | 36 | return PythonObjectType::List; |
35 | 37 | } else if (py::isinstance(ele, uuid)) { |
36 | 38 | return PythonObjectType::UUID; |
| 39 | + } else if (py::isinstance<py::dict>(ele)) { |
| 40 | + return PythonObjectType::Dict; |
37 | 41 | } else { |
38 | 42 | throw NotImplementedException(stringFormat("Scanning of type {} has not been implemented", |
39 | 43 | py::str(py::type::of(ele)).cast<std::string>())); |
@@ -63,7 +67,68 @@ void transformListValue(common::ValueVector* outputVector, uint64_t pos, py::han |
63 | 67 | } |
64 | 68 | } |
65 | 69 |
|
| 70 | +static std::vector<std::string> transformStructKeys(py::handle keys, idx_t size) { |
| 71 | + std::vector<std::string> res; |
| 72 | + res.reserve(size); |
| 73 | + for (auto i = 0u; i < size; i++) { |
| 74 | + res.emplace_back(py::str(keys.attr("__getitem__")(i))); |
| 75 | + } |
| 76 | + return res; |
| 77 | +} |
| 78 | + |
| 79 | +void transformDictionaryToStruct(common::ValueVector* outputVector, uint64_t pos, |
| 80 | + const PyDictionary& dict) { |
| 81 | + KU_ASSERT(outputVector->dataType.getLogicalTypeID() == LogicalTypeID::STRUCT); |
| 82 | + auto structKeys = transformStructKeys(dict.keys, dict.len); |
| 83 | + if (StructType::getNumFields(outputVector->dataType) != dict.len) { |
| 84 | + throw common::ConversionException( |
| 85 | + common::stringFormat("Failed to convert python dictionary: {} to target type {}", |
| 86 | + dict.toString(), outputVector->dataType.toString())); |
| 87 | + } |
| 88 | + |
| 89 | + common::case_insensitive_map_t<idx_t> keyMap; |
| 90 | + for (idx_t i = 0; i < structKeys.size(); i++) { |
| 91 | + keyMap[structKeys[i]] = i; |
| 92 | + } |
| 93 | + |
| 94 | + for (auto i = 0u; i < StructType::getNumFields(outputVector->dataType); i++) { |
| 95 | + auto& field = StructType::getField(outputVector->dataType, i); |
| 96 | + auto idx = keyMap[field.getName()]; |
| 97 | + transformPythonValue(StructVector::getFieldVector(outputVector, i).get(), pos, |
| 98 | + dict.values.attr("__getitem__")(idx)); |
| 99 | + } |
| 100 | +} |
| 101 | + |
| 102 | +void transformDictionaryToMap(common::ValueVector* outputVector, uint64_t pos, |
| 103 | + const PyDictionary& dict) { |
| 104 | + KU_ASSERT(outputVector->dataType.getLogicalTypeID() == LogicalTypeID::MAP); |
| 105 | + auto keys = dict.values.attr("__getitem__")(0); |
| 106 | + auto values = dict.values.attr("__getitem__")(1); |
| 107 | + |
| 108 | + if (py::none().is(keys) || py::none().is(values)) { |
| 109 | + // Null map |
| 110 | + outputVector->setNull(pos, true /* isNull */); |
| 111 | + } |
| 112 | + |
| 113 | + auto numKeys = py::len(keys); |
| 114 | + KU_ASSERT(numKeys == py::len(values)); |
| 115 | + auto listEntry = ListVector::addList(outputVector, numKeys); |
| 116 | + outputVector->setValue(pos, listEntry); |
| 117 | + auto structVector = ListVector::getDataVector(outputVector); |
| 118 | + auto keyVector = StructVector::getFieldVector(structVector, 0); |
| 119 | + auto valVector = StructVector::getFieldVector(structVector, 1); |
| 120 | + for (auto i = 0u; i < numKeys; i++) { |
| 121 | + transformPythonValue(keyVector.get(), listEntry.offset + i, keys.attr("__getitem__")(i)); |
| 122 | + transformPythonValue(valVector.get(), listEntry.offset + i, values.attr("__getitem__")(i)); |
| 123 | + } |
| 124 | +} |
| 125 | + |
66 | 126 | void transformPythonValue(common::ValueVector* outputVector, uint64_t pos, py::handle ele) { |
| 127 | + if (ele.is_none()) { |
| 128 | + outputVector->setNull(pos, true /* isNull */); |
| 129 | + return; |
| 130 | + } |
| 131 | + outputVector->setNull(pos, false /* isNull */); |
67 | 132 | auto objType = getPythonObjectType(ele); |
68 | 133 | switch (objType) { |
69 | 134 | case PythonObjectType::None: { |
@@ -117,6 +182,19 @@ void transformPythonValue(common::ValueVector* outputVector, uint64_t pos, py::h |
117 | 182 | UUID::fromString(ele.attr("hex").cast<std::string>(), result); |
118 | 183 | outputVector->setValue(pos, result); |
119 | 184 | } break; |
| 185 | + case PythonObjectType::Dict: { |
| 186 | + PyDictionary dict = PyDictionary(py::reinterpret_borrow<py::object>(ele)); |
| 187 | + switch (outputVector->dataType.getLogicalTypeID()) { |
| 188 | + case LogicalTypeID::STRUCT: { |
| 189 | + transformDictionaryToStruct(outputVector, pos, dict); |
| 190 | + } break; |
| 191 | + case LogicalTypeID::MAP: { |
| 192 | + transformDictionaryToMap(outputVector, pos, dict); |
| 193 | + } break; |
| 194 | + default: |
| 195 | + KU_UNREACHABLE; |
| 196 | + } |
| 197 | + } break; |
120 | 198 | default: |
121 | 199 | KU_UNREACHABLE; |
122 | 200 | } |
|
0 commit comments