Skip to content
This repository was archived by the owner on Sep 17, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
save
  • Loading branch information
nkreeger committed Mar 29, 2019
commit afdfc5576ca4edace361aef50dbbdbce0b3eec44
155 changes: 112 additions & 43 deletions binding/tfjs_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace tfnodejs {
// Used to hold strings beyond the lifetime of a JS call.
static std::set<std::string> ATTR_NAME_SET;

// Binds a ...
struct TFE_TensorHandleNum {
TFE_TensorHandle *handle;
int32_t handle_num;
};

// Cleans up extra reference count for shared V8/TF tensor memory:
static void DeallocTensor(void *data, size_t len, void *arg) {
NapiAutoRef *auto_ref = static_cast<NapiAutoRef *>(arg);
Expand Down Expand Up @@ -321,7 +327,6 @@ void CopyTFE_TensorHandleDataToTypedArray(napi_env env,
// current value to the newly allocated NAPI buffer.
memcpy(array_buffer_data, TF_TensorData(tensor.tensor), byte_length);

fprintf(stderr, "---> TensorHandle data to TypedArray\n");
nstatus = napi_create_typedarray(env, array_type, num_elements,
array_buffer_value, 0, result);
ENSURE_NAPI_OK(env, nstatus);
Expand Down Expand Up @@ -689,7 +694,9 @@ void AssignOpAttr(napi_env env, TFE_Op *tfe_op, napi_value attr_value) {
}
}

TFJSBackend::TFJSBackend(napi_env env) : next_tensor_id_(0) {
TFJSBackend::TFJSBackend(napi_env env)
: tfe_handle_map_(new std::map<int32_t, TFE_TensorHandle *>()),
next_tensor_id_(0) {
TF_AutoStatus tf_status;
TFE_ContextOptions *tfe_options = TFE_NewContextOptions();
tfe_context_ = TFE_NewContext(tfe_options, tf_status.status);
Expand Down Expand Up @@ -734,7 +741,7 @@ TFJSBackend::TFJSBackend(napi_env env) : next_tensor_id_(0) {
}

TFJSBackend::~TFJSBackend() {
for (auto &kv : tfe_handle_map_) {
for (auto &kv : *tfe_handle_map_) {
TFE_DeleteTensorHandle(kv.second);
}
if (tfe_context_ != nullptr) {
Expand All @@ -744,9 +751,86 @@ TFJSBackend::~TFJSBackend() {

TFJSBackend *TFJSBackend::Create(napi_env env) { return new TFJSBackend(env); }

int32_t TFJSBackend::InsertHandle(TFE_TensorHandle *tfe_handle) {
return tfe_handle_map_.insert(std::make_pair(next_tensor_id_++, tfe_handle))
.first->first;
static int32_t GC_COUNT = 0;

// TODO - move to top of method...
static void TFEHandlePairFinalize(napi_env env, void *data, void *hint) {
std::map<int32_t, TFE_TensorHandle *> *tfe_handle_map =
static_cast<std::map<int32_t, TFE_TensorHandle *> *>(data);
if (!tfe_handle_map) {
fprintf(stderr, "----> EXCEPTION HANDLE MAP IS NOT VALID!!!\n");
return;
}

napi_value tensor_id_value = static_cast<napi_value>(hint);
if (tensor_id_value == nullptr) {
fprintf(stderr, "----> EXCEPTION TENSOR ID IS NOT VALID!!!\n");
return;
}

// TODO - move cleanup to static method?
int32_t tensor_id;
napi_get_value_int32(env, tensor_id_value, &tensor_id);

// TODO - cleanup/refactor this... Use heap ints??? fragmentation?
auto tensor_entry = tfe_handle_map->find(tensor_id);
if (tensor_entry == tfe_handle_map->end()) {
// NAPI_THROW_ERROR(env,
// "Delete called on a Tensor not referenced (tensor_id:
// %d)", tensor_id);
return;
}

GC_COUNT++;
// if (GC_COUNT % 100 == 0) {
fprintf(stderr, "GC_COUNT: %d (TENSOR_ID: %d)\n", GC_COUNT, tensor_id);
// }
TFE_DeleteTensorHandle(tensor_entry->second);
tfe_handle_map->erase(tensor_entry);
}

napi_status TFJSBackend::CreateTensorMetadataValue(
napi_env env, TFE_TensorHandle *tfe_handle, napi_value shape_value,
napi_value dtype_value, napi_value *tensor_metadata_value) {
napi_status nstatus;

// First bump tensor index and insert into the handle map:
int32_t next_idx = next_tensor_id_++; // XXX heap?
tfe_handle_map_->insert(std::make_pair(next_idx, tfe_handle));

if (next_idx % 1000 == 0) {
fprintf(stderr, ":: next_id: %d\n", next_idx);
}

// Next, create an object to represent the TensorMetadata class.
nstatus = napi_create_object(env, tensor_metadata_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

// Assign all values of the TensorMetadata class:
napi_value id_value;
nstatus = napi_create_int32(env, next_idx, &id_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

nstatus =
napi_set_named_property(env, *tensor_metadata_value, "id", id_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

nstatus = napi_set_named_property(env, *tensor_metadata_value, "shape",
shape_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

nstatus = napi_set_named_property(env, *tensor_metadata_value, "dtype",
dtype_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

// Next create an external JS object that can be tracked for GC. This object
// must be tracked to ensure the underlying TFE_TensorHandle data is cleanedup
// when Tensor reference is GC'd.
nstatus = napi_wrap(env, *tensor_metadata_value, tfe_handle_map_,
TFEHandlePairFinalize, id_value, nullptr);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nstatus);

return napi_ok;
}

napi_value TFJSBackend::CreateTensor(napi_env env, napi_value shape_value,
Expand Down Expand Up @@ -786,26 +870,28 @@ napi_value TFJSBackend::CreateTensor(napi_env env, napi_value shape_value,
tfe_handle = new_handle;
}

napi_value output_tensor_id;
nstatus = napi_create_int32(env, InsertHandle(tfe_handle), &output_tensor_id);
napi_value tensor_metadata_value;
nstatus = CreateTensorMetadataValue(env, tfe_handle, shape_value, dtype_value,
&tensor_metadata_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
return output_tensor_id;

return tensor_metadata_value;
}

void TFJSBackend::DeleteTensor(napi_env env, napi_value tensor_id_value) {
int32_t tensor_id;
ENSURE_NAPI_OK(env, napi_get_value_int32(env, tensor_id_value, &tensor_id));

auto tensor_entry = tfe_handle_map_.find(tensor_id);
if (tensor_entry == tfe_handle_map_.end()) {
auto tensor_entry = tfe_handle_map_->find(tensor_id);
if (tensor_entry == tfe_handle_map_->end()) {
NAPI_THROW_ERROR(env,
"Delete called on a Tensor not referenced (tensor_id: %d)",
tensor_id);
return;
}

TFE_DeleteTensorHandle(tensor_entry->second);
tfe_handle_map_.erase(tensor_entry);
tfe_handle_map_->erase(tensor_entry);
}

napi_value TFJSBackend::GetTensorData(napi_env env,
Expand All @@ -814,8 +900,8 @@ napi_value TFJSBackend::GetTensorData(napi_env env,
ENSURE_NAPI_OK_RETVAL(
env, napi_get_value_int32(env, tensor_id_value, &tensor_id), nullptr);

auto tensor_entry = tfe_handle_map_.find(tensor_id);
if (tensor_entry == tfe_handle_map_.end()) {
auto tensor_entry = tfe_handle_map_->find(tensor_id);
if (tensor_entry == tfe_handle_map_->end()) {
NAPI_THROW_ERROR(
env, "Get data called on a Tensor not referenced (tensor_id: %d)",
tensor_id);
Expand Down Expand Up @@ -855,8 +941,8 @@ napi_value TFJSBackend::ExecuteOp(napi_env env, napi_value op_name_value,
nstatus = napi_get_value_int32(env, cur_input_id, &cur_input_tensor_id);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

auto input_tensor_entry = tfe_handle_map_.find(cur_input_tensor_id);
if (input_tensor_entry == tfe_handle_map_.end()) {
auto input_tensor_entry = tfe_handle_map_->find(cur_input_tensor_id);
if (input_tensor_entry == tfe_handle_map_->end()) {
NAPI_THROW_ERROR(env, "Input Tensor ID not referenced (tensor_id: %d)",
cur_input_tensor_id);
return nullptr;
Expand Down Expand Up @@ -899,42 +985,25 @@ napi_value TFJSBackend::ExecuteOp(napi_env env, napi_value op_name_value,
nstatus = napi_create_array_with_length(env, size, &output_tensor_infos);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

// TODO(kreeger): look at napi_adjust_external_memory for GC/heap usage in
// this block
for (int32_t i = 0; i < num_outputs; i++) {
// Output tensor info object:
napi_value tensor_info_value;
nstatus = napi_create_object(env, &tensor_info_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

TFE_TensorHandle *handle = result_handles[i];

// Output tensor ID:
napi_value output_tensor_id_value;
nstatus =
napi_create_int32(env, InsertHandle(handle), &output_tensor_id_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

nstatus = napi_set_named_property(env, tensor_info_value, "id",
output_tensor_id_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

// Output tensor shape:
napi_value shape_value;
GetTFE_TensorHandleShape(env, handle, &shape_value);
GetTFE_TensorHandleShape(env, handle, &shape_value); // nstatus??

nstatus =
napi_set_named_property(env, tensor_info_value, "shape", shape_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

// Output tensor dtype:
napi_value type_value;
GetTFE_TensorHandleType(env, handle, &type_value);
napi_value dtype_value;
GetTFE_TensorHandleType(env, handle, &dtype_value);

nstatus =
napi_set_named_property(env, tensor_info_value, "dtype", type_value);
napi_value tensor_metadata_value;
nstatus = CreateTensorMetadataValue(env, handle, shape_value, dtype_value,
&tensor_metadata_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

// Push into output array
nstatus = napi_set_element(env, output_tensor_infos, i, tensor_info_value);
nstatus =
napi_set_element(env, output_tensor_infos, i, tensor_metadata_value);
ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
}

Expand Down
10 changes: 8 additions & 2 deletions binding/tfjs_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,16 @@ class TFJSBackend {
TFJSBackend(napi_env env);
~TFJSBackend();

int32_t InsertHandle(TFE_TensorHandle* tfe_handle);
// TODO - doc me.
napi_status CreateTensorMetadataValue(napi_env env,
TFE_TensorHandle* tfe_handle,
napi_value shape_value,
napi_value dtype_value,
napi_value* tensor_metadata_value);

TFE_Context* tfe_context_;
std::map<int32_t, TFE_TensorHandle*> tfe_handle_map_;
// TODO (type-def this thing)
std::map<int32_t, TFE_TensorHandle*>* tfe_handle_map_;
int32_t next_tensor_id_;
std::string device_name;
};
Expand Down
4 changes: 4 additions & 0 deletions src/debug_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import * as tf from './index';

const c = tf.add([1, 2], [3, 4]);
console.log(c.dataSync());
29 changes: 12 additions & 17 deletions src/nodejs_kernel_backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ import {createTensorsTypeOpAttr, createTypeOpAttr, getTFDType} from './ops/op_ut
import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding';

type TensorInfo = {
shape: number[],
dtype: number,
metadata: TensorMetadata,
values: Float32Array|Int32Array|Uint8Array,
id: number
};

interface DataId {}
Expand Down Expand Up @@ -77,12 +75,7 @@ export class NodeJSKernelBackend extends KernelBackend {
private createOutputTensor(metadata: TensorMetadata): Tensor {
const newId = {};

this.tensorMap.set(newId, {
shape: metadata.shape,
dtype: metadata.dtype,
id: metadata.id,
values: null
});
this.tensorMap.set(newId, {metadata, values: null});

let dtype: DataType;
switch (metadata.dtype) {
Expand Down Expand Up @@ -122,18 +115,19 @@ export class NodeJSKernelBackend extends KernelBackend {
if (info.values != null) {
// Values were delayed to write into the TensorHandle. Do that before
// Op execution and clear stored values.
info.id =
this.binding.createTensor(info.shape, info.dtype, info.values);
info.metadata = this.binding.createTensor(
info.metadata.shape, info.metadata.dtype, info.values);
info.values = null;
this.tensorMap.set((tensors[i] as Tensor).dataId, info);
}
ids.push(info.id);
ids.push(info.metadata.id);
} else if (tensors[i] instanceof Int64Scalar) {
// Then `tensors[i]` is a Int64Scalar, which we currently represent
// using an `Int32Array`.
const value = (tensors[i] as Int64Scalar).valueArray;
const id = this.binding.createTensor([], this.binding.TF_INT64, value);
ids.push(id);
const metadata =
this.binding.createTensor([], this.binding.TF_INT64, value);
ids.push(metadata.id);
} else {
throw new Error(`Invalid Tensor type: ${typeof tensors[i]}`);
}
Expand Down Expand Up @@ -201,12 +195,12 @@ export class NodeJSKernelBackend extends KernelBackend {
if (info.values != null) {
return info.values;
} else {
return this.binding.tensorDataSync(info.id);
return this.binding.tensorDataSync(info.metadata.id);
}
}

disposeData(dataId: object): void {
const id = this.tensorMap.get(dataId).id;
const id = this.tensorMap.get(dataId).metadata.id;
if (id != null && id >= 0) {
this.binding.deleteTensor(id);
}
Expand All @@ -226,7 +220,8 @@ export class NodeJSKernelBackend extends KernelBackend {
register(dataId: object, shape: number[], dtype: DataType): void {
if (!this.tensorMap.has(dataId)) {
this.tensorMap.set(
dataId, {shape, dtype: getTFDType(dtype), values: null, id: -1});
dataId,
{metadata: {id: -1, shape, dtype: getTFDType(dtype)}, values: null});
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/tfjs_binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export interface TFJSBinding {
// Creates a tensor with the backend:
createTensor(
shape: number[], dtype: number,
buffer: Float32Array|Int32Array|Uint8Array): number;
buffer: Float32Array|Int32Array|Uint8Array): TensorMetadata;

// Deletes a tensor with the backend:
deleteTensor(tensorId: number): void;
Expand Down
Loading