Skip to content

Commit 235d266

Browse files
authored
Added query progress callbacks for nodejs api (#3591)
Co-authored-by: CI Bot <MSebanc@users.noreply.github.com>
1 parent d6c04eb commit 235d266

6 files changed

Lines changed: 183 additions & 12 deletions

File tree

src_cpp/include/node_connection.h

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "main/kuzu.h"
66
#include "node_database.h"
77
#include "node_prepared_statement.h"
8+
#include "node_progress_bar_display.h"
89
#include "node_query_result.h"
910
#include <napi.h>
1011

@@ -63,24 +64,42 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
6364
public:
6465
ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
6566
std::shared_ptr<PreparedStatement> preparedStatement, NodeQueryResult* nodeQueryResult,
66-
std::unordered_map<std::string, std::unique_ptr<Value>> params)
67+
std::unordered_map<std::string, std::unique_ptr<Value>> params,
68+
Napi::Value progressCallback)
6769
: Napi::AsyncWorker(callback), connection(connection),
6870
preparedStatement(std::move(preparedStatement)), nodeQueryResult(nodeQueryResult),
69-
params(std::move(params)) {}
71+
params(std::move(params)) {
72+
if (progressCallback.IsFunction()) {
73+
this->progressCallback = Napi::ThreadSafeFunction::New(Env(),
74+
progressCallback.As<Napi::Function>(), "ProgressCallback", 0, 1);
75+
}
76+
}
77+
7078
~ConnectionExecuteAsyncWorker() override = default;
7179

7280
void Execute() override {
81+
auto progressBar = connection->getClientContext()->getProgressBar();
82+
bool trackProgress = progressBar->getProgressBarPrinting();
83+
if (progressCallback) {
84+
progressBar->toggleProgressBarPrinting(true);
85+
progressBar->setDisplay(
86+
std::make_shared<NodeProgressBarDisplay>(*progressCallback, Env()));
87+
}
7388
try {
7489
auto result =
7590
connection->executeWithParams(preparedStatement.get(), std::move(params)).release();
7691
nodeQueryResult->SetQueryResult(result, true);
7792
if (!result->isSuccess()) {
7893
SetError(result->getErrorMessage());
79-
return;
8094
}
8195
} catch (const std::exception& exc) {
8296
SetError(std::string(exc.what()));
8397
}
98+
if (progressCallback) {
99+
progressBar->toggleProgressBarPrinting(trackProgress);
100+
progressBar->setDisplay(ProgressBar::DefaultProgressBarDisplay());
101+
progressCallback->Release();
102+
}
84103
}
85104

86105
void OnOK() override { Callback().Call({Env().Null()}); }
@@ -92,28 +111,45 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
92111
std::shared_ptr<PreparedStatement> preparedStatement;
93112
NodeQueryResult* nodeQueryResult;
94113
std::unordered_map<std::string, std::unique_ptr<Value>> params;
114+
std::optional<Napi::ThreadSafeFunction> progressCallback;
95115
};
96116

97117
class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
98118
public:
99119
ConnectionQueryAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
100-
std::string statement, NodeQueryResult* nodeQueryResult)
120+
std::string statement, NodeQueryResult* nodeQueryResult, Napi::Value progressCallback)
101121
: Napi::AsyncWorker(callback), connection(connection), statement(std::move(statement)),
102-
nodeQueryResult(nodeQueryResult) {}
122+
nodeQueryResult(nodeQueryResult) {
123+
if (progressCallback.IsFunction()) {
124+
this->progressCallback = Napi::ThreadSafeFunction::New(Env(),
125+
progressCallback.As<Napi::Function>(), "ProgressCallback", 0, 1);
126+
}
127+
}
103128

104129
~ConnectionQueryAsyncWorker() override = default;
105130

106131
void Execute() override {
132+
auto progressBar = connection->getClientContext()->getProgressBar();
133+
bool trackProgress = progressBar->getProgressBarPrinting();
134+
if (progressCallback) {
135+
progressBar->toggleProgressBarPrinting(true);
136+
progressBar->setDisplay(
137+
std::make_shared<NodeProgressBarDisplay>(*progressCallback, Env()));
138+
}
107139
try {
108140
auto result = connection->query(statement).release();
109141
nodeQueryResult->SetQueryResult(result, true);
110142
if (!result->isSuccess()) {
111143
SetError(result->getErrorMessage());
112-
return;
113144
}
114145
} catch (const std::exception& exc) {
115146
SetError(std::string(exc.what()));
116147
}
148+
if (progressCallback) {
149+
progressBar->toggleProgressBarPrinting(trackProgress);
150+
progressBar->setDisplay(ProgressBar::DefaultProgressBarDisplay());
151+
progressCallback->Release();
152+
}
117153
}
118154

