Skip to content

Commit 14e5632

Browse files
committed
Several changes to support cross-language transforms on DataflowRunner
1 parent d93e4e0 commit 14e5632

4 files changed

Lines changed: 39 additions & 9 deletions

File tree

sdks/python/apache_beam/runners/dataflow/dataflow_runner.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,13 +858,31 @@ def run_ParDo(self, transform_node, options):
858858
outputs = []
859859
step.encoding = self._get_encoded_output_coder(transform_node)
860860

861+
all_output_tags = transform_proto.outputs.keys()
862+
863+
from apache_beam.transforms.core import RunnerAPIPTransformHolder
864+
external_transform = isinstance(transform, RunnerAPIPTransformHolder)
865+
866+
# Some external transforms require output tags to not be modified.
867+
# So we randomly select one of the output tags as the main output and
868+
# leave others as side outputs. Transform execution should not change
869+
# dependending on which output tag we choose as the main output here.
870+
# Also, some SDKs do not work correctly if output tags are modified. So for
871+
# external transforms, we leave tags unmodified.
872+
main_output_tag = (
873+
all_output_tags[0] if external_transform else PropertyNames.OUT)
874+
875+
# Python SDK uses 'None' as the tag of the main output.
876+
tag_to_ignore = main_output_tag if external_transform else 'None'
877+
side_output_tags = set(all_output_tags).difference({tag_to_ignore})
878+
861879
# Add the main output to the description.
862880
outputs.append(
863881
{PropertyNames.USER_NAME: (
864882
'%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
865883
PropertyNames.ENCODING: step.encoding,
866-
PropertyNames.OUTPUT_NAME: PropertyNames.OUT})
867-
for side_tag in transform.output_tags:
884+
PropertyNames.OUTPUT_NAME: main_output_tag})
885+
for side_tag in side_output_tags:
868886
# The assumption here is that all outputs will have the same typehint
869887
# and coder as the main output. This is certainly the case right now
870888
# but conceivably it could change in the future.
@@ -873,7 +891,8 @@ def run_ParDo(self, transform_node, options):
873891
'%s.%s' % (transform_node.full_label, side_tag)),
874892
PropertyNames.ENCODING: step.encoding,
875893
PropertyNames.OUTPUT_NAME: (
876-
'%s_%s' % (PropertyNames.OUT, side_tag))})
894+
side_tag if external_transform
895+
else '%s_%s' % (PropertyNames.OUT, side_tag))})
877896

878897
step.add_property(PropertyNames.OUTPUT_INFO, outputs)
879898

sdks/python/apache_beam/runners/pipeline_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ class Environment(object):
4040
Provides consistency with how the other componentes are accessed.
4141
"""
4242
def __init__(self, proto):
43-
self._proto = proto
43+
self.proto = proto
4444

4545
def to_runner_api(self, context):
46-
return self._proto
46+
return self.proto
4747

4848
@staticmethod
4949
def from_runner_api(proto, context):

sdks/python/apache_beam/testing/test_pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def __init__(self,
6767
options=None,
6868
argv=None,
6969
is_integration_test=False,
70-
blocking=True):
70+
blocking=True,
71+
additional_pipeline_args=None):
7172
"""Initialize a pipeline object for test.
7273
7374
Args:
@@ -88,14 +89,18 @@ def __init__(self,
8889
test, :data:`False` otherwise.
8990
blocking (bool): Run method will wait until pipeline execution is
9091
completed.
92+
additional_pipeline_args (List[str]): additional pipeline arguments to be
93+
included when construction the pipeline options object.
9194
9295
Raises:
9396
~exceptions.ValueError: if either the runner or options argument is not
9497
of the expected type.
9598
"""
9699
self.is_integration_test = is_integration_test
97100
self.not_use_test_runner_api = False
98-
self.options_list = self._parse_test_option_args(argv)
101+
additional_pipeline_args = additional_pipeline_args or []
102+
self.options_list = (
103+
self._parse_test_option_args(argv) + additional_pipeline_args)
99104
self.blocking = blocking
100105
if options is None:
101106
options = PipelineOptions(self.options_list)

sdks/python/apache_beam/transforms/core.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,16 @@ def to_runner_api(self, context, has_parts=False):
362362
else:
363363
env1 = id_to_proto_map[env_id]
364364
env2 = context.environments[env_id]
365-
assert env1.SerializeToString() == env2.SerializeToString(), (
365+
assert env1.urn == env2.proto.urn, (
366366
'Expected environments with the same ID to be equal but received '
367+
'environments with different URNs '
367368
'%r and %r',
368-
env1, env2)
369+
env1.urn, env2.proto.urn)
370+
assert env1.payload == env2.proto.payload, (
371+
'Expected environments with the same ID to be equal but received '
372+
'environments with different payloads '
373+
'%r and %r',
374+
env1.payload, env2.proto.payload)
369375
return self._proto
370376

371377
def get_restriction_coder(self):

0 commit comments

Comments
 (0)