Skip to content

Commit 7511012

Browse files
mrrytensorflower-gardener
authored andcommitted
Set a more descriptive error when a Run() call is cancelled due to session closure.
Also fixes a potential memory leak, where the worker-side state of a failed Run() call would not be cleaned up. Change: 147752067
1 parent 4903865 commit 7511012

6 files changed

Lines changed: 53 additions & 14 deletions

File tree

tensorflow/core/distributed_runtime/master.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ void Master::GC() {
101101
<< "Note that if you are starting multiple replicas "
102102
<< "on a staggered delay, session_gc_seconds may need "
103103
<< "to be raised.";
104-
sess->Close().IgnoreError();
105-
sess->Unref();
104+
sess->GarbageCollect();
106105
});
107106
}
108107
}

tensorflow/core/distributed_runtime/master_session.cc

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,41 +1387,67 @@ Status MasterSession::DoRunWithLocalExecution(
13871387
pss.collect_rpcs = ph->should_collect_rpcs();
13881388
}
13891389

1390-
TF_RETURN_IF_ERROR(rcg->RunPartitions(env_, step_id, count,
1391-
execution_state_.get(), &pss, opts, req,
1392-
resp, cancellation_manager_, false));
1390+
Status s =
1391+
rcg->RunPartitions(env_, step_id, count, execution_state_.get(), &pss,
1392+
opts, req, resp, cancellation_manager_, false);
1393+
if (s.ok()) {
1394+
pss.end_micros = Env::Default()->NowMicros();
13931395

1394-
pss.end_micros = Env::Default()->NowMicros();
1395-
1396-
// Schedule post-processing and cleanup to be done asynchronously.
1396+
// Schedule post-processing and cleanup to be done asynchronously.
1397+
rcg->ProcessStats(step_id, &pss, execution_state_.get(), ph.get(),
1398+
req.options(), resp->mutable_metadata());
1399+
} else if (errors::IsCancelled(s)) {
1400+
mutex_lock l(mu_);
1401+
if (closed_) {
1402+
if (garbage_collected_) {
1403+
s = errors::Cancelled(
1404+
"Step was cancelled because the session was garbage collected due "
1405+
"to inactivity.");
1406+
} else {
1407+
s = errors::Cancelled(
1408+
"Step was cancelled by an explicit call to `Session::Close()`.");
1409+
}
1410+
}
1411+
}
13971412
rcg->Ref();
1398-
rcg->ProcessStats(step_id, &pss, execution_state_.get(), ph.get(),
1399-
req.options(), resp->mutable_metadata());
14001413
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
14011414
if (!s.ok()) {
14021415
LOG(ERROR) << "Cleanup partition error: " << s;
14031416
}
14041417
rcg->Unref();
14051418
});
1406-
return Status::OK();
1419+
return s;
14071420
}
14081421

14091422
Status MasterSession::Close() {
1423+
{
1424+
mutex_lock l(mu_);
1425+
closed_ = true; // All subsequent calls to Run() or Extend() will fail.
1426+
}
14101427
cancellation_manager_->StartCancel();
14111428
std::vector<ReffedClientGraph*> to_unref;
14121429
{
14131430
mutex_lock l(mu_);
14141431
while (num_running_ != 0) {
14151432
num_running_is_zero_.wait(l);
14161433
}
1417-
closed_ = true; // All subsequent calls to Run() or Extend() will fail.
14181434
ClearRunsTable(&to_unref, &run_graphs_);
14191435
ClearRunsTable(&to_unref, &partial_run_graphs_);
14201436
}
14211437
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
14221438
return Status::OK();
14231439
}
14241440

1441+
void MasterSession::GarbageCollect() {
1442+
{
1443+
mutex_lock l(mu_);
1444+
closed_ = true;
1445+
garbage_collected_ = true;
1446+
}
1447+
cancellation_manager_->StartCancel();
1448+
Unref();
1449+
}
1450+
14251451
MasterSession::RunState::RunState(const std::vector<string>& input_names,
14261452
const std::vector<string>& output_names,
14271453
ReffedClientGraph* rcg, const uint64 step_id,

tensorflow/core/distributed_runtime/master_session.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ class MasterSession : public core::RefCounted {
8888
// Close() may block the caller thread for a long time.
8989
Status Close();
9090

91+
// Close this session and release a reference on "*this".
92+
//
93+
// Note that, unlike Close(), this method does not block on the
94+
// completion of all work.
95+
void GarbageCollect();
96+
9197
private:
9298
SessionOptions session_opts_;
9399

@@ -158,6 +164,7 @@ class MasterSession : public core::RefCounted {
158164
int32 num_running_ GUARDED_BY(mu_) = 0;
159165

160166
bool closed_ GUARDED_BY(mu_) = false;
167+
bool garbage_collected_ GUARDED_BY(mu_) = false;
161168

162169
std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);
163170

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ Status GrpcServer::Init() {
162162
builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
163163
builder.SetOption(
164164
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
165-
master_impl_.reset(new Master(&master_env_, 0.0));
165+
master_impl_ = CreateMaster(&master_env_);
166166
master_service_ = NewGrpcMasterService(master_impl_.get(), &builder);
167167
worker_impl_.reset(NewGrpcWorker(&worker_env_));
168168
worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder);
@@ -297,6 +297,10 @@ ChannelCreationFunction GrpcServer::GetChannelCreationFunction(
297297
return NewHostPortGrpcChannel;
298298
}
299299

300+
std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
301+
return std::unique_ptr<Master>(new Master(master_env, 0.0));
302+
}
303+
300304
/* static */
301305
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
302306
std::unique_ptr<ServerInterface>* out_server) {

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class GrpcServer : public ServerInterface {
6363
virtual ChannelCreationFunction GetChannelCreationFunction(
6464
const ServerDef& server_def) const;
6565

66+
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
67+
6668
// Returns the port to which this server is bound.
6769
// This method may only be called after `this->Init()` returns successfully.
6870
int bound_port() const { return bound_port_; }

tensorflow/python/training/server_lib_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def testCloseCancelsBlockingOperation(self):
156156
sess.run(dequeue_t)
157157

158158
def blocking_dequeue():
159-
with self.assertRaises(errors_impl.CancelledError):
159+
with self.assertRaisesRegexp(errors_impl.CancelledError,
160+
"Session::Close"):
160161
sess.run(dequeue_t)
161162

162163
blocking_thread = self.checkedThread(blocking_dequeue)

0 commit comments

Comments
 (0)