Skip to content

Commit 0a4c5ce

Browse files
authored
Support scan pandas dict (#5370)
1 parent c06e634 commit 0a4c5ce

6 files changed

Lines changed: 260 additions & 0 deletions

File tree

src_cpp/include/pandas/pandas_analyzer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "common/types/types.h"
4+
#include "py_objects.h"
45
#include "pybind_include.h"
56

67
namespace kuzu {
@@ -21,6 +22,8 @@ class PandasAnalyzer {
2122

2223
private:
2324
common::LogicalType innerAnalyze(py::object column, bool& canConvert);
25+
common::LogicalType dictToMap(const PyDictionary& dict, bool& canConvert);
26+
common::LogicalType dictToStruct(const PyDictionary& dict, bool& canConvert);
2427

2528
private:
2629
PythonGILWrapper gil;

src_cpp/include/py_conversion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ enum class PythonObjectType : uint8_t {
2020
String,
2121
List,
2222
UUID,
23+
Dict,
2324
};
2425

2526
PythonObjectType getPythonObjectType(py::handle& ele);

src_cpp/include/py_objects.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
3+
#include "common/types/types.h"
4+
#include "pybind_include.h"
5+
6+
namespace kuzu {
7+
8+
struct PyDictionary {
9+
public:
10+
explicit PyDictionary(py::object dict)
11+
: keys{py::list(dict.attr("keys")())}, values{py::list(dict.attr("values")())},
12+
len{static_cast<common::idx_t>(py::len(keys))}, dict{std::move(dict)} {}
13+
14+
// These are cached so we don't have to create new objects all the time
15+
// The CPython API offers PyDict_Keys but that creates a new reference every time, same for
16+
// values
17+
py::object keys;
18+
py::object values;
19+
common::idx_t len;
20+
21+
public:
22+
py::handle operator[](const py::object& obj) const {
23+
return PyDict_GetItem(dict.ptr(), obj.ptr());
24+
}
25+
26+
std::string toString() const { return std::string(py::str(dict)); }
27+
28+
private:
29+
py::object dict;
30+
};
31+
32+
} // namespace kuzu

src_cpp/pandas/pandas_analyzer.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,81 @@ common::LogicalType PandasAnalyzer::getListType(py::object& ele, bool& canConver
5454
return listType;
5555
}
5656

57+
static bool isValidMapComponent(const py::handle& component) {
58+
if (py::none().is(component)) {
59+
return true;
60+
}
61+
if (!py::hasattr(component, "__getitem__")) {
62+
return false;
63+
}
64+
if (!py::hasattr(component, "__len__")) {
65+
return false;
66+
}
67+
return true;
68+
}
69+
70+
bool dictionaryHasMapFormat(const PyDictionary& dict) {
71+
if (dict.len != 2) {
72+
return false;
73+
}
74+
75+
// { 'key': [ .. keys .. ], 'value': [ .. values .. ]}
76+
auto keysKey = py::str("key");
77+
auto valuesKey = py::str("value");
78+
auto keys = dict[keysKey];
79+
auto values = dict[valuesKey];
80+
if (!keys || !values) {
81+
return false;
82+
}
83+
if (!isValidMapComponent(keys)) {
84+
return false;
85+
}
86+
if (!isValidMapComponent(values)) {
87+
return false;
88+
}
89+
if (py::none().is(keys) || py::none().is(values)) {
90+
return true;
91+
}
92+
auto size = py::len(keys);
93+
if (size != py::len(values)) {
94+
return false;
95+
}
96+
return true;
97+
}
98+
99+
common::LogicalType PandasAnalyzer::dictToMap(const PyDictionary& dict, bool& canConvert) {
100+
auto keys = dict.values.attr("__getitem__")(0);
101+
auto values = dict.values.attr("__getitem__")(1);
102+
103+
if (py::none().is(keys) || py::none().is(values)) {
104+
return common::LogicalType::MAP(common::LogicalType::ANY(), common::LogicalType::ANY());
105+
}
106+
107+
auto keyType = PandasAnalyzer::getListType(keys, canConvert);
108+
if (!canConvert) {
109+
return common::LogicalType::MAP(common::LogicalType::ANY(), common::LogicalType::ANY());
110+
}
111+
auto valueType = getListType(values, canConvert);
112+
if (!canConvert) {
113+
return common::LogicalType::MAP(common::LogicalType::ANY(), common::LogicalType::ANY());
114+
}
115+
116+
return common::LogicalType::MAP(std::move(keyType), std::move(valueType));
117+
}
118+
119+
common::LogicalType PandasAnalyzer::dictToStruct(const PyDictionary& dict, bool& canConvert) {
120+
std::vector<common::StructField> fields;
121+
122+
for (auto i = 0u; i < dict.len; i++) {
123+
auto dictKey = dict.keys.attr("__getitem__")(i);
124+
auto key = std::string(py::str(dictKey));
125+
auto dictVal = dict.values.attr("__getitem__")(i);
126+
auto val = getItemType(dictVal, canConvert);
127+
fields.emplace_back(std::move(key), std::move(val));
128+
}
129+
return common::LogicalType::STRUCT(std::move(fields));
130+
}
131+
57132
common::LogicalType PandasAnalyzer::getItemType(py::object ele, bool& canConvert) {
58133
auto objectType = getPythonObjectType(ele);
59134
switch (objectType) {
@@ -75,6 +150,16 @@ common::LogicalType PandasAnalyzer::getItemType(py::object ele, bool& canConvert
75150
return common::LogicalType::LIST(getListType(ele, canConvert));
76151
case PythonObjectType::UUID:
77152
return common::LogicalType::UUID();
153+
case PythonObjectType::Dict: {
154+
PyDictionary dict = PyDictionary(py::reinterpret_borrow<py::object>(ele));
155+
if (dict.len == 0) {
156+
return common::LogicalType::MAP(common::LogicalType::ANY(), common::LogicalType::ANY());
157+
}
158+
if (dictionaryHasMapFormat(dict)) {
159+
return dictToMap(dict, canConvert);
160+
}
161+
return dictToStruct(dict, canConvert);
162+
}
78163
default:
79164
KU_UNREACHABLE;
80165
}

src_cpp/py_conversion.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include "py_conversion.h"
22

33
#include "cached_import/py_cached_import.h"
4+
#include "common/case_insensitive_map.h"
45
#include "common/exception/not_implemented.h"
56
#include "common/type_utils.h"
67
#include "common/types/uuid.h"
8+
#include "py_objects.h"
79

810
namespace kuzu {
911

@@ -34,6 +36,8 @@ PythonObjectType getPythonObjectType(py::handle& ele) {
3436
return PythonObjectType::List;
3537
} else if (py::isinstance(ele, uuid)) {
3638
return PythonObjectType::UUID;
39+
} else if (py::isinstance<py::dict>(ele)) {
40+
return PythonObjectType::Dict;
3741
} else {
3842
throw NotImplementedException(stringFormat("Scanning of type {} has not been implemented",
3943
py::str(py::type::of(ele)).cast<std::string>()));
@@ -63,7 +67,68 @@ void transformListValue(common::ValueVector* outputVector, uint64_t pos, py::han
6367
}
6468
}
6569

70+
static std::vector<std::string> transformStructKeys(py::handle keys, idx_t size) {
71+
std::vector<std::string> res;
72+
res.reserve(size);
73+
for (auto i = 0u; i < size; i++) {
74+
res.emplace_back(py::str(keys.attr("__getitem__")(i)));
75+
}
76+
return res;
77+
}
78+
79+
void transformDictionaryToStruct(common::ValueVector* outputVector, uint64_t pos,
80+
const PyDictionary& dict) {
81+
KU_ASSERT(outputVector->dataType.getLogicalTypeID() == LogicalTypeID::STRUCT);
82+
auto structKeys = transformStructKeys(dict.keys, dict.len);
83+
if (StructType::getNumFields(outputVector->dataType) != dict.len) {
84+
throw common::ConversionException(
85+
common::stringFormat("Failed to convert python dictionary: {} to target type {}",
86+
dict.toString(), outputVector->dataType.toString()));
87+
}
88+
89+
common::case_insensitive_map_t<idx_t> keyMap;
90+
for (idx_t i = 0; i < structKeys.size(); i++) {
91+
keyMap[structKeys[i]] = i;
92+
}
93+
94+
for (auto i = 0u; i < StructType::getNumFields(outputVector->dataType); i++) {
95+
auto& field = StructType::getField(outputVector->dataType, i);
96+
auto idx = keyMap[field.getName()];
97+
transformPythonValue(StructVector::getFieldVector(outputVector, i).get(), pos,
98+
dict.values.attr("__getitem__")(idx));
99+
}
100+
}
101+
102+
void transformDictionaryToMap(common::ValueVector* outputVector, uint64_t pos,
103+
const PyDictionary& dict) {
104+
KU_ASSERT(outputVector->dataType.getLogicalTypeID() == LogicalTypeID::MAP);
105+
auto keys = dict.values.attr("__getitem__")(0);
106+
auto values = dict.values.attr("__getitem__")(1);
107+
108+
if (py::none().is(keys) || py::none().is(values)) {
109+
// Null map
110+
outputVector->setNull(pos, true /* isNull */);
111+
}
112+
113+
auto numKeys = py::len(keys);
114+
KU_ASSERT(numKeys == py::len(values));
115+
auto listEntry = ListVector::addList(outputVector, numKeys);
116+
outputVector->setValue(pos, listEntry);
117+
auto structVector = ListVector::getDataVector(outputVector);
118+
auto keyVector = StructVector::getFieldVector(structVector, 0);
119+
auto valVector = StructVector::getFieldVector(structVector, 1);
120+
for (auto i = 0u; i < numKeys; i++) {
121+
transformPythonValue(keyVector.get(), listEntry.offset + i, keys.attr("__getitem__")(i));
122+
transformPythonValue(valVector.get(), listEntry.offset + i, values.attr("__getitem__")(i));
123+
}
124+
}
125+
66126
void transformPythonValue(common::ValueVector* outputVector, uint64_t pos, py::handle ele) {
127+
if (ele.is_none()) {
128+
outputVector->setNull(pos, true /* isNull */);
129+
return;
130+
}
131+
outputVector->setNull(pos, false /* isNull */);
67132
auto objType = getPythonObjectType(ele);
68133
switch (objType) {
69134
case PythonObjectType::None: {
@@ -117,6 +182,19 @@ void transformPythonValue(common::ValueVector* outputVector, uint64_t pos, py::h
117182
UUID::fromString(ele.attr("hex").cast<std::string>(), result);
118183
outputVector->setValue(pos, result);
119184
} break;
185+
case PythonObjectType::Dict: {
186+
PyDictionary dict = PyDictionary(py::reinterpret_borrow<py::object>(ele));
187+
switch (outputVector->dataType.getLogicalTypeID()) {
188+
case LogicalTypeID::STRUCT: {
189+
transformDictionaryToStruct(outputVector, pos, dict);
190+
} break;
191+
case LogicalTypeID::MAP: {
192+
transformDictionaryToMap(outputVector, pos, dict);
193+
} break;
194+
default:
195+
KU_UNREACHABLE;
196+
}
197+
} break;
120198
default:
121199
KU_UNREACHABLE;
122200
}

test/test_scan_pandas.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ def test_scan_pandas_with_exists(tmp_path: Path) -> None:
587587
assert tp[0] == 3
588588
assert tp[1] == 1
589589

590+
590591
def test_scan_empty_list(tmp_path: Path) -> None:
591592
db = kuzu.Database(tmp_path)
592593
conn = kuzu.Connection(db)
@@ -599,3 +600,63 @@ def test_scan_empty_list(tmp_path: Path) -> None:
599600
tp = res.get_next()
600601
assert tp[0] == "1"
601602
assert tp[1] == []
603+
604+
605+
def test_scan_py_dict_struct_format(tmp_path: Path) -> None:
606+
db = kuzu.Database(tmp_path)
607+
conn = kuzu.Connection(db)
608+
df = pd.DataFrame({
609+
"id": [1, 3, 4],
610+
"dt": [{'key1': 5, 'key3': 4}, {'key1': 10, 'key3': 25}, None]
611+
})
612+
res = conn.execute("LOAD FROM df RETURN *")
613+
tp = res.get_next()
614+
assert tp[0] == 1
615+
assert tp[1] == {'key1': 5, 'key3': 4}
616+
tp = res.get_next()
617+
assert tp[0] == 3
618+
assert tp[1] == {'key1': 10, 'key3': 25}
619+
tp = res.get_next()
620+
assert tp[0] == 4
621+
assert tp[1] is None
622+
623+
624+
def test_scan_py_dict_map_format(tmp_path: Path) -> None:
625+
db = kuzu.Database(tmp_path)
626+
conn = kuzu.Connection(db)
627+
df = pd.DataFrame({
628+
"id": [1, 3, 4],
629+
"dt": [{'key': ['Alice', 'Bob'], 'value': [32, 41]}, {'key': ['Carol'], 'value': [2]},
630+
{'key': ['zoo', 'ela', 'dan'], 'value': [44, 52, 88]}]
631+
})
632+
res = conn.execute("LOAD FROM df RETURN *")
633+
tp = res.get_next()
634+
assert tp[0] == 1
635+
assert tp[1] == {'Alice': 32, 'Bob': 41}
636+
tp = res.get_next()
637+
assert tp[0] == 3
638+
assert tp[1] == {'Carol': 2}
639+
tp = res.get_next()
640+
assert tp[0] == 4
641+
assert tp[1] == {'zoo': 44, 'ela': 52, 'dan': 88}
642+
643+
# If key and value size don't match, kuzu sniffs it as struct.
644+
df = pd.DataFrame({
645+
"id": [4],
646+
"dt": [{'key': ['Alice', 'Bob'], 'value': []}]
647+
})
648+
res = conn.execute("LOAD FROM df RETURN *")
649+
tup = res.get_next()
650+
assert tup[0] == 4
651+
assert tup[1] == {'key': ['Alice', 'Bob'], 'value': []}
652+
653+
654+
def test_scan_py_dict_empty(tmp_path: Path) -> None:
655+
db = kuzu.Database(tmp_path)
656+
conn = kuzu.Connection(db)
657+
df = pd.DataFrame({
658+
"id": [],
659+
"dt": []
660+
})
661+
res = conn.execute("LOAD FROM df RETURN *")
662+
assert not res.has_next()

0 commit comments

Comments
 (0)