Skip to content

Commit 1cea7c2

Browse files
authored
Updated progress bar for asynchronous queries (#3665)
1 parent 99239d3 commit 1cea7c2

6 files changed

Lines changed: 78 additions & 73 deletions

File tree

src_cpp/include/node_connection.h

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class ConnectionInitAsyncWorker : public Napi::AsyncWorker {
6060
NodeConnection* nodeConnection;
6161
};
6262

63+
namespace kuzu {
64+
namespace main {
65+
6366
class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
6467
public:
6568
ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
@@ -78,27 +81,33 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
7881
~ConnectionExecuteAsyncWorker() override = default;
7982

8083
void Execute() override {
84+
uint64_t queryID = connection->getClientContext()->getDatabase()->getNextQueryID();
8185
auto progressBar = connection->getClientContext()->getProgressBar();
82-
bool trackProgress = progressBar->getProgressBarPrinting();
86+
auto trackProgress = progressBar->getProgressBarPrinting();
87+
auto display = progressBar->getDisplay().get();
88+
NodeProgressBarDisplay* nodeDisplay =
89+
ku_dynamic_cast<ProgressBarDisplay*, NodeProgressBarDisplay*>(display);
8390
if (progressCallback) {
91+
nodeDisplay->setCallbackFunction(queryID, *progressCallback);
8492
progressBar->toggleProgressBarPrinting(true);
85-
progressBar->setDisplay(
86-
std::make_shared<NodeProgressBarDisplay>(*progressCallback, Env()));
8793
}
8894
try {
8995
auto result =
90-
connection->executeWithParams(preparedStatement.get(), std::move(params)).release();
96+
connection
97+
->executeWithParamsWithID(preparedStatement.get(), std::move(params), queryID)
98+
.release();
9199
nodeQueryResult->SetQueryResult(result, true);
92100
if (!result->isSuccess()) {
93101
SetError(result->getErrorMessage());
102+
return;
94103
}
95104
} catch (const std::exception& exc) {
96105
SetError(std::string(exc.what()));
97106
}
98107
if (progressCallback) {
99-
progressBar->toggleProgressBarPrinting(trackProgress);
100-
progressBar->setDisplay(ProgressBar::DefaultProgressBarDisplay());
101-
progressCallback->Release();
108+
if (nodeDisplay->getNumCallbacks() == 0) {
109+
progressBar->toggleProgressBarPrinting(trackProgress);
110+
}
102111
}
103112
}
104113

@@ -129,15 +138,18 @@ class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
129138
~ConnectionQueryAsyncWorker() override = default;
130139

131140
void Execute() override {
141+
uint64_t queryID = connection->getClientContext()->getDatabase()->getNextQueryID();
132142
auto progressBar = connection->getClientContext()->getProgressBar();
133-
bool trackProgress = progressBar->getProgressBarPrinting();
143+
auto trackProgress = progressBar->getProgressBarPrinting();
144+
auto display = progressBar->getDisplay().get();
145+
NodeProgressBarDisplay* nodeDisplay =
146+
ku_dynamic_cast<ProgressBarDisplay*, NodeProgressBarDisplay*>(display);
134147
if (progressCallback) {
148+
nodeDisplay->setCallbackFunction(queryID, *progressCallback);
135149
progressBar->toggleProgressBarPrinting(true);
136-
progressBar->setDisplay(
137-
std::make_shared<NodeProgressBarDisplay>(*progressCallback, Env()));
138150
}
139151
try {
140-
auto result = connection->query(statement).release();
152+
auto result = connection->queryWithID(statement, queryID).release();
141153
nodeQueryResult->SetQueryResult(result, true);
142154
if (!result->isSuccess()) {
143155
SetError(result->getErrorMessage());
@@ -146,9 +158,9 @@ class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
146158
SetError(std::string(exc.what()));
147159
}
148160
if (progressCallback) {
149-
progressBar->toggleProgressBarPrinting(trackProgress);
150-
progressBar->setDisplay(ProgressBar::DefaultProgressBarDisplay());
151-
progressCallback->Release();
161+
if (nodeDisplay->getNumCallbacks() == 0) {
162+
progressBar->toggleProgressBarPrinting(trackProgress);
163+
}
152164
}
153165
}
154166

@@ -162,3 +174,6 @@ class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
162174
NodeQueryResult* nodeQueryResult;
163175
std::optional<Napi::ThreadSafeFunction> progressCallback;
164176
};
177+
178+
} // namespace main
179+
} // namespace kuzu
Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3+
#include <optional>
4+
#include <unordered_set>
5+
36
#include "common/task_system/progress_bar_display.h"
47
#include <napi.h>
58

@@ -11,14 +14,15 @@ using namespace common;
1114
*/
1215
class NodeProgressBarDisplay : public ProgressBarDisplay {
1316
public:
14-
NodeProgressBarDisplay(Napi::ThreadSafeFunction callback, Napi::Env env)
15-
: callback(callback), env(env) {}
17+
void updateProgress(uint64_t queryID, double newPipelineProgress,
18+
uint32_t newNumPipelinesFinished) override;
19+
20+
void finishProgress(uint64_t queryID) override;
1621

17-
void updateProgress(double newPipelineProgress, uint32_t newNumPipelinesFinished) override;
22+
void setCallbackFunction(uint64_t queryID, Napi::ThreadSafeFunction callback);
1823

19-
void finishProgress() override;
24+
uint32_t getNumCallbacks() { return queryCallbacks.size(); }
2025

2126
private:
22-
Napi::ThreadSafeFunction callback;
23-
Napi::Env env;
27+
std::unordered_map<uint64_t, Napi::ThreadSafeFunction> queryCallbacks;
2428
};

src_cpp/node_connection.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ Napi::Value NodeConnection::InitAsync(const Napi::CallbackInfo& info) {
4141

4242
void NodeConnection::InitCppConnection() {
4343
this->connection = std::make_shared<Connection>(database.get());
44+
connection->getClientContext()->getProgressBar()->setDisplay(
45+
std::make_shared<NodeProgressBarDisplay>());
4446
// After the connection is initialized, we do not need to hold a reference to the database.
4547
database.reset();
4648
}
Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,50 @@
11
#include "include/node_progress_bar_display.h"
22

3+
#include <tuple>
4+
35
using namespace kuzu;
46
using namespace common;
57

6-
void NodeProgressBarDisplay::updateProgress(double newPipelineProgress,
8+
void NodeProgressBarDisplay::updateProgress(uint64_t queryID, double newPipelineProgress,
79
uint32_t newNumPipelinesFinished) {
8-
uint32_t progress = (uint32_t)(newPipelineProgress * 100.0);
9-
uint32_t oldProgress = (uint32_t)(pipelineProgress * 100.0);
10-
if (progress > oldProgress || newNumPipelinesFinished > numPipelinesFinished) {
10+
if (numPipelines == 0) {
11+
return;
12+
}
13+
uint32_t curPipelineProgress = (uint32_t)(newPipelineProgress * 100.0);
14+
uint32_t oldPipelineProgress = (uint32_t)(pipelineProgress * 100.0);
15+
if (curPipelineProgress > oldPipelineProgress ||
16+
newNumPipelinesFinished > numPipelinesFinished) {
1117
pipelineProgress = newPipelineProgress;
1218
numPipelinesFinished = newNumPipelinesFinished;
13-
callback.BlockingCall([this](Napi::Env env, Napi::Function jsCallback) {
14-
jsCallback.Call({Napi::Number::New(env, pipelineProgress),
15-
Napi::Number::New(env, numPipelinesFinished),
16-
Napi::Number::New(env, numPipelines)});
17-
});
19+
auto callback = queryCallbacks.find(queryID);
20+
if (callback != queryCallbacks.end()) {
21+
double capturedPipelineProgress = pipelineProgress;
22+
uint32_t capturedNumPipelinesFinished = numPipelinesFinished;
23+
uint32_t capturedNumPipelines = numPipelines;
24+
callback->second.BlockingCall(
25+
[capturedPipelineProgress, capturedNumPipelinesFinished,
26+
capturedNumPipelines](Napi::Env env, Napi::Function jsCallback) {
27+
// Use the captured values directly inside the lambda
28+
jsCallback.Call({Napi::Number::New(env, capturedPipelineProgress),
29+
Napi::Number::New(env, capturedNumPipelinesFinished),
30+
Napi::Number::New(env, capturedNumPipelines)});
31+
});
32+
}
1833
}
1934
}
2035

21-
void NodeProgressBarDisplay::finishProgress() {
22-
pipelineProgress = 0;
36+
void NodeProgressBarDisplay::finishProgress(uint64_t queryID) {
2337
numPipelines = 0;
2438
numPipelinesFinished = 0;
39+
pipelineProgress = 0;
40+
auto callback = queryCallbacks.find(queryID);
41+
if (callback != queryCallbacks.end()) {
42+
callback->second.Release();
43+
}
44+
queryCallbacks.erase(queryID);
45+
}
46+
47+
void NodeProgressBarDisplay::setCallbackFunction(uint64_t queryID,
48+
Napi::ThreadSafeFunction callback) {
49+
queryCallbacks.emplace(queryID, callback);
2550
}

src_js/connection.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class Connection {
8484
* Execute a prepared statement with the given parameters.
8585
* @param {kuzu.PreparedStatement} preparedStatement the prepared statement to execute.
8686
* @param {Object} params a plain object mapping parameter names to values.
87-
* @param {Function} [progressCallback] optional callback function that is invoked with the progress of the query execution. The callback receives three arguments: pipelineProgress, numPipelinesFinished, and numPipelines.
87+
* @param {Function} [progressCallback] - Optional callback function that is invoked with the progress of the query execution. The callback receives three arguments: pipelineProgress, numPipelinesFinished, and numPipelines.
8888
* @returns {Promise<kuzu.QueryResult>} a promise that resolves to the query result. The promise is rejected if there is an error.
8989
*/
9090
execute(preparedStatement, params = {}, progressCallback) {

test/test_connection.js

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -274,47 +274,6 @@ describe("Close", function () {
274274
});
275275

276276
describe("Progress", function () {
277-
it("should execute a valid prepared statement with progress", async function () {
278-
await conn.query("CALL progress_bar_time = 0");
279-
let progressCalled = false;
280-
const progressCallback = (pipelineProgress, numPipelinesFinished, numPipelines) => {
281-
progressCalled = true;
282-
assert.isNumber(pipelineProgress);
283-
assert.isNumber(numPipelinesFinished);
284-
assert.isNumber(numPipelines);
285-
};
286-
const preparedStatement = await conn.prepare(
287-
"MATCH (a:person) WHERE a.ID = $1 RETURN COUNT(*)"
288-
);
289-
assert.exists(preparedStatement);
290-
assert.isTrue(preparedStatement.isSuccess());
291-
const queryResult = await conn.execute(preparedStatement, { 1: 0 }, progressCallback);
292-
assert.exists(queryResult);
293-
assert.equal(queryResult.constructor.name, "QueryResult");
294-
assert.isTrue(queryResult.hasNext());
295-
const tuple = await queryResult.getNext();
296-
assert.exists(tuple);
297-
assert.exists(tuple["COUNT_STAR()"]);
298-
assert.equal(tuple["COUNT_STAR()"], 1);
299-
assert.isTrue(progressCalled)
300-
});
301-
it("should throw error if the progress callback is not a function for execute", async function () {
302-
try {
303-
const preparedStatement = await conn.prepare(
304-
"MATCH (a:person) WHERE a.ID = $1 RETURN COUNT(*)"
305-
);
306-
assert.exists(preparedStatement);
307-
assert.isTrue(preparedStatement.isSuccess());
308-
await conn.execute(preparedStatement, { 1: 0 }, 10);
309-
assert.fail("No error thrown when progress callback is not a function.");
310-
} catch (e) {
311-
assert.equal(
312-
e.message,
313-
"progressCallback must be a function."
314-
);
315-
}
316-
});
317-
318277
it("should execute a valid query with progress", async function () {
319278
await conn.query("CALL progress_bar_time = 0");
320279
let progressCalled = false;

0 commit comments

Comments
 (0)