119155
void OnOK() override { Callback().Call({Env().Null()}); }
@@ -124,4 +160,5 @@ class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
124160
std::shared_ptr<Connection> connection;
125161
std::string statement;
126162
NodeQueryResult* nodeQueryResult;
163+
std::optional<Napi::ThreadSafeFunction> progressCallback;
127164
};
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include "common/task_system/progress_bar_display.h"
4+
#include <napi.h>
5+
6+
using namespace kuzu;
7+
using namespace common;
8+
9+
/**
10+
* @brief A class that displays a progress bar in the terminal.
11+
*/
12+
class NodeProgressBarDisplay : public ProgressBarDisplay {
13+
public:
14+
NodeProgressBarDisplay(Napi::ThreadSafeFunction callback, Napi::Env env)
15+
: callback(callback), env(env) {}
16+
17+
void updateProgress(double newPipelineProgress, uint32_t newNumPipelinesFinished) override;
18+
19+
void finishProgress() override;
20+
21+
private:
22+
Napi::ThreadSafeFunction callback;
23+
Napi::Env env;
24+
};

src_cpp/node_connection.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) {
8383
try {
8484
auto params = Util::TransformParametersForExec(info[2].As<Napi::Array>());
8585
auto asyncWorker = new ConnectionExecuteAsyncWorker(callback, connection,
86-
nodePreparedStatement->preparedStatement, nodeQueryResult, std::move(params));
86+
nodePreparedStatement->preparedStatement, nodeQueryResult, std::move(params), info[4]);
8787
asyncWorker->Queue();
8888
} catch (const std::exception& exc) {
8989
Napi::Error::New(env, std::string(exc.what())).ThrowAsJavaScriptException();
@@ -98,7 +98,7 @@ Napi::Value NodeConnection::QueryAsync(const Napi::CallbackInfo& info) {
9898
auto nodeQueryResult = Napi::ObjectWrap<NodeQueryResult>::Unwrap(info[1].As<Napi::Object>());
9999
auto callback = info[2].As<Napi::Function>();
100100
auto asyncWorker =
101-
new ConnectionQueryAsyncWorker(callback, connection, statement, nodeQueryResult);
101+
new ConnectionQueryAsyncWorker(callback, connection, statement, nodeQueryResult, info[3]);
102102
asyncWorker->Queue();
103103
return info.Env().Undefined();
104104
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include "include/node_progress_bar_display.h"
2+
3+
using namespace kuzu;
4+
using namespace common;
5+
6+
void NodeProgressBarDisplay::updateProgress(double newPipelineProgress,
7+
uint32_t newNumPipelinesFinished) {
8+
uint32_t progress = (uint32_t)(newPipelineProgress * 100.0);
9+
uint32_t oldProgress = (uint32_t)(pipelineProgress * 100.0);
10+
if (progress > oldProgress || newNumPipelinesFinished > numPipelinesFinished) {
11+
pipelineProgress = newPipelineProgress;
12+
numPipelinesFinished = newNumPipelinesFinished;
13+
callback.BlockingCall([this](Napi::Env env, Napi::Function jsCallback) {
14+
jsCallback.Call({Napi::Number::New(env, pipelineProgress),
15+
Napi::Number::New(env, numPipelinesFinished),
16+
Napi::Number::New(env, numPipelines)});
17+
});
18+
}
19+
}
20+
21+
void NodeProgressBarDisplay::finishProgress() {
22+
pipelineProgress = 0;
23+
numPipelines = 0;
24+
numPipelinesFinished = 0;
25+
}

src_js/connection.js

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ class Connection {
8484
* Execute a prepared statement with the given parameters.
8585
* @param {kuzu.PreparedStatement} preparedStatement the prepared statement to execute.
8686
* @param {Object} params a plain object mapping parameter names to values.
87+
* @param {Function} [progressCallback] optional callback function that is invoked with the progress of the query execution. The callback receives three arguments: pipelineProgress, numPipelinesFinished, and numPipelines.
8788
* @returns {Promise<kuzu.QueryResult>} a promise that resolves to the query result. The promise is rejected if there is an error.
8889
*/
89-
execute(preparedStatement, params = {}) {
90+
execute(preparedStatement, params = {}, progressCallback) {
9091
return new Promise((resolve, reject) => {
9192
if (
9293
!typeof preparedStatement === "object" ||
@@ -129,6 +130,9 @@ class Connection {
129130
)
130131
);
131132
}
133+
}
134+
if (progressCallback && typeof progressCallback !== "function") {
135+
return reject(new Error("progressCallback must be a function."));
132136
}
133137
this._getConnection()
134138
.then((connection) => {
@@ -149,7 +153,8 @@ class Connection {
149153
.catch((err) => {
150154
return reject(err);
151155
});
152-
}
156+
},
157+
progressCallback
153158
);
154159
} catch (e) {
155160
return reject(e);
@@ -193,13 +198,17 @@ class Connection {
193198
/**
194199
* Execute a query.
195200
* @param {String} statement the statement to execute.
201+
* @param {Function} [progressCallback] - Optional callback function that is invoked with the progress of the query execution. The callback receives three arguments: pipelineProgress, numPipelinesFinished, and numPipelines.
196202
* @returns {Promise<kuzu.QueryResult>} a promise that resolves to the query result. The promise is rejected if there is an error.
197203
*/
198-
query(statement) {
204+
query(statement, progressCallback) {
199205
return new Promise((resolve, reject) => {
200206
if (typeof statement !== "string") {
201207
return reject(new Error("statement must be a string."));
202208
}
209+
if (progressCallback && typeof progressCallback !== "function") {
210+
return reject(new Error("progressCallback must be a function."));
211+
}
203212
this._getConnection()
204213
.then((connection) => {
205214
const nodeQueryResult = new KuzuNative.NodeQueryResult();
@@ -215,7 +224,8 @@ class Connection {
215224
.catch((err) => {
216225
return reject(err);
217226
});
218-
});
227+
},
228+
progressCallback);
219229
} catch (e) {
220230
return reject(e);
221231
}

test/test_connection.js

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,78 @@ describe("Close", function () {
272272
}
273273
});
274274
});
275+
276+
describe("Progress", function () {
277+
it("should execute a valid prepared statement with progress", async function () {
278+
await conn.query("CALL progress_bar_time = 0");
279+
let progressCalled = false;
280+
const progressCallback = (pipelineProgress, numPipelinesFinished, numPipelines) => {
281+
progressCalled = true;
282+
assert.isNumber(pipelineProgress);
283+
assert.isNumber(numPipelinesFinished);
284+
assert.isNumber(numPipelines);
285+
};
286+
const preparedStatement = await conn.prepare(
287+
"MATCH (a:person) WHERE a.ID = $1 RETURN COUNT(*)"
288+
);
289+
assert.exists(preparedStatement);
290+
assert.isTrue(preparedStatement.isSuccess());
291+
const queryResult = await conn.execute(preparedStatement, { 1: 0 }, progressCallback);
292+
assert.exists(queryResult);
293+
assert.equal(queryResult.constructor.name, "QueryResult");
294+
assert.isTrue(queryResult.hasNext());
295+
const tuple = await queryResult.getNext();
296+
assert.exists(tuple);
297+
assert.exists(tuple["COUNT_STAR()"]);
298+
assert.equal(tuple["COUNT_STAR()"], 1);
299+
assert.isTrue(progressCalled)
300+
});
301+
it("should throw error if the progress callback is not a function for execute", async function () {
302+
try {
303+
const preparedStatement = await conn.prepare(
304+
"MATCH (a:person) WHERE a.ID = $1 RETURN COUNT(*)"
305+
);
306+
assert.exists(preparedStatement);
307+
assert.isTrue(preparedStatement.isSuccess());
308+
await conn.execute(preparedStatement, { 1: 0 }, 10);
309+
assert.fail("No error thrown when progress callback is not a function.");
310+
} catch (e) {
311+
assert.equal(
312+
e.message,
313+
"progressCallback must be a function."
314+
);
315+
}
316+
});
317+
318+
it("should execute a valid query with progress", async function () {
319+
await conn.query("CALL progress_bar_time = 0");
320+
let progressCalled = false;
321+
const progressCallback = (pipelineProgress, numPipelinesFinished, numPipelines) => {
322+
progressCalled = true;
323+
assert.isNumber(pipelineProgress);
324+
assert.isNumber(numPipelinesFinished);
325+
assert.isNumber(numPipelines);
326+
};
327+
const queryResult = await conn.query("MATCH (a:person) RETURN COUNT(*)", progressCallback);
328+
assert.exists(queryResult);
329+
assert.equal(queryResult.constructor.name, "QueryResult");
330+
assert.isTrue(queryResult.hasNext());
331+
const tuple = await queryResult.getNext();
332+
assert.exists(tuple);
333+
assert.exists(tuple["COUNT_STAR()"]);
334+
assert.equal(tuple["COUNT_STAR()"], 8);
335+
assert.isTrue(progressCalled);
336+
});
337+
338+
it("should throw error if the progress callback is not a function for query", async function () {
339+
try {
340+
await conn.query("MATCH (a:person) RETURN COUNT(*)", 10);
341+
assert.fail("No error thrown when progress callback is not a function.");
342+
} catch (e) {
343+
assert.equal(
344+
e.message,
345+
"progressCallback must be a function."
346+
);
347+
}
348+
});
349+
});

0 commit comments

Comments
 (0)