From 8d4024406210dbcb0a99cc036606efcfa3671c3a Mon Sep 17 00:00:00 2001 From: Chen Ding Date: Wed, 28 Feb 2024 16:57:59 +0800 Subject: [PATCH 01/14] [Docs] Update deeprec2402 release images and notes in README.md & RELEASE.md. (#975) Signed-off-by: candy.dc --- README.md | 4 +- RELEASE.md | 44 +++++++++++++++++++ docs/docs_en/DeepRec-Compile-And-Install.md | 4 +- docs/docs_en/Estimator-Compile-And-Install.md | 2 +- docs/docs_en/TFServing-Compile-And-Install.md | 2 +- docs/docs_zh/DeepRec-Compile-And-Install.md | 4 +- docs/docs_zh/Estimator-Compile-And-Install.md | 2 +- docs/docs_zh/TFServing-Compile-And-Install.md | 2 +- 8 files changed, 54 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 8f491e14665..b7d7b578c24 100644 --- a/README.md +++ b/README.md @@ -95,13 +95,13 @@ $ pip3 install /tmp/tensorflow_pkg/tensorflow-1.15.5+${version}-cp38-cp38m-linux #### Image for CPU ``` -alideeprec/deeprec-release:deeprec2310-cpu-py38-ubuntu20.04 +alideeprec/deeprec-release:deeprec2402-cpu-py38-ubuntu20.04 ``` #### Image for GPU CUDA11.6 ``` -alideeprec/deeprec-release:deeprec2310-gpu-py38-cu116-ubuntu20.04 +alideeprec/deeprec-release:deeprec2402-gpu-py38-cu116-ubuntu20.04 ``` *** diff --git a/RELEASE.md b/RELEASE.md index 6b7e4a7fd79..b095351d2a0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,48 @@ +# Release r1.15.5-deeprec2402 + +## **Major Features and Improvements** + +### **Embedding** + +- Refine KVInterface::GetShardedSnapshot API. +- Undefine EV GPU interface in CPU compile. +- Make Embedding backward compatible with previous saved_model. +- Log error when EV has been initialized in EV Import OP. + +### **Op Implement** + +- Implement of SliceSend/SliceRecv Op. +- Implement FileSliceSend/FileSliceRecvOp. + +### **SDK** + +- Add build SDK package. + +### **BugFix** + +- Fix shared embedding frequency counting problem. +- Fix Graph contains EmbeddingVariable compiling issue. +- Fix a scheduling issue. 
+- Fix tensor shape meta-data bug for DataFrame Value. + +### **ModelZoo** + +- Set Saver parameter sharded=True in distributed training. + +More details of features: [https://deeprec.readthedocs.io/zh/latest/](https://deeprec.readthedocs.io/zh/latest/) + +## **Release Images** + +### **CPU Image** + +`alideeprec/deeprec-release:deeprec2402-cpu-py38-ubuntu20.04` + +### **GPU Image** + +`alideeprec/deeprec-release:deeprec2402-gpu-py38-cu116-ubuntu20.04` + # Release r1.15.5-deeprec2310 + ## **Major Features and Improvements** ### **Embedding** diff --git a/docs/docs_en/DeepRec-Compile-And-Install.md b/docs/docs_en/DeepRec-Compile-And-Install.md index fdf3e295fdd..379526e5b24 100644 --- a/docs/docs_en/DeepRec-Compile-And-Install.md +++ b/docs/docs_en/DeepRec-Compile-And-Install.md @@ -111,7 +111,7 @@ pip3 install /tmp/tensorflow_pkg/tensorflow-1.15.5+${version}-cp38-cp38m-linux_x x86_64: ``` -alideeprec/deeprec-release:deeprec2310-cpu-py38-ubuntu20.04 +alideeprec/deeprec-release:deeprec2402-cpu-py38-ubuntu20.04 ``` arm64: @@ -122,5 +122,5 @@ alideeprec/deeprec-release:deeprec2302-cpu-py38-ubuntu22.04-arm64 **GPU Image with CUDA 11.6** ``` -alideeprec/deeprec-release:deeprec2310-gpu-py38-cu116-ubuntu20.04 +alideeprec/deeprec-release:deeprec2402-gpu-py38-cu116-ubuntu20.04 ``` diff --git a/docs/docs_en/Estimator-Compile-And-Install.md b/docs/docs_en/Estimator-Compile-And-Install.md index 55f759a3c2a..6305d739571 100644 --- a/docs/docs_en/Estimator-Compile-And-Install.md +++ b/docs/docs_en/Estimator-Compile-And-Install.md @@ -40,7 +40,7 @@ DeepRec provide new distributed protocols such as grpc++ and star_server, which Source Code: [https://github.com/DeepRec-AI/estimator](https://github.com/DeepRec-AI/estimator) -Develop Branch:master, Latest Release Branch: deeprec2310 +Develop Branch:master, Latest Release Branch: deeprec2402 ## Estimator Build diff --git a/docs/docs_en/TFServing-Compile-And-Install.md b/docs/docs_en/TFServing-Compile-And-Install.md index 79a0944aa3e..ea70f397c98 100644 --- 
a/docs/docs_en/TFServing-Compile-And-Install.md +++ b/docs/docs_en/TFServing-Compile-And-Install.md @@ -39,7 +39,7 @@ We provide optimized TFServing which could highly improve performance in inferen Source Code: [https://github.com/DeepRec-AI/serving](https://github.com/DeepRec-AI/serving) -Develop Branch: master, Latest Release Branch: deeprec2310 +Develop Branch: master, Latest Release Branch: deeprec2402 ## TFServing Build diff --git a/docs/docs_zh/DeepRec-Compile-And-Install.md b/docs/docs_zh/DeepRec-Compile-And-Install.md index ad8fd36dbf7..0c11dca394f 100644 --- a/docs/docs_zh/DeepRec-Compile-And-Install.md +++ b/docs/docs_zh/DeepRec-Compile-And-Install.md @@ -108,7 +108,7 @@ pip3 install /tmp/tensorflow_pkg/tensorflow-1.15.5+${version}-cp38-cp38m-linux_x x86_64: ``` -alideeprec/deeprec-release:deeprec2310-cpu-py38-ubuntu20.04 +alideeprec/deeprec-release:deeprec2402-cpu-py38-ubuntu20.04 ``` arm64: @@ -119,7 +119,7 @@ alideeprec/deeprec-release:deeprec2302-cpu-py38-ubuntu22.04-arm64 **GPU CUDA11.6镜像** ``` -alideeprec/deeprec-release:deeprec2310-gpu-py38-cu116-ubuntu20.04 +alideeprec/deeprec-release:deeprec2402-gpu-py38-cu116-ubuntu20.04 ``` ## DeepRec Processor编译打包 diff --git a/docs/docs_zh/Estimator-Compile-And-Install.md b/docs/docs_zh/Estimator-Compile-And-Install.md index e54c8ddbd2f..eeb4f66dc99 100644 --- a/docs/docs_zh/Estimator-Compile-And-Install.md +++ b/docs/docs_zh/Estimator-Compile-And-Install.md @@ -40,7 +40,7 @@ 代码库:[https://github.com/DeepRec-AI/estimator](https://github.com/DeepRec-AI/estimator) -开发分支:master,最新Release分支:deeprec2310 +开发分支:master,最新Release分支:deeprec2402 ## Estimator编译 diff --git a/docs/docs_zh/TFServing-Compile-And-Install.md b/docs/docs_zh/TFServing-Compile-And-Install.md index a43d2d517a6..b0460934165 100644 --- a/docs/docs_zh/TFServing-Compile-And-Install.md +++ b/docs/docs_zh/TFServing-Compile-And-Install.md @@ -39,7 +39,7 @@ 代码库:[https://github.com/DeepRec-AI/serving](https://github.com/DeepRec-AI/serving) 
-开发分支:master,最新Release分支:deeprec2310 +开发分支:master,最新Release分支:deeprec2402 ## TFServing编译&打包 From 8b58f9b93e144fa2d6517d5d370dc0df4fd3644b Mon Sep 17 00:00:00 2001 From: Chen Ding Date: Wed, 28 Feb 2024 17:18:29 +0800 Subject: [PATCH 02/14] [Dockerfile] Add DeepRec release image dockerfile. (#976) Signed-off-by: candy.dc --- cibuild/dockerfiles/Dockerfile.release | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 cibuild/dockerfiles/Dockerfile.release diff --git a/cibuild/dockerfiles/Dockerfile.release b/cibuild/dockerfiles/Dockerfile.release new file mode 100644 index 00000000000..77b013f840d --- /dev/null +++ b/cibuild/dockerfiles/Dockerfile.release @@ -0,0 +1,32 @@ +# build DeepRec & estimator wheel +FROM alideeprec/deeprec-base:deeprec-base-cpu-py38-ubuntu20.04 AS deeprec_build + +ARG TF_COMMIT=deeprec2402 + +RUN mkdir -p /src +RUN wget -nv -O /src/install_bazel.sh \ + http://pythonrun.oss-cn-zhangjiakou.aliyuncs.com/bazel-0.26.1-installer-linux-x86_64.sh && \ + bash /src/install_bazel.sh + +RUN git clone https://github.com/DeepRec-AI/DeepRec.git /src/DeepRec && \ + cd /src/DeepRec && \ + git checkout ${TF_COMMIT} +RUN cd /src/DeepRec && \ + yes "" | bash ./configure || true && \ + bazel build -c opt --config=opt //tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package /src/ + +RUN pip install /src/tensorflow-1.15.5+${TF_COMMIT}-cp38-cp38-linux_x86_64.whl + +RUN git clone https://github.com/DeepRec-AI/estimator.git /src/estimator && \ + cd /src/estimator && \ + git checkout ${TF_COMMIT} +RUN cd /src/estimator && \ + bazel build //tensorflow_estimator/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow_estimator/tools/pip_package/build_pip_package /src/ + +# build DeepRec release image +FROM alideeprec/deeprec-base:deeprec-base-cpu-py38-ubuntu20.04 +COPY --from=deeprec_build /src/*.whl / +RUN pip install /tensorflow-1.15.5+${TF_COMMIT}-cp38-cp38-linux_x86_64.whl 
tensorflow_estimator-1.15.2+${TF_COMMIT}-py2.py3-none-any.whl +RUN rm -f /tensorflow-1.15.5+${TF_COMMIT}-cp38-cp38-linux_x86_64.whl /tensorflow_estimator-1.15.2+${TF_COMMIT}-py2.py3-none-any.whl From 186afd0479bb43c629cafa808be70b7f5ac33d83 Mon Sep 17 00:00:00 2001 From: Chen Ding Date: Thu, 29 Feb 2024 10:10:38 +0800 Subject: [PATCH 03/14] [Serving] Fix syntax error in generate timeline tool. (#977) Signed-off-by: candy.dc --- serving/tools/timeline/gen_timeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serving/tools/timeline/gen_timeline.py b/serving/tools/timeline/gen_timeline.py index f055e473fa0..d56c1b39897 100644 --- a/serving/tools/timeline/gen_timeline.py +++ b/serving/tools/timeline/gen_timeline.py @@ -1,6 +1,6 @@ import sys -import config_pb2 -import timeline +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import timeline def gen_timeline(src_name, dest_name): run_metadata = config_pb2.RunMetadata() From 6dae552cb40e954cce59e125977f141c6a926ada Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Thu, 7 Mar 2024 14:35:36 +0800 Subject: [PATCH 04/14] [Embedding] Refine header file of embedding variable. (#978) Signed-off-by: chenbangduo.cbd --- tensorflow/core/framework/embedding/embedding_var.h | 1 - tensorflow/core/kernels/kv_variable_ops.cc | 1 + tensorflow/core/kernels/kv_variable_restore_ops.cc | 1 + tensorflow/core/kernels/training_ali_ops.cc | 8 ++++---- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index c0d26a2f4d8..81941bc9ff9 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -34,7 +34,6 @@ limitations under the License. 
#include "tensorflow/core/framework/embedding/gpu_hash_map_kv.h" #include "tensorflow/core/framework/embedding/embedding_config.h" #include "tensorflow/core/framework/embedding/storage.h" -#include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/typed_allocator.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc index 5cd0ef140bd..b7567ffe924 100644 --- a/tensorflow/core/kernels/kv_variable_ops.cc +++ b/tensorflow/core/kernels/kv_variable_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/embedding/cache.h" #include "tensorflow/core/framework/embedding/config.pb.h" #include "tensorflow/core/framework/embedding/embedding_var.h" +#include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" diff --git a/tensorflow/core/kernels/kv_variable_restore_ops.cc b/tensorflow/core/kernels/kv_variable_restore_ops.cc index 2eccf485ef8..e16db9b4cd6 100644 --- a/tensorflow/core/kernels/kv_variable_restore_ops.cc +++ b/tensorflow/core/kernels/kv_variable_restore_ops.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/framework/embedding/cache.h" #include "tensorflow/core/framework/embedding/config.pb.h" #include "tensorflow/core/framework/embedding/embedding_var.h" +#include "tensorflow/core/framework/embedding/storage_factory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" diff --git a/tensorflow/core/kernels/training_ali_ops.cc b/tensorflow/core/kernels/training_ali_ops.cc index 546b30e29dd..fc21ab610cf 100644 --- a/tensorflow/core/kernels/training_ali_ops.cc +++ b/tensorflow/core/kernels/training_ali_ops.cc @@ -236,7 +236,7 @@ class KvSparseApplyAdagradGPUOp : public OpKernel { T** dev_a = dev_v + task_size; CHECK(dev_a); CHECK(dev_v); - DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2); + se::DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2); stream->ThenMemcpy(&dev_v_ptr, v, sizeof(T*) * task_size * 2); int block_size = 128; @@ -1606,7 +1606,7 @@ class KvSparseApplyAdamGPUOp : public OpKernel { CHECK(dev_m_ptr); CHECK(dev_v_ptr); - DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); + se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3); int block_size = 128; @@ -2579,7 +2579,7 @@ class KvSparseApplyAdamAsyncGPUOp : public OpKernel { CHECK(dev_m_ptr); CHECK(dev_v_ptr); - DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); + se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3); int block_size = 128; @@ -3236,7 +3236,7 @@ class KvSparseApplyAdamWGPUOp : public OpKernel { CHECK(dev_m_ptr); CHECK(dev_v_ptr); - DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); + se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3); stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3); int 
block_size = 128; From cf16856d01551c9d1cb005722d7f62a448df7095 Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Tue, 26 Mar 2024 17:15:18 +0800 Subject: [PATCH 05/14] [Incremental Checkpoint] Fix import incremental embedding variable. (#983) Signed-off-by: chenbangduo.cbd --- .../embedding/embedding_var_restore.cc | 50 +++++++++-------- tensorflow/python/training/incr_ckpt_test.py | 54 +++++++++++++++++++ 2 files changed, 82 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/framework/embedding/embedding_var_restore.cc b/tensorflow/core/framework/embedding/embedding_var_restore.cc index 11c13008995..6ff07bf7e43 100644 --- a/tensorflow/core/framework/embedding/embedding_var_restore.cc +++ b/tensorflow/core/framework/embedding/embedding_var_restore.cc @@ -102,45 +102,48 @@ void CheckpointLoader::RestoreInternal( Tensor part_filter_offset_tensor; if (!restore_args_.m_is_oldform) { /****** InitPartOffsetTensor ******/ - TensorShape part_offset_shape, part_filter_offset_shape; - DataType part_offset_type, part_filter_offset_type; + TensorShape part_offset_shape; + DataType part_offset_type; string offset_tensor_name; if (!restore_args_.m_is_incr) { offset_tensor_name = name_string + kPartOffsetTensorSuffsix; } else { offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix; } - - string offset_filter_tensor_name = - name_string + kPartFilterOffsetTensorSuffsix; + Status s = reader_->LookupDtypeAndShape( offset_tensor_name, &part_offset_type, &part_offset_shape); if (!s.ok()) { LOG(ERROR) << "EV restoring fail:" << s.error_message(); } - s = reader_->LookupDtypeAndShape(offset_filter_tensor_name, - &part_filter_offset_type, - &part_filter_offset_shape); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail: " << s.error_message(); - } part_offset_tensor = Tensor(cpu_allocator(), part_offset_type, part_offset_shape); - part_filter_offset_tensor = Tensor( - cpu_allocator(), part_filter_offset_type, part_filter_offset_shape); s = 
reader_->Lookup(offset_tensor_name, &part_offset_tensor); if (!s.ok()) { LOG(ERROR) << "EV restoring fail:" << s.error_message(); } - s = reader_->Lookup(offset_filter_tensor_name, - &part_filter_offset_tensor); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail: " << s.error_message(); + if (restore_args_.m_has_filter) { + TensorShape part_filter_offset_shape; + DataType part_filter_offset_type; + string offset_filter_tensor_name = + name_string + kPartFilterOffsetTensorSuffsix; + s = reader_->LookupDtypeAndShape(offset_filter_tensor_name, + &part_filter_offset_type, + &part_filter_offset_shape); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail: " << s.error_message(); + } + part_filter_offset_tensor = \ + Tensor(cpu_allocator(), part_filter_offset_type, + part_filter_offset_shape); + s = reader_->Lookup(offset_filter_tensor_name, + &part_filter_offset_tensor); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail: " << s.error_message(); + } } } - auto part_offset_flat = part_offset_tensor.flat(); - auto part_filter_offset_flat = part_filter_offset_tensor.flat(); if (restore_args_.m_is_oldform) { VLOG(1) << "old form, EV name:" << name_string @@ -164,6 +167,7 @@ void CheckpointLoader::RestoreInternal( VLOG(1) << "new form checkpoint... 
:" << name_string << " , partition_id:" << restore_args_.m_partition_id << " , partition_num:" << restore_args_.m_partition_num; + auto part_offset_flat = part_offset_tensor.flat(); for (size_t i = 0; i < restore_args_.m_loaded_parts.size(); i++) { int subpart_id = restore_args_.m_loaded_parts[i]; size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim; @@ -183,6 +187,7 @@ void CheckpointLoader::RestoreInternal( new_dim, emb_config, device); if (restore_args_.m_has_filter) { + auto part_filter_offset_flat = part_filter_offset_tensor.flat(); Status s = EVRestoreFilteredFeatures( subpart_id, new_dim, restore_buff, part_filter_offset_flat, emb_config, device); @@ -444,7 +449,7 @@ Status CheckpointLoader::EVInitTensorNameAndShape( } st = reader_->LookupHeader(restore_args_.m_tensor_version + "_filtered", sizeof(K) * version_filter_shape.dim_size(0)); - if (!st.ok()) { + if (!st.ok() && st.code() != error::NOT_FOUND) { return st; } st = reader_->LookupTensorShape(restore_args_.m_tensor_freq + "_filtered", @@ -463,7 +468,8 @@ Status CheckpointLoader::EVInitTensorNameAndShape( return st; } } - return st; + + return Status::OK(); } #define REGISTER_KERNELS(ktype, vtype) \ template Status CheckpointLoader::EVInitTensorNameAndShape(\ @@ -644,4 +650,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) #undef REGISTER_KERNELS_ALL_INDEX #undef REGISTER_KERNELS -}// namespace tensorflow \ No newline at end of file +}// namespace tensorflow diff --git a/tensorflow/python/training/incr_ckpt_test.py b/tensorflow/python/training/incr_ckpt_test.py index b4f7ded3cea..55cf748a9d6 100644 --- a/tensorflow/python/training/incr_ckpt_test.py +++ b/tensorflow/python/training/incr_ckpt_test.py @@ -451,5 +451,59 @@ def testIncrementalSaverForResourceVariable(self): saver.build() incr_saver = incr_saver_module._get_incremental_saver(True, saver) + def testIncrementalSaverSaveAndRestore(self): + tmp_path = self.get_temp_dir() + full_ckpt_dir = os.path.join(tmp_path, "model.ckpt") + 
incr_ckpt_dir = os.path.join(tmp_path, "incr.ckpt") + full_ckpt_path = None + incr_ckpt_path = None + + # construct graph + emb_var = variable_scope.get_embedding_variable("emb", embedding_dim=3, + initializer = init_ops.ones_initializer(dtypes.float32)) + emb = embedding_ops.embedding_lookup(emb_var, + math_ops.cast([0, 1, 2, 3, 4], dtypes.int64)) + loss = math_ops.reduce_sum(emb, name = 'reduce_sum') + opt = adagrad.AdagradOptimizer(0.1) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + init = variables.global_variables_initializer() + saver = saver_module.Saver(sharded=True, incremental_save_restore=True) + incr_saver = \ + incr_saver_module.IncrementalSaver(sharded=True, + saver_def=saver.saver_def, defer_build=True) + incr_saver.build(saver._builder.filename_tensor) + + # generate full ckpt and incr ckpt. + full_ckpt_value=None + incr_ckpt_value=None + with self.test_session() as sess: + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) + sess.run([init]) + sess.run([train_op]) + full_ckpt_path = saver.save(sess, full_ckpt_dir, global_step = 10) + full_ckpt_value = sess.run([emb]) + print("full_ckpt: {}".format(full_ckpt_value)) + sess.run([train_op]) + incr_ckpt_path = \ + incr_saver.incremental_save(sess, incr_ckpt_dir, global_step=20) + incr_ckpt_value = sess.run([emb]) + print("incr_ckpt: {}".format(incr_ckpt_value)) + + # check the value after restoring parameter. 
+ with self.test_session() as sess: + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) + sess.run([init]) + saver.restore(sess, full_ckpt_path) + restore_full_ckpt_value = sess.run([emb]) + print("restore_full_ckpt: {}".format(restore_full_ckpt_value)) + incr_saver.incremental_restore(sess, full_ckpt_path, incr_ckpt_path) + restore_incr_ckpt_value = sess.run([emb]) + print("restore_incr_ckpt: {}".format(restore_incr_ckpt_value)) + self.assertAllClose(full_ckpt_value, restore_full_ckpt_value) + self.assertAllClose(incr_ckpt_value, restore_incr_ckpt_value) + if __name__ == "__main__": googletest.main() From d5f7f6ad77a59b70679835009dbe31add175dba3 Mon Sep 17 00:00:00 2001 From: "Secret.Sun" Date: Wed, 10 Apr 2024 14:41:50 +0800 Subject: [PATCH 06/14] [Runtime] Remove read limit of ReadBinaryProto. (#981) Signed-off-by: Secret.Sun --- tensorflow/core/platform/env.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index ac91b79a07f..b835677627a 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -508,14 +508,7 @@ Status ReadBinaryProto(Env* env, const string& fname, TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file)); std::unique_ptr stream(new FileStream(file.get())); - // TODO(jiayq): the following coded stream is for debugging purposes to allow - // one to parse arbitrarily large messages for MessageLite. One most likely - // doesn't want to put protobufs larger than 64MB on Android, so we should - // eventually remove this and quit loud when a large protobuf is passed in. ::tensorflow::protobuf::io::CodedInputStream coded_stream(stream.get()); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. 
- coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); if (!proto->ParseFromCodedStream(&coded_stream) || !coded_stream.ConsumedEntireMessage()) { From a4489e31a4b9bc8371198537a0a15af6011ef8ae Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Fri, 12 Apr 2024 14:22:32 +0800 Subject: [PATCH 07/14] [EVAllocator] Fix the bug in configuring ARENA_ARRAY_SIZE. (#986) Signed-off-by: chenbangduo.cbd --- tensorflow/core/framework/ev_allocator.h | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/framework/ev_allocator.h b/tensorflow/core/framework/ev_allocator.h index d3251b14782..5082ee04b72 100644 --- a/tensorflow/core/framework/ev_allocator.h +++ b/tensorflow/core/framework/ev_allocator.h @@ -546,15 +546,15 @@ class EVAllocatorImpl { page_map_ = new PageMap(); page_map_->Init(); - int64 arena_array_size = ARENA_ARRAY_SIZE; + arena_array_size_ = ARENA_ARRAY_SIZE; Status s = ReadInt64FromEnvVar("ARENA_ARRAY_SIZE", - ARENA_ARRAY_SIZE, &arena_array_size); + ARENA_ARRAY_SIZE, &arena_array_size_); if (!s.ok()) { LOG(ERROR) << "Read ARENA_ARRAY_SIZE env error: " << s.error_message(); } - LOG(INFO) << "EVAllocator set arena array size: " << arena_array_size; + LOG(INFO) << "EVAllocator set arena array size: " << arena_array_size_; - arenas_ = new std::vector>(arena_array_size, page_map_); + arenas_ = new std::vector>(arena_array_size_, page_map_); arena_cur_index = 0; } @@ -602,7 +602,7 @@ class EVAllocatorImpl { { mutex_lock l(mu_arena_index_); ret = &((*arenas_)[arena_cur_index]); - arena_cur_index = (arena_cur_index + 1) % ARENA_ARRAY_SIZE; + arena_cur_index = (arena_cur_index + 1) % arena_array_size_; } return ret; @@ -619,6 +619,7 @@ class EVAllocatorImpl { PageMap* page_map_ = nullptr; std::vector> *arenas_ = nullptr; int arena_cur_index GUARDED_BY(mu_arena_index_); + int64 arena_array_size_; }; template From 04413cf0ee6ca57f35446095c4e27bc1cfdf2b0d Mon Sep 17 00:00:00 2001 From: Chaofeng Guo Date: Thu, 18 Apr 2024 19:56:17 
+0800 Subject: [PATCH 08/14] [Embedding] Fix the issue of default_value type mismatch in the EV Gather op. (#989) Signed-off-by: Lyaction --- tensorflow/python/ops/kv_variable_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py index 840aadf2541..55e01537c0d 100644 --- a/tensorflow/python/ops/kv_variable_ops.py +++ b/tensorflow/python/ops/kv_variable_ops.py @@ -858,10 +858,10 @@ def sparse_read(self, indices, name=None, ev_init_value=None, counts=None): if self._trainable: tape.variable_accessed(self) if ev_init_value is not None: - default_value = ev_init_value + default_value = math_ops.cast(ev_init_value, self.dtype) is_use_default_value_tensor = True else: - default_value = ops.convert_to_tensor(1.0) + default_value = ops.convert_to_tensor(1.0, dtype=self.dtype) is_use_default_value_tensor = False if counts != None: value = gen_kv_variable_ops.kv_resource_gather_v1(self._handle, From fc08e1b605490e818cdf80bc2389b68028c19049 Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Fri, 26 Apr 2024 11:33:59 +0800 Subject: [PATCH 09/14] [Hook] Add 'before_create_session' interface to SessionRunHook. (#991) Signed-off-by: chenbangduo.cbd --- tensorflow/python/training/monitored_session.py | 3 +++ tensorflow/python/training/session_run_hook.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 6eb204785dd..9492028a200 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -957,6 +957,8 @@ def __init__(self, session_creator, hooks, stop_grace_period_secs): def create_session(self): """Creates a coordinated session.""" # Keep the tf_sess for unit testing. 
+ for hook in self._hooks: + hook.before_create_session() self.tf_sess = self._session_creator.create_session() # We don't want coordinator to suppress any exception. self.coord = coordinator.Coordinator(clean_stop_exception_types=[]) @@ -1027,6 +1029,7 @@ class MonitoredSession(_MonitoredSession): in given order: * calls `hook.begin()` for each given hook + * calls `hook.before_create_session()` * finalizes the graph via `scaffold.finalize()` * create session * initializes the model via initialization ops provided by `Scaffold` diff --git a/tensorflow/python/training/session_run_hook.py b/tensorflow/python/training/session_run_hook.py index e598bc2d98c..9d05d04c139 100644 --- a/tensorflow/python/training/session_run_hook.py +++ b/tensorflow/python/training/session_run_hook.py @@ -109,6 +109,20 @@ def begin(self): """ pass + def before_create_session(self): + """Called before new TensorFlow session is created. + + This has two essential differences with the situation in which `begin` is + called: + + * Do not modify the graph in this method, ops should not be added to graph. + The modification of the graph should take place within the begin + interface. + * This method will also be called prior to the recovery of a wrapped + session, not just at the beginning of the overall session. + """ + pass + def after_create_session(self, session, coord): # pylint: disable=unused-argument """Called when new TensorFlow session is created. From e10d4411dfb93ca47f6e1908ac878d1417c7db58 Mon Sep 17 00:00:00 2001 From: Chen Ding Date: Mon, 29 Apr 2024 17:18:35 +0800 Subject: [PATCH 10/14] [Docs] Fix readthedoc build fail. 
(#993) - Add configure file: docs/docs_zh/.readthedocs.yaml docs/docs_en/.readthedocs.yaml Signed-off-by: Chen Ding --- docs/docs_en/.readthedocs.yaml | 35 ++++++++++++++++++++++++++++++++++ docs/docs_zh/.readthedocs.yaml | 35 ++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 docs/docs_en/.readthedocs.yaml create mode 100644 docs/docs_zh/.readthedocs.yaml diff --git a/docs/docs_en/.readthedocs.yaml b/docs/docs_en/.readthedocs.yaml new file mode 100644 index 00000000000..c69bbd13812 --- /dev/null +++ b/docs/docs_en/.readthedocs.yaml @@ -0,0 +1,35 @@ +# Read the Docs configuration file for Sphinx projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + # You can also specify other tool versions: + # nodejs: "20" + # rust: "1.70" + # golang: "1.20" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/docs_en/conf.py + # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs + # builder: "dirhtml" + # Fail on all warnings to avoid broken references + # fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/docs_en/requirements.txt diff --git a/docs/docs_zh/.readthedocs.yaml b/docs/docs_zh/.readthedocs.yaml new file mode 100644 index 00000000000..859db8adfa5 --- /dev/null +++ b/docs/docs_zh/.readthedocs.yaml @@ -0,0 +1,35 @@ +# Read the Docs configuration file for Sphinx projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + 
+# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + # You can also specify other tool versions: + # nodejs: "20" + # rust: "1.70" + # golang: "1.20" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/docs_zh/conf.py + # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs + # builder: "dirhtml" + # Fail on all warnings to avoid broken references + # fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/docs_zh/requirements.txt From b2aed9686182124fca72f8093e74136cc13dcd39 Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Tue, 14 May 2024 10:43:13 +0800 Subject: [PATCH 11/14] [Embedding] Change the log level for EV restore. (#995) Signed-off-by: chenbangduo.cbd --- tensorflow/core/kernels/kv_variable_restore_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/kv_variable_restore_ops.cc b/tensorflow/core/kernels/kv_variable_restore_ops.cc index e16db9b4cd6..0a0165595f0 100644 --- a/tensorflow/core/kernels/kv_variable_restore_ops.cc +++ b/tensorflow/core/kernels/kv_variable_restore_ops.cc @@ -376,8 +376,8 @@ class KvResourceImportV3Op: public AsyncOpKernel { // EV should not be initialized at this time. 
if (ev->IsInitialized()) { - LOG(ERROR) << "Import parameter for EV (" << name_string - << ") failed, this EV has already been initialized."; + LOG(WARNING) << "EV (" << name_string + << ") has already been initialized."; } auto do_compute = [this, context, file_name_string, ev, From 93c69ad9576d6ee0f7b9479bef9b091451e5b91a Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Tue, 21 May 2024 19:26:07 +0800 Subject: [PATCH 12/14] [Rendezvous] RemoteRendezvous supports FlowControl. (#994) Signed-off-by: chenbangduo.cbd --- .../base_rendezvous_mgr.cc | 213 ++++++++++++++- .../distributed_runtime/base_rendezvous_mgr.h | 45 ++++ .../rendezvous_mgr_interface.h | 11 +- .../rpc/grpc_remote_worker.cc | 10 + .../rpc/grpc_worker_interface.h | 6 + .../rpc/grpc_worker_service.cc | 162 ++++++++++++ .../rpc/grpc_worker_service.h | 4 + .../rpc/grpc_worker_service_impl.cc | 2 + .../rpc/grpc_worker_service_impl.h | 1 + .../rpc/rpc_rendezvous_mgr.cc | 245 ++++++++++++++++++ .../rpc/rpc_rendezvous_mgr_test.cc | 26 ++ tensorflow/core/framework/rendezvous.cc | 41 +++ tensorflow/core/framework/rendezvous.h | 26 ++ .../core/kernels/file_slice_sendrecv_ops.cc | 20 +- .../core/kernels/file_slice_sendrecv_ops.h | 2 + .../kernels/file_slice_sendrecv_ops_test.cc | 13 + tensorflow/core/kernels/slice_sendrecv_ops.cc | 40 +-- tensorflow/core/kernels/slice_sendrecv_ops.h | 2 + .../core/kernels/slice_sendrecv_ops_test.cc | 13 + tensorflow/core/protobuf/worker.proto | 46 ++++ tensorflow/core/protobuf/worker_service.proto | 5 + 21 files changed, 903 insertions(+), 30 deletions(-) diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index 17935eb8982..ead121b30c8 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -34,11 +34,13 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { namespace { uint64 kGlobalStepId = 0x100000000000000uLL; + int64 kFlowControlMaxSize = 16; } // namespace anonymous static void StartAbortRendevous(Rendezvous* rendez, const Status& s) { @@ -127,6 +129,23 @@ void BaseRendezvousMgr::FuseRecvLocalAsync( rendez->FuseRecvLocalAsync(parsed_keys, std::move(done_cb)); } +void BaseRendezvousMgr::FlowControlRecvLocalAsync(int64 step_id, + const StringPiece& tag, const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) { + auto rendez = FindOrCreate(step_id); + using namespace std::placeholders; + Rendezvous::DoneCallback done_cb = std::bind( + [rendez](Rendezvous::DoneCallback done, + // Begin unbound arguments. + const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { + rendez->Unref(); + done(s, send_args, recv_args, v, dead); + }, + std::move(done), _1, _2, _3, _4, _5); + rendez->FlowControlRecvLocalAsync(tag, parsed, std::move(done_cb)); +} + void BaseRendezvousMgr::Cleanup(int64 step_id) { Rendezvous* rendez = nullptr; { @@ -174,7 +193,17 @@ BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id) : env_(env), step_id_(step_id), local_(NewLocalRendezvous()), - session_(nullptr) {} + session_(nullptr), + flow_control_num_(0) { + Status s = ReadInt64FromEnvVar("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE", + kFlowControlMaxSize, &flow_control_max_size_); + if (!s.ok()) { + LOG(ERROR) << "Read REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE env error: " + << s.error_message(); + } + VLOG(2) << "BaseRemoteRendezvous set flow control max size: " + << flow_control_max_size_; +} BaseRemoteRendezvous::~BaseRemoteRendezvous() { CHECK(active_.empty()); @@ -221,6 +250,16 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { 
std::move(fuse_call.done)); } + std::vector deferred_flow_control_calls; + { + mutex_lock l(mu_); + std::swap(deferred_flow_control_calls, deferred_flow_control_calls_); + } + for (auto& fc_call : deferred_flow_control_calls) { + FlowControlRecvLocalAsyncInternal(fc_call.tag, fc_call.parsed, + std::move(fc_call.done)); + } + return Status::OK(); } @@ -271,6 +310,43 @@ Status BaseRemoteRendezvous::Send(const ParsedKey& parsed, return local_->Send(parsed, args, val, mu, is_dead); } +Status BaseRemoteRendezvous::FlowControlSend(const StringPiece& tag, + const ParsedKey& parsed, + const Args& args, + const Tensor& val, + const bool is_dead, + const int64 timeout_millis) { + VLOG(1) << "BaseRemoteRendezvous FlowControlSend " << this << " " + << parsed.FullKey(); + const std::string tag_string(tag.data(), tag.size()); + { + mutex_lock l(mu_); + while(status_.ok() && flow_control_num_ >= flow_control_max_size_) { + if (flow_control_cv_.wait_for( + l, std::chrono::milliseconds(timeout_millis)) == \ + std::cv_status::timeout) { + return errors::DeadlineExceeded("FlowControlSend has timed out."); + } + } + + if (!status_.ok()) return status_; + DCHECK(is_initialized_locked()); + if (!IsLocalDevice(session_->worker_name, parsed.src_device)) { + return errors::InvalidArgument( + "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ", + session_->worker_name); + } + + flow_control_num_++; + if (flow_control_counters_.count(tag_string) == 0) { + flow_control_counters_[tag_string] = 0; + } + flow_control_counters_[tag_string]++; + } + // Buffers "val" and "device_context" in local_. 
+ return local_->Send(parsed, args, val, is_dead); +} + Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, bool is_src) { // Cache session pointer to avoid repeatedly taking & releasing the lock @@ -413,6 +489,63 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, } } +void BaseRemoteRendezvous::FlowControlRecvAsync(const StringPiece& tag, + const ParsedKey& parsed, + const Args& recv_args, + DoneCallback done) { + VLOG(1) << "RemoteRendezvous FlowControlRecvAsync " << this + << " " << tag << " " << parsed.FullKey(); + + Status s = ValidateDevices(parsed, false /*!is_src*/); + if (s.ok() && !is_initialized()) { + s.Update(errors::Internal( + "FlowControlRecvAsync called when uninitialized (key:", + parsed.FullKey(), ").")); + } + if (!s.ok()) { + done(s, Args(), recv_args, Tensor(), false); + return; + } + + // Are src and dst in the same worker? + if (IsSameWorker(parsed.src, parsed.dst)) { + // Recv the tensor from local_. + local_->RecvAsync( + parsed, recv_args, + [this, tag, parsed, done]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { + VLOG(2) << "RemoteRendezvous Finished Recv " << this << " " + << parsed.FullKey(); + Tensor* out = new Tensor; + StatusCallback final_callback = [done, send_args, recv_args, out, + is_dead](const Status& s) { + done(s, send_args, recv_args, *out, is_dead); + delete out; + }; + + if (status.ok()) { + SameWorkerRecvDone(parsed, send_args, recv_args, in, out, + std::move(final_callback)); + const std::string tag_string(tag.data(), tag.size()); + { + mutex_lock l(mu_); + flow_control_num_--; + DCHECK(flow_control_counters_.count(tag_string) != 0); + flow_control_counters_[tag_string]--; + } + flow_control_cv_.notify_one(); + } else { + final_callback(status); + } + }); + return; + } else { + FlowControlRecvFromRemoteAsync(tag, parsed, recv_args, std::move(done)); + } + +} + void 
BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, DoneCallback done) { { @@ -600,6 +733,58 @@ void BaseRemoteRendezvous::FuseRecvLocalAsyncInternal( } } +void BaseRemoteRendezvous::FlowControlRecvLocalAsync(const StringPiece& tag, + const ParsedKey& parsed, + DoneCallback done) { + { + mutex_lock l(mu_); + if (!is_initialized_locked()) { + // FlowControlRecvLocalAsync can be called (due to an incoming RecvTensor + // RPC from a remote worker) before the RunStep (or PartialRunStep) RPC + // from the master arrives. RecvLocalAsync thus buffers the arguments + // until after the RemoteRendezvous is Initialize()'d, when it completes + // the rendezvous logic. At some point after Initialize() is called, a + // Tensor is produced locally that will then be sent in response to the + // incoming RPC. + DeferredFlowControlCall call(tag, parsed, std::move(done)); + deferred_flow_control_calls_.push_back(call); + return; + } + } + FlowControlRecvLocalAsyncInternal(tag, parsed, std::move(done)); +} + +void BaseRemoteRendezvous::FlowControlRecvLocalAsyncInternal( + const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) { + Status s = ValidateDevices(parsed, true /* is_src */); + if (!s.ok()) { + done(s, Args(), Args(), Tensor(), false); + return; + } + + using namespace std::placeholders; + Rendezvous::DoneCallback done_cb = std::bind( + [this, tag](Rendezvous::DoneCallback done, + // Begin unbound arguments. 
+ const Status& s, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { + done(s, send_args, recv_args, v, dead); + if (s.ok()) { + const std::string tag_string(tag.data(), tag.size()); + { + mutex_lock l(mu_); + flow_control_num_--; + DCHECK(flow_control_counters_.count(tag_string) != 0); + flow_control_counters_[tag_string]--; + } + flow_control_cv_.notify_one(); + } + }, + std::move(done), _1, _2, _3, _4, _5); + + local_->RecvAsync(parsed, Args(), std::move(done_cb)); +} + void BaseRemoteRendezvous::FuseRecvFromRemoteAsync( const std::vector& parsed_keys, const Rendezvous::Args& args, @@ -607,6 +792,12 @@ void BaseRemoteRendezvous::FuseRecvFromRemoteAsync( CHECK(false) << "FuseRecvFromRemoteAsync Unimplemented"; } +void BaseRemoteRendezvous::FlowControlRecvFromRemoteAsync( + const StringPiece& tag, const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, DoneCallback done) { + CHECK(false) << "FlowControlRecvFromRemoteAsync Unimplemented."; +} + void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, RefDoneCallback done) { @@ -636,6 +827,19 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, } } +int64 BaseRemoteRendezvous::GetAllFlowControlItemNum() { + mutex_lock l(mu_); + return flow_control_num_; +} + +int64 BaseRemoteRendezvous::GetFlowControlItemNum(StringPiece tag) { + const std::string tag_string(tag.data(), tag.size()); + mutex_lock l(mu_); + if (flow_control_counters_.count(tag_string) == 0) + return 0; + return flow_control_counters_[tag_string]; +} + void BaseRemoteRendezvous::StartAbort(const Status& s) { CHECK(!s.ok()); // Use a "derived" status as the status for the rendezvous. 
Derived @@ -656,7 +860,10 @@ void BaseRemoteRendezvous::StartAbort(const Status& s) { } active_.clear(); } + flow_control_num_ = 0; + flow_control_counters_.clear(); } + flow_control_cv_.notify_all(); } void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call, @@ -707,4 +914,8 @@ BaseRemoteRendezvous::DeferredFuseCall::DeferredFuseCall( const std::vector& parsed_keys, FuseDoneCallback done) : parsed_keys(parsed_keys), done(std::move(done)) {} +BaseRemoteRendezvous::DeferredFlowControlCall::DeferredFlowControlCall( + const StringPiece& tag, const ParsedKey& parsed, DoneCallback done) + : tag(tag), parsed(parsed), done(std::move(done)) {} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index b65e59436c0..fc72d9bedfc 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ #include +#include #include #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" @@ -86,6 +87,10 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { const std::vector& parsed_keys, Rendezvous::FuseDoneCallback done) override; + void FlowControlRecvLocalAsync(int64 step_id, const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) override; + // Removes rendezvous for "step_id". 
// // TODO(zhifengc): Have a background thread in worker that @@ -140,6 +145,11 @@ class BaseRemoteRendezvous : public RemoteRendezvous { Status Send(const ParsedKey& key, const Rendezvous::Args& args, Tensor* val, mutex* mu, const bool is_dead) override; + Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead, + const int64 timeout_millis) override; + // This method is called only by the RecvOp. It tests to see // whether the value will be produced by a local or remote device // and handles accordingly. In the local case it forwards to @@ -147,6 +157,10 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; + void FlowControlRecvAsync(const StringPiece& tag, + const ParsedKey& parsed_key, + const Args& args, DoneCallback done) override; + void StartAbort(const Status& status) override; // This method is called only by the local Worker, forwarded through @@ -171,10 +185,18 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void FuseRecvLocalSync(const std::vector& parsed_keys, FuseDoneCallback done); + void FlowControlRecvLocalAsync(const StringPiece& tag, + const ParsedKey& parsed, DoneCallback done); + // For ref send/recv void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, RefDoneCallback done) override; + // Obtain statistical information + int64 GetAllFlowControlItemNum() override; + + int64 GetFlowControlItemNum(StringPiece tag) override; + protected: virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, @@ -185,6 +207,10 @@ class BaseRemoteRendezvous : public RemoteRendezvous { const Rendezvous::Args& args, FuseDoneCallback done); + virtual void FlowControlRecvFromRemoteAsync(const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, DoneCallback done); + // Returns true if "src" and 
"dst" are located in the same worker, // and hence may use a local rendezvous. virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, @@ -210,6 +236,12 @@ class BaseRemoteRendezvous : public RemoteRendezvous { mutable mutex mu_; + // For Flow Control. + int64 flow_control_max_size_; + int64 flow_control_num_ GUARDED_BY(mu_); + std::unordered_map flow_control_counters_ GUARDED_BY(mu_); + tensorflow::condition_variable flow_control_cv_; + // Status given by StartAbort() if any. Status status_ GUARDED_BY(mu_); WorkerSession* session_ GUARDED_BY(mu_); // Not owned. @@ -233,6 +265,16 @@ class BaseRemoteRendezvous : public RemoteRendezvous { }; std::vector deferred_fuse_calls_ GUARDED_BY(mu_); + struct DeferredFlowControlCall { + const StringPiece tag; + const ParsedKey parsed; + DoneCallback done; + + DeferredFlowControlCall(const StringPiece& tag, const ParsedKey& parsed, + DoneCallback done); + }; + std::vector deferred_flow_control_calls_ GUARDED_BY(mu_); + typedef std::function InactiveCallback; // Active outstanding RecvTensor calls. @@ -262,6 +304,9 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void FuseRecvLocalAsyncInternal(const std::vector& parsed_keys, FuseDoneCallback done); + void FlowControlRecvLocalAsyncInternal(const StringPiece& tag, + const ParsedKey& parsed, + DoneCallback done); TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); }; diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index caf4af97ac2..abc971c4552 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -40,6 +40,11 @@ class RemoteRendezvous : public Rendezvous { public: // Fully construct the RemoteRendezvous. 
virtual Status Initialize(WorkerSession* session) = 0; + + // Obtain statistical information + virtual int64 GetAllFlowControlItemNum() = 0; + + virtual int64 GetFlowControlItemNum(StringPiece tag) = 0; }; // RendezvousMgr keeps track of a set of local rendezvous instances. @@ -87,7 +92,11 @@ class RendezvousMgrInterface { virtual void FuseRecvLocalAsync( int64 step_id, const std::vector& parsed_keys, - Rendezvous::FuseDoneCallback done) = 0; + Rendezvous::FuseDoneCallback done) = 0; + + virtual void FlowControlRecvLocalAsync(int64 step_id, const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) = 0; // Removes rendezvous for "step_id". // diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index ba95e80b496..c3fb6a8ee6c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -63,6 +63,7 @@ class GrpcRemoteWorker : cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)), recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)), fuserecvtensor_(Method(GrpcWorkerMethod::kFuseRecvTensor)), + flowcontrolrecvtensor_(Method(GrpcWorkerMethod::kFlowControlRecvTensor)), recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)), logging_(Method(GrpcWorkerMethod::kLogging)), tracing_(Method(GrpcWorkerMethod::kTracing)), @@ -210,6 +211,14 @@ class GrpcRemoteWorker : IssueRequest(request, response, fuserecvtensor_, done, call_opts); } + void FlowControlRecvTensorAsync(CallOptions* call_opts, + const FlowControlRecvTensorRequest* request, + TensorResponse* response, + StatusCallback done) { + VLOG(1) << "FlowControlRecvTensorAsync req: " << request->DebugString(); + IssueRequest(request, response, flowcontrolrecvtensor_, done, call_opts); + } + void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, TensorResponse* response, StatusCallback done) override { VLOG(1) << 
"RecvTensorAsync req: " << request->DebugString(); @@ -341,6 +350,7 @@ class GrpcRemoteWorker : const ::grpc::string cleanupall_; const ::grpc::string recvtensor_; const ::grpc::string fuserecvtensor_; + const ::grpc::string flowcontrolrecvtensor_; const ::grpc::string recvbuf_; const ::grpc::string logging_; const ::grpc::string tracing_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h index 20f1d2b5a62..2c885fec75d 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_interface.h @@ -6,6 +6,8 @@ namespace tensorflow { class CallOptions; class FuseTensorResponse; class FuseRecvTensorRequest; +class FlowControlRecvTensorRequest; +class TensorResponse; class GrpcWorkerInterface { public: @@ -13,6 +15,10 @@ class GrpcWorkerInterface { const FuseRecvTensorRequest* request, FuseTensorResponse* response, StatusCallback done) = 0; + + virtual void FlowControlRecvTensorAsync(CallOptions* call_opts, + const FlowControlRecvTensorRequest* request, + TensorResponse* response, StatusCallback done) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index ef4fbeab438..3bdacc29a12 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -170,6 +170,15 @@ class GrpcWorkerServiceThread { EnqueueFuseRecvTensorRequestRaw(); } + // Support FlowControlRecv + for (int i = 0; + i < gtl::FindWithDefault( + queue_depth_, static_cast(GrpcWorkerMethod::kFlowControlRecvTensor), + 1000); + ++i) { + EnqueueFlowControlRecvTensorRequestRaw(); + } + void* tag; bool ok; @@ -312,6 +321,24 @@ class GrpcWorkerServiceThread { EnqueueFuseRecvTensorRequestRaw(); } + void FlowControlRecvTensorHandlerRaw( + WorkerCall* 
call) { + Schedule([this, call]() { + CallOptions* call_opts = new CallOptions; + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + + worker_->GrpcFlowControlRecvTensorAsync(call_opts, &call->request, + &call->response, + [call, call_opts + ](const Status& s) { + call->ClearCancelCallback(); + delete call_opts; + call->SendResponse(ToGrpcStatus(s)); + }); + }); + EnqueueFlowControlRecvTensorRequestRaw(); + } + void RecvBufHandler(WorkerCall* call) { Schedule([this, call]() { CallOptions* call_opts = new CallOptions; @@ -394,6 +421,19 @@ class GrpcWorkerServiceThread { } } + void EnqueueFlowControlRecvTensorRequestRaw() { + mutex_lock l(shutdown_mu_); + if (!is_shutdown_) { + Call:: + EnqueueRequestForMethod( + worker_service_, cq_.get(), + static_cast(GrpcWorkerMethod::kFlowControlRecvTensor), + &GrpcWorkerServiceThread::FlowControlRecvTensorHandlerRaw, + true /* supports cancel*/); + } + } + GrpcWorker* const worker_ = nullptr; // Not owned. std::unique_ptr<::grpc::ServerCompletionQueue> cq_; std::unique_ptr thread_; @@ -746,6 +786,128 @@ void GrpcWorker::GrpcFuseRecvTensorAsync(CallOptions* opts, }); } +// GrpcFlowControlRecvTensorAsync: unlike the other Worker methods, which use +// protocol buffers for a response object, to avoid extra protocol buffer +// serialization overhead we generate our response directly into a +// ::grpc::ByteBuffer object +void GrpcWorker::GrpcFlowControlRecvTensorAsync(CallOptions* opts, + const FlowControlRecvTensorRequest* request, + ::grpc::ByteBuffer* response, StatusCallback done) { + VLOG(1) << "GrpcFlowControlRecvTensorAsync req: " << request->DebugString(); + const int64 request_id = request->request_id(); + const int64 step_id = request->step_id(); + + bool cache_enabled = (response_cache_ != nullptr && request_id != 0); + + auto do_response = [response, done, cache_enabled](const Tensor& tensor, + bool is_dead, + const Status& status) { + if (status.ok()) { + grpc::EncodeTensorToByteBuffer(is_dead, 
tensor, cache_enabled, response); + } + done(status); + }; + + // If response cache is enabled and the response cache already contains the + // request, we delegate this retry request to the response cache. Otherwise, + // we add the request to the response cache and start the computation to + // retrieve the requested data. + if (cache_enabled && + response_cache_->QueueRequest(request_id, step_id, do_response)) { + return; + } + + auto rendezvous_done = [this, request_id, do_response, cache_enabled]( + const Tensor& tensor, bool is_dead, + const Status& status) { + if (cache_enabled) { + // Data is ready. Process all pending requests in the response cache. + response_cache_->OnRequestFinished(request_id, tensor, is_dead, status); + } else { + do_response(tensor, is_dead, status); + } + }; + + auto fail = [&rendezvous_done](const Status& status) { + rendezvous_done(Tensor(), false, status); + }; + + Status s = recent_request_ids_.TrackUnique( + request_id, "RecvTensor (GrpcWorker)", *request); + if (!s.ok()) { + fail(s); + return; + } + + const string& key = request->rendezvous_key(); + TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); + Rendezvous::ParsedKey parsed; + s = Rendezvous::ParseKey(key, &parsed); + Device* src_dev = nullptr; + if (s.ok()) { + s = PrepareRecvTensor(parsed, &src_dev); + } + if (!s.ok()) { + fail(s); + return; + } + + // Request the tensor associated with the rendezvous key. + // Note that we log the cancellation here but do not abort the current step. + // gRPC can generate cancellations in response to transient network failures, + // and aborting the step eliminates the opportunity for client side retries. + // Repeated client failures will eventually cause the step to be aborted by + // the client. 
+ opts->SetCancelCallback( + [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; }); + StringPiece tag = request->tag(); + env_->rendezvous_mgr->FlowControlRecvLocalAsync( + step_id, tag, parsed, + [opts, rendezvous_done, src_dev, request]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, + const bool is_dead) { + opts->ClearCancelCallback(); + if (status.ok()) { + // DMA can only be used for Tensors that do not fall into + // the following three odd edge cases: 1) a zero-size + // buffer, 2) a dead tensor which has an uninit value, and + // 3) the tensor has the on_host allocation attribute, + // i.e. it's in CPU RAM *independent of its assigned + // device type*. + const bool on_host = send_args.alloc_attrs.on_host(); + { + // Non-DMA cases. + if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { + DeviceContext* send_dev_context = send_args.device_context; + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = src_dev->GetAllocator(alloc_attrs); + Tensor* copy = new Tensor(alloc, val.dtype(), val.shape()); + CHECK(send_dev_context) + << "send dev name: " << src_dev->name() + << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. + StatusCallback copy_ready = [rendezvous_done, copy, + is_dead](const Status& s) { + // The value is now ready to be returned on the wire. 
+ rendezvous_done(*copy, is_dead, s); + delete copy; + }; + + CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(), + src_dev, copy, send_dev_context, copy_ready); + return; + } + } + } + + rendezvous_done(val, is_dead, status); + }); +} + namespace { // If RecvBufRespExtra.tensor_content is a single large string, then gRPC // can stall on the recv side when the string buffer needs to be enlarged, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index 69759c420cc..48941d438c9 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -51,6 +51,10 @@ class GrpcWorker : public Worker { ::grpc::ByteBuffer* response, StatusCallback done); + virtual void GrpcFlowControlRecvTensorAsync(CallOptions* opts, + const FlowControlRecvTensorRequest* request, + ::grpc::ByteBuffer* response, StatusCallback done); + void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, StatusCallback done) override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 515d6e90beb..2095540e36a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -48,6 +48,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { return "/tensorflow.WorkerService/RecvTensor"; case GrpcWorkerMethod::kFuseRecvTensor: return "/tensorflow.WorkerService/FuseRecvTensor"; + case GrpcWorkerMethod::kFlowControlRecvTensor: + return "/tensorflow.WorkerService/FlowControlRecvTensor"; case GrpcWorkerMethod::kRecvBuf: return "/tensorflow.WorkerService/RecvBuf"; case GrpcWorkerMethod::kLogging: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h 
b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index ff8e1c07cb4..ad77ee0fd80 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -80,6 +80,7 @@ enum class GrpcWorkerMethod { kCleanupAll, kRecvTensor, kFuseRecvTensor, + kFlowControlRecvTensor, kRecvBuf, kLogging, kTracing, diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 69f1481f59e..267bf09e66f 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -53,6 +53,10 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { const Rendezvous::Args& args, FuseDoneCallback done) override; + void FlowControlRecvFromRemoteAsync(const StringPiece& tag, + const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, + DoneCallback done) override; + private: ~RpcRemoteRendezvous() override {} @@ -529,6 +533,247 @@ void RpcRemoteRendezvous::FuseRecvFromRemoteAsync( }); } + + +class FlowControlRpcRecvTensorCall : public BaseRecvTensorCall { + public: + FlowControlRpcRecvTensorCall() + : wi_(nullptr), dst_device_(nullptr) {} + + void Init(WorkerInterface* wi, int64 step_id, const StringPiece& tag, + const StringPiece& key, AllocatorAttributes alloc_attrs, + Device* dst_device, const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done) { + wi_ = wi; + grpc_wi_ = dynamic_cast(wi_); + alloc_attrs_ = alloc_attrs; + dst_device_ = dst_device; + recv_args_ = recv_args; + done_ = std::move(done); + req_.set_step_id(step_id); + req_.set_tag(tag.data(), tag.size()); + req_.set_request_id(GetUniqueRequestId()); + req_.set_rendezvous_key(key.data(), key.size()); + } + + void Reset() { + // The FlowControlRpcRemoteRendezvous using this object is responsible for + // calling ReleaseWorker() before Reset(). 
+ DCHECK_EQ(static_cast(nullptr), wi_) + << "Leaking WorkerInterface in RpcRecvTensorCall::Reset()."; + + alloc_attrs_ = AllocatorAttributes(); + dst_device_ = nullptr; + // We don't clear opts_ and assume that Init will set up the state for + // opts_ appropriately. + req_.Clear(); + resp_.Clear(); + { + mutex_lock l(mu_); + status_ = Status::OK(); + } + done_ = nullptr; + } + + ~FlowControlRpcRecvTensorCall() override { + // Since only the FlowControlRpcRecvTensorFreeList will delete an + // FlowControlRpcRecvTensorCall, and it always sets this->wi_ to null when + // a call object is released to it, we can assert that this->wi_ is + // always null at the point of deletion. + CHECK_EQ(static_cast(nullptr), wi_) + << "Leaking WorkerInterface in FlowControlRpcRecvTensorCall destructor."; + } + + void Start(std::function recv_done) override { + StartRTCall(std::move(recv_done)); + } + + void StartAbort(const Status& s) override { + { + mutex_lock l(mu_); + status_.Update(s); + } + opts_.StartCancel(); + } + + Status status() const override { + mutex_lock l(mu_); + return status_; + } + + void ReleaseWorker(WorkerCacheInterface* worker_cache) { + DCHECK_NE(static_cast(nullptr), wi_) + << "FlowControlRpcRecvTensorCall::ReleaseWorker() called twice."; + worker_cache->ReleaseWorker(src_worker_, wi_); + wi_ = nullptr; + grpc_wi_ = nullptr; + } + + const Tensor& tensor() const { return resp_.tensor(); } + + bool is_dead() const { return resp_.metadata().is_dead(); } + + Device* dst_device() const { return dst_device_; } + const Rendezvous::Args recv_args() const { return recv_args_; } + const Rendezvous::DoneCallback& done() const { return done_; } + + private: + friend class RpcRemoteRendezvous; + + // Start the main RecvTensor call, checking for an async abort. 
+ void StartRTCall(std::function recv_done) { + resp_.InitAlloc(dst_device_, alloc_attrs_); + using namespace std::placeholders; + StatusCallback cb = std::bind( + [this](std::function recv_done, + // Begin unbound arguments. + const Status& s) { + if (!s.ok()) { + mutex_lock l(mu_); + status_.Update(s); + } + recv_done(); + }, + std::move(recv_done), _1); + grpc_wi_->FlowControlRecvTensorAsync(&opts_, &req_, &resp_, std::move(cb)); + } + + string src_worker_; + string src_rel_device_; + WorkerInterface* wi_; // Not owned. + GrpcWorkerInterface* grpc_wi_; + AllocatorAttributes alloc_attrs_; + Device* dst_device_; + CallOptions opts_; + FlowControlRecvTensorRequest req_; + TensorResponse resp_; + Rendezvous::Args recv_args_; + Rendezvous::DoneCallback done_; + + mutable mutex mu_; + Status status_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(FlowControlRpcRecvTensorCall); +}; + +class FlowControlRpcRecvTensorFreeList { + public: + FlowControlRpcRecvTensorFreeList() {} + ~FlowControlRpcRecvTensorFreeList() { + for (size_t i = 0; i < objects_.size(); i++) { + delete objects_[i]; + } + } + + FlowControlRpcRecvTensorCall* New() { + { + mutex_lock l(mu_); + if (!objects_.empty()) { + FlowControlRpcRecvTensorCall* result = objects_.back(); + objects_.pop_back(); + return result; + } + } + return new FlowControlRpcRecvTensorCall; + } + + void Release(FlowControlRpcRecvTensorCall* obj) { + obj->Reset(); + { + mutex_lock l(mu_); + if (objects_.size() < kMaxObjects) { + objects_.push_back(obj); + return; + } + } + delete obj; + } + + private: + static const int kMaxObjects = 1000; + + mutex mu_; + std::vector objects_ GUARDED_BY(mu_); +}; + +static FlowControlRpcRecvTensorFreeList* get_flow_control_call_freelist() { + static FlowControlRpcRecvTensorFreeList* call_freelist = \ + new FlowControlRpcRecvTensorFreeList(); + return call_freelist; +} + +void RpcRemoteRendezvous::FlowControlRecvFromRemoteAsync( + const StringPiece& tag, const Rendezvous::ParsedKey& parsed, + 
const Rendezvous::Args& recv_args, DoneCallback done) { + CHECK(is_initialized()); + Status s; + + // Prepare a FlowControlRecvTensor call that can handle being aborted. + FlowControlRpcRecvTensorCall* call = get_flow_control_call_freelist()->New(); + + // key.src_device identifies a remote device. + if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_, + &call->src_rel_device_)) { + s = errors::Internal(parsed.src_device, + " is invalid remote source device."); + } + + WorkerSession* sess = session(); + WorkerInterface* rwi = + sess->worker_cache->GetOrCreateWorker(call->src_worker_); + if (s.ok() && rwi == nullptr) { + s = errors::Internal("No worker known as ", call->src_worker_); + } + + Device* dst_device; + if (s.ok()) { + s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); + } + if (!s.ok()) { + if (rwi != nullptr) { + sess->worker_cache->ReleaseWorker(call->src_worker_, rwi); + } + get_flow_control_call_freelist()->Release(call); + done(s, Args(), recv_args, Tensor{}, false); + return; + } + + call->Init(rwi, step_id_, tag, parsed.FullKey(), recv_args.alloc_attrs, + dst_device, recv_args, std::move(done)); + + // Record "call" in active_ so that it can be aborted cleanly. + RegisterCall(call, recv_args); + + // RendezvousMgr already aborted, shouldn't send RPC call any more + if (!call->status().ok()) { + // NOTE: `*sess` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. + call->ReleaseWorker(sess->worker_cache.get()); + call->done()(call->status(), Args(), Args(), Tensor(), false); + get_flow_control_call_freelist()->Release(call); + return; + } + + // Start "call". + Ref(); + call->Start([this, call]() { + // Removes "call" from active_. Prevent StartAbort(). + DeregisterCall(call); + // If StartAbort was called prior to DeregisterCall, then the + // current status should be bad. 
+ Status s = call->status(); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. + call->ReleaseWorker(session()->worker_cache.get()); + call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); + get_flow_control_call_freelist()->Release(call); + Unref(); + }); + +} + } // namespace RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 5021853ce23..75f41ab3057 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -211,6 +211,32 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) { } } +TEST_F(RpcRendezvousMgrTest, FlowControlSend) { + setenv("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE", "2", 1); + const int64 step_id = 123; + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); + { + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + TF_ASSERT_OK( + rendez->FlowControlSend("TEST", key, args, V("peach_0"), false)); + TF_ASSERT_OK( + rendez->FlowControlSend("TEST", key, args, V("peach_1"), false)); + + EXPECT_NE( + rendez->FlowControlSend("TEST", key, args, V("peach_2"), false, 100), + Status::OK()); + EXPECT_EQ(rendez->GetAllFlowControlItemNum(), 2); + EXPECT_EQ(rendez->GetFlowControlItemNum("TEST"), 2); + } + + unsetenv("REMOTE_RENDEZVOUS_FLOW_CONTROL_MAX_SIZE"); +} + class DummyDeviceContext : public DeviceContext { public: explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} diff --git a/tensorflow/core/framework/rendezvous.cc 
b/tensorflow/core/framework/rendezvous.cc index e4db066a562..4d1adf1a070 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -146,6 +146,47 @@ Status Rendezvous::Recv(const ParsedKey& key, const Args& args, Tensor* val, return Recv(key, args, val, is_dead, no_timeout); } +Status Rendezvous::FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead) { + int64 no_timeout = 300000; + return FlowControlSend(tag, key, args, val, is_dead, no_timeout); +} + +Status Rendezvous::FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, bool* is_dead, + int64 timeout_ms) { + Status ret; + Notification n; + FlowControlRecvAsync(tag, key, args, [&ret, &n, val, is_dead]( + const Status& s, const Args& send_args, + const Args& recv_args, const Tensor& v, + const bool dead) { + ret = s; + *val = v; + *is_dead = dead; + n.Notify(); + }); + if (timeout_ms > 0) { + int64 timeout_us = timeout_ms * 1000; + bool notified = WaitForNotificationWithTimeout(&n, timeout_us); + if (!notified) { + return Status(error::DEADLINE_EXCEEDED, + "Timed out waiting for notification"); + } + } else { + n.WaitForNotification(); + } + return ret; +} + +Status Rendezvous::FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, + bool* is_dead) { + const int64 no_timeout = 0; + return FlowControlRecv(tag, key, args, val, is_dead, no_timeout); +} + class LocalRendezvousImpl : public Rendezvous { public: explicit LocalRendezvousImpl() {} diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index 3aa65534272..106c0f26b32 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -108,6 +108,17 @@ class Rendezvous : public core::RefCounted { virtual Status Send(const ParsedKey& key, const Args& args, Tensor* ref_val, mutex* ref_mu, const bool 
is_dead) { return Status::OK(); } + virtual Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead, + const int64 timeout_millis) { + return errors::Unimplemented("[Rendezvous] unimplement FlowControlSend."); + } + + virtual Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead); + // Callback provided by a tensor consumer waiting on the rendezvous. // It will be invoked when the tensor is available, or when a non-OK // status arises in the production of that tensor. It also gets @@ -139,12 +150,27 @@ class Rendezvous : public core::RefCounted { virtual void FuseRecvAsync(const std::vector& parsed_keys, const Args& args, FuseDoneCallback done) {} + // Local rendezvous does not need this. + virtual void FlowControlRecvAsync(const StringPiece& tag, + const ParsedKey& parsed_key, const Args& args, + DoneCallback done) { + CHECK(false) << "[Rendezvous] unimplement FlowControlRecvAsync."; + } + // Synchronous wrapper for RecvAsync. Status Recv(const ParsedKey& key, const Args& args, Tensor* val, bool* is_dead, int64 timeout_ms); Status Recv(const ParsedKey& key, const Args& args, Tensor* val, bool* is_dead); + // Synchronous wrapper for FlowControlRecvAsync. + Status FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, bool* is_dead, + int64 timeout_ms); + + Status FlowControlRecv(const StringPiece& tag, const ParsedKey& key, + const Args& args, Tensor* val, bool* is_dead); + // Aborts all pending and future Send/Recv with the given "status". // // StartAbort() does not wait for ongoing calls to finish. 
diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops.cc b/tensorflow/core/kernels/file_slice_sendrecv_ops.cc index 6bfe54363f9..a919238a5ee 100644 --- a/tensorflow/core/kernels/file_slice_sendrecv_ops.cc +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops.cc @@ -33,11 +33,10 @@ FileSliceSendOp::FileSliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; @@ -212,8 +211,9 @@ Status FileSliceSendOp::SendFileSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "FileSliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, data_t, + ctx->is_input_dead())); } @@ -253,11 +253,10 @@ FileSliceRecvOp::FileSliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; } 
@@ -464,8 +463,9 @@ Status FileSliceRecvOp::RecvFileSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "FileSliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); file_ptr->Append(data_t.scalar()()); diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops.h b/tensorflow/core/kernels/file_slice_sendrecv_ops.h index 6701196d481..df7e6c646f8 100644 --- a/tensorflow/core/kernels/file_slice_sendrecv_ops.h +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops.h @@ -28,6 +28,7 @@ class FileSliceSendOp : public OpKernel { private: // Variables. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; int32 slice_size_; @@ -63,6 +64,7 @@ class FileSliceRecvOp: public OpKernel { private: // Variables. 
+ string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; string recv_dir_; diff --git a/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc b/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc index 931cd152253..62f5596bb62 100644 --- a/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/file_slice_sendrecv_ops_test.cc @@ -50,6 +50,13 @@ class DummyRendezvous : public Rendezvous { kv_.erase(key_str); return Status::OK(); } + + Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead) { + return Send(key, args, val, is_dead); + } + void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) override { std::string key_str = { key.FullKey().data(), key.FullKey().size() }; @@ -72,6 +79,12 @@ class DummyRendezvous : public Rendezvous { done(Status::OK(), var.args, args, var.data, var.is_dead); kv_.erase(key_str); } + + void FlowControlRecvAsync(const StringPiece& tag, const ParsedKey& parsed_key, + const Args& args, DoneCallback done) { + RecvAsync(parsed_key, args, done); + } + void StartAbort(const Status& status) override {} private: diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.cc b/tensorflow/core/kernels/slice_sendrecv_ops.cc index 25f1a4e8738..ee0e5426cbc 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops.cc +++ b/tensorflow/core/kernels/slice_sendrecv_ops.cc @@ -30,11 +30,10 @@ SliceSendOp::SliceSendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if 
(!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; @@ -171,8 +170,9 @@ Status SliceSendOp::SendString(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, + data_t, ctx->is_input_dead())); } else { TF_RETURN_IF_ERROR(SendStringSlice(ctx, frame_iter, elem, i)); } @@ -209,8 +209,9 @@ Status SliceSendOp::SendStringSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, data_t, + ctx->is_input_dead())); } return Status::OK(); @@ -248,8 +249,9 @@ Status SliceSendOp::SendBasicType(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Send(parsed_key, args, data_t, - ctx->is_input_dead())); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlSend(tensor_name_, parsed_key, args, data_t, + ctx->is_input_dead())); } return Status::OK(); @@ -270,11 +272,10 @@ SliceRecvOp::SliceRecvOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("send_device_incarnation", reinterpret_cast(&send_device_incarnation))); - string tensor_name; - OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name_)); key_prefix_ = \ slice_sendrecv::GetSliceRendezvousKeyPrefix(send_device, - recv_device, 
send_device_incarnation, tensor_name); + recv_device, send_device_incarnation, tensor_name_); if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { hostmem_sendrecv_ = false; } @@ -440,8 +441,9 @@ Status SliceRecvOp::RecvString(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); output_flat(i) = data_t.scalar()(); @@ -484,8 +486,9 @@ Status SliceRecvOp::RecvStringSlice(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceRecv " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. CHECK_EQ(is_dead, false); output_flat(index) += data_t.scalar()(); @@ -529,8 +532,9 @@ Status SliceRecvOp::RecvBasicType(OpKernelContext* ctx, frame_iter, &parsed_key.buf_); VLOG(2) << "SliceSend " << parsed_key.buf_; TF_RETURN_IF_ERROR(Rendezvous::ParseKey(parsed_key.buf_, &parsed_key)); - TF_RETURN_IF_ERROR(ctx->rendezvous()->Recv(parsed_key, args, &data_t, - &is_dead, timeout_ms_)); + TF_RETURN_IF_ERROR( + ctx->rendezvous()->FlowControlRecv(tensor_name_, parsed_key, args, + &data_t, &is_dead, timeout_ms_)); // This shouldn't be a dead tensor. 
CHECK_EQ(is_dead, false); auto data_base = data_t.data(); diff --git a/tensorflow/core/kernels/slice_sendrecv_ops.h b/tensorflow/core/kernels/slice_sendrecv_ops.h index 43429bff32f..12e583e5551 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops.h +++ b/tensorflow/core/kernels/slice_sendrecv_ops.h @@ -28,6 +28,7 @@ class SliceSendOp : public OpKernel { private: // Variables. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; int32 slice_size_; @@ -58,6 +59,7 @@ class SliceRecvOp : public OpKernel { private: // Variable. + string tensor_name_; string key_prefix_; bool hostmem_sendrecv_; int32 slice_size_; diff --git a/tensorflow/core/kernels/slice_sendrecv_ops_test.cc b/tensorflow/core/kernels/slice_sendrecv_ops_test.cc index 5693ed57918..0eeb6d98c36 100644 --- a/tensorflow/core/kernels/slice_sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/slice_sendrecv_ops_test.cc @@ -50,6 +50,13 @@ class DummyRendezvous : public Rendezvous { kv_.erase(key_str); return Status::OK(); } + + Status FlowControlSend(const StringPiece& tag, const ParsedKey& key, + const Args& args, const Tensor& val, + const bool is_dead) { + return Send(key, args, val, is_dead); + } + void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) override { std::string key_str = { key.FullKey().data(), key.FullKey().size() }; @@ -72,6 +79,12 @@ class DummyRendezvous : public Rendezvous { done(Status::OK(), var.args, args, var.data, var.is_dead); kv_.erase(key_str); } + + void FlowControlRecvAsync(const StringPiece& tag, const ParsedKey& parsed_key, + const Args& args, DoneCallback done) { + RecvAsync(parsed_key, args, done); + } + void StartAbort(const Status& status) override {} private: diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 65ec7ffe4bc..fa18fec180c 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -441,6 +441,52 @@ message MarkRecvFinishedRequest { message 
MarkRecvFinishedResponse {} +//////////////////////////////////////////////////////////////////////////////// +// +// FlowControlRecvTensor method request messages +// +//////////////////////////////////////////////////////////////////////////////// + +message FlowControlRecvTensorRequest { + // The step in which the tensor will be produced. + // + // REQUIRED: This must eventually correspond to the `step_id` passed + // into a RunGraph call on the same WorkerService. + int64 step_id = 1; + + string tag = 2; + + // A key identifying the channel to receive tensors from. A RecvTensor request + // retrieves one tensor from the channel, but multiple tensors can be sent and + // received over the same channel with multiple RecvTensor requests. See + // rendezvous.h for details. + string rendezvous_key = 3; + + // If true, use an out-of-band DMA mechanism to transfer the + // received tensor. + bool dma_ok = 4; + + // Optional information on client-side device locality. + DeviceLocality client_locality = 5; + + // Optional information on server-side device locality. + DeviceLocality server_locality = 6; + + // Optional information needed by the RPC subsystem. + google.protobuf.Any transport_options = 7; + + // Unique identifier for this request. Every RecvTensorRequest must have a + // unique request_id, and retried RecvTensorRequests must have the same + // request_id. If request_id is zero, retry detection and response cache + // are disabled. + // + // Retried RecvTensorRequests are problematic because a RecvTensor with no + // corresponding sender will wait forever, and the tensor may have been + // delivered to a previous retry. Workers use request_ids to reject retried + // RecvTensor requests instead of waiting forever. 
+ int64 request_id = 8; +} + //////////////////////////////////////////////////////////////////////////////// // // Logging method request/response messages diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto index 07a64c55ad8..8591f2fe6ab 100644 --- a/tensorflow/core/protobuf/worker_service.proto +++ b/tensorflow/core/protobuf/worker_service.proto @@ -72,6 +72,11 @@ service WorkerService { // FuseRecvTensor Method } + // See worker.proto for details. + rpc FlowControlRecvTensor(FlowControlRecvTensorRequest) returns (RecvTensorResponse) { + // FlowControlRecvTensor Method + } + // See worker.proto for details. rpc Logging(LoggingRequest) returns (LoggingResponse); From 9e30ab604aa316359f249bc061b5fe87a5773604 Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Thu, 23 May 2024 12:00:02 +0800 Subject: [PATCH 13/14] [Embedding] Check the sharded property of tf.train.Saver. (#996) Signed-off-by: chenbangduo.cbd --- modelzoo/bst/train.py | 3 +- modelzoo/dbmtl/train.py | 3 +- modelzoo/dcn/train.py | 3 +- modelzoo/dcnv2/train.py | 3 +- modelzoo/deepfm/train.py | 3 +- modelzoo/dien/train.py | 3 +- modelzoo/din/train.py | 3 +- modelzoo/dlrm/train.py | 3 +- modelzoo/dssm/train.py | 3 +- modelzoo/esmm/train.py | 3 +- modelzoo/masknet/train.py | 3 +- modelzoo/mlperf/train.py | 3 +- modelzoo/mmoe/train.py | 3 +- modelzoo/ple/train.py | 3 +- modelzoo/simple_multitask/train.py | 3 +- modelzoo/wide_and_deep/train.py | 3 +- .../feature_column/feature_column_v2_test.py | 6 +- .../ops/embedding_variable_ops_gpu_test.py | 7 +- .../python/ops/embedding_variable_ops_test.py | 64 ++++++++++--------- tensorflow/python/training/incr_ckpt_test.py | 5 +- tensorflow/python/training/saver.py | 11 ++++ tensorflow/python/training/saver_test.py | 6 ++ 22 files changed, 76 insertions(+), 71 deletions(-) diff --git a/modelzoo/bst/train.py b/modelzoo/bst/train.py index eeeb136678b..536ddbc6905 100644 --- a/modelzoo/bst/train.py +++ 
b/modelzoo/bst/train.py @@ -612,10 +612,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/dbmtl/train.py b/modelzoo/dbmtl/train.py index c848cbc76b2..36f2685a175 100644 --- a/modelzoo/dbmtl/train.py +++ b/modelzoo/dbmtl/train.py @@ -527,10 +527,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/dcn/train.py b/modelzoo/dcn/train.py index 44701e22d9f..5094a18bd85 100644 --- a/modelzoo/dcn/train.py +++ b/modelzoo/dcn/train.py @@ -594,10 +594,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/dcnv2/train.py b/modelzoo/dcnv2/train.py index 5b572af0425..c1346ad6d7d 100644 --- a/modelzoo/dcnv2/train.py +++ b/modelzoo/dcnv2/train.py @@ -610,10 +610,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - 
sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/deepfm/train.py b/modelzoo/deepfm/train.py index 166bedec0d0..89b2b823a46 100644 --- a/modelzoo/deepfm/train.py +++ b/modelzoo/deepfm/train.py @@ -472,10 +472,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/dien/train.py b/modelzoo/dien/train.py index 190695f6ce0..f43fd2f1e73 100644 --- a/modelzoo/dien/train.py +++ b/modelzoo/dien/train.py @@ -776,10 +776,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/din/train.py b/modelzoo/din/train.py index 058583ce6fd..34621dee45e 100644 --- a/modelzoo/din/train.py +++ b/modelzoo/din/train.py @@ -594,10 +594,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( 
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/dlrm/train.py b/modelzoo/dlrm/train.py index cc4c045c349..9dff32aca52 100644 --- a/modelzoo/dlrm/train.py +++ b/modelzoo/dlrm/train.py @@ -507,10 +507,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/dssm/train.py b/modelzoo/dssm/train.py index db949aac5e8..9d2264d9ce9 100644 --- a/modelzoo/dssm/train.py +++ b/modelzoo/dssm/train.py @@ -478,10 +478,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/esmm/train.py b/modelzoo/esmm/train.py index 073b08814d4..1916ed76c27 100755 --- a/modelzoo/esmm/train.py +++ b/modelzoo/esmm/train.py @@ -534,10 +534,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - 
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=train_steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/masknet/train.py b/modelzoo/masknet/train.py index bb96a467701..bb9eee0ec3f 100644 --- a/modelzoo/masknet/train.py +++ b/modelzoo/masknet/train.py @@ -529,10 +529,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/mlperf/train.py b/modelzoo/mlperf/train.py index ce34fe5e55c..559e4fb6efc 100644 --- a/modelzoo/mlperf/train.py +++ b/modelzoo/mlperf/train.py @@ -522,10 +522,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/mmoe/train.py b/modelzoo/mmoe/train.py index 694eb45da80..a3a6c9146d8 100644 --- a/modelzoo/mmoe/train.py +++ b/modelzoo/mmoe/train.py @@ -523,10 +523,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + 
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/ple/train.py b/modelzoo/ple/train.py index b2d2f2057ec..33aa9a15e8e 100644 --- a/modelzoo/ple/train.py +++ b/modelzoo/ple/train.py @@ -592,10 +592,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/simple_multitask/train.py b/modelzoo/simple_multitask/train.py index 4ef1874a521..6eb51f7d4e9 100644 --- a/modelzoo/simple_multitask/train.py +++ b/modelzoo/simple_multitask/train.py @@ -427,10 +427,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=train_steps) log_hook = tf.train.LoggingTensorHook( diff --git a/modelzoo/wide_and_deep/train.py b/modelzoo/wide_and_deep/train.py index 3024f58024e..2d1c964e593 100644 --- a/modelzoo/wide_and_deep/train.py +++ b/modelzoo/wide_and_deep/train.py @@ -543,10 +543,9 @@ def train(sess_config, hooks = [] hooks.extend(input_hooks) - sharded_saver = tf_config != None scaffold = tf.train.Scaffold( local_init_op=tf.group(tf.local_variables_initializer(), data_init_op), - saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver)) + 
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True)) stop_hook = tf.train.StopAtStepHook(last_step=steps) log_hook = tf.train.LoggingTensorHook( diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 7946aee1e1a..24f8a36daa4 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self): opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables_lib.global_variables_initializer() with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) @@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables_lib.global_variables_initializer() - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) @test_util.run_deprecated_v1 def testEmbeddingVariableForInt32ID(self): @@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self): opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables_lib.global_variables_initializer() with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) diff --git a/tensorflow/python/ops/embedding_variable_ops_gpu_test.py b/tensorflow/python/ops/embedding_variable_ops_gpu_test.py index d47d94d0d99..3c69153ab1b 100644 --- a/tensorflow/python/ops/embedding_variable_ops_gpu_test.py +++ 
b/tensorflow/python/ops/embedding_variable_ops_gpu_test.py @@ -63,7 +63,8 @@ def testEmbeddingVariableForInitFromProto(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) graph = ops.get_default_graph() - meta_graph_def = saver_module.export_meta_graph() + saver = saver_module.Saver(sharded=True) + meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def()) ops.reset_default_graph() with self.test_session() as sess: res = saver_module.import_meta_graph(meta_graph_def) @@ -748,7 +749,7 @@ def testSaveV3(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) init = variables.global_variables_initializer() - saver = saver = saver_module.Saver() + saver = saver = saver_module.Saver(sharded=True) checkpoint_directory = self.get_temp_dir() model_path = os.path.join(checkpoint_directory, "model.ckpt") with self.test_session() as sess: @@ -816,7 +817,7 @@ def testEmbeddingVariableSaveAndRestoreOptimzierStatesForMultiTierWithHbm(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) graph = ops.get_default_graph() with self.test_session(graph = graph) as sess: saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt-12345")) diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py index dbf254d5f14..1119fd1c194 100644 --- a/tensorflow/python/ops/embedding_variable_ops_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_test.py @@ -162,7 +162,7 @@ def _RecordFreqTestTemplate(self, optimizer): opt = self._CreateOptimizer(optimizer) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, 
"model1.ckpt") @@ -194,7 +194,7 @@ def _RecordVersionTemplate(self, optimizer): opt = self._CreateOptimizer(optimizer) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -232,7 +232,7 @@ def testSaveVersionWithGlobalStepEviction(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) init = variables.global_variables_initializer() - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) model_path = os.path.join(checkpoint_directory, "model.ckpt") with self.test_session() as sess: sess.run([init]) @@ -269,7 +269,7 @@ def testFeatureColumnRecordFreqWithPartition(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -313,7 +313,7 @@ def testFeatureColumnRecordFreqSGDWithPartition(self): opt = gradient_descent.GradientDescentOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -387,7 +387,8 @@ def testDynamicEmbeddingVariableForInitFromProto(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) graph = ops.get_default_graph() - meta_graph_def = saver_module.export_meta_graph() + saver = saver_module.Saver(sharded=True) + meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def()) ops.reset_default_graph() with self.test_session() as sess: res = saver_module.import_meta_graph(meta_graph_def) @@ -406,7 +407,8 
@@ def testEmbeddingVariableForInitFromProto(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) graph = ops.get_default_graph() - meta_graph_def = saver_module.export_meta_graph() + saver = saver_module.Saver(sharded=True) + meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def()) ops.reset_default_graph() with self.test_session() as sess: res = saver_module.import_meta_graph(meta_graph_def) @@ -450,7 +452,7 @@ def testEmbeddingVariableForLookupInt32(self): opt = adam.AdamOptimizer(0.01) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) @@ -643,7 +645,7 @@ def testEmbeddingVariableForL2FeatureEvictionFromContribFeatureColumn(self): opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) @@ -682,7 +684,7 @@ def testEmbeddingVariableForGlobalStepEviction(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session() as sess: sess.run([init]) @@ -720,7 +722,7 @@ def testEmbeddingVariableForL2FeatureEviction(self): opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = 
variables.global_variables_initializer() with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) @@ -1534,7 +1536,7 @@ def testEmbeddingVariableForSaveFreq(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) model_path = os.path.join(checkpoint_directory, "model.ckpt") with self.test_session() as sess: sess.run([init]) @@ -1567,7 +1569,7 @@ def testEmbeddingVariableForL2FeatureEvictionDRAM(self): opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session() as sess: sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) @@ -1724,7 +1726,7 @@ def runTestAdagrad(self, var, g): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -1778,7 +1780,7 @@ def runTestAdagrad(self, var, g): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -1849,7 +1851,7 @@ def runTestAdagrad(self, var, g): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() 
model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -1923,7 +1925,7 @@ def testEmbeddingVariableForRecordFreq(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -1963,7 +1965,7 @@ def testEmbeddingVariableForRecordFreqWithCounterFilter(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -2278,7 +2280,7 @@ def testEmbeddingVariableForContirbFeatureColumnWithPartitionNum(self): opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) def testSaveV3(self): print("testSaveV3") @@ -2295,7 +2297,7 @@ def testSaveV3(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, global_step=gs) init = variables.global_variables_initializer() - saver = saver = saver_module.Saver() + saver = saver = saver_module.Saver(sharded=True) checkpoint_directory = self.get_temp_dir() model_path = os.path.join(checkpoint_directory, "model.ckpt") with self.test_session() as sess: @@ -2326,7 +2328,7 @@ def testEmbeddingVariableForNotSaveUnfilterFeature(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -2359,7 +2361,7 @@ def 
testEmbeddingVariableForSaveUnfilterFeature(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() model_path = os.path.join(checkpoint_directory, "model1.ckpt") @@ -2390,7 +2392,7 @@ def testEmbeddingVariableForMultiTierInference(self): opt = adagrad.AdagradOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v, gs) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session() as sess: sess.run([init]) @@ -2412,7 +2414,7 @@ def testEmbeddingVariableForMultiTierInference(self): emb = embedding_ops.embedding_lookup(emb_var, ids) tires = kv_variable_ops.lookup_tier(emb_var, math_ops.cast([1,2,3,4], dtypes.int64)) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) graph = ops.get_default_graph() with self.test_session(graph = graph) as sess: saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt")) @@ -2784,7 +2786,7 @@ def testSetInitializedWithoutRestore(self): g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) init = variables.global_variables_initializer() - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) with self.test_session() as sess: result = sess.run(var._is_initialized_op) self.assertEqual(False, result) @@ -2806,7 +2808,7 @@ def testSetInitializedWithRestore(self): opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session(graph=g) as sess: sess.run([init]) @@ -2823,7 +2825,7 @@ def testSetInitializedWithRestore(self): opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) g_v = 
opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session(graph=g) as sess: result = sess.run(var._is_initialized_op) @@ -2860,7 +2862,7 @@ def testCountsTensor(self): opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session(graph=g) as sess: sess.run([init]) @@ -2893,7 +2895,7 @@ def testCountsWithSparseAndDenseTensor(self): opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session(graph=g) as sess: sess.run([init]) @@ -2929,7 +2931,7 @@ def testCountsTensorWithGradientDescent(self): opt = gradient_descent.GradientDescentOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session(graph=g) as sess: sess.run([init]) @@ -2964,7 +2966,7 @@ def testCountsDenseAndSparseTensorWithGradientDescent(self): opt = gradient_descent.GradientDescentOptimizer(0.1) g_v = opt.compute_gradients(loss) train_op = opt.apply_gradients(g_v) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() with self.test_session(graph=g) as sess: sess.run([init]) diff --git a/tensorflow/python/training/incr_ckpt_test.py b/tensorflow/python/training/incr_ckpt_test.py index 55cf748a9d6..849c73a44dc 100644 --- a/tensorflow/python/training/incr_ckpt_test.py +++ b/tensorflow/python/training/incr_ckpt_test.py @@ 
-75,7 +75,7 @@ def testSparseEvIncrSaveRestore(self): emb = embedding_ops.embedding_lookup(var, math_ops.cast([0,1,2,5,6,7], dtypes.int64)) with ops.device("/device:CPU:0"): apply_incr = gen_io_ops.record_sparse_indices(math_ops.cast([0,1,2,5,6,7], dtypes.int64), "var_ev1") - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() ev_var_name = "var_ev1" incr_save_op = gen_io_ops.incr_save(incr_ckpt_path, [ev_var_name], [], [True],[var.handle]) @@ -178,7 +178,7 @@ def testMixIncrSaveRestore(self): activate_op = gen_io_ops. activate_sparse_recorder(["var_ev1","var_norm1"]) - saver = saver_module.Saver() + saver = saver_module.Saver(sharded=True) init = variables.global_variables_initializer() incr_save_op = gen_io_ops.incr_save(incr_ckpt_path, ["var_norm1", "var_ev1"], [], [True, True], [var_norm, var_ev.handle]) @@ -445,6 +445,7 @@ def testIncrementalSaverForResourceVariable(self): variable_scope.get_variable('var', shape=[100], use_resource=False) variable_scope.get_embedding_variable('ev', embedding_dim=100) saver = saver_module.Saver( + sharded=True, save_relative_paths=True, incremental_save_restore=True, ) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index acc9723c183..e70226f2968 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1071,10 +1071,14 @@ def _build(self, checkpoint_path, build_save, build_restore): # pylint: disable=protected-access self._var_list = variables._all_saveable_objects() from tensorflow.python.ops import hash_table + from tensorflow.python.ops import kv_variable_ops if isinstance(self._var_list, dict): + ev = {} ht = {} lst = {} for name, x in self._var_list.items(): + if isinstance(x, kv_variable_ops.EmbeddingVariable): + ev[name] = x if isinstance(x, hash_table.HashTable): if x.hash_table not in ht: ht[x.hash_table] = [x] @@ -1084,15 +1088,20 @@ def _build(self, checkpoint_path, 
build_save, build_restore): lst[name] = BloomFilterSaveable(x) else: lst[name] = x + if len(ev) != 0 and not self._sharded: + raise ValueError("EmbeddingVariable can only use sharded saver") if len(ht) != 0 and not self._sharded: raise ValueError("HashTable can only use sharded saver") for x, y in ht.items(): lst[x.name] = HashTableSaveable(y) self._var_list = lst else: + ev = [] ht = {} lst = [] for x in self._var_list: + if isinstance(x, kv_variable_ops.EmbeddingVariable): + ev.append(x) if isinstance(x, hash_table.HashTable): if x.hash_table not in ht: ht[x.hash_table] = [x] @@ -1102,6 +1111,8 @@ def _build(self, checkpoint_path, build_save, build_restore): lst.append(BloomFilterSaveable(x)) else: lst.append(x) + if len(ev) != 0 and not self._sharded: + raise ValueError("EmbeddingVariable can only use sharded saver") if len(ht) != 0 and not self._sharded: raise ValueError("HashTable can only use sharded saver") for x, y in ht.items(): diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index b48f00d0c14..365ef85af1d 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -852,6 +852,12 @@ def _model(): for orig, restored in zip(orig_vals, restored_vals): self.assertAllEqual(orig, restored) + def testEnableSaverShardedWhenUseEmbeddingVariable(self): + with ops_lib.Graph().as_default(): + emb_var = \ + variable_scope.get_embedding_variable(name="emb_var", embedding_dim=64) + with self.assertRaisesRegexp(ValueError, "EmbeddingVariable"): + saver_module.Saver([emb_var], sharded=False) class SaveRestoreShardedTest(test.TestCase): From d1c5a6e9aa2ec62da93f6719c6755293cf6406a5 Mon Sep 17 00:00:00 2001 From: LightWang4 <303176469@qq.com> Date: Tue, 21 Jan 2025 17:54:28 +0800 Subject: [PATCH 14/14] [Embedding] Fix op dependency in init_from_checkpoint API. 
(#1012) Signed-off-by: lightwang --- tensorflow/python/training/checkpoint_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index db887fa12f1..d87a9f1b39b 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -443,7 +443,8 @@ def _set_checkpoint_initializer(variable, is_partitioned_ev = variable._save_slice_info is not None partition_id = variable._save_slice_info.var_offset[0] if is_partitioned_ev else 0 partition_num = variable._save_slice_info.full_shape[0] if is_partitioned_ev else 1 - with ops.control_dependencies([variable._initializer_op]): + restore_dependency = ops.get_collection(ops.GraphKeys.EMBEDDING_VARIABLE_RESTORE_DEPENDENCY)[0] + with ops.control_dependencies(restore_dependency[variable._primary_handle]): rank = variable.initial_value.get_shape().rank - 1 restore_op = gen_kv_variable_ops.kv_resource_import_v3( ckpt_file,