diff --git a/graphframes-connect/src/main/protobuf/graphframes.proto b/graphframes-connect/src/main/protobuf/graphframes.proto index 3e81f6c4f..8f5e6759d 100644 --- a/graphframes-connect/src/main/protobuf/graphframes.proto +++ b/graphframes-connect/src/main/protobuf/graphframes.proto @@ -111,6 +111,7 @@ message Pregel { string additional_col_name = 6; ColumnOrExpression additional_col_initial = 7; ColumnOrExpression additional_col_upd = 8; + optional bool early_stopping = 9; } message ShortestPaths { diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index e4990fc37..0e20c29ed 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -174,6 +174,10 @@ object GraphFramesConnectUtils { .map(parseColumnOrExpression(_, planner)) .foldLeft(pregel)((p, col) => p.sendMsgToDst(col)) + if (pregelProto.hasEarlyStopping) { + pregel = pregel.setEarlyStopping(pregelProto.getEarlyStopping) + } + pregel.run() } case MethodCase.SHORTEST_PATHS => { diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index bb74e19a6..5de5ba723 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -24,6 +24,7 @@ def __init__(self, graph: "GraphFrameConnect") -> None: self._send_msg_to_src = [] self._send_msg_to_dst = [] self._agg_msg = None + self._early_stopping = False def setMaxIter(self, value: int) -> Self: self._max_iter = value @@ -33,6 +34,10 @@ def setCheckpointInterval(self, value: int) -> Self: self._checkpoint_interval = value return self + def setEarlyStopping(self, value: bool) -> Self: + self._early_stopping = value + return self + def withVertexColumn( self, colName: str, @@ -62,6 +67,7 @@ def __init__( self, max_iter: int, checkpoint_interval: int, + early_stopping: bool, vertex_col_name: str, agg_msg: Column | str, send2dst: list[Column | str], @@ -74,6 +80,7 @@ def __init__( super().__init__(None) self.max_iter = max_iter self.checkpoint_interval = checkpoint_interval + self.early_stopping = early_stopping self.vertex_col_name = vertex_col_name self.agg_msg = agg_msg self.send2dst = send2dst @@ -97,6 +104,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: additional_col_name=self.vertex_col_name, additional_col_initial=make_column_or_expr(self.vertex_col_init, session), additional_col_upd=make_column_or_expr(self.vertex_col_upd, session), + early_stopping=self.early_stopping, ) pb_message = pb.GraphFramesAPI( vertices=dataframe_to_proto(self.vertices, session), @@ -129,6 +137,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: send2src=self._send_msg_to_src, vertices=self.graph._vertices, edges=self.graph._edges, + early_stopping=self._early_stopping, ), session=self.graph._spark, ) diff --git a/python/graphframes/connect/proto/graphframes_pb2.py b/python/graphframes/connect/proto/graphframes_pb2.py index 39e99776e..156f3afa7 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.py +++ b/python/graphframes/connect/proto/graphframes_pb2.py @@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xd6\x0c\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12k\n\x16\x64rop_isolated_vertices\x18\x06 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x07 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\x08 \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\t \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12^\n\x11label_propagation\x18\n \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12\x46\n\tpage_rank\x18\x0b \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0c \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12w\n\x1apower_iteration_clustering\x18\r \x01(\x0b\x32\x37.org.graphframes.connect.proto.PowerIterationClusteringH\x00R\x18powerIterationClustering\x12?\n\x06pregel\x18\x0e \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x0f \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x10 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x11 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x12 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x13 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"v\n\x18PowerIterationClustering\x12\x0c\n\x01k\x18\x01 \x01(\x05R\x01k\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12"\n\nweight_col\x18\x03 \x01(\tH\x00R\tweightCol\x88\x01\x01\x42\r\n\x0b_weight_col"\xd0\x04\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"\x0f\n\rTriangleCount"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' + b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xd6\x0c\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12k\n\x16\x64rop_isolated_vertices\x18\x06 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x07 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\x08 \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\t \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12^\n\x11label_propagation\x18\n \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12\x46\n\tpage_rank\x18\x0b \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0c \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12w\n\x1apower_iteration_clustering\x18\r \x01(\x0b\x32\x37.org.graphframes.connect.proto.PowerIterationClusteringH\x00R\x18powerIterationClustering\x12?\n\x06pregel\x18\x0e \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x0f \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x10 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x11 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x12 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x13 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"v\n\x18PowerIterationClustering\x12\x0c\n\x01k\x18\x01 \x01(\x05R\x01k\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12"\n\nweight_col\x18\x03 \x01(\tH\x00R\tweightCol\x88\x01\x01\x42\r\n\x0b_weight_col"\x8f\x05\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd\x12*\n\x0e\x65\x61rly_stopping\x18\t \x01(\x08H\x00R\rearlyStopping\x88\x01\x01\x42\x11\n\x0f_early_stopping"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"\x0f\n\rTriangleCount"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' ) _globals = globals() @@ -59,15 +59,15 @@ _globals["_POWERITERATIONCLUSTERING"]._serialized_start = 3296 _globals["_POWERITERATIONCLUSTERING"]._serialized_end = 3414 _globals["_PREGEL"]._serialized_start = 3417 - _globals["_PREGEL"]._serialized_end = 4009 - _globals["_SHORTESTPATHS"]._serialized_start = 4011 - _globals["_SHORTESTPATHS"]._serialized_end = 4103 - _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 4105 - _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 4161 - _globals["_SVDPLUSPLUS"]._serialized_start = 4164 - _globals["_SVDPLUSPLUS"]._serialized_end = 4378 - _globals["_TRIANGLECOUNT"]._serialized_start = 4380 - _globals["_TRIANGLECOUNT"]._serialized_end = 4395 - _globals["_TRIPLETS"]._serialized_start = 4397 - _globals["_TRIPLETS"]._serialized_end = 4407 + _globals["_PREGEL"]._serialized_end = 4072 + _globals["_SHORTESTPATHS"]._serialized_start = 4074 + _globals["_SHORTESTPATHS"]._serialized_end = 4166 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 4168 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 4224 + _globals["_SVDPLUSPLUS"]._serialized_start = 4227 + _globals["_SVDPLUSPLUS"]._serialized_end = 4441 + _globals["_TRIANGLECOUNT"]._serialized_start = 4443 + _globals["_TRIANGLECOUNT"]._serialized_end = 4458 + _globals["_TRIPLETS"]._serialized_start = 4460 + _globals["_TRIPLETS"]._serialized_end = 4470 # @@protoc_insertion_point(module_scope) diff --git a/python/graphframes/connect/proto/graphframes_pb2.pyi b/python/graphframes/connect/proto/graphframes_pb2.pyi index 00649305c..0306ae927 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.pyi +++ b/python/graphframes/connect/proto/graphframes_pb2.pyi @@ -251,6 +251,7 @@ class Pregel(_message.Message): "additional_col_name", "additional_col_initial", "additional_col_upd", + "early_stopping", ) AGG_MSGS_FIELD_NUMBER: _ClassVar[int] SEND_MSG_TO_DST_FIELD_NUMBER: _ClassVar[int] @@ -260,6 +261,7 @@ class Pregel(_message.Message): ADDITIONAL_COL_NAME_FIELD_NUMBER: _ClassVar[int] ADDITIONAL_COL_INITIAL_FIELD_NUMBER: _ClassVar[int] ADDITIONAL_COL_UPD_FIELD_NUMBER: _ClassVar[int] + EARLY_STOPPING_FIELD_NUMBER: _ClassVar[int] agg_msgs: ColumnOrExpression send_msg_to_dst: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] send_msg_to_src: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] @@ -268,6 +270,7 @@ class Pregel(_message.Message): additional_col_name: str additional_col_initial: ColumnOrExpression additional_col_upd: ColumnOrExpression + early_stopping: bool def __init__( self, agg_msgs: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., @@ -278,6 +281,7 @@ class Pregel(_message.Message): additional_col_name: _Optional[str] = ..., additional_col_initial: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., additional_col_upd: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + early_stopping: bool = ..., ) -> None: ... class ShortestPaths(_message.Message): diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/lib/pregel.py index 0d9c5c25f..642b45fe6 100644 --- a/python/graphframes/lib/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -106,6 +106,23 @@ def setCheckpointInterval(self, value: int) -> "Pregel": self._java_obj.setCheckpointInterval(int(value)) return self + def setEarlyStopping(self, value: bool) -> "Pregel": + """ + Set should Pregel stop earlier in case of no new messages to send or not. + + Early stopping allows to terminate Pregel before reaching maxIter by checking if there are any non-null messages. + While in some cases it may gain significant performance boost, in other cases it can lead to performance degradation, + because checking if the messages DataFrame is empty or not is an action and requires materialization of the Spark Plan + with some additional computations. + + In the case when the user can assume a good value of maxIter, it is recommended to leave this value to the default "false". + In the case when it is hard to estimate the number of iterations required for convergence, + it is recommended to set this value to "false" to avoid iterating over convergence until reaching maxIter. + When this value is "true", maxIter can be set to a bigger value without risks. + """ # noqa: E501 + self._java_obj.setEarlyStopping(bool(value)) + return self + def withVertexColumn( self, colName: str, initialExpr: Any, updateAfterAggMsgsExpr: Any ) -> "Pregel": diff --git a/python/tests/test_graphframes.py b/python/tests/test_graphframes.py index e46f2f40f..1a889ee37 100644 --- a/python/tests/test_graphframes.py +++ b/python/tests/test_graphframes.py @@ -207,6 +207,51 @@ def test_page_rank(spark): for a, b in zip(result, expected): assert a == pytest.approx(b, abs=1e-3) +def test_pregel_early_stopping(spark): + edges = spark.createDataFrame( + [ + [0, 1], + [1, 2], + [2, 4], + [2, 0], + [3, 4], # 3 has no in-links + [4, 0], + [4, 2], + ], + ["src", "dst"], + ) + edges.cache() + vertices = spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) + numVertices = vertices.count() + + vertices = GraphFrame(vertices, edges).outDegrees + vertices.toPandas().head() + vertices.cache() + + # Construct a new GraphFrame with the updated vertices DataFrame. + graph = GraphFrame(vertices, edges) + alpha = 0.15 + pregel = graph.pregel + ranks = ( + graph.pregel.setMaxIter(5).setEarlyStopping(True) + .withVertexColumn( + "rank", + sqlfunctions.lit(1.0 / numVertices), + sqlfunctions.coalesce(pregel.msg(), sqlfunctions.lit(0.0)) + * sqlfunctions.lit(1.0 - alpha) + + sqlfunctions.lit(alpha / numVertices), + ) + .sendMsgToDst(pregel.src("rank") / pregel.src("outDegree")) + .aggMsgs(sqlfunctions.sum(pregel.msg())) + .run() + ) + resultRows = ranks.sort("id").collect() + result = map(lambda x: x.rank, resultRows) + expected = [0.245, 0.224, 0.303, 0.03, 0.197] + + # Compare each result with its expected value using a tolerance of 1e-3. + for a, b in zip(result, expected): + assert a == pytest.approx(b, abs=1e-3) def _hasCols(graph, vcols=[], ecols=[]): for c in vcols: diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala index fd4ec96b7..8f8abd238 100644 --- a/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/src/main/scala/org/graphframes/lib/Pregel.scala @@ -25,6 +25,8 @@ import org.graphframes.GraphFrame._ import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.functions.{array, col, explode, struct} +import scala.util.control.Breaks.{breakable, break} + /** * Implements a Pregel-like bulk-synchronous message-passing API based on DataFrame operations. * @@ -79,6 +81,7 @@ class Pregel(val graph: GraphFrame) { private var maxIter: Int = 10 private var checkpointInterval = 2 + private var earlyStopping = false private var sendMsgs = collection.mutable.ListBuffer.empty[(Column, Column)] private var aggMsgsCol: Column = null @@ -104,6 +107,30 @@ class Pregel(val graph: GraphFrame) { this } + /** + * Should Pregel stop earlier in case of no new messages to send? + * + * Early stopping allows to terminate Pregel before reaching maxIter by checking is there any + * non-null message or not. While in some cases it may gain significant performance boost, it + * other cases it can tend to performance degradation, because checking is messages DataFrame is + * empty or not is an action and requires materialization of the Spark Plan with some additional + * computations. + * + * In the case when user can assume a good value of maxIter it is recommended to leave this + * value to the default "false". In the case when it is hard to estimate an amount of iterations + * required for convergence, it is recommended to set this value to "false" to avoid iterating + * over convergence until reaching maxIter. When this value is "true", maxIter can be set to a + * bigger value without risks. + * + * @param value + * should Pregel checks for the termination condition on each step + * @return + */ + def setEarlyStopping(value: Boolean): this.type = { + earlyStopping = value + this + } + /** * Defines an additional vertex column at the start of run and how to update it in each * iteration. @@ -231,54 +258,63 @@ class Pregel(val graph: GraphFrame) { val shouldCheckpoint = checkpointInterval > 0 - while (iteration <= maxIter) { - val tripletsDF = currentVertices - .select(struct(col("*")).as(SRC)) - .join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC)) - .join( - currentVertices.select(struct(col("*")).as(DST)), - Pregel.edge(DST) === Pregel.dst(ID)) - - var msgDF: DataFrame = tripletsDF - .select(explode(array(sendMsgsColList: _*)).as("msg")) - .select(col("msg.id"), col("msg.msg").as(Pregel.MSG_COL_NAME)) - - val newAggMsgDF = msgDF - .filter(Pregel.msg.isNotNull) - .groupBy(ID) - .agg(aggMsgsCol.as(Pregel.MSG_COL_NAME)) - - val verticesWithMsg = currentVertices.join(newAggMsgDF, Seq(ID), "left_outer") - - var newVertexUpdateColDF = verticesWithMsg.select((col(ID) :: updateVertexCols): _*) - - if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) { - // Spark Connect workaround - graph.spark.conf.getOption("spark.checkpoint.dir") match { - case Some(d) => graph.spark.sparkContext.setCheckpointDir(d) - case None => - throw new IOException( - "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" + - "or by specifying the conf 'spark.checkpoint.dir'.") - } + if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) { + // Spark Connect workaround + graph.spark.conf.getOption("spark.checkpoint.dir") match { + case Some(d) => graph.spark.sparkContext.setCheckpointDir(d) + case None => + throw new IOException( + "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" + + "or by specifying the conf 'spark.checkpoint.dir'.") } + } - if (shouldCheckpoint && iteration % checkpointInterval == 0) { - // do checkpoint, use lazy checkpoint because later we will materialize this DF. - newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false) - // TODO: remove last checkpoint file. - } - newVertexUpdateColDF.cache() - newVertexUpdateColDF.count() // materialize it + breakable { + while (iteration <= maxIter) { + val tripletsDF = currentVertices + .select(struct(col("*")).as(SRC)) + .join(edges.select(struct(col("*")).as(EDGE)), Pregel.src(ID) === Pregel.edge(SRC)) + .join( + currentVertices.select(struct(col("*")).as(DST)), + Pregel.edge(DST) === Pregel.dst(ID)) + + val msgDF: DataFrame = tripletsDF + .select(explode(array(sendMsgsColList: _*)).as("msg")) + .select(col("msg.id"), col("msg.msg").as(Pregel.MSG_COL_NAME)) + .filter(Pregel.msg.isNotNull) + + if (earlyStopping && msgDF.isEmpty) { + if (vertexUpdateColDF != null) { + vertexUpdateColDF.unpersist() + } + break + } - if (vertexUpdateColDF != null) { - vertexUpdateColDF.unpersist() - } - vertexUpdateColDF = newVertexUpdateColDF + val newAggMsgDF = msgDF + .groupBy(ID) + .agg(aggMsgsCol.as(Pregel.MSG_COL_NAME)) + + val verticesWithMsg = currentVertices.join(newAggMsgDF, Seq(ID), "left_outer") - currentVertices = graph.vertices.join(vertexUpdateColDF, ID) + var newVertexUpdateColDF = verticesWithMsg.select((col(ID) :: updateVertexCols): _*) - iteration += 1 + if (shouldCheckpoint && iteration % checkpointInterval == 0) { + // do checkpoint, use lazy checkpoint because later we will materialize this DF. + newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false) + // TODO: remove last checkpoint file. + } + newVertexUpdateColDF.cache() + newVertexUpdateColDF.count() // materialize it + + if (vertexUpdateColDF != null) { + vertexUpdateColDF.unpersist() + } + vertexUpdateColDF = newVertexUpdateColDF + + currentVertices = graph.vertices.join(vertexUpdateColDF, ID) + + iteration += 1 + } } currentVertices diff --git a/src/test/scala/org/graphframes/lib/PregelSuite.scala b/src/test/scala/org/graphframes/lib/PregelSuite.scala index dc6cc35c8..669172a56 100644 --- a/src/test/scala/org/graphframes/lib/PregelSuite.scala +++ b/src/test/scala/org/graphframes/lib/PregelSuite.scala @@ -110,4 +110,27 @@ class PregelSuite extends SparkFunSuite with GraphFrameTestSparkContext { assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) } + test("chain propagation with termination") { + val n = 5 + val verDF = (1 to n).toDF("id").repartition(3) + val edgeDF = (1 until n) + .map(x => (x, x + 1)) + .toDF("src", "dst") + .repartition(3) + + val graph = GraphFrame(verDF, edgeDF) + + val resultDF = graph.pregel + .setMaxIter(1000) + .setEarlyStopping(true) + .withVertexColumn( + "value", + when(col("id") === lit(1), lit(1)).otherwise(lit(0)), + when(Pregel.msg > col("value"), Pregel.msg).otherwise(col("value"))) + .sendMsgToDst(when(Pregel.dst("value") =!= Pregel.src("value"), Pregel.src("value"))) + .aggMsgs(max(Pregel.msg)) + .run() + + assert(resultDF.sort("id").select("value").as[Int].collect() === Array.fill(n)(1)) + } }