Skip to content

Commit 7c36fdd

Browse files
committed
Add unified matrix table add interface.
1 parent 00e67ac commit 7c36fdd

2 files changed

Lines changed: 27 additions & 10 deletions

File tree

include/multiverso/table/matrix_table.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ class MatrixWorkerTable : public WorkerTable {
2929
void Get(const std::vector<integer_t>& row_ids,
3030
const std::vector<T*>& data_vec, size_t size);
3131

32-
// Add whole table
33-
void Add(T* data, size_t size, const AddOption* option = nullptr);
34-
3532
void Add(integer_t row_id, T* data, size_t size,
3633
const AddOption* option = nullptr);
3734

3835
void Add(const std::vector<integer_t>& row_ids,
3936
const std::vector<T*>& data_vec, size_t size,
4037
const AddOption* option = nullptr);
4138

39+
void Add(T* data, size_t size, integer_t* row_ids = nullptr,
40+
integer_t row_ids_size = 0, const AddOption* option = nullptr);
41+
4242
int Partition(const std::vector<Blob>& kv,
4343
std::unordered_map<int, std::vector<Blob>>* out) override;
4444

src/table/matrix_table.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,6 @@ void MatrixWorkerTable<T>::Get(const std::vector<integer_t>& row_ids,
8989
Log::Debug("[Get] worker = %d, #rows_set = %d\n", MV_Rank(), row_ids.size());
9090
}
9191

92-
template <typename T>
93-
void MatrixWorkerTable<T>::Add(T* data, size_t size, const AddOption* option) {
94-
CHECK(size == num_col_ * num_row_);
95-
integer_t whole_table = -1;
96-
Add(whole_table, data, size, option);
97-
}
98-
9992
template <typename T>
10093
void MatrixWorkerTable<T>::Add(integer_t row_id, T* data, size_t size,
10194
const AddOption* option) {
@@ -122,6 +115,30 @@ void MatrixWorkerTable<T>::Add(const std::vector<integer_t>& row_ids,
122115
Log::Debug("[Add] worker = %d, #rows_set = %d\n", MV_Rank(), row_ids.size());
123116
}
124117

118+
template <typename T>
119+
void MatrixWorkerTable<T>::Add(T* data, size_t size, integer_t* row_ids,
120+
integer_t row_ids_size,
121+
const AddOption* option) {
122+
if (row_ids_size == 0) {
123+
CHECK(size == num_col_ * num_row_);
124+
integer_t row_id = -1;
125+
Blob ids_blob(&row_id, sizeof(integer_t));
126+
Blob data_blob(data, size * sizeof(T));
127+
WorkerTable::Add(ids_blob, data_blob, option);
128+
Log::Debug("[Add] worker = %d, #row = %d\n", MV_Rank(), row_id);
129+
} else {
130+
CHECK(size == num_col_);
131+
Blob ids_blob(row_ids, sizeof(integer_t) * row_ids_size);
132+
Blob data_blob(row_ids_size * row_size_);
133+
//copy each row
134+
for (auto i = 0; i < row_ids_size; ++i){
135+
memcpy(data_blob.data() + i * row_size_, &data[i * num_col_], row_size_);
136+
}
137+
WorkerTable::Add(ids_blob, data_blob, option);
138+
Log::Debug("[Add] worker = %d, #rows_set = %d\n", MV_Rank(), row_ids_size);
139+
}
140+
}
141+
125142
template <typename T>
126143
int MatrixWorkerTable<T>::Partition(const std::vector<Blob>& kv,
127144
std::unordered_map<int, std::vector<Blob>>* out) {

0 commit comments

Comments
 (0)