Skip to content

Commit 7585042

Browse files
Jiho Choitensorflower-gardener
authored andcommitted
Use the event name as part of the step name for the explicit root events.
PiperOrigin-RevId: 318634056 Change-Id: I2860534f4ebe62e732306a39a6a8fd57f6366b16
1 parent 4aa879a commit 7585042

3 files changed

Lines changed: 23 additions & 4 deletions

File tree

tensorflow/core/profiler/convert/xplane_to_trace_events.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ void ConvertXPlaneToTraceEvents(uint32 device_id, const XPlaneVisitor& xplane,
9191
xevent.ForEachStat([&](const XStatVisitor& stat) {
9292
if (stat.ValueCase() == XStat::VALUE_NOT_SET) return;
9393
if (IsInternalStat(stat.Type())) return;
94+
if (stat.Type() == StatType::kStepName) {
95+
event->set_name(stat.ToString());
96+
}
9497
args[std::string(stat.Name())] = stat.ToString();
9598
});
9699
});

tensorflow/core/profiler/utils/group_events.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,25 @@ bool HasFunctionRun(EventNode* event_node) {
139139
return false;
140140
}
141141

142+
bool IsImplicitRootEvent(const XEventVisitor& event) {
143+
static const auto* const kImplicitRootEvents = new absl::flat_hash_set<int64>{
144+
HostEventType::kFunctionRun, HostEventType::kSessionRun,
145+
HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
146+
return event.Type().has_value() &&
147+
kImplicitRootEvents->contains(*event.Type());
148+
}
149+
142150
void ProcessRootEvent(int64 group_id, EventNode* root_event,
143151
EventGroupNameMap* event_group_name_map) {
144152
root_event->PropagateGroupId(group_id);
145153
std::string group_name = root_event->GetGroupName();
146154
// TODO(jihochoi): change event name instead.
147-
root_event->AddStepName(group_name);
155+
if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
156+
// Add the `step_name` stat for the user-defined root events only. When an
157+
// XEvent is converted to a trace event, the trace event name is set to the
158+
// `step_name` stat's value if present.
159+
root_event->AddStepName(group_name);
160+
}
148161
event_group_name_map->emplace(group_id, std::move(group_name));
149162
}
150163

@@ -336,6 +349,8 @@ std::string EventNode::GetGroupName() const {
336349
if (absl::optional<XStatVisitor> stat =
337350
GetContextStat(StatType::kGraphType)) {
338351
absl::StrAppend(&name, stat->StrOrRefValue(), " ");
352+
} else if (!(IsImplicitRootEvent(visitor_))) {
353+
absl::StrAppend(&name, GetEventVisitor().Name(), " ");
339354
}
340355
int64 step_num = group_id_.value_or(0);
341356
if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {

tensorflow/core/profiler/utils/group_events_test.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ TEST(GroupEventsTest, GroupGpuTraceTest) {
4040
host_plane_builder.ReserveLines(2);
4141

4242
auto main_thread = host_plane_builder.GetOrCreateLine(0);
43-
CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext,
44-
0, 100, {{StatType::kStepNum, kStepNum}});
43+
CreateXEvent(
44+
&host_plane_builder, &main_thread, HostEventType::kTraceContext, 0, 100,
45+
{{StatType::kGraphType, "train"}, {StatType::kStepNum, kStepNum}});
4546
CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
4647
10, 90, {{StatType::kStepId, kStepId}});
4748

@@ -68,7 +69,7 @@ TEST(GroupEventsTest, GroupGpuTraceTest) {
6869
device_plane->lines(0).events(0).stats(1)),
6970
StatType::kGroupId);
7071
EXPECT_EQ(event_group_name_map.size(), 1);
71-
EXPECT_EQ(event_group_name_map[0], "123");
72+
EXPECT_EQ(event_group_name_map[0], "train 123");
7273
}
7374

7475
TEST(GroupEventsTest, GroupTensorFlowLoopTest) {

0 commit comments

Comments
 (0)