@@ -23,16 +23,32 @@ db_copy_thread_t::db_copy_thread_t(std::string const &conninfo)
2323 });
2424}
2525
26+ db_copy_thread_t ::~db_copy_thread_t () { finish (); }
27+
2628void db_copy_thread_t::add_buffer (std::unique_ptr<db_cmd_t > &&buffer)
2729{
30+ assert (m_worker.joinable ()); // thread must not have been finished
2831 std::unique_lock<std::mutex> lock (m_queue_mutex);
2932 m_worker_queue.push_back (std::move (buffer));
33+ m_queue_cond.notify_one ();
34+ }
35+
36+ void db_copy_thread_t::sync_and_wait ()
37+ {
38+ std::promise<void > barrier;
39+ std::future<void > sync = barrier.get_future ();
40+ add_buffer (std::unique_ptr<db_cmd_t >(new db_cmd_sync_t (std::move (barrier))));
41+ sync.wait ();
3042}
3143
3244void db_copy_thread_t::finish ()
3345{
34- add_buffer (std::unique_ptr<db_cmd_t >(new db_cmd_finish_t ()));
35- m_worker.join ();
46+ if (m_worker.joinable ()) {
47+ finish_copy ();
48+
49+ add_buffer (std::unique_ptr<db_cmd_t >(new db_cmd_finish_t ()));
50+ m_worker.join ();
51+ }
3652}
3753
3854void db_copy_thread_t::worker_thread ()
@@ -61,6 +77,7 @@ void db_copy_thread_t::worker_thread()
6177 execute_sql (static_cast <db_cmd_sql_t *>(item.get ())->buffer );
6278 break ;
6379 case db_cmd_t ::Cmd_sync:
80+ finish_copy ();
6481 static_cast <db_cmd_sync_t *>(item.get ())->barrier .set_value ();
6582 break ;
6683 case db_cmd_t ::Cmd_finish:
@@ -69,8 +86,7 @@ void db_copy_thread_t::worker_thread()
6986 }
7087 }
7188
72- if (m_inflight)
73- finish_copy ();
89+ finish_copy ();
7490
7591 disconnect ();
7692}
@@ -93,8 +109,7 @@ void db_copy_thread_t::connect()
93109
94110void db_copy_thread_t::execute_sql (std::string const &sql_cmd)
95111{
96- if (m_inflight)
97- finish_copy ();
112+ finish_copy ();
98113
99114 pgsql_exec_simple (m_conn, PGRES_COMMAND_OK, sql_cmd.c_str ());
100115}
@@ -117,7 +132,8 @@ void db_copy_thread_t::write_to_db(db_cmd_copy_t *buffer)
117132 if (!buffer->deletables .empty ())
118133 delete_rows (buffer);
119134
120- start_copy (buffer->target );
135+ if (!m_inflight)
136+ start_copy (buffer->target );
121137
122138 pgsql_CopyData (buffer->target ->name .c_str (), m_conn, buffer->buffer );
123139}
@@ -130,24 +146,21 @@ void db_copy_thread_t::delete_rows(db_cmd_copy_t *buffer)
130146 sql.reserve (buffer->target ->name .size () + buffer->deletables .size () * 15 +
131147 30 );
132148 sql += buffer->target ->name ;
133- sql += " WHERE " ;
149+ sql += " WHERE " ;
134150 sql += buffer->target ->id ;
135151 sql += " IN (" ;
136152 for (auto id : buffer->deletables ) {
137153 sql += std::to_string (id);
138154 sql += ' ,' ;
139155 }
140- sql + = ' )' ;
156+ sql[sql. size () - 1 ] = ' )' ;
141157
142158 pgsql_exec_simple (m_conn, PGRES_COMMAND_OK, sql);
143159}
144160
145161void db_copy_thread_t::start_copy (std::shared_ptr<db_target_descr_t > const &target)
146162{
147- if (m_inflight)
148- return ;
149-
150- assert (m_inflight.get () == target.get ());
163+ m_inflight = target;
151164
152165 std::string copystr = " COPY " ;
153166 copystr.reserve (target->name .size () + target->rows .size () + 14 );
@@ -165,6 +178,9 @@ void db_copy_thread_t::start_copy(std::shared_ptr<db_target_descr_t> const &targ
165178
166179void db_copy_thread_t::finish_copy ()
167180{
181+ if (!m_inflight)
182+ return ;
183+
168184 if (PQputCopyEnd (m_conn, nullptr ) != 1 )
169185 throw std::runtime_error ((fmt (" stop COPY_END for %1% failed: %2%\n " ) %
170186 m_inflight->name %
@@ -215,10 +231,12 @@ void db_copy_mgr_t::exec_sql(std::string const &sql_cmd)
215231
216232void db_copy_mgr_t::sync ()
217233{
218- std::promise<void > barrier;
219- std::future<void > sync = barrier.get_future ();
220- m_processor->add_buffer (std::unique_ptr<db_cmd_t >(new db_cmd_sync_t (std::move (barrier))));
221- sync.wait ();
234+ // finish any ongoing copy operations
235+ if (m_current) {
236+ m_processor->add_buffer (std::move (m_current));
237+ }
238+
239+ m_processor->sync_and_wait ();
222240}
223241
224242void db_copy_mgr_t::finish_line ()
0 commit comments