Update save_restore_v2_ops.cc
Fix coding style conventions
MrRobotsAA authored Oct 24, 2022
commit 5da8c3cb253c545fa270f59bbccf1bd356bfeb50
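The changes below are formatting-only: long signatures are wrapped at the 80-column limit, else branches are joined to the closing brace as "} else {", the local constant Nosql_Marker becomes nosql_marker, and the Chinese comments are removed. This lines up with the Google C++ style guide that TensorFlow follows. A minimal, self-contained sketch of the same conventions is shown below; WriteTensorSlice, its parameters, and the printed strings are hypothetical, invented for illustration rather than taken from save_restore_v2_ops.cc.

#include <iostream>
#include <string>

// Long parameter lists wrap at 80 columns, with continuation lines aligned
// under the first parameter.
void WriteTensorSlice(const std::string& tensor_name,
                      const std::string& shape_spec, int variable_index) {
  // Local constants use lower_snake_case (nosql_marker), not Nosql_Marker.
  const int nosql_marker = 0;
  if (nosql_marker == 1) {
    std::cout << "writing " << tensor_name << " to the NoSQL backend\n";
  } else {  // "else" stays on the same line as the closing brace.
    std::cout << "writing " << tensor_name << " (" << shape_spec
              << ") at index " << variable_index << "\n";
  }
}

int main() {
  WriteTensorSlice("embedding/part_0", "10 20 0,5:0,20", 3);
  return 0;
}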
58 changes: 26 additions & 32 deletions tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -108,7 +108,9 @@ class SaveV2 : public OpKernel {
}

template <typename TKey, typename TValue>
void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index, const string& tensor_name, BundleWriter& writer, DataType global_step_type) {
void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
const string& tensor_name, BundleWriter& writer,
DataType global_step_type) {
if (global_step_type == DT_INT32) {
DumpEv<TKey, TValue, int32>(context, variable_index, tensor_name, writer);
} else {
@@ -117,7 +119,8 @@ class SaveV2 : public OpKernel {
}

template <typename TKey, typename TValue, typename TGlobalStep>
void DumpEv(OpKernelContext* context, int variable_index, const string& tensor_name, BundleWriter& writer) {
void DumpEv(OpKernelContext* context, int variable_index,
const string& tensor_name, BundleWriter& writer) {
EmbeddingVar<TKey, TValue>* variable = nullptr;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, variable_index), &variable));
@@ -143,12 +146,12 @@ class SaveV2 : public OpKernel {
shape_and_slices);

const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices.
const int num_tensors = static_cast<int>(tensor_names.NumElements()); // get the number of tensors
const int num_tensors = static_cast<int>(tensor_names.NumElements());
const string& prefix_string = prefix.scalar<tstring>()();
const auto& tensor_names_flat = tensor_names.flat<tstring>();
const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

const int Nosql_Marker = 0;
const int nosql_marker = 0;
auto tempstate = random::New64();
string db_prefix_tmp = strings::StrCat(prefix_string,"--temp",tempstate);
DBWriter dbwriter(Env::Default(), prefix_string,db_prefix_tmp);
@@ -171,8 +174,7 @@ class SaveV2 : public OpKernel {
const string& tensor_name = tensor_names_flat(i);


if (tensor_types_[i] == DT_RESOURCE)
{
if (tensor_types_[i] == DT_RESOURCE) {
auto& handle = HandleFromInput(context, i + kFixedInputs);
if (IsHandle<EmbeddingVar<int64, float>>(handle)) {
EmbeddingVar<int64, float>* variable = nullptr;
@@ -190,8 +192,7 @@ class SaveV2 : public OpKernel {
} else if (ev_key_types_[start_ev_key_index] == DT_INT64) {
DumpEvWithGlobalStep<int64, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
}
}
else if (IsHandle<HashTableResource>(handle)) {
} else if (IsHandle<HashTableResource>(handle)) {
auto handles = context->input(i + kFixedInputs).flat<ResourceHandle>();
int tensible_size = handles.size() - 1;
std::vector<core::ScopedUnref> unrefs;
@@ -232,10 +233,8 @@ class SaveV2 : public OpKernel {
&writer, hashtable, tensibles, table_name, tensible_name,
slice.start(0), slice.length(0), slice_shape.dim_size(0)));


}
}
else if (IsHandle<HashTableAdmitStrategyResource>(handle)) {
} else if (IsHandle<HashTableAdmitStrategyResource>(handle)) {
HashTableAdmitStrategyResource* resource;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, i + kFixedInputs), &resource));
@@ -254,12 +253,9 @@ class SaveV2 : public OpKernel {
&writer, bf, tensor_name, slice.start(0),
slice.length(0), slice_shape.dim_size(0)));
}



start_ev_key_index++;
}
else
{
} else {
const Tensor& tensor = context->input(i + kFixedInputs);
if (!shape_and_slices_flat(i).empty()) {
const string& shape_spec = shape_and_slices_flat(i);
@@ -277,32 +273,32 @@ class SaveV2 : public OpKernel {
shape_spec, ", tensor: ",
tensor.shape().DebugString()));

if(Nosql_Marker==1){
if(nosql_marker==1){

OP_REQUIRES_OK(context,
dbwriter.AddSlice(tensor_name, shape, slice, tensor,"slice_tensor"));
}
else{
} else{

OP_REQUIRES_OK(context,
writer.AddSlice(tensor_name, shape, slice, tensor));
}
}
else {
if(Nosql_Marker==1){
OP_REQUIRES_OK(context, dbwriter.Add(tensor_name, tensor,"normal_tensor"));
}
else{
string tmp_dbfile_prefix_string = strings::StrCat(prefix_string,"--temp",tempstate,"--data--0--1","--tensor--",tensor_name);
} else {
if(nosql_marker==1){
OP_REQUIRES_OK(context,
dbwriter.Add(tensor_name, tensor,"normal_tensor"));
} else{
string tmp_dbfile_prefix_string =
strings::StrCat(prefix_string,"--temp",tempstate,"--data--0--1","--tensor--",tensor_name);
OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor,tmp_dbfile_prefix_string));
}
}
}
}
if(Nosql_Marker==1){
if(nosql_marker==1){

OP_REQUIRES_OK(context, dbwriter.Finish());
}
else{
} else{
OP_REQUIRES_OK(context, writer.Finish());
}
}
@@ -533,9 +529,7 @@ class MergeV2Checkpoints : public OpKernel {
const string& merged_prefix = destination_prefix.scalar<tstring>()();
OP_REQUIRES_OK(
context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
// merge the different checkpoint source files

// delete the old directories

if (delete_old_dirs_) {
const string merged_dir(io::Dirname(merged_prefix));
for (const string& input_prefix : input_prefixes) {