From c15881571acf0dd9fbedead719363ea2dfa626a3 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Fri, 7 Feb 2025 13:36:06 +0100 Subject: [PATCH 01/27] wip --- buf.gen.yaml | 12 + buf.yaml | 3 + build.sbt | 88 ++- .../src/main/protobuf/graphframes.proto | 139 +++++ .../sql/graphframes/GraphFramesConnect.scala | 23 + .../graphframes/GraphFramesConnectUtils.scala | 205 +++++++ project/plugins.sbt | 4 + .../graphframes/connect/graphframe_client.py | 502 ++++++++++++++++++ .../connect/proto/graphframes_pb2.py | 79 +++ .../connect/proto/graphframes_pb2.pyi | 222 ++++++++ .../connect/proto/graphframes_pb2_grpc.py | 4 + python/graphframes/connect/utils.py | 19 + 12 files changed, 1255 insertions(+), 45 deletions(-) create mode 100644 buf.gen.yaml create mode 100644 buf.yaml create mode 100644 graphframes-connect/src/main/protobuf/graphframes.proto create mode 100644 graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala create mode 100644 graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala create mode 100644 python/graphframes/connect/graphframe_client.py create mode 100644 python/graphframes/connect/proto/graphframes_pb2.py create mode 100644 python/graphframes/connect/proto/graphframes_pb2.pyi create mode 100644 python/graphframes/connect/proto/graphframes_pb2_grpc.py create mode 100644 python/graphframes/connect/utils.py diff --git a/buf.gen.yaml b/buf.gen.yaml new file mode 100644 index 000000000..33c55306b --- /dev/null +++ b/buf.gen.yaml @@ -0,0 +1,12 @@ +version: v2 +managed: + enabled: true + +plugins: + # Python API + - remote: buf.build/grpc/python:v1.64.2 + out: python/graphframes/connect/proto + - remote: buf.build/protocolbuffers/python:v27.1 + out: python/graphframes/connect/proto + - remote: buf.build/protocolbuffers/pyi + out: python/graphframes/connect/proto \ No newline at end of file diff --git a/buf.yaml b/buf.yaml new file mode 100644 index 000000000..e0cdbd729 --- /dev/null +++ b/buf.yaml @@ -0,0 +1,3 @@ +version: v2 +modules: + - path: graphframes-connect/src/main/protobuf \ No newline at end of file diff --git a/build.sbt b/build.sbt index 061901717..cf4bb3ea4 100644 --- a/build.sbt +++ b/build.sbt @@ -20,53 +20,44 @@ ThisBuild / scalaVersion := scalaVer ThisBuild / organization := "org.graphframes" ThisBuild / crossScalaVersions := Seq("2.12.18", "2.13.8") +lazy val commonSetting = Seq( + libraryDependencies ++= Seq( + "org.apache.spark" %% "spark-graphx" % sparkVer % "provided" cross CrossVersion.for3Use2_13, + "org.apache.spark" %% "spark-sql" % sparkVer % "provided" cross CrossVersion.for3Use2_13, + "org.apache.spark" %% "spark-mllib" % sparkVer % "provided" cross CrossVersion.for3Use2_13, + "org.slf4j" % "slf4j-api" % "2.0.16", + "org.scalatest" %% "scalatest" % defaultScalaTestVer % Test, + "com.github.zafarkhaja" % "java-semver" % "0.10.2" % Test), + credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"), + licenses := Seq("Apache-2.0" -> url("https://opensource.org/licenses/Apache-2.0")), + Compile / scalacOptions ++= Seq("-deprecation", "-feature"), + Compile / doc / scalacOptions ++= Seq( + "-groups", + "-implicits", + "-skip-packages", + Seq("org.apache.spark").mkString(":")), + Test / doc / scalacOptions ++= Seq("-groups", "-implicits"), + + // Test settings + Test / fork := true, + Test / parallelExecution := false, + Test / javaOptions ++= Seq( + "-XX:+IgnoreUnrecognizedVMOptions", + "-Xmx2048m", + "-XX:ReservedCodeCacheSize=384m", + "-XX:MaxMetaspaceSize=384m", + "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", + "--add-opens=java.base/java.lang=ALL-UNNAMED")) + lazy val root = (project in file(".")) .settings( + commonSetting, name := "graphframes", - - // Replace spark-packages plugin functionality with explicit dependencies - libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-graphx" % sparkVer % "provided" cross CrossVersion.for3Use2_13, - "org.apache.spark" %% "spark-sql" % sparkVer % "provided" cross CrossVersion.for3Use2_13, - "org.apache.spark" %% "spark-mllib" % sparkVer % "provided" cross CrossVersion.for3Use2_13, - "org.slf4j" % "slf4j-api" % "1.7.16", - "org.scalatest" %% "scalatest" % defaultScalaTestVer % Test, - "com.github.zafarkhaja" % "java-semver" % "0.9.0" % Test - ), - - licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")), - - // Modern way to set Scala options Compile / scalacOptions ++= Seq("-deprecation", "-feature"), - Compile / doc / scalacOptions ++= Seq( - "-groups", - "-implicits", - "-skip-packages", Seq("org.apache.spark").mkString(":") - ), - - Test / doc / scalacOptions ++= Seq("-groups", "-implicits"), - - // Test settings - Test / fork := true, - Test / parallelExecution := false, - - Test / javaOptions ++= Seq( - "-XX:+IgnoreUnrecognizedVMOptions", - "-Xmx2048m", - "-XX:ReservedCodeCacheSize=384m", - "-XX:MaxMetaspaceSize=384m", - "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", - "--add-opens=java.base/java.lang=ALL-UNNAMED" - ), - // Global settings - Global / concurrentRestrictions := Seq( - Tags.limitAll(1) - ), - + Global / concurrentRestrictions := Seq(Tags.limitAll(1)), autoAPIMappings := true, - coverageHighlighting := false, // Release settings @@ -76,8 +67,7 @@ lazy val root = (project in file(".")) commitReleaseVersion, tagRelease, setNextVersion, - commitNextVersion - ), + commitNextVersion), // Assembly settings assembly / test := {}, // No tests in assembly @@ -87,7 +77,15 @@ lazy val root = (project in file(".")) case x => val oldStrategy = (assembly / assemblyMergeStrategy).value oldStrategy(x) - }, + }) - credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") - ) \ No newline at end of file +lazy val connect = (project in file("graphframes-connect")) + .dependsOn(root) + .settings( + commonSetting, + name := "graphframes-connect", + Compile / PB.targets := Seq(PB.gens.java -> (Compile / sourceManaged).value), + Compile / PB.includePaths ++= Seq(file("src/main/protobuf")), + PB.protocVersion := "3.23.4", // Spark 3.5 branch + libraryDependencies ++= Seq( + "org.apache.spark" %% "spark-connect" % sparkVer % "provided" cross CrossVersion.for3Use2_13)) diff --git a/graphframes-connect/src/main/protobuf/graphframes.proto b/graphframes-connect/src/main/protobuf/graphframes.proto new file mode 100644 index 000000000..2cdd0febf --- /dev/null +++ b/graphframes-connect/src/main/protobuf/graphframes.proto @@ -0,0 +1,139 @@ +syntax = 'proto3'; + +package org.graphframes.connect.proto; + +option java_multiple_files = true; +option java_package = "org.graphframes.connect.proto"; +option java_generate_equals_and_hash = true; +option optimize_for=SPEED; + + +message GraphFramesAPI { + bytes vertices = 1; + bytes edges = 2; + oneof method { + AggregateMessages aggregate_messages = 3; + BFS bfs = 4; + ConnectedComponents connected_components = 5; + Degrees degrees = 6; + DropIsolatedVertices drop_isolated_vertices = 7; + FilterEdges filter_edges = 8; + FilterVertices filter_vertices = 9; + Find find = 10; + InDegrees in_degrees = 11; + LabelPropagation label_propagation = 12; + OutDegrees out_degrees = 13; + PageRank page_rank = 14; + ParallelPersonalizedPageRank parallel_personalized_page_rank = 15; + Pregel pregel = 16; + ShortestPaths shortest_paths = 17; + StronglyConnectedComponents strongly_connected_components = 18; + SVDPlusPlus svd_plus_plus = 19; + TriangleCount triangle_count = 20; + Triplets triplets = 21; + } +} + +message ColumnOrExpression { + oneof col_or_expr { + bytes col = 1; + string expr = 2; + } +} + +message StringOrLongID { + oneof id { + int64 long_id = 1; + string string_id = 2; + } +} + +message AggregateMessages { + ColumnOrExpression agg_col = 1; + optional ColumnOrExpression send_to_src = 2; + optional ColumnOrExpression send_to_dst = 3; +} + +message BFS { + ColumnOrExpression from_expr = 1; + ColumnOrExpression to_expr = 2; + ColumnOrExpression edge_filter = 3; + int32 max_path_length = 4; +} + +message ConnectedComponents { + string algorithm = 1; + int32 checkpoint_interval = 2; + int32 broadcast_threshold = 3; +} + +message Degrees {} + +message DropIsolatedVertices {} + +message FilterEdges { + ColumnOrExpression condition = 1; +} + +message FilterVertices { + ColumnOrExpression condition = 2; +} + +message Find { + string pattern = 1; +} + +message InDegrees {} + +message LabelPropagation { + int32 max_iter = 1; +} + +message OutDegrees {} + +message PageRank { + double reset_probability = 1; + optional StringOrLongID source_id = 2; + int32 max_iter = 3; + double tol = 4; +} + +message ParallelPersonalizedPageRank { + double reset_probability = 1; + repeated StringOrLongID source_ids = 2; + int32 max_iter = 3; +} + +message Pregel { + ColumnOrExpression agg_msgs = 1; + ColumnOrExpression send_msg_to_dst = 2; + ColumnOrExpression send_msg_to_src = 3; + int32 checkpoint_interval = 4; + int32 max_iter = 5; + string additional_col_name = 6; + ColumnOrExpression additional_col_initial = 7; + ColumnOrExpression additional_col_upd = 8; +} + +message ShortestPaths { + repeated StringOrLongID landmarks = 1; +} + +message StronglyConnectedComponents { + int32 max_iter = 1; +} + +message SVDPlusPlus { + int32 rank = 1; + int32 max_iter = 2; + double min_value = 3; + double max_value = 4; + double gamma1 = 5; + double gamma2 = 6; + double gamma6 = 7; + double gamma7 = 8; +} + +message TriangleCount {} + +message Triplets {} diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala new file mode 100644 index 000000000..d079612c2 --- /dev/null +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala @@ -0,0 +1,23 @@ +package org.apache.spark.sql.graphframes + +import org.graphframes.connect.proto.GraphFramesAPI + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.plugin.RelationPlugin + +import com.google.protobuf + +class GraphFramesConnect extends RelationPlugin { + override def transform( + relation: protobuf.Any, + planner: SparkConnectPlanner): Option[LogicalPlan] = { + if (relation.is(classOf[GraphFramesAPI])) { + val protoCall = relation.unpack(classOf[GraphFramesAPI]) + // Because the plugins API is changed in spark 4.0 it makes sense to separate plugin impl from the parsing logic + Option(GraphFramesConnectUtils.parseAPICall(protoCall, planner).logicalPlan) + } else { + Option.empty + } + } +} diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala new file mode 100644 index 000000000..062241d9e --- /dev/null +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -0,0 +1,205 @@ +// Because Dataset.ofRows is private[sql] we are forced to use spark package +// Same about Column helper object. +package org.apache.spark.sql.graphframes + +import scala.jdk.CollectionConverters._ + +import org.graphframes.GraphFrame +import org.graphframes.connect.proto.{ColumnOrExpression, GraphFramesAPI, StringOrLongID} +import org.graphframes.connect.proto.ColumnOrExpression.ColOrExprCase +import org.graphframes.connect.proto.GraphFramesAPI.MethodCase +import org.graphframes.connect.proto.StringOrLongID.IdCase + +import org.apache.spark.sql.{Column, DataFrame, Dataset} +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.functions.{expr, lit} + +import com.google.protobuf.ByteString + +object GraphFramesConnectUtils { + private[graphframes] def parseColumnOrExpression( + colOrExpr: ColumnOrExpression, + planner: SparkConnectPlanner): Column = { + colOrExpr.getColOrExprCase match { + case ColOrExprCase.COL => + Column( + planner.transformExpression( + org.apache.spark.connect.proto.Expression.parseFrom(colOrExpr.getCol.toByteArray))) + case ColOrExprCase.EXPR => expr(colOrExpr.getExpr) + case _ => + throw new RuntimeException( + "INTERNAL ERROR: unreachable case in function parseColumnOrExpression") + } + } + + private[graphframes] def parseLongOrStringID(id: StringOrLongID): Any = { + id.getIdCase match { + case IdCase.LONG_ID => id.getLongId + case IdCase.STRING_ID => id.getStringId + case _ => + throw new RuntimeException( + "INTERNAL ERROR: unreachable case in function parseLongOrStringID") + } + } + + private[graphframes] def parseDataFrame( + data: ByteString, + planner: SparkConnectPlanner): DataFrame = { + Dataset.ofRows( + planner.sessionHolder.session, + planner.transformRelation( + org.apache.spark.connect.proto.Relation.parseFrom(data.toByteArray))) + } + + private[graphframes] def extractGraphFrame( + apiMessage: GraphFramesAPI, + planner: SparkConnectPlanner): GraphFrame = { + val vertices = parseDataFrame(apiMessage.getVertices, planner) + val edges = parseDataFrame(apiMessage.getEdges, planner) + + GraphFrame(vertices, edges) + } + + private[graphframes] def parseAPICall( + apiMessage: GraphFramesAPI, + planner: SparkConnectPlanner): DataFrame = { + val graphFrame = extractGraphFrame(apiMessage, planner) + + apiMessage.getMethodCase match { + case MethodCase.AGGREGATE_MESSAGES => { + val aggregateMessagesProto = apiMessage.getAggregateMessages + var aggregateMessages = graphFrame.aggregateMessages + if (aggregateMessagesProto.hasSendToDst) { + aggregateMessages = aggregateMessages.sendToDst( + parseColumnOrExpression(aggregateMessagesProto.getSendToDst, planner)) + } + if (aggregateMessagesProto.hasSendToSrc) { + aggregateMessages = aggregateMessages.sendToSrc( + parseColumnOrExpression(aggregateMessagesProto.getSendToSrc, planner)) + } + + aggregateMessages.agg(parseColumnOrExpression(aggregateMessagesProto.getAggCol, planner)) + } + case MethodCase.BFS => { + val bfsProto = apiMessage.getBfs + graphFrame.bfs + .toExpr(parseColumnOrExpression(bfsProto.getToExpr, planner)) + .fromExpr(parseColumnOrExpression(bfsProto.getFromExpr, planner)) + .edgeFilter(parseColumnOrExpression(bfsProto.getEdgeFilter, planner)) + .maxPathLength(bfsProto.getMaxPathLength) + .run() + } + case MethodCase.CONNECTED_COMPONENTS => { + val cc = apiMessage.getConnectedComponents + graphFrame.connectedComponents + .setAlgorithm(cc.getAlgorithm) + .setCheckpointInterval(cc.getCheckpointInterval) + .setBroadcastThreshold(cc.getBroadcastThreshold) + .run() + } + case MethodCase.DEGREES => { + graphFrame.degrees + } + case MethodCase.DROP_ISOLATED_VERTICES => { + graphFrame.dropIsolatedVertices().vertices + } + case MethodCase.FILTER_EDGES => { + val condition = parseColumnOrExpression(apiMessage.getFilterEdges.getCondition, planner) + graphFrame.filterEdges(condition).edges + } + case MethodCase.FILTER_VERTICES => { + val condition = + parseColumnOrExpression(apiMessage.getFilterVertices.getCondition, planner) + graphFrame.filterVertices(condition).vertices + } + case MethodCase.FIND => { + graphFrame.find(apiMessage.getFind.getPattern) + } + case MethodCase.IN_DEGREES => { + graphFrame.inDegrees + } + case MethodCase.LABEL_PROPAGATION => { + graphFrame.labelPropagation.maxIter(apiMessage.getLabelPropagation.getMaxIter).run() + } + case MethodCase.OUT_DEGREES => { + graphFrame.outDegrees + } + case MethodCase.PAGE_RANK => { + val pageRankProto = apiMessage.getPageRank + val pageRank = graphFrame.pageRank + + pageRank + .maxIter(pageRankProto.getMaxIter) + .tol(pageRankProto.getTol) + .resetProbability(pageRankProto.getResetProbability) + + if (pageRankProto.hasSourceId) { + pageRank.sourceId(parseLongOrStringID(pageRankProto.getSourceId)) + } + + // Edges should be updated on the client side + // TODO: do we really need an edge weights in that case? + pageRank.run().vertices + } + case MethodCase.PARALLEL_PERSONALIZED_PAGE_RANK => { + val pPageRankProto = apiMessage.getParallelPersonalizedPageRank + val sourceIds = pPageRankProto.getSourceIdsList.asScala + .map(parseLongOrStringID) + .toArray + val pPageRank = graphFrame.parallelPersonalizedPageRank + pPageRank + .resetProbability(pPageRankProto.getResetProbability) + .maxIter(pPageRankProto.getMaxIter) + .sourceIds(sourceIds) + .run() + .vertices // See comment in the PageRank + } + case MethodCase.PREGEL => { + val pregelProto = apiMessage.getPregel + val pregel = graphFrame.pregel + pregel + .aggMsgs(parseColumnOrExpression(pregelProto.getAggMsgs, planner)) + .sendMsgToDst(parseColumnOrExpression(pregelProto.getSendMsgToDst, planner)) + .sendMsgToDst(parseColumnOrExpression(pregelProto.getSendMsgToDst, planner)) + .setCheckpointInterval(pregelProto.getCheckpointInterval) + .setMaxIter(pregelProto.getMaxIter) + .withVertexColumn( + pregelProto.getAdditionalColName, + parseColumnOrExpression(pregelProto.getAdditionalColInitial, planner), + parseColumnOrExpression(pregelProto.getAdditionalColUpd, planner)) + .run() + } + case MethodCase.SHORTEST_PATHS => { + graphFrame.shortestPaths + .landmarks( + apiMessage.getShortestPaths.getLandmarksList.asScala.map(parseLongOrStringID)) + .run() + } + case MethodCase.STRONGLY_CONNECTED_COMPONENTS => { + graphFrame.stronglyConnectedComponents + .maxIter(apiMessage.getStronglyConnectedComponents.getMaxIter) + .run() + } + case MethodCase.SVD_PLUS_PLUS => { + val svdPPProto = apiMessage.getSvdPlusPlus + val svd = graphFrame.svdPlusPlus + .maxIter(svdPPProto.getMaxIter) + .gamma1(svdPPProto.getGamma1) + .gamma2(svdPPProto.getGamma2) + .gamma6(svdPPProto.getGamma6) + .gamma7(svdPPProto.getGamma7) + .rank(svdPPProto.getRank) + .minValue(svdPPProto.getMinValue) + .maxValue(svdPPProto.getMaxValue) + val svdResult = svd.run() + svdResult.withColumn("loss", lit(svd.loss)) + } + case MethodCase.TRIANGLE_COUNT => { + graphFrame.triangleCount.run() + } + case MethodCase.TRIPLETS => { + graphFrame.triplets + } + } + } +} diff --git a/project/plugins.sbt b/project/plugins.sbt index 46028c336..0a3e292ef 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -8,3 +8,7 @@ ThisBuild / libraryDependencySchemes ++= Seq( addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.10") addSbtPlugin("com.github.sbt" % "sbt-release" % "1.4.0") addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1") + +// Protobuf things needed for the Spark Connect +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7") +libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.10" diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py new file mode 100644 index 000000000..f228eac6e --- /dev/null +++ b/python/graphframes/connect/graphframe_client.py @@ -0,0 +1,502 @@ +from typing import Self + +from pyspark.sql.connect import proto +from pyspark.sql.connect.client import SparkConnectClient +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.plan import LogicalPlan +from pyspark.storagelevel import StorageLevel + +from .proto import graphframes_pb2 as pb +from .utils import column_to_proto, dataframe_to_proto + + +class GraphFrameConnect: + ID = "id" + SRC = "src" + DST = "dst" + EDGE = "edge" + + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self._vertices = v + self._edges = e + self._spark = v.sparkSession + + if self.ID not in v.columns: + raise ValueError( + "Vertex ID column {} missing from vertex DataFrame, which has columns: {}".format( + self.ID, ",".join(v.columns) + ) + ) + if self.SRC not in e.columns: + raise ValueError( + "Source vertex ID column {} missing from edge DataFrame, which has columns: {}".format( + self.SRC, ",".join(e.columns) + ) + ) + if self.DST not in e.columns: + raise ValueError( + "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}".format( + self.DST, ",".join(e.columns) + ) + ) + + @staticmethod + def _get_pb_api_message( + vertices: DataFrame, edges: DataFrame, client: SparkConnectClient + ) -> pb.GraphFramesAPI: + return pb.GraphFramesAPI( + vertices=dataframe_to_proto(vertices, client), + edges=dataframe_to_proto(edges, client), + ) + + @property + def vertices(self) -> DataFrame: + """ + :class:`DataFrame` holding vertex information, with unique column "id" + for vertex IDs. + """ + return self._vertices + + @property + def edges(self) -> DataFrame: + """ + :class:`DataFrame` holding edge information, with unique columns "src" and + "dst" storing source vertex IDs and destination vertex IDs of edges, + respectively. + """ + return self._edges + + def __repr__(self) -> str: + # Exactly like in the scala core + v_cols = [self.ID] + [col for col in self.vertices.columns if col != self.ID] + e_cols = [self.SRC, self.DST] + [ + col for col in self.edges.columns if col not in {self.SRC, self.DST} + ] + v = self.vertices.select(*v_cols).__repr__() + e = self.edges.select(*e_cols).__repr__() + + return f"GraphFrame(v:{v}, e:{e})" + + def cache(self) -> Self: + """Persist the dataframe representation of vertices and edges of the graph with the default + storage level. + """ + self._vertices = self._vertices.cache() + self._edges = self._edges.cache() + return self + + def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> Self: + """Persist the dataframe representation of vertices and edges of the graph with the given + storage level. + """ + self._vertices = self._vertices.persist(storageLevel=storageLevel) + self._edges = self._edges.persist(storageLevel=storageLevel) + return self + + def unpersist(self, blocking: bool = False) -> Self: + """Mark the dataframe representation of vertices and edges of the graph as non-persistent, + and remove all blocks for it from memory and disk. + """ + self._vertices = self._vertices.unpersist(blocking=blocking) + self._edges = self._edges.unpersist(blocking=blocking) + return self + + @property + def outDegrees(self) -> DataFrame: + """ + The out-degree of each vertex in the graph, returned as a DataFrame with two columns: + - "id": the ID of the vertex + - "outDegree" (integer) storing the out-degree of the vertex + + Note that vertices with 0 out-edges are not returned in the result. + + :return: DataFrame with new vertices column "outDegree" + """ + + class OutDegrees(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self.v = v + self.e = e + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.out_degrees = pb.OutDegrees() + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + return DataFrame.withPlan(OutDegrees(self._vertices, self._edges), self._spark) + + @property + def inDegrees(self) -> DataFrame: + """ + The in-degree of each vertex in the graph, returned as a DataFame with two columns: + - "id": the ID of the vertex + - "inDegree" (int) storing the in-degree of the vertex + + Note that vertices with 0 in-edges are not returned in the result. + + :return: DataFrame with new vertices column "inDegree" + """ + + class InDegrees(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self.v = v + self.e = e + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.in_degrees = pb.InDegrees() + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + return DataFrame.withPlan(InDegrees(self._vertices, self._edges), self._spark) + + @property + def degrees(self) -> DataFrame: + """ + The degree of each vertex in the graph, returned as a DataFrame with two columns: + - "id": the ID of the vertex + - 'degree' (integer) the degree of the vertex + + Note that vertices with 0 edges are not returned in the result. + + :return: DataFrame with new vertices column "degree" + """ + + class Degrees(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self.v = v + self.e = e + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.degrees = pb.Degrees() + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + return DataFrame.withPlan(Degrees(self._vertices, self._edges), self._spark) + + @property + def triplets(self) -> DataFrame: + """ + The triplets (source vertex)-[edge]->(destination vertex) for all edges in the graph. + + Returned as a :class:`DataFrame` with three columns: + - "src": source vertex with schema matching 'vertices' + - "edge": edge with schema matching 'edges' + - 'dst': destination vertex with schema matching 'vertices' + + :return: DataFrame with columns 'src', 'edge', and 'dst' + """ + jdf = self._jvm_graph.triplets() + return DataFrame(jdf, self._spark) + + @property + def pregel(self): + """ + Get the :class:`graphframes.lib.Pregel` object for running pregel. + + See :class:`graphframes.lib.Pregel` for more details. + """ + return Pregel(self) + + def find(self, pattern: str) -> DataFrame: + """ + Motif finding. + + See Scala documentation for more details. + + :param pattern: String describing the motif to search for. + :return: DataFrame with one Row for each instance of the motif found + """ + jdf = self._jvm_graph.find(pattern) + return DataFrame(jdf, self._spark) + + def filterVertices(self, condition: Union[str, Column]) -> "GraphFrame": + """ + Filters the vertices based on expression, remove edges containing any dropped vertices. + + :param condition: String or Column describing the condition expression for filtering. + :return: GraphFrame with filtered vertices and edges. + """ + + if isinstance(condition, basestring): + jdf = self._jvm_graph.filterVertices(condition) + elif isinstance(condition, Column): + jdf = self._jvm_graph.filterVertices(condition._jc) + else: + raise TypeError("condition should be string or Column") + return _from_java_gf(jdf, self._spark) + + def filterEdges(self, condition: Union[str, Column]) -> "GraphFrame": + """ + Filters the edges based on expression, keep all vertices. + + :param condition: String or Column describing the condition expression for filtering. + :return: GraphFrame with filtered edges. + """ + if isinstance(condition, basestring): + jdf = self._jvm_graph.filterEdges(condition) + elif isinstance(condition, Column): + jdf = self._jvm_graph.filterEdges(condition._jc) + else: + raise TypeError("condition should be string or Column") + return _from_java_gf(jdf, self._spark) + + def dropIsolatedVertices(self) -> "GraphFrame": + """ + Drops isolated vertices, vertices are not contained in any edges. + + :return: GraphFrame with filtered vertices. + """ + jdf = self._jvm_graph.dropIsolatedVertices() + return _from_java_gf(jdf, self._spark) + + def bfs( + self, + fromExpr: str, + toExpr: str, + edgeFilter: Optional[str] = None, + maxPathLength: int = 10, + ) -> DataFrame: + """ + Breadth-first search (BFS). + + See Scala documentation for more details. + + :return: DataFrame with one Row for each shortest path between matching vertices. + """ + builder = ( + self._jvm_graph.bfs() + .fromExpr(fromExpr) + .toExpr(toExpr) + .maxPathLength(maxPathLength) + ) + if edgeFilter is not None: + builder.edgeFilter(edgeFilter) + jdf = builder.run() + return DataFrame(jdf, self._spark) + + def aggregateMessages( + self, + aggCol: Union[Column, str], + sendToSrc: Union[Column, str, None] = None, + sendToDst: Union[Column, str, None] = None, + ) -> DataFrame: + """ + Aggregates messages from the neighbours. + + When specifying the messages and aggregation function, the user may reference columns using + the static methods in :class:`graphframes.lib.AggregateMessages`. + + See Scala documentation for more details. + + :param aggCol: the requested aggregation output either as + :class:`pyspark.sql.Column` or SQL expression string + :param sendToSrc: message sent to the source vertex of each triplet either as + :class:`pyspark.sql.Column` or SQL expression string (default: None) + :param sendToDst: message sent to the destination vertex of each triplet either as + :class:`pyspark.sql.Column` or SQL expression string (default: None) + + :return: DataFrame with columns for the vertex ID and the resulting aggregated message + """ + # Check that either sendToSrc, sendToDst, or both are provided + if sendToSrc is None and sendToDst is None: + raise ValueError( + "Either `sendToSrc`, `sendToDst`, or both have to be provided" + ) + builder = self._jvm_graph.aggregateMessages() + if sendToSrc is not None: + if isinstance(sendToSrc, Column): + builder.sendToSrc(sendToSrc._jc) + elif isinstance(sendToSrc, basestring): + builder.sendToSrc(sendToSrc) + else: + raise TypeError("Provide message either as `Column` or `str`") + if sendToDst is not None: + if isinstance(sendToDst, Column): + builder.sendToDst(sendToDst._jc) + elif isinstance(sendToDst, basestring): + builder.sendToDst(sendToDst) + else: + raise TypeError("Provide message either as `Column` or `str`") + if isinstance(aggCol, Column): + jdf = builder.agg(aggCol._jc) + else: + jdf = builder.agg(aggCol) + return DataFrame(jdf, self._spark) + + # Standard algorithms + + def connectedComponents( + self, + algorithm: str = "graphframes", + checkpointInterval: int = 2, + broadcastThreshold: int = 1000000, + ) -> DataFrame: + """ + Computes the connected components of the graph. + + See Scala documentation for more details. + + :param algorithm: connected components algorithm to use (default: "graphframes") + Supported algorithms are "graphframes" and "graphx". + :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2) + :param broadcastThreshold: broadcast threshold in propagating component assignments + (default: 1000000) + + :return: DataFrame with new vertices column "component" + """ + jdf = ( + self._jvm_graph.connectedComponents() + .setAlgorithm(algorithm) + .setCheckpointInterval(checkpointInterval) + .setBroadcastThreshold(broadcastThreshold) + .run() + ) + return DataFrame(jdf, self._spark) + + def labelPropagation(self, maxIter: int) -> DataFrame: + """ + Runs static label propagation for detecting communities in networks. + + See Scala documentation for more details. + + :param maxIter: the number of iterations to be performed + :return: DataFrame with new vertices column "label" + """ + jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() + return DataFrame(jdf, self._spark) + + def pageRank( + self, + resetProbability: float = 0.15, + sourceId: Optional[Any] = None, + maxIter: Optional[int] = None, + tol: Optional[float] = None, + ) -> "GraphFrame": + """ + Runs the PageRank algorithm on the graph. + Note: Exactly one of fixed_num_iter or tolerance must be set. + + See Scala documentation for more details. + + :param resetProbability: Probability of resetting to a random vertex. + :param sourceId: (optional) the source vertex for a personalized PageRank. + :param maxIter: If set, the algorithm is run for a fixed number + of iterations. This may not be set if the `tol` parameter is set. + :param tol: If set, the algorithm is run until the given tolerance. + This may not be set if the `numIter` parameter is set. + :return: GraphFrame with new vertices column "pagerank" and new edges column "weight" + """ + builder = self._jvm_graph.pageRank().resetProbability(resetProbability) + if sourceId is not None: + builder = builder.sourceId(sourceId) + if maxIter is not None: + builder = builder.maxIter(maxIter) + assert tol is None, "Exactly one of maxIter or tol should be set." + else: + assert tol is not None, "Exactly one of maxIter or tol should be set." + builder = builder.tol(tol) + jgf = builder.run() + return _from_java_gf(jgf, self._spark) + + def parallelPersonalizedPageRank( + self, + resetProbability: float = 0.15, + sourceIds: Optional[list[Any]] = None, + maxIter: Optional[int] = None, + ) -> "GraphFrame": + """ + Run the personalized PageRank algorithm on the graph, + from the provided list of sources in parallel for a fixed number of iterations. + + See Scala documentation for more details. + + :param resetProbability: Probability of resetting to a random vertex + :param sourceIds: the source vertices for a personalized PageRank + :param maxIter: the fixed number of iterations this algorithm runs + :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" + """ + assert ( + sourceIds is not None and len(sourceIds) > 0 + ), "Source vertices Ids sourceIds must be provided" + assert maxIter is not None, "Max number of iterations maxIter must be provided" + sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) + builder = self._jvm_graph.parallelPersonalizedPageRank() + builder = builder.resetProbability(resetProbability) + builder = builder.sourceIds(sourceIds) + builder = builder.maxIter(maxIter) + jgf = builder.run() + return _from_java_gf(jgf, self._spark) + + def shortestPaths(self, landmarks: list[Any]) -> DataFrame: + """ + Runs the shortest path algorithm from a set of landmark vertices in the graph. + + See Scala documentation for more details. + + :param landmarks: a set of one or more landmarks + :return: DataFrame with new vertices column "distances" + """ + jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() + return DataFrame(jdf, self._spark) + + def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: + """ + Runs the strongly connected components algorithm on this graph. + + See Scala documentation for more details. + + :param maxIter: the number of iterations to run + :return: DataFrame with new vertex column "component" + """ + jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() + return DataFrame(jdf, self._spark) + + def svdPlusPlus( + self, + rank: int = 10, + maxIter: int = 2, + minValue: float = 0.0, + maxValue: float = 5.0, + gamma1: float = 0.007, + gamma2: float = 0.007, + gamma6: float = 0.005, + gamma7: float = 0.015, + ) -> tuple[DataFrame, float]: + """ + Runs the SVD++ algorithm. + + See Scala documentation for more details. + + :return: Tuple of DataFrame with new vertex columns storing learned model, and loss value + """ + # This call is actually useless, because one needs to build the configuration first... + builder = self._jvm_graph.svdPlusPlus() + builder.rank(rank).maxIter(maxIter).minValue(minValue).maxValue(maxValue) + builder.gamma1(gamma1).gamma2(gamma2).gamma6(gamma6).gamma7(gamma7) + jdf = builder.run() + loss = builder.loss() + v = DataFrame(jdf, self._spark) + return (v, loss) + + def triangleCount(self) -> DataFrame: + """ + Counts the number of triangles passing through each vertex in this graph. + + See Scala documentation for more details. + + :return: DataFrame with new vertex column "count" + """ + jdf = self._jvm_graph.triangleCount().run() + return DataFrame(jdf, self._spark) diff --git a/python/graphframes/connect/proto/graphframes_pb2.py b/python/graphframes/connect/proto/graphframes_pb2.py new file mode 100644 index 000000000..429dfb59c --- /dev/null +++ b/python/graphframes/connect/proto/graphframes_pb2.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: graphframes.proto +# Protobuf Python Version: 5.27.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 27, + 1, + '', + 'graphframes.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto\"\xba\r\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32\".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12\x42\n\x07\x64\x65grees\x18\x06 \x01(\x0b\x32&.org.graphframes.connect.proto.DegreesH\x00R\x07\x64\x65grees\x12k\n\x16\x64rop_isolated_vertices\x18\x07 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x08 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\t \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\n \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12I\n\nin_degrees\x18\x0b \x01(\x0b\x32(.org.graphframes.connect.proto.InDegreesH\x00R\tinDegrees\x12^\n\x11label_propagation\x18\x0c \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12L\n\x0bout_degrees\x18\r \x01(\x0b\x32).org.graphframes.connect.proto.OutDegreesH\x00R\noutDegrees\x12\x46\n\tpage_rank\x18\x0e \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0f \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12?\n\x06pregel\x18\x10 \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x11 \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x12 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x13 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x14 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x15 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method\"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr\"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id\"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst\"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength\"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold\"\t\n\x07\x44\x65grees\"\x16\n\x14\x44ropIsolatedVertices\"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition\"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition\" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern\"\x0b\n\tInDegrees\"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter\"\x0c\n\nOutDegrees\"\xc3\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter\x12\x10\n\x03tol\x18\x04 \x01(\x01R\x03tolB\x0c\n\n_source_id\"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter\"\xd0\x04\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd\"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks\"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter\"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7\"\x0f\n\rTriangleCount\"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'graphframes_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n!com.org.graphframes.connect.protoB\020GraphframesProtoH\001P\001\240\001\001\242\002\004OGCP\252\002\035Org.Graphframes.Connect.Proto\312\002\035Org\\Graphframes\\Connect\\Proto\342\002)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\352\002 Org::Graphframes::Connect::Proto' + _globals['_GRAPHFRAMESAPI']._serialized_start=53 + _globals['_GRAPHFRAMESAPI']._serialized_end=1775 + _globals['_COLUMNOREXPRESSION']._serialized_start=1777 + _globals['_COLUMNOREXPRESSION']._serialized_end=1854 + _globals['_STRINGORLONGID']._serialized_start=1856 + _globals['_STRINGORLONGID']._serialized_end=1936 + _globals['_AGGREGATEMESSAGES']._serialized_start=1939 + _globals['_AGGREGATEMESSAGES']._serialized_end=2242 + _globals['_BFS']._serialized_start=2245 + _globals['_BFS']._serialized_end=2530 + _globals['_CONNECTEDCOMPONENTS']._serialized_start=2533 + _globals['_CONNECTEDCOMPONENTS']._serialized_end=2682 + _globals['_DEGREES']._serialized_start=2684 + _globals['_DEGREES']._serialized_end=2693 + _globals['_DROPISOLATEDVERTICES']._serialized_start=2695 + _globals['_DROPISOLATEDVERTICES']._serialized_end=2717 + _globals['_FILTEREDGES']._serialized_start=2719 + _globals['_FILTEREDGES']._serialized_end=2813 + _globals['_FILTERVERTICES']._serialized_start=2815 + _globals['_FILTERVERTICES']._serialized_end=2912 + _globals['_FIND']._serialized_start=2914 + _globals['_FIND']._serialized_end=2946 + _globals['_INDEGREES']._serialized_start=2948 + _globals['_INDEGREES']._serialized_end=2959 + _globals['_LABELPROPAGATION']._serialized_start=2961 + _globals['_LABELPROPAGATION']._serialized_end=3006 + _globals['_OUTDEGREES']._serialized_start=3008 + _globals['_OUTDEGREES']._serialized_end=3020 + _globals['_PAGERANK']._serialized_start=3023 + _globals['_PAGERANK']._serialized_end=3218 + _globals['_PARALLELPERSONALIZEDPAGERANK']._serialized_start=3221 + _globals['_PARALLELPERSONALIZEDPAGERANK']._serialized_end=3401 + _globals['_PREGEL']._serialized_start=3404 + _globals['_PREGEL']._serialized_end=3996 + _globals['_SHORTESTPATHS']._serialized_start=3998 + _globals['_SHORTESTPATHS']._serialized_end=4090 + _globals['_STRONGLYCONNECTEDCOMPONENTS']._serialized_start=4092 + _globals['_STRONGLYCONNECTEDCOMPONENTS']._serialized_end=4148 + _globals['_SVDPLUSPLUS']._serialized_start=4151 + _globals['_SVDPLUSPLUS']._serialized_end=4365 + _globals['_TRIANGLECOUNT']._serialized_start=4367 + _globals['_TRIANGLECOUNT']._serialized_end=4382 + _globals['_TRIPLETS']._serialized_start=4384 + _globals['_TRIPLETS']._serialized_end=4394 +# @@protoc_insertion_point(module_scope) diff --git a/python/graphframes/connect/proto/graphframes_pb2.pyi b/python/graphframes/connect/proto/graphframes_pb2.pyi new file mode 100644 index 000000000..df0f20f89 --- /dev/null +++ b/python/graphframes/connect/proto/graphframes_pb2.pyi @@ -0,0 +1,222 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class GraphFramesAPI(_message.Message): + __slots__ = ("vertices", "edges", "aggregate_messages", "bfs", "connected_components", "degrees", "drop_isolated_vertices", "filter_edges", "filter_vertices", "find", "in_degrees", "label_propagation", "out_degrees", "page_rank", "parallel_personalized_page_rank", "pregel", "shortest_paths", "strongly_connected_components", "svd_plus_plus", "triangle_count", "triplets") + VERTICES_FIELD_NUMBER: _ClassVar[int] + EDGES_FIELD_NUMBER: _ClassVar[int] + AGGREGATE_MESSAGES_FIELD_NUMBER: _ClassVar[int] + BFS_FIELD_NUMBER: _ClassVar[int] + CONNECTED_COMPONENTS_FIELD_NUMBER: _ClassVar[int] + DEGREES_FIELD_NUMBER: _ClassVar[int] + DROP_ISOLATED_VERTICES_FIELD_NUMBER: _ClassVar[int] + FILTER_EDGES_FIELD_NUMBER: _ClassVar[int] + FILTER_VERTICES_FIELD_NUMBER: _ClassVar[int] + FIND_FIELD_NUMBER: _ClassVar[int] + IN_DEGREES_FIELD_NUMBER: _ClassVar[int] + LABEL_PROPAGATION_FIELD_NUMBER: _ClassVar[int] + OUT_DEGREES_FIELD_NUMBER: _ClassVar[int] + PAGE_RANK_FIELD_NUMBER: _ClassVar[int] + PARALLEL_PERSONALIZED_PAGE_RANK_FIELD_NUMBER: _ClassVar[int] + PREGEL_FIELD_NUMBER: _ClassVar[int] + SHORTEST_PATHS_FIELD_NUMBER: _ClassVar[int] + STRONGLY_CONNECTED_COMPONENTS_FIELD_NUMBER: _ClassVar[int] + SVD_PLUS_PLUS_FIELD_NUMBER: _ClassVar[int] + TRIANGLE_COUNT_FIELD_NUMBER: _ClassVar[int] + TRIPLETS_FIELD_NUMBER: _ClassVar[int] + vertices: bytes + edges: bytes + aggregate_messages: AggregateMessages + bfs: BFS + connected_components: ConnectedComponents + degrees: Degrees + drop_isolated_vertices: DropIsolatedVertices + filter_edges: FilterEdges + filter_vertices: FilterVertices + find: Find + in_degrees: InDegrees + label_propagation: LabelPropagation + out_degrees: OutDegrees + page_rank: PageRank + parallel_personalized_page_rank: ParallelPersonalizedPageRank + pregel: Pregel + shortest_paths: ShortestPaths + strongly_connected_components: StronglyConnectedComponents + svd_plus_plus: SVDPlusPlus + triangle_count: TriangleCount + triplets: Triplets + def __init__(self, vertices: _Optional[bytes] = ..., edges: _Optional[bytes] = ..., aggregate_messages: _Optional[_Union[AggregateMessages, _Mapping]] = ..., bfs: _Optional[_Union[BFS, _Mapping]] = ..., connected_components: _Optional[_Union[ConnectedComponents, _Mapping]] = ..., degrees: _Optional[_Union[Degrees, _Mapping]] = ..., drop_isolated_vertices: _Optional[_Union[DropIsolatedVertices, _Mapping]] = ..., filter_edges: _Optional[_Union[FilterEdges, _Mapping]] = ..., filter_vertices: _Optional[_Union[FilterVertices, _Mapping]] = ..., find: _Optional[_Union[Find, _Mapping]] = ..., in_degrees: _Optional[_Union[InDegrees, _Mapping]] = ..., label_propagation: _Optional[_Union[LabelPropagation, _Mapping]] = ..., out_degrees: _Optional[_Union[OutDegrees, _Mapping]] = ..., page_rank: _Optional[_Union[PageRank, _Mapping]] = ..., parallel_personalized_page_rank: _Optional[_Union[ParallelPersonalizedPageRank, _Mapping]] = ..., pregel: _Optional[_Union[Pregel, _Mapping]] = ..., shortest_paths: _Optional[_Union[ShortestPaths, _Mapping]] = ..., strongly_connected_components: _Optional[_Union[StronglyConnectedComponents, _Mapping]] = ..., svd_plus_plus: _Optional[_Union[SVDPlusPlus, _Mapping]] = ..., triangle_count: _Optional[_Union[TriangleCount, _Mapping]] = ..., triplets: _Optional[_Union[Triplets, _Mapping]] = ...) -> None: ... + +class ColumnOrExpression(_message.Message): + __slots__ = ("col", "expr") + COL_FIELD_NUMBER: _ClassVar[int] + EXPR_FIELD_NUMBER: _ClassVar[int] + col: bytes + expr: str + def __init__(self, col: _Optional[bytes] = ..., expr: _Optional[str] = ...) -> None: ... + +class StringOrLongID(_message.Message): + __slots__ = ("long_id", "string_id") + LONG_ID_FIELD_NUMBER: _ClassVar[int] + STRING_ID_FIELD_NUMBER: _ClassVar[int] + long_id: int + string_id: str + def __init__(self, long_id: _Optional[int] = ..., string_id: _Optional[str] = ...) -> None: ... + +class AggregateMessages(_message.Message): + __slots__ = ("agg_col", "send_to_src", "send_to_dst") + AGG_COL_FIELD_NUMBER: _ClassVar[int] + SEND_TO_SRC_FIELD_NUMBER: _ClassVar[int] + SEND_TO_DST_FIELD_NUMBER: _ClassVar[int] + agg_col: ColumnOrExpression + send_to_src: ColumnOrExpression + send_to_dst: ColumnOrExpression + def __init__(self, agg_col: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_to_src: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_to_dst: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + +class BFS(_message.Message): + __slots__ = ("from_expr", "to_expr", "edge_filter", "max_path_length") + FROM_EXPR_FIELD_NUMBER: _ClassVar[int] + TO_EXPR_FIELD_NUMBER: _ClassVar[int] + EDGE_FILTER_FIELD_NUMBER: _ClassVar[int] + MAX_PATH_LENGTH_FIELD_NUMBER: _ClassVar[int] + from_expr: ColumnOrExpression + to_expr: ColumnOrExpression + edge_filter: ColumnOrExpression + max_path_length: int + def __init__(self, from_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., to_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., edge_filter: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., max_path_length: _Optional[int] = ...) -> None: ... + +class ConnectedComponents(_message.Message): + __slots__ = ("algorithm", "checkpoint_interval", "broadcast_threshold") + ALGORITHM_FIELD_NUMBER: _ClassVar[int] + CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int] + BROADCAST_THRESHOLD_FIELD_NUMBER: _ClassVar[int] + algorithm: str + checkpoint_interval: int + broadcast_threshold: int + def __init__(self, algorithm: _Optional[str] = ..., checkpoint_interval: _Optional[int] = ..., broadcast_threshold: _Optional[int] = ...) -> None: ... + +class Degrees(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class DropIsolatedVertices(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class FilterEdges(_message.Message): + __slots__ = ("condition",) + CONDITION_FIELD_NUMBER: _ClassVar[int] + condition: ColumnOrExpression + def __init__(self, condition: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + +class FilterVertices(_message.Message): + __slots__ = ("condition",) + CONDITION_FIELD_NUMBER: _ClassVar[int] + condition: ColumnOrExpression + def __init__(self, condition: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + +class Find(_message.Message): + __slots__ = ("pattern",) + PATTERN_FIELD_NUMBER: _ClassVar[int] + pattern: str + def __init__(self, pattern: _Optional[str] = ...) -> None: ... + +class InDegrees(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class LabelPropagation(_message.Message): + __slots__ = ("max_iter",) + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + max_iter: int + def __init__(self, max_iter: _Optional[int] = ...) -> None: ... + +class OutDegrees(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class PageRank(_message.Message): + __slots__ = ("reset_probability", "source_id", "max_iter", "tol") + RESET_PROBABILITY_FIELD_NUMBER: _ClassVar[int] + SOURCE_ID_FIELD_NUMBER: _ClassVar[int] + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + TOL_FIELD_NUMBER: _ClassVar[int] + reset_probability: float + source_id: StringOrLongID + max_iter: int + tol: float + def __init__(self, reset_probability: _Optional[float] = ..., source_id: _Optional[_Union[StringOrLongID, _Mapping]] = ..., max_iter: _Optional[int] = ..., tol: _Optional[float] = ...) -> None: ... + +class ParallelPersonalizedPageRank(_message.Message): + __slots__ = ("reset_probability", "source_ids", "max_iter") + RESET_PROBABILITY_FIELD_NUMBER: _ClassVar[int] + SOURCE_IDS_FIELD_NUMBER: _ClassVar[int] + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + reset_probability: float + source_ids: _containers.RepeatedCompositeFieldContainer[StringOrLongID] + max_iter: int + def __init__(self, reset_probability: _Optional[float] = ..., source_ids: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ..., max_iter: _Optional[int] = ...) -> None: ... + +class Pregel(_message.Message): + __slots__ = ("agg_msgs", "send_msg_to_dst", "send_msg_to_src", "checkpoint_interval", "max_iter", "additional_col_name", "additional_col_initial", "additional_col_upd") + AGG_MSGS_FIELD_NUMBER: _ClassVar[int] + SEND_MSG_TO_DST_FIELD_NUMBER: _ClassVar[int] + SEND_MSG_TO_SRC_FIELD_NUMBER: _ClassVar[int] + CHECKPOINT_INTERVAL_FIELD_NUMBER: _ClassVar[int] + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + ADDITIONAL_COL_NAME_FIELD_NUMBER: _ClassVar[int] + ADDITIONAL_COL_INITIAL_FIELD_NUMBER: _ClassVar[int] + ADDITIONAL_COL_UPD_FIELD_NUMBER: _ClassVar[int] + agg_msgs: ColumnOrExpression + send_msg_to_dst: ColumnOrExpression + send_msg_to_src: ColumnOrExpression + checkpoint_interval: int + max_iter: int + additional_col_name: str + additional_col_initial: ColumnOrExpression + additional_col_upd: ColumnOrExpression + def __init__(self, agg_msgs: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_msg_to_dst: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_msg_to_src: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., checkpoint_interval: _Optional[int] = ..., max_iter: _Optional[int] = ..., additional_col_name: _Optional[str] = ..., additional_col_initial: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., additional_col_upd: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + +class ShortestPaths(_message.Message): + __slots__ = ("landmarks",) + LANDMARKS_FIELD_NUMBER: _ClassVar[int] + landmarks: _containers.RepeatedCompositeFieldContainer[StringOrLongID] + def __init__(self, landmarks: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ...) -> None: ... + +class StronglyConnectedComponents(_message.Message): + __slots__ = ("max_iter",) + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + max_iter: int + def __init__(self, max_iter: _Optional[int] = ...) -> None: ... + +class SVDPlusPlus(_message.Message): + __slots__ = ("rank", "max_iter", "min_value", "max_value", "gamma1", "gamma2", "gamma6", "gamma7") + RANK_FIELD_NUMBER: _ClassVar[int] + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + MIN_VALUE_FIELD_NUMBER: _ClassVar[int] + MAX_VALUE_FIELD_NUMBER: _ClassVar[int] + GAMMA1_FIELD_NUMBER: _ClassVar[int] + GAMMA2_FIELD_NUMBER: _ClassVar[int] + GAMMA6_FIELD_NUMBER: _ClassVar[int] + GAMMA7_FIELD_NUMBER: _ClassVar[int] + rank: int + max_iter: int + min_value: float + max_value: float + gamma1: float + gamma2: float + gamma6: float + gamma7: float + def __init__(self, rank: _Optional[int] = ..., max_iter: _Optional[int] = ..., min_value: _Optional[float] = ..., max_value: _Optional[float] = ..., gamma1: _Optional[float] = ..., gamma2: _Optional[float] = ..., gamma6: _Optional[float] = ..., gamma7: _Optional[float] = ...) -> None: ... + +class TriangleCount(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Triplets(_message.Message): + __slots__ = () + def __init__(self) -> None: ... diff --git a/python/graphframes/connect/proto/graphframes_pb2_grpc.py b/python/graphframes/connect/proto/graphframes_pb2_grpc.py new file mode 100644 index 000000000..2daafffeb --- /dev/null +++ b/python/graphframes/connect/proto/graphframes_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/python/graphframes/connect/utils.py b/python/graphframes/connect/utils.py new file mode 100644 index 000000000..9ac92040e --- /dev/null +++ b/python/graphframes/connect/utils.py @@ -0,0 +1,19 @@ +from pyspark.sql.connect.client import SparkConnectClient +from pyspark.sql.connect.column import Column +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.expressions import Expression +from pyspark.sql.connect.plan import LogicalPlan + + +def dataframe_to_proto(df: DataFrame, client: SparkConnectClient) -> bytes: + plan = df._plan + assert plan is not None + assert isinstance(plan, LogicalPlan) + return plan.to_proto(client).SerializeToString() + + +def column_to_proto(col: Column, client: SparkConnectClient) -> bytes: + expr = col._expr + assert expr is not None + assert isinstance(expr, Expression) + return expr.to_plan(client).SerializeToString() From ea11df6f02e0d9a2f274122449243b2dcb9a6b0c Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 8 Feb 2025 15:39:08 +0100 Subject: [PATCH 02/27] wip --- .../graphframes/connect/graphframe_client.py | 425 ++++++++++++------ python/graphframes/connect/utils.py | 8 + 2 files changed, 307 insertions(+), 126 deletions(-) diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index f228eac6e..fd925e283 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -1,13 +1,159 @@ +from __future__ import annotations + from typing import Self +from pyspark.sql.connect import functions as F from pyspark.sql.connect import proto from pyspark.sql.connect.client import SparkConnectClient +from pyspark.sql.connect.column import Column from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.plan import LogicalPlan from pyspark.storagelevel import StorageLevel from .proto import graphframes_pb2 as pb -from .utils import column_to_proto, dataframe_to_proto +from .utils import dataframe_to_proto, make_column_or_expr + + +class PregelConnect: + def __init__(self, graph: "GraphFrameConnect") -> None: + self.graph = graph + self._max_iter = 10 + self._checkpoint_interval = 2 + self._col_name = None + self._initial_expr = None + self._update_after_agg_msgs_expr = None + self._send_msg_to_src = None + self._send_msg_to_dst = None + self._agg_msg = None + + def setMaxIter(self, value: int) -> Self: + self._max_iter = value + return self + + def setCheckpointInterval(self, value: int) -> Self: + self._checkpoint_interval = value + return self + + def withVertexColumn( + self, + colName: str, + initialExpr: Column | str, + updateAfterAggMsgsExpr: Column | str, + ) -> Self: + self._col_name = colName + self._initial_expr = initialExpr + self._update_after_agg_msgs_expr = updateAfterAggMsgsExpr + return self + + def sendMsgToSrc(self, msgExpr: Column | str) -> Self: + self._send_msg_to_src = msgExpr + return self + + def sendMsgToDst(self, msgExpr: Column | str) -> Self: + self._send_msg_to_dst = msgExpr + return self + + def aggMsgs(self, aggExpr: Column) -> Self: + self._agg_msg = aggExpr + return self + + def run(self) -> DataFrame: + class Pregel(LogicalPlan): + def __init__( + self, + max_iter: int, + checkpoint_interval: int, + vertex_col_name: str, + agg_msg: Column | str, + send2dst: Column | str, + send2src: Column | str, + vertex_col_init: Column | str, + vertex_col_upd: Column | str, + vertices: DataFrame, + edges: DataFrame, + ) -> None: + self.max_iter = max_iter + self.checkpoint_interval = checkpoint_interval + self.vertex_col_name = vertex_col_name + self.agg_msg = agg_msg + self.send2dst = send2dst + self.send2src = send2src + self.vertex_col_init = vertex_col_init + self.vertex_col_upd = vertex_col_upd + self.vertices = vertices + self.edges = edges + + def plan(self, session: SparkConnectClient) -> proto.Relation: + plan = self._create_proto_relation() + pregel = pb.Pregel( + agg_msgs=make_column_or_expr(self.agg_msg, session), + send_msg_to_dst=make_column_or_expr(self.send2dst, session), + send_msg_to_src=make_column_or_expr(self.send2src, session), + checkpoint_interval=self.checkpoint_interval, + max_iter=self.max_iter, + additional_col_name=self.vertex_col_name, + additional_col_initial=make_column_or_expr( + self.vertex_col_init, session + ), + additional_col_upd=make_column_or_expr( + self.vertex_col_upd, session + ), + ) + pb_message = pb.GraphFramesAPI( + vertices=dataframe_to_proto(self.vertices, session), + edges=dataframe_to_proto(self.edges, session), + ) + pb_message.pregel = pregel + plan.extension.Pack(pb_message) + return plan + + if ( + (self._col_name is None) + or (self._initial_expr is None) + or (self._update_after_agg_msgs_expr is None) + ): + raise ValueError("Initial vertex column is not initialized!") + + if self._agg_msg is None: + raise ValueError("AggMsg is not initialized!") + + if self._send_msg_to_src is None: + raise ValueError("Send-to-src column is not initialized!") + + if self._send_msg_to_dst is None: + raise ValueError("Send-to-dst column is not initialized!") + + return DataFrame.withPlan( + Pregel( + max_iter=self._max_iter, + checkpoint_interval=self._checkpoint_interval, + vertex_col_name=self._col_name, + vertex_col_init=self._initial_expr, + vertex_col_upd=self._update_after_agg_msgs_expr, + agg_msg=self._agg_msg, + send2dst=self._send_msg_to_dst, + send2src=self._send_msg_to_src, + vertices=self.graph._vertices, + edges=self.graph._edges, + ), + session=self.graph._spark, + ) + + @staticmethod + def msg() -> Column: + return F.col("_pregel_msg_") + + @staticmethod + def src(colName: str) -> Column: + return F.col("src." + colName) + + @staticmethod + def dst(colName: str) -> Column: + return F.col("dst." + colName) + + @staticmethod + def edge(colName: str) -> Column: + return F.col("edge." + colName) class GraphFrameConnect: @@ -51,19 +197,10 @@ def _get_pb_api_message( @property def vertices(self) -> DataFrame: - """ - :class:`DataFrame` holding vertex information, with unique column "id" - for vertex IDs. - """ return self._vertices @property def edges(self) -> DataFrame: - """ - :class:`DataFrame` holding edge information, with unique columns "src" and - "dst" storing source vertex IDs and destination vertex IDs of edges, - respectively. - """ return self._edges def __repr__(self) -> str: @@ -78,47 +215,28 @@ def __repr__(self) -> str: return f"GraphFrame(v:{v}, e:{e})" def cache(self) -> Self: - """Persist the dataframe representation of vertices and edges of the graph with the default - storage level. - """ self._vertices = self._vertices.cache() self._edges = self._edges.cache() return self def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> Self: - """Persist the dataframe representation of vertices and edges of the graph with the given - storage level. - """ self._vertices = self._vertices.persist(storageLevel=storageLevel) self._edges = self._edges.persist(storageLevel=storageLevel) return self def unpersist(self, blocking: bool = False) -> Self: - """Mark the dataframe representation of vertices and edges of the graph as non-persistent, - and remove all blocks for it from memory and disk. - """ self._vertices = self._vertices.unpersist(blocking=blocking) self._edges = self._edges.unpersist(blocking=blocking) return self @property def outDegrees(self) -> DataFrame: - """ - The out-degree of each vertex in the graph, returned as a DataFrame with two columns: - - "id": the ID of the vertex - - "outDegree" (integer) storing the out-degree of the vertex - - Note that vertices with 0 out-edges are not returned in the result. - - :return: DataFrame with new vertices column "outDegree" - """ - class OutDegrees(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: self.v = v self.e = e - def plan(self, session: "SparkConnectClient") -> proto.Relation: + def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) @@ -131,22 +249,12 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def inDegrees(self) -> DataFrame: - """ - The in-degree of each vertex in the graph, returned as a DataFame with two columns: - - "id": the ID of the vertex - - "inDegree" (int) storing the in-degree of the vertex - - Note that vertices with 0 in-edges are not returned in the result. - - :return: DataFrame with new vertices column "inDegree" - """ - class InDegrees(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: self.v = v self.e = e - def plan(self, session: "SparkConnectClient") -> proto.Relation: + def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) @@ -159,22 +267,12 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def degrees(self) -> DataFrame: - """ - The degree of each vertex in the graph, returned as a DataFrame with two columns: - - "id": the ID of the vertex - - 'degree' (integer) the degree of the vertex - - Note that vertices with 0 edges are not returned in the result. - - :return: DataFrame with new vertices column "degree" - """ - class Degrees(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: self.v = v self.e = e - def plan(self, session: "SparkConnectClient") -> proto.Relation: + def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) @@ -187,104 +285,179 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: @property def triplets(self) -> DataFrame: - """ - The triplets (source vertex)-[edge]->(destination vertex) for all edges in the graph. + class Triplets(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self.v = v + self.e = e - Returned as a :class:`DataFrame` with three columns: - - "src": source vertex with schema matching 'vertices' - - "edge": edge with schema matching 'edges' - - 'dst': destination vertex with schema matching 'vertices' + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.triplets = pb.Triplets() + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :return: DataFrame with columns 'src', 'edge', and 'dst' - """ - jdf = self._jvm_graph.triplets() - return DataFrame(jdf, self._spark) + return DataFrame.withPlan(Triplets(self._vertices, self._edges), self._spark) @property def pregel(self): - """ - Get the :class:`graphframes.lib.Pregel` object for running pregel. - - See :class:`graphframes.lib.Pregel` for more details. - """ - return Pregel(self) + return PregelConnect(self) def find(self, pattern: str) -> DataFrame: - """ - Motif finding. + class Find(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame, pattern: str) -> None: + self.v = v + self.e = e + self.p = pattern - See Scala documentation for more details. + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.find = pb.Find(pattern=self.p) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :param pattern: String describing the motif to search for. - :return: DataFrame with one Row for each instance of the motif found - """ - jdf = self._jvm_graph.find(pattern) - return DataFrame(jdf, self._spark) + return DataFrame.withPlan( + Find(self._vertices, self._edges, pattern), self._spark + ) - def filterVertices(self, condition: Union[str, Column]) -> "GraphFrame": - """ - Filters the vertices based on expression, remove edges containing any dropped vertices. + def filterVertices(self, condition: str | Column) -> Self: + class FilterVertices(LogicalPlan): + def __init__( + self, v: DataFrame, e: DataFrame, condition: str | Column + ) -> None: + self.v = v + self.e = e + self.c = condition - :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered vertices and edges. - """ + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + col_or_expr = make_column_or_expr(self.c, session) + graphframes_api_call.filter_vertices = pb.FilterVertices( + condition=col_or_expr + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - if isinstance(condition, basestring): - jdf = self._jvm_graph.filterVertices(condition) - elif isinstance(condition, Column): - jdf = self._jvm_graph.filterVertices(condition._jc) - else: - raise TypeError("condition should be string or Column") - return _from_java_gf(jdf, self._spark) + self._vertices = DataFrame.withPlan( + FilterVertices(self._vertices, self._edges, condition), self._spark + ) + self._edges = self._edges.join( + self._vertices.withColumn(self.SRC, F.col(self.ID)), + on=[self.SRC], + how="left_semi", + ).join( + self._vertices.withColumn(self.DST, F.col(self.ID)), + on=[self.DST], + how="left_semi", + ) + return self - def filterEdges(self, condition: Union[str, Column]) -> "GraphFrame": - """ - Filters the edges based on expression, keep all vertices. + def filterEdges(self, condition: str | Column) -> Self: + class FilterEdges(LogicalPlan): + def __init__( + self, v: DataFrame, e: DataFrame, condition: str | Column + ) -> None: + self.v = v + self.e = e + self.c = condition - :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered edges. - """ - if isinstance(condition, basestring): - jdf = self._jvm_graph.filterEdges(condition) - elif isinstance(condition, Column): - jdf = self._jvm_graph.filterEdges(condition._jc) - else: - raise TypeError("condition should be string or Column") - return _from_java_gf(jdf, self._spark) + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + col_or_expr = make_column_or_expr(self.c, session) + graphframes_api_call.filter_edges = pb.FilterEdges( + condition=col_or_expr + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - def dropIsolatedVertices(self) -> "GraphFrame": - """ - Drops isolated vertices, vertices are not contained in any edges. + self._edges = DataFrame.withPlan( + FilterEdges(self._vertices, self._edges, condition), self._spark + ) + return self - :return: GraphFrame with filtered vertices. - """ - jdf = self._jvm_graph.dropIsolatedVertices() - return _from_java_gf(jdf, self._spark) + def dropIsolatedVertices(self) -> Self: + class DropIsolatedVertices(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self.v = v + self.e = e + + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.drop_isolated_vertices = pb.DropIsolatedVertices() + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + self._vertices = DataFrame.withPlan( + DropIsolatedVertices(self._vertices, self._edges), self._spark + ) + return self def bfs( self, - fromExpr: str, - toExpr: str, - edgeFilter: Optional[str] = None, + fromExpr: Column | str, + toExpr: Column | str, + edgeFilter: Column | str | None = None, maxPathLength: int = 10, ) -> DataFrame: - """ - Breadth-first search (BFS). + class BFS(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + from_expr: Column | str, + to_expr: Column | str, + edge_filter: Column | str, + max_path_len: int, + ) -> None: + self.v = v + self.e = e + self.from_expr = from_expr + self.to_expr = to_expr + self.edge_filter = edge_filter + self.max_path_len = max_path_len - See Scala documentation for more details. + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.bfs = pb.BFS( + from_expr=make_column_or_expr(self.from_expr, session), + to_expr=make_column_or_expr(self.to_expr, session), + edge_filter=make_column_or_expr(self.edge_filter, session), + max_path_length=self.max_path_len, + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :return: DataFrame with one Row for each shortest path between matching vertices. - """ - builder = ( - self._jvm_graph.bfs() - .fromExpr(fromExpr) - .toExpr(toExpr) - .maxPathLength(maxPathLength) + if edgeFilter is None: + edgeFilter = F.lit(True) + + return DataFrame.withPlan( + BFS( + v=self._vertices, + e=self._edges, + from_expr=fromExpr, + to_expr=toExpr, + edge_filter=edgeFilter, + max_path_len=maxPathLength, + ), + self._spark, ) - if edgeFilter is not None: - builder.edgeFilter(edgeFilter) - jdf = builder.run() - return DataFrame(jdf, self._spark) def aggregateMessages( self, diff --git a/python/graphframes/connect/utils.py b/python/graphframes/connect/utils.py index 9ac92040e..3e5d55105 100644 --- a/python/graphframes/connect/utils.py +++ b/python/graphframes/connect/utils.py @@ -4,6 +4,8 @@ from pyspark.sql.connect.expressions import Expression from pyspark.sql.connect.plan import LogicalPlan +from .proto.graphframes_pb2 import ColumnOrExpression + def dataframe_to_proto(df: DataFrame, client: SparkConnectClient) -> bytes: plan = df._plan @@ -17,3 +19,9 @@ def column_to_proto(col: Column, client: SparkConnectClient) -> bytes: assert expr is not None assert isinstance(expr, Expression) return expr.to_plan(client).SerializeToString() + +def make_column_or_expr(col: Column | str, client: SparkConnectClient) -> ColumnOrExpression: + if isinstance(col, Column): + return ColumnOrExpression(col=column_to_proto(col, client)) + else: + return ColumnOrExpression(expr=col) From 0ddd5bdcc7a3c1bfc7c48ae67a024e09e0e461eb Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 8 Feb 2025 16:50:35 +0100 Subject: [PATCH 03/27] wip --- .../graphframes/connect/graphframe_client.py | 348 +++++++++++------- python/graphframes/connect/utils.py | 8 +- 2 files changed, 222 insertions(+), 134 deletions(-) diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index fd925e283..b0711b2fe 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -11,7 +11,7 @@ from pyspark.storagelevel import StorageLevel from .proto import graphframes_pb2 as pb -from .utils import dataframe_to_proto, make_column_or_expr +from .utils import dataframe_to_proto, make_column_or_expr, make_str_or_long_id class PregelConnect: @@ -461,54 +461,53 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def aggregateMessages( self, - aggCol: Union[Column, str], - sendToSrc: Union[Column, str, None] = None, - sendToDst: Union[Column, str, None] = None, + aggCol: Column | str, + sendToSrc: Column | str | None = None, + sendToDst: Column | str | None = None, ) -> DataFrame: - """ - Aggregates messages from the neighbours. - - When specifying the messages and aggregation function, the user may reference columns using - the static methods in :class:`graphframes.lib.AggregateMessages`. - - See Scala documentation for more details. + class AggregateMessages(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + agg_col: Column | str, + send2src: Column | str | None, + send2dst: Column | str | None, + ) -> None: + self.v = v + self.e = e + self.agg_col = agg_col + self.send2src = send2src + self.send2dst = send2dst - :param aggCol: the requested aggregation output either as - :class:`pyspark.sql.Column` or SQL expression string - :param sendToSrc: message sent to the source vertex of each triplet either as - :class:`pyspark.sql.Column` or SQL expression string (default: None) - :param sendToDst: message sent to the destination vertex of each triplet either as - :class:`pyspark.sql.Column` or SQL expression string (default: None) + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.aggregate_messages = pb.AggregateMessages( + agg_col=make_column_or_expr(self.agg_col, session), + send_to_src=None + if self.send2src is None + else make_column_or_expr(self.send2src, session), + send_to_dst=None + if self.send2dst is None + else make_column_or_expr(self.send2dst, session), + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :return: DataFrame with columns for the vertex ID and the resulting aggregated message - """ - # Check that either sendToSrc, sendToDst, or both are provided if sendToSrc is None and sendToDst is None: raise ValueError( "Either `sendToSrc`, `sendToDst`, or both have to be provided" ) - builder = self._jvm_graph.aggregateMessages() - if sendToSrc is not None: - if isinstance(sendToSrc, Column): - builder.sendToSrc(sendToSrc._jc) - elif isinstance(sendToSrc, basestring): - builder.sendToSrc(sendToSrc) - else: - raise TypeError("Provide message either as `Column` or `str`") - if sendToDst is not None: - if isinstance(sendToDst, Column): - builder.sendToDst(sendToDst._jc) - elif isinstance(sendToDst, basestring): - builder.sendToDst(sendToDst) - else: - raise TypeError("Provide message either as `Column` or `str`") - if isinstance(aggCol, Column): - jdf = builder.agg(aggCol._jc) - else: - jdf = builder.agg(aggCol) - return DataFrame(jdf, self._spark) - # Standard algorithms + return DataFrame.withPlan( + AggregateMessages( + self._vertices, self._edges, aggCol, sendToSrc, sendToDst + ), + self._spark, + ) def connectedComponents( self, @@ -516,39 +515,66 @@ def connectedComponents( checkpointInterval: int = 2, broadcastThreshold: int = 1000000, ) -> DataFrame: - """ - Computes the connected components of the graph. - - See Scala documentation for more details. - - :param algorithm: connected components algorithm to use (default: "graphframes") - Supported algorithms are "graphframes" and "graphx". - :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2) - :param broadcastThreshold: broadcast threshold in propagating component assignments - (default: 1000000) - - :return: DataFrame with new vertices column "component" - """ - jdf = ( - self._jvm_graph.connectedComponents() - .setAlgorithm(algorithm) - .setCheckpointInterval(checkpointInterval) - .setBroadcastThreshold(broadcastThreshold) - .run() + class ConnectedComponents(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + algorithm: str, + checkpoint_interval: int, + broadcast_threshold: int, + ) -> None: + self.v = v + self.e = e + self.algorithm = algorithm + self.checkpoint_interval = checkpoint_interval + self.broadcast_threshold = broadcast_threshold + + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.connected_components = pb.ConnectedComponents( + algorithm=self.algorithm, + checkpoint_interval=self.checkpoint_interval, + broadcast_threshold=self.broadcast_threshold, + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + return DataFrame.withPlan( + ConnectedComponents( + self._vertices, + self._edges, + algorithm, + checkpointInterval, + broadcastThreshold, + ), + self._spark, ) - return DataFrame(jdf, self._spark) def labelPropagation(self, maxIter: int) -> DataFrame: - """ - Runs static label propagation for detecting communities in networks. + class LabelPropagation(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: + self.v = v + self.e = e + self.max_iter = max_iter - See Scala documentation for more details. + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.label_propagation = pb.LabelPropagation( + max_iter=self.max_iter + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :param maxIter: the number of iterations to be performed - :return: DataFrame with new vertices column "label" - """ - jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() - return DataFrame(jdf, self._spark) + return DataFrame.withPlan( + LabelPropagation(self._vertices, self._edges, maxIter), self._spark + ) def pageRank( self, @@ -557,20 +583,6 @@ def pageRank( maxIter: Optional[int] = None, tol: Optional[float] = None, ) -> "GraphFrame": - """ - Runs the PageRank algorithm on the graph. - Note: Exactly one of fixed_num_iter or tolerance must be set. - - See Scala documentation for more details. - - :param resetProbability: Probability of resetting to a random vertex. - :param sourceId: (optional) the source vertex for a personalized PageRank. - :param maxIter: If set, the algorithm is run for a fixed number - of iterations. This may not be set if the `tol` parameter is set. - :param tol: If set, the algorithm is run until the given tolerance. - This may not be set if the `numIter` parameter is set. - :return: GraphFrame with new vertices column "pagerank" and new edges column "weight" - """ builder = self._jvm_graph.pageRank().resetProbability(resetProbability) if sourceId is not None: builder = builder.sourceId(sourceId) @@ -589,17 +601,6 @@ def parallelPersonalizedPageRank( sourceIds: Optional[list[Any]] = None, maxIter: Optional[int] = None, ) -> "GraphFrame": - """ - Run the personalized PageRank algorithm on the graph, - from the provided list of sources in parallel for a fixed number of iterations. - - See Scala documentation for more details. - - :param resetProbability: Probability of resetting to a random vertex - :param sourceIds: the source vertices for a personalized PageRank - :param maxIter: the fixed number of iterations this algorithm runs - :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" - """ assert ( sourceIds is not None and len(sourceIds) > 0 ), "Source vertices Ids sourceIds must be provided" @@ -612,29 +613,52 @@ def parallelPersonalizedPageRank( jgf = builder.run() return _from_java_gf(jgf, self._spark) - def shortestPaths(self, landmarks: list[Any]) -> DataFrame: - """ - Runs the shortest path algorithm from a set of landmark vertices in the graph. + def shortestPaths(self, landmarks: list[str | int]) -> DataFrame: + class ShortestPaths(LogicalPlan): + def __init__( + self, v: DataFrame, e: DataFrame, landmarks: list[str | int] + ) -> None: + self.v = v + self.e = e + self.landmarks = landmarks - See Scala documentation for more details. + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.shortest_paths = pb.ShortestPaths( + landmarks=[make_str_or_long_id(raw_id) for raw_id in self.landmarks] + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :param landmarks: a set of one or more landmarks - :return: DataFrame with new vertices column "distances" - """ - jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() - return DataFrame(jdf, self._spark) + return DataFrame.withPlan( + ShortestPaths(self._vertices, self._edges, landmarks), self._spark + ) def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: - """ - Runs the strongly connected components algorithm on this graph. + class StronglyConnectedComponents(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: + self.v = v + self.e = e + self.max_iter = max_iter - See Scala documentation for more details. + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.strongly_connected_components = ( + pb.StronglyConnectedComponents(max_iter=self.max_iter) + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :param maxIter: the number of iterations to run - :return: DataFrame with new vertex column "component" - """ - jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() - return DataFrame(jdf, self._spark) + return DataFrame.withPlan( + StronglyConnectedComponents(self._vertices, self._edges, maxIter), + self._spark, + ) def svdPlusPlus( self, @@ -646,30 +670,88 @@ def svdPlusPlus( gamma2: float = 0.007, gamma6: float = 0.005, gamma7: float = 0.015, + return_loss: bool = False, # TODO: should it be True to mimic the classic API? ) -> tuple[DataFrame, float]: - """ - Runs the SVD++ algorithm. - - See Scala documentation for more details. - - :return: Tuple of DataFrame with new vertex columns storing learned model, and loss value - """ - # This call is actually useless, because one needs to build the configuration first... - builder = self._jvm_graph.svdPlusPlus() - builder.rank(rank).maxIter(maxIter).minValue(minValue).maxValue(maxValue) - builder.gamma1(gamma1).gamma2(gamma2).gamma6(gamma6).gamma7(gamma7) - jdf = builder.run() - loss = builder.loss() - v = DataFrame(jdf, self._spark) - return (v, loss) + class SVDPlusPlus(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + rank: int, + max_iter: int, + min_value: float, + max_value: float, + gamma1: float, + gamma2: float, + gamma6: float, + gamma7: float, + ) -> None: + self.v = v + self.e = e + self.rank = rank + self.max_iter = max_iter + self.min_value = min_value + self.max_value = max_value + self.gamma1 = gamma1 + self.gamma2 = gamma2 + self.gamma6 = gamma6 + self.gamma7 = gamma7 + + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.svd_plus_plus = pb.SVDPlusPlus( + rank=self.rank, + max_iter=self.max_iter, + min_value=self.min_value, + max_value=self.max_value, + gamma1=self.gamma1, + gamma2=self.gamma2, + gamma6=self.gamma6, + gamma7=self.gamma7, + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + output = DataFrame.withPlan( + SVDPlusPlus( + self._vertices, + self._edges, + rank=rank, + max_iter=maxIter, + min_value=minValue, + max_value=maxValue, + gamma1=gamma1, + gamma2=gamma2, + gamma6=gamma6, + gamma7=gamma7, + ), + self._spark, + ) + + if return_loss: + # This branch may be computationaly expensive and it is not lazy! + return (output.drop("loss"), output.select("loss").take(1)[0]["loss"]) + else: + return (output.drop("loss"), -1.0) def triangleCount(self) -> DataFrame: - """ - Counts the number of triangles passing through each vertex in this graph. + class TriangleCount(LogicalPlan): + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self.v = v + self.e = e - See Scala documentation for more details. + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.triangle_count = pb.TriangleCount() + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan - :return: DataFrame with new vertex column "count" - """ - jdf = self._jvm_graph.triangleCount().run() - return DataFrame(jdf, self._spark) + return DataFrame.withPlan( + TriangleCount(self._vertices, self._edges), self._spark + ) diff --git a/python/graphframes/connect/utils.py b/python/graphframes/connect/utils.py index 3e5d55105..a4c270130 100644 --- a/python/graphframes/connect/utils.py +++ b/python/graphframes/connect/utils.py @@ -4,7 +4,7 @@ from pyspark.sql.connect.expressions import Expression from pyspark.sql.connect.plan import LogicalPlan -from .proto.graphframes_pb2 import ColumnOrExpression +from .proto.graphframes_pb2 import ColumnOrExpression, StringOrLongID def dataframe_to_proto(df: DataFrame, client: SparkConnectClient) -> bytes: @@ -25,3 +25,9 @@ def make_column_or_expr(col: Column | str, client: SparkConnectClient) -> Column return ColumnOrExpression(col=column_to_proto(col, client)) else: return ColumnOrExpression(expr=col) + +def make_str_or_long_id(str_or_long: str | int) -> StringOrLongID: + if isinstance(str_or_long, str): + return StringOrLongID(string_id=str_or_long) + else: + return StringOrLongID(long_id=str_or_long) From da7eccc2aef9617b743a23bffeaf9b5d98637a21 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 8 Feb 2025 17:45:55 +0100 Subject: [PATCH 04/27] wip --- .../src/main/protobuf/graphframes.proto | 4 +- .../graphframes/GraphFramesConnectUtils.scala | 12 +- .../graphframes/connect/graphframe_client.py | 202 +++++++++++++----- 3 files changed, 160 insertions(+), 58 deletions(-) diff --git a/graphframes-connect/src/main/protobuf/graphframes.proto b/graphframes-connect/src/main/protobuf/graphframes.proto index 2cdd0febf..ca5594d64 100644 --- a/graphframes-connect/src/main/protobuf/graphframes.proto +++ b/graphframes-connect/src/main/protobuf/graphframes.proto @@ -94,8 +94,8 @@ message OutDegrees {} message PageRank { double reset_probability = 1; optional StringOrLongID source_id = 2; - int32 max_iter = 3; - double tol = 4; + optional int32 max_iter = 3; + optional double tol = 4; } message ParallelPersonalizedPageRank { diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index 062241d9e..f80d557d9 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -126,12 +126,13 @@ object GraphFramesConnectUtils { } case MethodCase.PAGE_RANK => { val pageRankProto = apiMessage.getPageRank - val pageRank = graphFrame.pageRank + val pageRank = graphFrame.pageRank.resetProbability(pageRankProto.getResetProbability) - pageRank - .maxIter(pageRankProto.getMaxIter) - .tol(pageRankProto.getTol) - .resetProbability(pageRankProto.getResetProbability) + if (pageRankProto.hasMaxIter) { + pageRank.maxIter(pageRankProto.getMaxIter) + } else { + pageRank.tol(pageRankProto.getTol) + } if (pageRankProto.hasSourceId) { pageRank.sourceId(parseLongOrStringID(pageRankProto.getSourceId)) @@ -139,6 +140,7 @@ object GraphFramesConnectUtils { // Edges should be updated on the client side // TODO: do we really need an edge weights in that case? + // see comments in the Python API pageRank.run().vertices } case MethodCase.PARALLEL_PERSONALIZED_PAGE_RANK => { diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index b0711b2fe..68d9a74b7 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -214,20 +214,22 @@ def __repr__(self) -> str: return f"GraphFrame(v:{v}, e:{e})" - def cache(self) -> Self: - self._vertices = self._vertices.cache() - self._edges = self._edges.cache() - return self - - def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> Self: - self._vertices = self._vertices.persist(storageLevel=storageLevel) - self._edges = self._edges.persist(storageLevel=storageLevel) - return self - - def unpersist(self, blocking: bool = False) -> Self: - self._vertices = self._vertices.unpersist(blocking=blocking) - self._edges = self._edges.unpersist(blocking=blocking) - return self + def cache(self) -> "GraphFrameConnect": + new_vertices = self._vertices.cache() + new_edges = self._edges.cache() + return GraphFrameConnect(new_vertices, new_edges) + + def persist( + self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + ) -> "GraphFrameConnect": + new_vertices = self._vertices.persist(storageLevel=storageLevel) + new_edges = self._edges.persist(storageLevel=storageLevel) + return GraphFrameConnect(new_vertices, new_edges) + + def unpersist(self, blocking: bool = False) -> "GraphFrameConnect": + new_vertices = self._vertices.unpersist(blocking=blocking) + new_edges = self._edges.unpersist(blocking=blocking) + return GraphFrameConnect(new_vertices, new_edges) @property def outDegrees(self) -> DataFrame: @@ -325,7 +327,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: Find(self._vertices, self._edges, pattern), self._spark ) - def filterVertices(self, condition: str | Column) -> Self: + def filterVertices(self, condition: str | Column) -> "GraphFrameConnect": class FilterVertices(LogicalPlan): def __init__( self, v: DataFrame, e: DataFrame, condition: str | Column @@ -346,21 +348,22 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - self._vertices = DataFrame.withPlan( + new_vertices = DataFrame.withPlan( FilterVertices(self._vertices, self._edges, condition), self._spark ) - self._edges = self._edges.join( - self._vertices.withColumn(self.SRC, F.col(self.ID)), + # Exactly like in the scala-core + new_edges = self._edges.join( + new_vertices.withColumn(self.SRC, F.col(self.ID)), on=[self.SRC], how="left_semi", ).join( - self._vertices.withColumn(self.DST, F.col(self.ID)), + new_vertices.withColumn(self.DST, F.col(self.ID)), on=[self.DST], how="left_semi", ) - return self + return GraphFrameConnect(new_vertices, new_edges) - def filterEdges(self, condition: str | Column) -> Self: + def filterEdges(self, condition: str | Column) -> "GraphFrameConnect": class FilterEdges(LogicalPlan): def __init__( self, v: DataFrame, e: DataFrame, condition: str | Column @@ -381,12 +384,12 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - self._edges = DataFrame.withPlan( + new_edges = DataFrame.withPlan( FilterEdges(self._vertices, self._edges, condition), self._spark ) - return self + return GraphFrameConnect(self._vertices, new_edges) - def dropIsolatedVertices(self) -> Self: + def dropIsolatedVertices(self) -> "GraphFrameConnect": class DropIsolatedVertices(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: self.v = v @@ -401,10 +404,10 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - self._vertices = DataFrame.withPlan( + new_vertices = DataFrame.withPlan( DropIsolatedVertices(self._vertices, self._edges), self._spark ) - return self + return GraphFrameConnect(new_vertices, self._edges) def bfs( self, @@ -576,42 +579,139 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: LabelPropagation(self._vertices, self._edges, maxIter), self._spark ) + def _update_page_rank_edge_weights( + self, new_vertices: DataFrame + ) -> "GraphFrameConnect": + cols2select = self.edges.columns + ["weight"] + new_edges = ( + self._edges.join( + new_vertices.withColumn(self.SRC, F.col(self.ID)), + on=[self.SRC], + how="inner", + ) + .join( + self.outDegrees.withColumn(self.SRC, F.col(self.ID)), + on=[self.SRC], + how="inner", + ) + .withColumn("weight", F.col("pagerank") / F.col("outDegree")) + .select(*cols2select) + ) + return GraphFrameConnect(new_vertices, new_edges) + def pageRank( self, resetProbability: float = 0.15, - sourceId: Optional[Any] = None, - maxIter: Optional[int] = None, - tol: Optional[float] = None, - ) -> "GraphFrame": - builder = self._jvm_graph.pageRank().resetProbability(resetProbability) - if sourceId is not None: - builder = builder.sourceId(sourceId) - if maxIter is not None: - builder = builder.maxIter(maxIter) - assert tol is None, "Exactly one of maxIter or tol should be set." - else: - assert tol is not None, "Exactly one of maxIter or tol should be set." - builder = builder.tol(tol) - jgf = builder.run() - return _from_java_gf(jgf, self._spark) + sourceId: str | int | None = None, + maxIter: int | None = None, + tol: float | None = None, + ) -> "GraphFrameConnect": + class PageRank(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + reset_prob: float, + source_id: str | int | None, + max_iter: int | None, + tol: float | None, + ) -> None: + self.v = v + self.e = e + self.reset_prob = reset_prob + self.source_id = source_id + self.max_iter = max_iter + self.tol = tol + + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.page_rank = pb.PageRank( + reset_probability=self.reset_prob, + source_id=None + if self.source_id is None + else make_str_or_long_id(self.source_id), + max_iter=self.max_iter, + tol=self.tol, + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + if (maxIter is None) == (tol is None): + # TODO: in classic it is not an axception but assert; + # at the same time I think it should be an exception. + raise ValueError("Exactly one of maxIter or tol should be set.") + + new_vertices = DataFrame.withPlan( + PageRank( + self._vertices, + self._edges, + reset_prob=resetProbability, + source_id=sourceId, + max_iter=maxIter, + tol=tol, + ), + self._spark, + ) + # TODO: should this part to be optional? Like 'compute_edge_weights'? + return self._update_page_rank_edge_weights(new_vertices) def parallelPersonalizedPageRank( self, resetProbability: float = 0.15, - sourceIds: Optional[list[Any]] = None, - maxIter: Optional[int] = None, - ) -> "GraphFrame": + sourceIds: list[str | int] | None = None, + maxIter: int | None = None, + ) -> "GraphFrameConnect": + class ParallelPersonalizedPageRank(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + reset_prob: float, + source_ids: list[str | int], + max_iter: int, + ) -> None: + self.v = v + self.e = e + self.reset_prob = reset_prob + self.source_ids = source_ids + self.max_iter = max_iter + + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.parallel_personalized_page_rank = ( + pb.ParallelPersonalizedPageRank( + reset_probability=self.reset_prob, + source_ids=[ + make_str_or_long_id(raw_id) for raw_id in self.source_ids + ], + max_iter=self.max_iter, + ) + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + assert ( sourceIds is not None and len(sourceIds) > 0 ), "Source vertices Ids sourceIds must be provided" assert maxIter is not None, "Max number of iterations maxIter must be provided" - sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) - builder = self._jvm_graph.parallelPersonalizedPageRank() - builder = builder.resetProbability(resetProbability) - builder = builder.sourceIds(sourceIds) - builder = builder.maxIter(maxIter) - jgf = builder.run() - return _from_java_gf(jgf, self._spark) + + new_vertices = DataFrame.withPlan( + ParallelPersonalizedPageRank( + self._vertices, + self._edges, + reset_prob=resetProbability, + source_ids=sourceIds, + max_iter=maxIter, + ), + self._spark, + ) + return self._update_page_rank_edge_weights(new_vertices) def shortestPaths(self, landmarks: list[str | int]) -> DataFrame: class ShortestPaths(LogicalPlan): From fb784a35b30b8311d540caf68bb1e1f8f4a5c6c2 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sat, 8 Feb 2025 19:53:26 +0100 Subject: [PATCH 05/27] The first working version --- build.sbt | 17 +- dev/run_connect.py | 112 ++++ .../sql/graphframes/GraphFramesConnect.scala | 5 +- .../graphframes/GraphFramesConnectUtils.scala | 6 +- python/graphframes/connect/__init__.py | 0 .../graphframes/connect/graphframe_client.py | 138 +++-- python/tests/tests.py | 577 ++++++++++++++++++ 7 files changed, 798 insertions(+), 57 deletions(-) create mode 100644 dev/run_connect.py create mode 100644 python/graphframes/connect/__init__.py create mode 100644 python/tests/tests.py diff --git a/build.sbt b/build.sbt index cf4bb3ea4..298e4c477 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,5 @@ -import ReleaseTransformations._ +import ReleaseTransformations.* +import sbtassembly.AssemblyPlugin.autoImport.assembly lazy val sparkVer = sys.props.getOrElse("spark.version", "3.5.4") lazy val sparkBranch = sparkVer.substring(0, 3) @@ -88,4 +89,16 @@ lazy val connect = (project in file("graphframes-connect")) Compile / PB.includePaths ++= Seq(file("src/main/protobuf")), PB.protocVersion := "3.23.4", // Spark 3.5 branch libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-connect" % sparkVer % "provided" cross CrossVersion.for3Use2_13)) + "org.apache.spark" %% "spark-connect" % sparkVer % "provided" cross CrossVersion.for3Use2_13), + + // Assembly and shading + assembly / test := {}, + assembly / assemblyShadeRules := Seq( + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll), + assembly / assemblyMergeStrategy := { + case PathList("META-INF", xs @ _*) => MergeStrategy.discard + case x if x.endsWith("module-info.class") => MergeStrategy.discard + case x => + val oldStrategy = (assembly / assemblyMergeStrategy).value + oldStrategy(x) + }) diff --git a/dev/run_connect.py b/dev/run_connect.py new file mode 100644 index 000000000..14a7bf6e8 --- /dev/null +++ b/dev/run_connect.py @@ -0,0 +1,112 @@ +#!/usr/bin/python + +# Inspired by https://github.com/mrpowers-io/tsumugi-spark/blob/main/dev/run-connect.py + +import os +import shutil +import subprocess +import sys +from pathlib import Path + +SBT_BUILD_COMMAND = ["./build/sbt", "connect/assembly"] +SPARK_VERSION = "3.5.4" +SCALA_VERSION = "2.12" +GRAPHFRAMES_VERSION = "0.8.4" + + +if __name__ == "__main__": + prj_root = Path(__file__).parent.parent + scala_root = prj_root.joinpath("graphframes-connect") + + print("Build Graphframes...") + os.chdir(prj_root) + build_sbt = subprocess.run( + SBT_BUILD_COMMAND, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + + if build_sbt.returncode == 0: + print("Done.") + else: + print(f"SBT build return an error: {build_sbt.returncode}") + print("stdout: ", build_sbt.stdout) + print("stderr: ", build_sbt.stderr) + sys.exit(1) + + tmp_dir = prj_root.joinpath("tmp") + tmp_dir.mkdir(exist_ok=True) + os.chdir(tmp_dir) + + unpackaed_spark_binary = f"spark-{SPARK_VERSION}-bin-hadoop3" + if not tmp_dir.joinpath(unpackaed_spark_binary).exists(): + print(f"Download spark {SPARK_VERSION}...") + if tmp_dir.joinpath(f"spark-{SPARK_VERSION}-bin-hadoop3.tgz").exists(): + shutil.rmtree( + tmp_dir.joinpath(f"spark-{SPARK_VERSION}-bin-hadoop3.tgz"), + ignore_errors=True, + ) + + get_spark = subprocess.run( + [ + "wget", + f"https://archive.apache.org/dist/spark/spark-{SPARK_VERSION}/spark-{SPARK_VERSION}-bin-hadoop3.tgz", + ], + stdout=subprocess.PIPE, + universal_newlines=True, + ) + if get_spark.returncode == 0: + print("Done.") + else: + print("Downlad failed.") + print("stdout: ", get_spark.stdout) + print("stdeerr: ", get_spark.stderr) + sys.exit(1) + + print("Unpack Spark...") + unpack_spark = subprocess.run( + [ + "tar", + "-xzf", + f"spark-{SPARK_VERSION}-bin-hadoop3.tgz", + ], + stdout=subprocess.PIPE, + universal_newlines=True, + ) + if unpack_spark.returncode == 0: + print("Done.") + else: + print("Unpacking failed.") + print("stdout: ", unpack_spark.stdout) + print("stdeerr: ", unpack_spark.stderr) + sys.exit(1) + + spark_home = tmp_dir.joinpath(unpackaed_spark_binary) + os.chdir(spark_home) + + gf_jar = ( + scala_root.joinpath("target") + .joinpath(f"scala-{SCALA_VERSION}") + .joinpath(f"graphframes-connect-assembly-{GRAPHFRAMES_VERSION}.jar") + ) + shutil.copyfile(gf_jar, spark_home.joinpath(gf_jar.name)) + + run_connect_command = [ + "./sbin/start-connect-server.sh", + "--wait", + "--jars", + f"{gf_jar}", + "--conf", + "spark.connect.extensions.relation.classes=org.apache.spark.sql.graphframes.GraphFramesConnect", + "--packages", + f"org.apache.spark:spark-connect_{SCALA_VERSION}:{SPARK_VERSION}", + ] + print("Starting SparkConnect Server...") + spark_connect = subprocess.run( + run_connect_command, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + + if spark_connect.returncode == 0: + print("Done.") diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala index d079612c2..a8088a84b 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnect.scala @@ -15,9 +15,10 @@ class GraphFramesConnect extends RelationPlugin { if (relation.is(classOf[GraphFramesAPI])) { val protoCall = relation.unpack(classOf[GraphFramesAPI]) // Because the plugins API is changed in spark 4.0 it makes sense to separate plugin impl from the parsing logic - Option(GraphFramesConnectUtils.parseAPICall(protoCall, planner).logicalPlan) + val result = GraphFramesConnectUtils.parseAPICall(protoCall, planner) + Some(result.logicalPlan) } else { - Option.empty + None } } } diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index f80d557d9..edc0e47d4 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -45,10 +45,14 @@ object GraphFramesConnectUtils { private[graphframes] def parseDataFrame( data: ByteString, planner: SparkConnectPlanner): DataFrame = { + if (data.isEmpty) { + throw new IllegalArgumentException( + "Expected a serialized DataFrame but got an empty ByteString.") + } Dataset.ofRows( planner.sessionHolder.session, planner.transformRelation( - org.apache.spark.connect.proto.Relation.parseFrom(data.toByteArray))) + org.apache.spark.connect.proto.Plan.parseFrom(data.toByteArray).getRoot)) } private[graphframes] def extractGraphFrame( diff --git a/python/graphframes/connect/__init__.py b/python/graphframes/connect/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index 68d9a74b7..f07e27c7a 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Self +from typing_extensions import Self from pyspark.sql.connect import functions as F from pyspark.sql.connect import proto @@ -103,7 +103,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: vertices=dataframe_to_proto(self.vertices, session), edges=dataframe_to_proto(self.edges, session), ) - pb_message.pregel = pregel + pb_message.pregel.CopyFrom(pregel) plan.extension.Pack(pb_message) return plan @@ -235,6 +235,7 @@ def unpersist(self, blocking: bool = False) -> "GraphFrameConnect": def outDegrees(self) -> DataFrame: class OutDegrees(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: + super().__init__(None) self.v = v self.e = e @@ -242,7 +243,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.out_degrees = pb.OutDegrees() + graphframes_api_call.out_degrees.CopyFrom(pb.OutDegrees()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -253,6 +254,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def inDegrees(self) -> DataFrame: class InDegrees(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: + super().__init__(None) self.v = v self.e = e @@ -260,7 +262,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.in_degrees = pb.InDegrees() + graphframes_api_call.in_degrees.CopyFrom(pb.InDegrees()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -271,6 +273,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def degrees(self) -> DataFrame: class Degrees(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: + super().__init__(None) self.v = v self.e = e @@ -278,7 +281,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.degrees = pb.Degrees() + graphframes_api_call.degrees.CopyFrom(pb.Degrees()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -289,6 +292,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def triplets(self) -> DataFrame: class Triplets(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: + super().__init__(None) self.v = v self.e = e @@ -296,7 +300,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.triplets = pb.Triplets() + graphframes_api_call.triplets.CopyFrom(pb.Triplets()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -310,6 +314,7 @@ def pregel(self): def find(self, pattern: str) -> DataFrame: class Find(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame, pattern: str) -> None: + super().__init__(None) self.v = v self.e = e self.p = pattern @@ -318,7 +323,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.find = pb.Find(pattern=self.p) + graphframes_api_call.find.CopyFrom(pb.Find(pattern=self.p)) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -332,6 +337,7 @@ class FilterVertices(LogicalPlan): def __init__( self, v: DataFrame, e: DataFrame, condition: str | Column ) -> None: + super().__init__(None) self.v = v self.e = e self.c = condition @@ -341,8 +347,8 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: self.v, self.e, session ) col_or_expr = make_column_or_expr(self.c, session) - graphframes_api_call.filter_vertices = pb.FilterVertices( - condition=col_or_expr + graphframes_api_call.filter_vertices.CopyFrom( + pb.FilterVertices(condition=col_or_expr) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -368,6 +374,7 @@ class FilterEdges(LogicalPlan): def __init__( self, v: DataFrame, e: DataFrame, condition: str | Column ) -> None: + super().__init__(None) self.v = v self.e = e self.c = condition @@ -377,8 +384,8 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: self.v, self.e, session ) col_or_expr = make_column_or_expr(self.c, session) - graphframes_api_call.filter_edges = pb.FilterEdges( - condition=col_or_expr + graphframes_api_call.filter_edges.CopyFrom( + pb.FilterEdges(condition=col_or_expr) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -392,6 +399,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def dropIsolatedVertices(self) -> "GraphFrameConnect": class DropIsolatedVertices(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: + super().__init__(None) self.v = v self.e = e @@ -399,7 +407,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.drop_isolated_vertices = pb.DropIsolatedVertices() + graphframes_api_call.drop_isolated_vertices.CopyFrom( + pb.DropIsolatedVertices() + ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -426,6 +436,7 @@ def __init__( edge_filter: Column | str, max_path_len: int, ) -> None: + super().__init__(None) self.v = v self.e = e self.from_expr = from_expr @@ -437,11 +448,13 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.bfs = pb.BFS( - from_expr=make_column_or_expr(self.from_expr, session), - to_expr=make_column_or_expr(self.to_expr, session), - edge_filter=make_column_or_expr(self.edge_filter, session), - max_path_length=self.max_path_len, + graphframes_api_call.bfs.CopyFrom( + pb.BFS( + from_expr=make_column_or_expr(self.from_expr, session), + to_expr=make_column_or_expr(self.to_expr, session), + edge_filter=make_column_or_expr(self.edge_filter, session), + max_path_length=self.max_path_len, + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -477,6 +490,7 @@ def __init__( send2src: Column | str | None, send2dst: Column | str | None, ) -> None: + super().__init__(None) self.v = v self.e = e self.agg_col = agg_col @@ -487,14 +501,16 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.aggregate_messages = pb.AggregateMessages( - agg_col=make_column_or_expr(self.agg_col, session), - send_to_src=None - if self.send2src is None - else make_column_or_expr(self.send2src, session), - send_to_dst=None - if self.send2dst is None - else make_column_or_expr(self.send2dst, session), + graphframes_api_call.aggregate_messages.CopyFrom( + pb.AggregateMessages( + agg_col=make_column_or_expr(self.agg_col, session), + send_to_src=None + if self.send2src is None + else make_column_or_expr(self.send2src, session), + send_to_dst=None + if self.send2dst is None + else make_column_or_expr(self.send2dst, session), + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -527,6 +543,7 @@ def __init__( checkpoint_interval: int, broadcast_threshold: int, ) -> None: + super().__init__(None) self.v = v self.e = e self.algorithm = algorithm @@ -537,10 +554,12 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.connected_components = pb.ConnectedComponents( - algorithm=self.algorithm, - checkpoint_interval=self.checkpoint_interval, - broadcast_threshold=self.broadcast_threshold, + graphframes_api_call.connected_components.CopyFrom( + pb.ConnectedComponents( + algorithm=self.algorithm, + checkpoint_interval=self.checkpoint_interval, + broadcast_threshold=self.broadcast_threshold, + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -560,6 +579,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def labelPropagation(self, maxIter: int) -> DataFrame: class LabelPropagation(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: + super().__init__(None) self.v = v self.e = e self.max_iter = max_iter @@ -568,8 +588,8 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.label_propagation = pb.LabelPropagation( - max_iter=self.max_iter + graphframes_api_call.label_propagation.CopyFrom( + pb.LabelPropagation(max_iter=self.max_iter) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -616,6 +636,7 @@ def __init__( max_iter: int | None, tol: float | None, ) -> None: + super().__init__(None) self.v = v self.e = e self.reset_prob = reset_prob @@ -627,13 +648,15 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.page_rank = pb.PageRank( - reset_probability=self.reset_prob, - source_id=None - if self.source_id is None - else make_str_or_long_id(self.source_id), - max_iter=self.max_iter, - tol=self.tol, + graphframes_api_call.page_rank.CopyFrom( + pb.PageRank( + reset_probability=self.reset_prob, + source_id=None + if self.source_id is None + else make_str_or_long_id(self.source_id), + max_iter=self.max_iter, + tol=self.tol, + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -673,6 +696,7 @@ def __init__( source_ids: list[str | int], max_iter: int, ) -> None: + super().__init__(None) self.v = v self.e = e self.reset_prob = reset_prob @@ -683,7 +707,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.parallel_personalized_page_rank = ( + graphframes_api_call.parallel_personalized_page_rank.CopyFrom( pb.ParallelPersonalizedPageRank( reset_probability=self.reset_prob, source_ids=[ @@ -718,6 +742,7 @@ class ShortestPaths(LogicalPlan): def __init__( self, v: DataFrame, e: DataFrame, landmarks: list[str | int] ) -> None: + super().__init__(None) self.v = v self.e = e self.landmarks = landmarks @@ -726,8 +751,12 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.shortest_paths = pb.ShortestPaths( - landmarks=[make_str_or_long_id(raw_id) for raw_id in self.landmarks] + graphframes_api_call.shortest_paths.CopyFrom( + pb.ShortestPaths( + landmarks=[ + make_str_or_long_id(raw_id) for raw_id in self.landmarks + ] + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -740,6 +769,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: class StronglyConnectedComponents(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame, max_iter: int) -> None: + super().__init__(None) self.v = v self.e = e self.max_iter = max_iter @@ -748,7 +778,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.strongly_connected_components = ( + graphframes_api_call.strongly_connected_components.CopyFrom( pb.StronglyConnectedComponents(max_iter=self.max_iter) ) plan = self._create_proto_relation() @@ -786,6 +816,7 @@ def __init__( gamma6: float, gamma7: float, ) -> None: + super().__init__(None) self.v = v self.e = e self.rank = rank @@ -801,15 +832,17 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.svd_plus_plus = pb.SVDPlusPlus( - rank=self.rank, - max_iter=self.max_iter, - min_value=self.min_value, - max_value=self.max_value, - gamma1=self.gamma1, - gamma2=self.gamma2, - gamma6=self.gamma6, - gamma7=self.gamma7, + graphframes_api_call.svd_plus_plus.CopyFrom( + pb.SVDPlusPlus( + rank=self.rank, + max_iter=self.max_iter, + min_value=self.min_value, + max_value=self.max_value, + gamma1=self.gamma1, + gamma2=self.gamma2, + gamma6=self.gamma6, + gamma7=self.gamma7, + ) ) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) @@ -840,6 +873,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def triangleCount(self) -> DataFrame: class TriangleCount(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame) -> None: + super().__init__(None) self.v = v self.e = e @@ -847,7 +881,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.triangle_count = pb.TriangleCount() + graphframes_api_call.triangle_count.CopyFrom(pb.TriangleCount()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan diff --git a/python/tests/tests.py b/python/tests/tests.py new file mode 100644 index 000000000..f8330cc16 --- /dev/null +++ b/python/tests/tests.py @@ -0,0 +1,577 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import re +import shutil +import sys +import tempfile +import unittest +from pathlib import Path + +from pyspark.sql import SparkSession +from pyspark.sql import functions as sqlfunctions + +# Workaround that won't be needed after setting up setup.py +prj_root = Path(__file__).parent.parent +sys.path.insert(0, prj_root.absolute().__str__()) + +from graphframes.examples import BeliefPropagation, Graphs +from graphframes.graphframe import GraphFrame, Pregel, _from_java_gf, _java_api +from graphframes.connect.graphframe_client import GraphFrameConnect, PregelConnect +from graphframes.lib import AggregateMessages as AM + + +class GraphFrameTestUtils(object): + @classmethod + def parse_spark_version(cls, version_str): + """take an input version string + return version items in a dictionary + """ + _sc_ver_patt = r"(\d+)\.(\d+)(\.(\d+)(-(.+))?)?" + m = re.match(_sc_ver_patt, version_str) + if not m: + raise TypeError( + "version {} shoud be in ..".format( + version_str + ) + ) + version_info = {} + try: + version_info["major"] = int(m.group(1)) + except: + raise TypeError("invalid minor version") + try: + version_info["minor"] = int(m.group(2)) + except: + raise TypeError("invalid major version") + try: + version_info["maintenance"] = int(m.group(4)) + except: + version_info["maintenance"] = 0 + try: + version_info["special"] = m.group(6) + except: + pass + return version_info + + @classmethod + def createSparkSession(cls): + cls.checkpointDir = tempfile.mkdtemp() + if "SPARK_CONNECT_MODE_ENABLED" in os.environ: + cls.sc = ( + SparkSession.builder.remote("sc://localhost:15002") + .appName("GraphFramesTest") + .config("spark.sql.shuffle.partitions", 4) + .config("spark.checkpoint.dir", cls.checkpointDir) + .getOrCreate() + ) + else: + cls.sc = ( + SparkSession.builder.master("local[4]") + .appName("GraphFramesTest") + .config("spark.sql.shuffle.partitions", 4) + .config("spark.checkpoint.dir", cls.checkpointDir) + .getOrCreate() + ) + + assert cls.sc is not None + cls.spark_version = cls.parse_spark_version(cls.sc.version) + + @classmethod + def stopSparkSession(cls): + assert cls.sc is not None + cls.sc.stop() + cls.sc = None + shutil.rmtree(cls.checkpointDir) + + @classmethod + def spark_at_least_of_version(cls, version_str): + assert hasattr(cls, "spark_version") + required_version = cls.parse_spark_version(version_str) + spark_version = cls.spark_version + for _name in ["major", "minor", "maintenance"]: + sc_ver = spark_version[_name] + req_ver = required_version[_name] + if sc_ver != req_ver: + return sc_ver > req_ver + # All major.minor.maintenance equal + return True + + +def setUpModule(): + GraphFrameTestUtils.createSparkSession() + + +def tearDownModule(): + GraphFrameTestUtils.stopSparkSession() + + +class GraphFrameTestCase(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Small tests run much faster with spark.sql.shuffle.partitions = 4 + if "SPARK_CONNECT_MODE_ENABLED" in os.environ: + cls.spark = SparkSession.builder.remote( + "sc://localhost:15002" + ).getOrCreate() + cls.connect_mode = 1 + else: + cls.spark = ( + SparkSession(GraphFrameTestUtils.sc) + .builder.config("spark.sql.shuffle.partitions", 4) + .getOrCreate() + ) + cls.connect_mode = 0 + + @classmethod + def tearDownClass(cls): + cls.spark = None + cls.connect_mode = None + + +class GraphFrameTest(GraphFrameTestCase): + def setUp(self): + super(GraphFrameTest, self).setUp() + localVertices = [(1, "A"), (2, "B"), (3, "C")] + localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] + v = self.spark.createDataFrame(localVertices, ["id", "name"]) + e = self.spark.createDataFrame(localEdges, ["src", "dst", "action"]) + if self.connect_mode: + self.g = GraphFrameConnect(v, e) + else: + self.g = GraphFrame(v, e) + + def test_spark_version_check(self): + # SparkContext is not available in Spark Connect + if not self.connect_mode: + gtu = GraphFrameTestUtils + gtu.spark_version = gtu.parse_spark_version("2.0.2") + self.assertTrue(gtu.spark_at_least_of_version("1.7")) + self.assertTrue(gtu.spark_at_least_of_version("2.0")) + self.assertTrue(gtu.spark_at_least_of_version("2.0.1")) + self.assertTrue(gtu.spark_at_least_of_version("2.0.2")) + self.assertFalse(gtu.spark_at_least_of_version("2.0.3")) + self.assertFalse(gtu.spark_at_least_of_version("2.1")) + + def test_construction(self): + g = self.g + vertexIDs = map(lambda x: x[0], g.vertices.select("id").collect()) + assert sorted(vertexIDs) == [1, 2, 3] + edgeActions = map(lambda x: x[0], g.edges.select("action").collect()) + assert sorted(edgeActions) == ["follow", "hate", "love"] + tripletsFirst = list( + map( + lambda x: (x[0][1], x[1][1], x[2][2]), + g.triplets.sort("src.id").select("src", "dst", "edge").take(1), + ) + ) + assert tripletsFirst == [("A", "B", "love")], tripletsFirst + # Try with invalid vertices and edges DataFrames + v_invalid = self.spark.createDataFrame( + [(1, "A"), (2, "B"), (3, "C")], ["invalid_colname_1", "invalid_colname_2"] + ) + e_invalid = self.spark.createDataFrame( + [(1, 2), (2, 3), (3, 1)], ["invalid_colname_3", "invalid_colname_4"] + ) + with self.assertRaises(ValueError): + if self.connect_mode: + GraphFrameConnect(v_invalid, e_invalid) + else: + GraphFrame(v_invalid, e_invalid) + + def test_cache(self): + g = self.g + g.cache() + g.unpersist() + + def test_degrees(self): + g = self.g + outDeg = g.outDegrees + self.assertSetEqual(set(outDeg.columns), {"id", "outDegree"}) + inDeg = g.inDegrees + self.assertSetEqual(set(inDeg.columns), {"id", "inDegree"}) + deg = g.degrees + self.assertSetEqual(set(deg.columns), {"id", "degree"}) + + def test_motif_finding(self): + g = self.g + motifs = g.find("(a)-[e]->(b)") + assert motifs.count() == 3 + self.assertSetEqual(set(motifs.columns), {"a", "e", "b"}) + + def test_filterVertices(self): + g = self.g + conditions = ["id < 3", g.vertices.id < 3] + expected_v = [(1, "A"), (2, "B")] + expected_e = [(1, 2, "love"), (2, 1, "hate")] + for cond in conditions: + g2 = g.filterVertices(cond) + v2 = g2.vertices.select("id", "name").collect() + e2 = g2.edges.select("src", "dst", "action").collect() + assert len(v2) == len(expected_v) + assert len(e2) == len(expected_e) + self.assertSetEqual(set(v2), set(expected_v)) + self.assertSetEqual(set(e2), set(expected_e)) + + def test_filterEdges(self): + g = self.g + conditions = ["dst > 2", g.edges.dst > 2] + expected_v = [(1, "A"), (2, "B"), (3, "C")] + expected_e = [(2, 3, "follow")] + for cond in conditions: + g2 = g.filterEdges(cond) + v2 = g2.vertices.select("id", "name").collect() + e2 = g2.edges.select("src", "dst", "action").collect() + assert len(v2) == len(expected_v) + assert len(e2) == len(expected_e) + self.assertSetEqual(set(v2), set(expected_v)) + self.assertSetEqual(set(e2), set(expected_e)) + + def test_dropIsolatedVertices(self): + g = self.g + g2 = g.filterEdges("dst > 2").dropIsolatedVertices() + v2 = g2.vertices.select("id", "name").collect() + e2 = g2.edges.select("src", "dst", "action").collect() + expected_v = [(2, "B"), (3, "C")] + expected_e = [(2, 3, "follow")] + assert len(v2) == len(expected_v) + assert len(e2) == len(expected_e) + self.assertSetEqual(set(v2), set(expected_v)) + self.assertSetEqual(set(e2), set(expected_e)) + + def test_bfs(self): + g = self.g + paths = g.bfs("name='A'", "name='C'") + self.assertEqual(paths.count(), 1) + self.assertEqual(paths.select("v1.name").head()[0], "B") + paths2 = g.bfs("name='A'", "name='C'", edgeFilter="action!='follow'") + self.assertEqual(paths2.count(), 0) + paths3 = g.bfs("name='A'", "name='C'", maxPathLength=1) + self.assertEqual(paths3.count(), 0) + + +class PregelTest(GraphFrameTestCase): + def setUp(self): + super(PregelTest, self).setUp() + + def test_page_rank(self): + from pyspark.sql.functions import coalesce, col, lit, sum, when + + edges = self.spark.createDataFrame( + [ + [0, 1], + [1, 2], + [2, 4], + [2, 0], + [3, 4], # 3 has no in-links + [4, 0], + [4, 2], + ], + ["src", "dst"], + ) + edges.cache() + assert self.spark is not None + vertices = self.spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) + numVertices = vertices.count() + + if self.connect_mode: + vertices = GraphFrameConnect(vertices, edges).outDegrees + else: + vertices = GraphFrame(vertices, edges).outDegrees + vertices.cache() + + if self.connect_mode: + graph = GraphFrameConnect(vertices, edges) + else: + graph = GraphFrame(vertices, edges) + alpha = 0.15 + ranks = ( + graph.pregel.setMaxIter(5) + .withVertexColumn( + "rank", + lit(1.0 / numVertices), + coalesce(Pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + + lit(alpha / numVertices), + ) + .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) + .aggMsgs(sum(Pregel.msg())) + .run() + ) + resultRows = ranks.sort(ranks.id).collect() + result = map(lambda x: x.rank, resultRows) + expected = [0.245, 0.224, 0.303, 0.03, 0.197] + for a, b in zip(result, expected): + self.assertAlmostEqual(a, b, delta=1e-3) + + +class GraphFrameLibTest(GraphFrameTestCase): + def setUp(self): + super(GraphFrameLibTest, self).setUp() + if self.connect_mode: + self.japi = None + else: + self.japi = _java_api(self.spark._sc) + + def _hasCols(self, graph, vcols=[], ecols=[]): + map(lambda c: self.assertIn(c, graph.vertices.columns), vcols) + map(lambda c: self.assertIn(c, graph.edges.columns), ecols) + + def _df_hasCols(self, vertices, vcols=[]): + map(lambda c: self.assertIn(c, vertices.columns), vcols) + + def _graph(self, name, *args): + """ + Convenience to call one of the example graphs, passing the arguments and wrapping the result back + as a python object. + :param name: the name of the example graph + :param args: all the required arguments, without the initial spark session + :return: + """ + examples = self.japi.examples() + jgraph = getattr(examples, name)(*args) + return _from_java_gf(jgraph, self.spark) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_aggregate_messages(self): + g = self._graph("friends") + # For each user, sum the ages of the adjacent users, + # plus 1 for the src's sum if the edge is "friend". + sendToSrc = AM.dst["age"] + sqlfunctions.when( + AM.edge["relationship"] == "friend", sqlfunctions.lit(1) + ).otherwise(0) + sendToDst = AM.src["age"] + agg = g.aggregateMessages( + sqlfunctions.sum(AM.msg).alias("summedAges"), + sendToSrc=sendToSrc, + sendToDst=sendToDst, + ) + # Run the aggregation again providing SQL expressions as String instead. + agg2 = g.aggregateMessages( + "sum(MSG) AS `summedAges`", + sendToSrc="(dst['age'] + CASE WHEN (edge['relationship'] = 'friend') THEN 1 ELSE 0 END)", + sendToDst="src['age']", + ) + # Convert agg and agg2 to a mapping from id to the aggregated message. + aggMap = {id_: s for id_, s in agg.select("id", "summedAges").collect()} + agg2Map = {id_: s for id_, s in agg2.select("id", "summedAges").collect()} + # Compute the truth via brute force. + user2age = {id_: age for id_, age in g.vertices.select("id", "age").collect()} + trueAgg = {} + for src, dst, rel in g.edges.select("src", "dst", "relationship").collect(): + trueAgg[src] = ( + trueAgg.get(src, 0) + user2age[dst] + (1 if rel == "friend" else 0) + ) + trueAgg[dst] = trueAgg.get(dst, 0) + user2age[src] + # Compare if the agg mappings match the brute force mapping + self.assertEqual(aggMap, trueAgg) + self.assertEqual(agg2Map, trueAgg) + # Check that TypeError is raises with messages of wrong type + with self.assertRaises(TypeError): + g.aggregateMessages( + "sum(MSG) AS `summedAges`", sendToSrc=object(), sendToDst="src['age']" + ) + with self.assertRaises(TypeError): + g.aggregateMessages( + "sum(MSG) AS `summedAges`", sendToSrc=dst["age"], sendToDst=object() + ) + + def test_connected_components(self): + v = self.spark.createDataFrame([(0, "a", "b")], ["id", "vattr", "gender"]) + e = self.spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter( + "src > 10" + ) + if self.connect_mode: + g = GraphFrameConnect(v, e) + else: + g = GraphFrame(v, e) + comps = g.connectedComponents() + self._df_hasCols(comps, vcols=["id", "component", "vattr", "gender"]) + self.assertEqual(comps.count(), 1) + + def test_connected_components2(self): + v = self.spark.createDataFrame( + [(0, "a0", "b0"), (1, "a1", "b1")], ["id", "A", "B"] + ) + e = self.spark.createDataFrame([(0, 1, "a01", "b01")], ["src", "dst", "A", "B"]) + if self.connect_mode: + g = GraphFrameConnect(v, e) + else: + g = GraphFrame(v, e) + comps = g.connectedComponents() + self._df_hasCols(comps, vcols=["id", "component", "A", "B"]) + self.assertEqual(comps.count(), 2) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_connected_components_friends(self): + g = self._graph("friends") + comps_tests = [] + comps_tests += [g.connectedComponents()] + comps_tests += [g.connectedComponents(broadcastThreshold=1)] + comps_tests += [g.connectedComponents(checkpointInterval=0)] + comps_tests += [g.connectedComponents(checkpointInterval=10)] + comps_tests += [g.connectedComponents(algorithm="graphx")] + for c in comps_tests: + self.assertEqual(c.groupBy("component").count().count(), 2) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_label_progagation(self): + n = 5 + g = self._graph("twoBlobs", n) + labels = g.labelPropagation(maxIter=4 * n) + labels1 = labels.filter("id < 5").select("label").collect() + all1 = set([x.label for x in labels1]) + assert len(all1) == 1 + labels2 = labels.filter("id >= 5").select("label").collect() + all2 = set([x.label for x in labels2]) + assert len(all2) == 1 + assert all1 != all2 + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_page_rank(self): + n = 100 + g = self._graph("star", n) + resetProb = 0.15 + errorTol = 1.0e-5 + pr = g.pageRank(resetProb, tol=errorTol) + self._hasCols(pr, vcols=["id", "pagerank"], ecols=["src", "dst", "weight"]) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_parallel_personalized_page_rank(self): + n = 100 + g = self._graph("star", n) + resetProb = 0.15 + maxIter = 15 + sourceIds = [1, 2, 3, 4] + pr = g.parallelPersonalizedPageRank( + resetProb, sourceIds=sourceIds, maxIter=maxIter + ) + self._hasCols(pr, vcols=["id", "pageranks"], ecols=["src", "dst", "weight"]) + + def test_shortest_paths(self): + edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)] + all_edges = [z for (a, b) in edges for z in [(a, b), (b, a)]] + assert self.spark is not None + edges = self.spark.createDataFrame(all_edges, ["src", "dst"]) + vertices = self.spark.createDataFrame([(i,) for i in range(1, 7)], ["id"]) + if self.connect_mode: + g = GraphFrameConnect(vertices, edges) + else: + g = GraphFrame(vertices, edges) + landmarks = [1, 4] + v2 = g.shortestPaths(landmarks) + self._df_hasCols(v2, vcols=["id", "distances"]) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_svd_plus_plus(self): + g = self._graph("ALSSyntheticData") + (v2, cost) = g.svdPlusPlus() + self._df_hasCols(v2, vcols=["id", "column1", "column2", "column3", "column4"]) + + def test_strongly_connected_components(self): + # Simple island test + assert self.spark is not None + vertices = self.spark.createDataFrame([(i,) for i in range(1, 6)], ["id"]) + edges = self.spark.createDataFrame([(7, 8)], ["src", "dst"]) + if self.connect_mode: + g = GraphFrameConnect(vertices, edges) + else: + g = GraphFrame(vertices, edges) + c = g.stronglyConnectedComponents(5) + for row in c.collect(): + self.assertEqual(row.id, row.component) + + def test_triangle_counts(self): + assert self.spark is not None + edges = self.spark.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"]) + vertices = self.spark.createDataFrame([(0,), (1,), (2,)], ["id"]) + if self.connect_mode: + g = GraphFrameConnect(vertices, edges) + else: + g = GraphFrame(vertices, edges) + c = g.triangleCount() + for row in c.select("id", "count").collect(): + self.assertEqual(row.asDict()["count"], 1) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_mutithreaded_sparksession_usage(self): + # Test that we can use the GraphFrame API from multiple threads + localVertices = [(1, "A"), (2, "B"), (3, "C")] + localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] + v = self.spark.createDataFrame(localVertices, ["id", "name"]) + e = self.spark.createDataFrame(localEdges, ["src", "dst", "action"]) + + exc = None + + def run_graphframe() -> None: + try: + GraphFrame(v, e) + except Exception as _e: + nonlocal exc + exc = _e + + import threading + + thread = threading.Thread(target=run_graphframe) + thread.start() + thread.join() + self.assertIsNone(exc, f"Exception was raised in thread: {exc}") + + +class GraphFrameExamplesTest(GraphFrameTestCase): + def setUp(self): + super(GraphFrameExamplesTest, self).setUp() + self.japi = _java_api(self.spark._sc) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_belief_propagation(self): + # create graphical model g of size 3 x 3 + g = Graphs(self.spark).gridIsingModel(3) + # run BP for 5 iterations + numIter = 5 + results = BeliefPropagation.runBPwithGraphFrames(g, numIter) + # check beliefs are valid + for row in results.vertices.select("belief").collect(): + belief = row["belief"] + self.assertTrue( + 0 <= belief <= 1, + msg="Expected belief to be probability in [0,1], but found {}".format( + belief + ), + ) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_graph_friends(self): + # construct graph + g = Graphs(self.spark).friends() + # check that a GraphFrame instance was returned + self.assertIsInstance(g, GraphFrame) + + @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + def test_graph_grid_ising_model(self): + # construct graph + n = 3 + g = Graphs(self.spark).gridIsingModel(n) + # check that all the vertices exist + ids = [v["id"] for v in g.vertices.collect()] + for i in range(n): + for j in range(n): + self.assertIn("{},{}".format(i, j), ids) + + +if __name__ == "__main__": + unittest.main() From 20f7575c38902f65f6497f43267aa6cfa99dacf1 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 17:18:36 +0100 Subject: [PATCH 06/27] WIP --- .github/workflows/python-ci.yml | 6 +- .gitignore | 6 + build.sbt | 2 + python/dev/build_jar.py | 54 +++ python/graphframes/classic/graphframe.py | 476 +++++++++++++++++++++++ python/graphframes/graphframe.py | 347 ++++++++--------- python/graphframes/tests.py | 468 ---------------------- python/pyproject.toml | 11 +- python/tests/__init__.py | 0 python/tests/tests.py | 222 +++-------- 10 files changed, 761 insertions(+), 831 deletions(-) create mode 100644 python/dev/build_jar.py create mode 100644 python/graphframes/classic/graphframe.py delete mode 100644 python/graphframes/tests.py create mode 100644 python/tests/__init__.py diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 1095ce49e..e943f19d4 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -26,8 +26,6 @@ jobs: path: | ~/.ivy2/cache key: sbt-ivy-cache-spark-${{ matrix.spark-version}}-scala-${{ matrix.scala-version }} - - name: Assembly - run: build/sbt -v ++${{ matrix.scala-version }} -Dspark.version=${{ matrix.spark-version }} "set test in assembly := {}" assembly - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -41,10 +39,8 @@ jobs: - name: Build Python package and its dependencies working-directory: ./python run: | - poetry build poetry install - name: Test working-directory: ./python run: | - export SPARK_HOME=$(poetry run python -c "import os; from importlib.util import find_spec; spec = find_spec('pyspark'); print(os.path.join(os.path.dirname(spec.origin)))") - ./run-tests.sh + poetry run python -m unittest discover -s tests/ diff --git a/.gitignore b/.gitignore index 93246acbe..4d0a174e7 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,9 @@ python/graphframes.egg-info python/graphframes/tutorials/data python/docs/_build python/docs/_site + +# JAR that is build during the installation +python/graphframes/resources/* + +# tmp data for spark connect +tmp/* diff --git a/build.sbt b/build.sbt index 2776ca730..e8af6e344 100644 --- a/build.sbt +++ b/build.sbt @@ -4,6 +4,8 @@ lazy val sparkVer = sys.props.getOrElse("spark.version", "3.5.4") lazy val sparkBranch = sparkVer.substring(0, 3) lazy val defaultScalaVer = sparkBranch match { case "3.5" => "2.12.18" + case "3.4" => "2.12.17" + case "3.3" => "2.12.15" case _ => throw new IllegalArgumentException(s"Unsupported Spark version: $sparkVer.") } lazy val scalaVer = sys.props.getOrElse("scala.version", defaultScalaVer) diff --git a/python/dev/build_jar.py b/python/dev/build_jar.py new file mode 100644 index 000000000..0c40d0e37 --- /dev/null +++ b/python/dev/build_jar.py @@ -0,0 +1,54 @@ +import shutil +import subprocess +import sys +from pathlib import Path + + +def build(spark_version: str = "3.5.4"): + print("Building GraphFrames JAR...") + print(f"SPARK_VERSION: {spark_version[:3]}") + assert spark_version[:3] in {"3.3", "3.4", "3.5"}, "Unsopported spark version!" + project_root = Path(__file__).parent.parent.parent + sbt_executable = project_root.joinpath("build").joinpath("sbt").absolute().__str__() + sbt_build_command = [sbt_executable, f"-Dspark.version={spark_version}", "assembly"] + sbt_build = subprocess.Popen( + sbt_build_command, + stdout=subprocess.PIPE, + universal_newlines=True, + cwd=project_root, + ) + while sbt_build.poll() is None: + assert sbt_build.stdout is not None # typing stuff + line = sbt_build.stdout.readline() + print(line.rstrip(), flush=True) + + if sbt_build.returncode != 0: + print("Error during the build of GraphFrames JAR!") + print("stdout: ", sbt_build.stdout) + print("stdeerr: ", sbt_build.stderr) + sys.exit(1) + else: + print("Building DONE successfully!") + + python_resources = ( + project_root.joinpath("python").joinpath("graphframes").joinpath("resources") + ) + target_dir = project_root.joinpath("target").joinpath("scala-2.12") + gf_jar = None + + for pp in target_dir.glob("*.jar"): + if "graphframes-assembly" in pp.name: + gf_jar = pp + break + + assert gf_jar is not None, "Missing JAR!" + python_resources.mkdir(parents=True, exist_ok=True) + shutil.copy(gf_jar, python_resources.joinpath(gf_jar.name)) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + spark_version = sys.argv[1] + build(spark_version) + else: + build() diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py new file mode 100644 index 000000000..5381ec8b5 --- /dev/null +++ b/python/graphframes/classic/graphframe.py @@ -0,0 +1,476 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +from typing import Any, Union, Optional + +if sys.version > '3': + basestring = str + +from graphframes.lib import Pregel +from pyspark import SparkContext +from pyspark.sql import Column, DataFrame, SparkSession +from pyspark.storagelevel import StorageLevel + + +def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': + """ + (internal) creates a python GraphFrame wrapper from a java GraphFrame. + + :param jgf: + """ + pv = DataFrame(jgf.vertices(), spark) + pe = DataFrame(jgf.edges(), spark) + return GraphFrame(pv, pe) + +def _java_api(jsc: SparkContext) -> Any: + javaClassName = "org.graphframes.GraphFramePythonAPI" + return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ + .newInstance() + + +class GraphFrame: + """ + Represents a graph with vertices and edges stored as DataFrames. + + :param v: :class:`DataFrame` holding vertex information. + Must contain a column named "id" that stores unique + vertex IDs. + :param e: :class:`DataFrame` holding edge information. + Must contain two columns "src" and "dst" storing source + vertex IDs and destination vertex IDs of edges, respectively. + + >>> localVertices = [(1,"A"), (2,"B"), (3, "C")] + >>> localEdges = [(1,2,"love"), (2,1,"hate"), (2,3,"follow")] + >>> v = spark.createDataFrame(localVertices, ["id", "name"]) + >>> e = spark.createDataFrame(localEdges, ["src", "dst", "action"]) + >>> g = GraphFrame(v, e) + """ + + def __init__(self, v: DataFrame, e: DataFrame) -> None: + self._vertices = v + self._edges = e + self._spark = v.sparkSession + self._sc = self._spark._sc + self._jvm_gf_api = _java_api(self._sc) + + self.ID = self._jvm_gf_api.ID() + self.SRC = self._jvm_gf_api.SRC() + self.DST = self._jvm_gf_api.DST() + self._ATTR = self._jvm_gf_api.ATTR() + + # Check that provided DataFrames contain required columns + if self.ID not in v.columns: + raise ValueError( + "Vertex ID column {} missing from vertex DataFrame, which has columns: {}" + .format(self.ID, ",".join(v.columns))) + if self.SRC not in e.columns: + raise ValueError( + "Source vertex ID column {} missing from edge DataFrame, which has columns: {}" + .format(self.SRC, ",".join(e.columns))) + if self.DST not in e.columns: + raise ValueError( + "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}" + .format(self.DST, ",".join(e.columns))) + + self._jvm_graph = self._jvm_gf_api.createGraph(v._jdf, e._jdf) + + @property + def vertices(self) -> DataFrame: + """ + :class:`DataFrame` holding vertex information, with unique column "id" + for vertex IDs. + """ + return self._vertices + + @property + def edges(self) -> DataFrame: + """ + :class:`DataFrame` holding edge information, with unique columns "src" and + "dst" storing source vertex IDs and destination vertex IDs of edges, + respectively. + """ + return self._edges + + def __repr__(self): + return self._jvm_graph.toString() + + def cache(self) -> 'GraphFrame': + """ Persist the dataframe representation of vertices and edges of the graph with the default + storage level. + """ + self._jvm_graph.cache() + return self + + def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": + """Persist the dataframe representation of vertices and edges of the graph with the given + storage level. + """ + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jvm_graph.persist(javaStorageLevel) + return self + + def unpersist(self, blocking: bool = False) -> 'GraphFrame': + """Mark the dataframe representation of vertices and edges of the graph as non-persistent, + and remove all blocks for it from memory and disk. + """ + self._jvm_graph.unpersist(blocking) + return self + + @property + def outDegrees(self) -> DataFrame: + """ + The out-degree of each vertex in the graph, returned as a DataFrame with two columns: + - "id": the ID of the vertex + - "outDegree" (integer) storing the out-degree of the vertex + + Note that vertices with 0 out-edges are not returned in the result. + + :return: DataFrame with new vertices column "outDegree" + """ + jdf = self._jvm_graph.outDegrees() + return DataFrame(jdf, self._spark) + + @property + def inDegrees(self) -> DataFrame: + """ + The in-degree of each vertex in the graph, returned as a DataFame with two columns: + - "id": the ID of the vertex + - "inDegree" (int) storing the in-degree of the vertex + + Note that vertices with 0 in-edges are not returned in the result. + + :return: DataFrame with new vertices column "inDegree" + """ + jdf = self._jvm_graph.inDegrees() + return DataFrame(jdf, self._spark) + + @property + def degrees(self) -> DataFrame: + """ + The degree of each vertex in the graph, returned as a DataFrame with two columns: + - "id": the ID of the vertex + - 'degree' (integer) the degree of the vertex + + Note that vertices with 0 edges are not returned in the result. + + :return: DataFrame with new vertices column "degree" + """ + jdf = self._jvm_graph.degrees() + return DataFrame(jdf, self._spark) + + @property + def triplets(self) -> DataFrame: + """ + The triplets (source vertex)-[edge]->(destination vertex) for all edges in the graph. + + Returned as a :class:`DataFrame` with three columns: + - "src": source vertex with schema matching 'vertices' + - "edge": edge with schema matching 'edges' + - 'dst': destination vertex with schema matching 'vertices' + + :return: DataFrame with columns 'src', 'edge', and 'dst' + """ + jdf = self._jvm_graph.triplets() + return DataFrame(jdf, self._spark) + + @property + def pregel(self): + """ + Get the :class:`graphframes.lib.Pregel` object for running pregel. + + See :class:`graphframes.lib.Pregel` for more details. + """ + return Pregel(self) + + def find(self, pattern: str) -> DataFrame: + """ + Motif finding. + + See Scala documentation for more details. + + :param pattern: String describing the motif to search for. + :return: DataFrame with one Row for each instance of the motif found + """ + jdf = self._jvm_graph.find(pattern) + return DataFrame(jdf, self._spark) + + def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': + """ + Filters the vertices based on expression, remove edges containing any dropped vertices. + + :param condition: String or Column describing the condition expression for filtering. + :return: GraphFrame with filtered vertices and edges. + """ + + if isinstance(condition, basestring): + jdf = self._jvm_graph.filterVertices(condition) + elif isinstance(condition, Column): + jdf = self._jvm_graph.filterVertices(condition._jc) + else: + raise TypeError("condition should be string or Column") + return _from_java_gf(jdf, self._spark) + + def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': + """ + Filters the edges based on expression, keep all vertices. + + :param condition: String or Column describing the condition expression for filtering. + :return: GraphFrame with filtered edges. + """ + if isinstance(condition, basestring): + jdf = self._jvm_graph.filterEdges(condition) + elif isinstance(condition, Column): + jdf = self._jvm_graph.filterEdges(condition._jc) + else: + raise TypeError("condition should be string or Column") + return _from_java_gf(jdf, self._spark) + + def dropIsolatedVertices(self) -> 'GraphFrame': + """ + Drops isolated vertices, vertices are not contained in any edges. + + :return: GraphFrame with filtered vertices. + """ + jdf = self._jvm_graph.dropIsolatedVertices() + return _from_java_gf(jdf, self._spark) + + def bfs(self, fromExpr: str, toExpr: str, + edgeFilter: Optional[str] = None, + maxPathLength: int = 10) -> DataFrame: + """ + Breadth-first search (BFS). + + See Scala documentation for more details. + + :return: DataFrame with one Row for each shortest path between matching vertices. + """ + builder = self._jvm_graph.bfs()\ + .fromExpr(fromExpr)\ + .toExpr(toExpr)\ + .maxPathLength(maxPathLength) + if edgeFilter is not None: + builder.edgeFilter(edgeFilter) + jdf = builder.run() + return DataFrame(jdf, self._spark) + + def aggregateMessages(self, aggCol: Union[Column, str], + sendToSrc: Union[Column, str, None] = None, + sendToDst: Union[Column, str, None] = None) -> DataFrame: + """ + Aggregates messages from the neighbours. + + When specifying the messages and aggregation function, the user may reference columns using + the static methods in :class:`graphframes.lib.AggregateMessages`. + + See Scala documentation for more details. + + :param aggCol: the requested aggregation output either as + :class:`pyspark.sql.Column` or SQL expression string + :param sendToSrc: message sent to the source vertex of each triplet either as + :class:`pyspark.sql.Column` or SQL expression string (default: None) + :param sendToDst: message sent to the destination vertex of each triplet either as + :class:`pyspark.sql.Column` or SQL expression string (default: None) + + :return: DataFrame with columns for the vertex ID and the resulting aggregated message + """ + # Check that either sendToSrc, sendToDst, or both are provided + if sendToSrc is None and sendToDst is None: + raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") + builder = self._jvm_graph.aggregateMessages() + if sendToSrc is not None: + if isinstance(sendToSrc, Column): + builder.sendToSrc(sendToSrc._jc) + elif isinstance(sendToSrc, basestring): + builder.sendToSrc(sendToSrc) + else: + raise TypeError("Provide message either as `Column` or `str`") + if sendToDst is not None: + if isinstance(sendToDst, Column): + builder.sendToDst(sendToDst._jc) + elif isinstance(sendToDst, basestring): + builder.sendToDst(sendToDst) + else: + raise TypeError("Provide message either as `Column` or `str`") + if isinstance(aggCol, Column): + jdf = builder.agg(aggCol._jc) + else: + jdf = builder.agg(aggCol) + return DataFrame(jdf, self._spark) + + # Standard algorithms + + def connectedComponents(self, algorithm: str = 'graphframes', + checkpointInterval: int = 2, + broadcastThreshold: int = 1000000) -> DataFrame: + """ + Computes the connected components of the graph. + + See Scala documentation for more details. + + :param algorithm: connected components algorithm to use (default: "graphframes") + Supported algorithms are "graphframes" and "graphx". + :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2) + :param broadcastThreshold: broadcast threshold in propagating component assignments + (default: 1000000) + + :return: DataFrame with new vertices column "component" + """ + jdf = self._jvm_graph.connectedComponents() \ + .setAlgorithm(algorithm) \ + .setCheckpointInterval(checkpointInterval) \ + .setBroadcastThreshold(broadcastThreshold) \ + .run() + return DataFrame(jdf, self._spark) + + def labelPropagation(self, maxIter: int) -> DataFrame: + """ + Runs static label propagation for detecting communities in networks. + + See Scala documentation for more details. + + :param maxIter: the number of iterations to be performed + :return: DataFrame with new vertices column "label" + """ + jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() + return DataFrame(jdf, self._spark) + + def pageRank(self, resetProbability: float = 0.15, + sourceId: Optional[Any] = None, + maxIter: Optional[int] = None, + tol: Optional[float] = None) -> 'GraphFrame': + """ + Runs the PageRank algorithm on the graph. + Note: Exactly one of fixed_num_iter or tolerance must be set. + + See Scala documentation for more details. + + :param resetProbability: Probability of resetting to a random vertex. + :param sourceId: (optional) the source vertex for a personalized PageRank. + :param maxIter: If set, the algorithm is run for a fixed number + of iterations. This may not be set if the `tol` parameter is set. + :param tol: If set, the algorithm is run until the given tolerance. + This may not be set if the `numIter` parameter is set. + :return: GraphFrame with new vertices column "pagerank" and new edges column "weight" + """ + builder = self._jvm_graph.pageRank().resetProbability(resetProbability) + if sourceId is not None: + builder = builder.sourceId(sourceId) + if maxIter is not None: + builder = builder.maxIter(maxIter) + assert tol is None, "Exactly one of maxIter or tol should be set." + else: + assert tol is not None, "Exactly one of maxIter or tol should be set." + builder = builder.tol(tol) + jgf = builder.run() + return _from_java_gf(jgf, self._spark) + + def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, + sourceIds: Optional[list[Any]] = None, + maxIter: Optional[int] = None) -> 'GraphFrame': + """ + Run the personalized PageRank algorithm on the graph, + from the provided list of sources in parallel for a fixed number of iterations. + + See Scala documentation for more details. + + :param resetProbability: Probability of resetting to a random vertex + :param sourceIds: the source vertices for a personalized PageRank + :param maxIter: the fixed number of iterations this algorithm runs + :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" + """ + assert sourceIds is not None and len(sourceIds) > 0, "Source vertices Ids sourceIds must be provided" + assert maxIter is not None, "Max number of iterations maxIter must be provided" + sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) + builder = self._jvm_graph.parallelPersonalizedPageRank() + builder = builder.resetProbability(resetProbability) + builder = builder.sourceIds(sourceIds) + builder = builder.maxIter(maxIter) + jgf = builder.run() + return _from_java_gf(jgf, self._spark) + + def shortestPaths(self, landmarks: list[Any]) -> DataFrame: + """ + Runs the shortest path algorithm from a set of landmark vertices in the graph. + + See Scala documentation for more details. + + :param landmarks: a set of one or more landmarks + :return: DataFrame with new vertices column "distances" + """ + jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() + return DataFrame(jdf, self._spark) + + def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: + """ + Runs the strongly connected components algorithm on this graph. + + See Scala documentation for more details. + + :param maxIter: the number of iterations to run + :return: DataFrame with new vertex column "component" + """ + jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() + return DataFrame(jdf, self._spark) + + def svdPlusPlus(self, rank: int = 10, maxIter: int = 2, + minValue: float = 0.0, maxValue: float = 5.0, + gamma1: float = 0.007, gamma2: float = 0.007, + gamma6: float = 0.005, gamma7: float = 0.015) -> tuple[DataFrame, float]: + """ + Runs the SVD++ algorithm. + + See Scala documentation for more details. + + :return: Tuple of DataFrame with new vertex columns storing learned model, and loss value + """ + # This call is actually useless, because one needs to build the configuration first... + builder = self._jvm_graph.svdPlusPlus() + builder.rank(rank).maxIter(maxIter).minValue(minValue).maxValue(maxValue) + builder.gamma1(gamma1).gamma2(gamma2).gamma6(gamma6).gamma7(gamma7) + jdf = builder.run() + loss = builder.loss() + v = DataFrame(jdf, self._spark) + return (v, loss) + + def triangleCount(self) -> DataFrame: + """ + Counts the number of triangles passing through each vertex in this graph. + + See Scala documentation for more details. + + :return: DataFrame with new vertex column "count" + """ + jdf = self._jvm_graph.triangleCount().run() + return DataFrame(jdf, self._spark) + + +def _test(): + import doctest + import graphframe + globs = graphframe.__dict__.copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['spark'] = SparkSession(globs['sc']).builder.getOrCreate() + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 5381ec8b5..f0dbd1f8a 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -15,32 +15,32 @@ # limitations under the License. # -import sys -from typing import Any, Union, Optional +from __future__ import annotations -if sys.version > '3': - basestring = str +from typing import Any, Optional -from graphframes.lib import Pregel -from pyspark import SparkContext -from pyspark.sql import Column, DataFrame, SparkSession +from pyspark.sql import Column, DataFrame from pyspark.storagelevel import StorageLevel +from pyspark.version import __version__ +if __version__[:2] >= "3.4": + from pyspark.sql.utils import is_remote +else: + # All the Connect-related utilities are accessible starting from 3.4.x + def is_remote() -> bool: + return False -def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': - """ - (internal) creates a python GraphFrame wrapper from a java GraphFrame. - :param jgf: - """ - pv = DataFrame(jgf.vertices(), spark) - pe = DataFrame(jgf.edges(), spark) - return GraphFrame(pv, pe) +from graphframes.lib import Pregel + +from .classic.graphframe import GraphFrame as GraphFrameClassic -def _java_api(jsc: SparkContext) -> Any: - javaClassName = "org.graphframes.GraphFramePythonAPI" - return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ - .newInstance() +if __version__[:2] >= "3.4": + from graphframes.connect.graphframe_client import GraphFrameConnect +else: + class GraphFrameConnect: + def __init__(self, *args, **kwargs) -> None: + raise ValueError("Unreachable error happened!") class GraphFrame: @@ -61,33 +61,15 @@ class GraphFrame: >>> g = GraphFrame(v, e) """ + @staticmethod + def _from_impl(impl: GraphFrameClassic | GraphFrameConnect) -> "GraphFrame": + return GraphFrame(impl.vertices, impl.edges) + def __init__(self, v: DataFrame, e: DataFrame) -> None: - self._vertices = v - self._edges = e - self._spark = v.sparkSession - self._sc = self._spark._sc - self._jvm_gf_api = _java_api(self._sc) - - self.ID = self._jvm_gf_api.ID() - self.SRC = self._jvm_gf_api.SRC() - self.DST = self._jvm_gf_api.DST() - self._ATTR = self._jvm_gf_api.ATTR() - - # Check that provided DataFrames contain required columns - if self.ID not in v.columns: - raise ValueError( - "Vertex ID column {} missing from vertex DataFrame, which has columns: {}" - .format(self.ID, ",".join(v.columns))) - if self.SRC not in e.columns: - raise ValueError( - "Source vertex ID column {} missing from edge DataFrame, which has columns: {}" - .format(self.SRC, ",".join(e.columns))) - if self.DST not in e.columns: - raise ValueError( - "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}" - .format(self.DST, ",".join(e.columns))) - - self._jvm_graph = self._jvm_gf_api.createGraph(v._jdf, e._jdf) + if is_remote(): + self._impl = GraphFrameConnect(v, e) + else: + self._impl = GraphFrameClassic(v, e) @property def vertices(self) -> DataFrame: @@ -95,7 +77,7 @@ def vertices(self) -> DataFrame: :class:`DataFrame` holding vertex information, with unique column "id" for vertex IDs. """ - return self._vertices + return self._impl.vertices @property def edges(self) -> DataFrame: @@ -104,32 +86,30 @@ def edges(self) -> DataFrame: "dst" storing source vertex IDs and destination vertex IDs of edges, respectively. """ - return self._edges + return self._impl.edges - def __repr__(self): - return self._jvm_graph.toString() + def __repr__(self) -> str: + return self._impl.__repr__ - def cache(self) -> 'GraphFrame': - """ Persist the dataframe representation of vertices and edges of the graph with the default + def cache(self) -> "GraphFrame": + """Persist the dataframe representation of vertices and edges of the graph with the default storage level. """ - self._jvm_graph.cache() - return self + return GraphFrame._from_impl(self._impl.cache()) - def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": + def persist( + self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY + ) -> "GraphFrame": """Persist the dataframe representation of vertices and edges of the graph with the given storage level. """ - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) - self._jvm_graph.persist(javaStorageLevel) - return self + return GraphFrame._from_impl(self._impl.persist(storageLevel=storageLevel)) - def unpersist(self, blocking: bool = False) -> 'GraphFrame': + def unpersist(self, blocking: bool = False) -> "GraphFrame": """Mark the dataframe representation of vertices and edges of the graph as non-persistent, and remove all blocks for it from memory and disk. """ - self._jvm_graph.unpersist(blocking) - return self + return GraphFrame._from_impl(self._impl.unpersist(blocking=blocking)) @property def outDegrees(self) -> DataFrame: @@ -142,8 +122,7 @@ def outDegrees(self) -> DataFrame: :return: DataFrame with new vertices column "outDegree" """ - jdf = self._jvm_graph.outDegrees() - return DataFrame(jdf, self._spark) + return self._impl.outDegrees @property def inDegrees(self) -> DataFrame: @@ -156,8 +135,7 @@ def inDegrees(self) -> DataFrame: :return: DataFrame with new vertices column "inDegree" """ - jdf = self._jvm_graph.inDegrees() - return DataFrame(jdf, self._spark) + return self._impl.inDegrees @property def degrees(self) -> DataFrame: @@ -170,8 +148,7 @@ def degrees(self) -> DataFrame: :return: DataFrame with new vertices column "degree" """ - jdf = self._jvm_graph.degrees() - return DataFrame(jdf, self._spark) + return self._impl.degrees @property def triplets(self) -> DataFrame: @@ -185,17 +162,16 @@ def triplets(self) -> DataFrame: :return: DataFrame with columns 'src', 'edge', and 'dst' """ - jdf = self._jvm_graph.triplets() - return DataFrame(jdf, self._spark) + return self._impl.triplets @property - def pregel(self): + def pregel(self) -> Pregel: """ Get the :class:`graphframes.lib.Pregel` object for running pregel. See :class:`graphframes.lib.Pregel` for more details. """ - return Pregel(self) + return self._impl.pregel def find(self, pattern: str) -> DataFrame: """ @@ -206,52 +182,41 @@ def find(self, pattern: str) -> DataFrame: :param pattern: String describing the motif to search for. :return: DataFrame with one Row for each instance of the motif found """ - jdf = self._jvm_graph.find(pattern) - return DataFrame(jdf, self._spark) + return self._impl.find(pattern=pattern) - def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': + def filterVertices(self, condition: str | Column) -> "GraphFrame": """ Filters the vertices based on expression, remove edges containing any dropped vertices. - + :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered vertices and edges. + :return: GraphFrame with filtered vertices and edges. """ + return GraphFrame._from_impl(self._impl.filterVertices(condition=condition)) - if isinstance(condition, basestring): - jdf = self._jvm_graph.filterVertices(condition) - elif isinstance(condition, Column): - jdf = self._jvm_graph.filterVertices(condition._jc) - else: - raise TypeError("condition should be string or Column") - return _from_java_gf(jdf, self._spark) - - def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': + def filterEdges(self, condition: str | Column) -> "GraphFrame": """ Filters the edges based on expression, keep all vertices. - + :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered edges. + :return: GraphFrame with filtered edges. """ - if isinstance(condition, basestring): - jdf = self._jvm_graph.filterEdges(condition) - elif isinstance(condition, Column): - jdf = self._jvm_graph.filterEdges(condition._jc) - else: - raise TypeError("condition should be string or Column") - return _from_java_gf(jdf, self._spark) + return GraphFrame._from_impl(self._impl.filterEdges(condition=condition)) - def dropIsolatedVertices(self) -> 'GraphFrame': + def dropIsolatedVertices(self) -> "GraphFrame": """ Drops isolated vertices, vertices are not contained in any edges. - :return: GraphFrame with filtered vertices. + :return: GraphFrame with filtered vertices. """ - jdf = self._jvm_graph.dropIsolatedVertices() - return _from_java_gf(jdf, self._spark) + return GraphFrame._from_impl(self._impl.dropIsolatedVertices()) - def bfs(self, fromExpr: str, toExpr: str, - edgeFilter: Optional[str] = None, - maxPathLength: int = 10) -> DataFrame: + def bfs( + self, + fromExpr: str, + toExpr: str, + edgeFilter: str | None = None, + maxPathLength: int = 10, + ) -> DataFrame: """ Breadth-first search (BFS). @@ -259,18 +224,19 @@ def bfs(self, fromExpr: str, toExpr: str, :return: DataFrame with one Row for each shortest path between matching vertices. """ - builder = self._jvm_graph.bfs()\ - .fromExpr(fromExpr)\ - .toExpr(toExpr)\ - .maxPathLength(maxPathLength) - if edgeFilter is not None: - builder.edgeFilter(edgeFilter) - jdf = builder.run() - return DataFrame(jdf, self._spark) + return self._impl.bfs( + fromExpr=fromExpr, + toExpr=toExpr, + edgeFilter=edgeFilter, + maxPathLength=maxPathLength, + ) - def aggregateMessages(self, aggCol: Union[Column, str], - sendToSrc: Union[Column, str, None] = None, - sendToDst: Union[Column, str, None] = None) -> DataFrame: + def aggregateMessages( + self, + aggCol: Column | str, + sendToSrc: Column | str | None = None, + sendToDst: Column | str | None = None, + ) -> DataFrame: """ Aggregates messages from the neighbours. @@ -288,35 +254,18 @@ def aggregateMessages(self, aggCol: Union[Column, str], :return: DataFrame with columns for the vertex ID and the resulting aggregated message """ - # Check that either sendToSrc, sendToDst, or both are provided - if sendToSrc is None and sendToDst is None: - raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") - builder = self._jvm_graph.aggregateMessages() - if sendToSrc is not None: - if isinstance(sendToSrc, Column): - builder.sendToSrc(sendToSrc._jc) - elif isinstance(sendToSrc, basestring): - builder.sendToSrc(sendToSrc) - else: - raise TypeError("Provide message either as `Column` or `str`") - if sendToDst is not None: - if isinstance(sendToDst, Column): - builder.sendToDst(sendToDst._jc) - elif isinstance(sendToDst, basestring): - builder.sendToDst(sendToDst) - else: - raise TypeError("Provide message either as `Column` or `str`") - if isinstance(aggCol, Column): - jdf = builder.agg(aggCol._jc) - else: - jdf = builder.agg(aggCol) - return DataFrame(jdf, self._spark) + return self._impl.aggregateMessages( + aggCol=aggCol, sendToSrc=sendToSrc, sendToDst=sendToDst + ) # Standard algorithms - def connectedComponents(self, algorithm: str = 'graphframes', - checkpointInterval: int = 2, - broadcastThreshold: int = 1000000) -> DataFrame: + def connectedComponents( + self, + algorithm: str = "graphframes", + checkpointInterval: int = 2, + broadcastThreshold: int = 1000000, + ) -> DataFrame: """ Computes the connected components of the graph. @@ -330,12 +279,11 @@ def connectedComponents(self, algorithm: str = 'graphframes', :return: DataFrame with new vertices column "component" """ - jdf = self._jvm_graph.connectedComponents() \ - .setAlgorithm(algorithm) \ - .setCheckpointInterval(checkpointInterval) \ - .setBroadcastThreshold(broadcastThreshold) \ - .run() - return DataFrame(jdf, self._spark) + return self._impl.connectedComponents( + algorithm=algorithm, + checkpointInterval=checkpointInterval, + broadcastThreshold=broadcastThreshold, + ) def labelPropagation(self, maxIter: int) -> DataFrame: """ @@ -346,13 +294,15 @@ def labelPropagation(self, maxIter: int) -> DataFrame: :param maxIter: the number of iterations to be performed :return: DataFrame with new vertices column "label" """ - jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() - return DataFrame(jdf, self._spark) + return self._impl.labelPropagation(maxIter=maxIter) - def pageRank(self, resetProbability: float = 0.15, - sourceId: Optional[Any] = None, - maxIter: Optional[int] = None, - tol: Optional[float] = None) -> 'GraphFrame': + def pageRank( + self, + resetProbability: float = 0.15, + sourceId: Optional[Any] = None, + maxIter: Optional[int] = None, + tol: Optional[float] = None, + ) -> "GraphFrame": """ Runs the PageRank algorithm on the graph. Note: Exactly one of fixed_num_iter or tolerance must be set. @@ -367,21 +317,21 @@ def pageRank(self, resetProbability: float = 0.15, This may not be set if the `numIter` parameter is set. :return: GraphFrame with new vertices column "pagerank" and new edges column "weight" """ - builder = self._jvm_graph.pageRank().resetProbability(resetProbability) - if sourceId is not None: - builder = builder.sourceId(sourceId) - if maxIter is not None: - builder = builder.maxIter(maxIter) - assert tol is None, "Exactly one of maxIter or tol should be set." - else: - assert tol is not None, "Exactly one of maxIter or tol should be set." - builder = builder.tol(tol) - jgf = builder.run() - return _from_java_gf(jgf, self._spark) - - def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, - sourceIds: Optional[list[Any]] = None, - maxIter: Optional[int] = None) -> 'GraphFrame': + return GraphFrame._from_impl( + self._impl.pageRank( + resetProbability=resetProbability, + sourceId=sourceId, + maxIter=maxIter, + tol=tol, + ) + ) + + def parallelPersonalizedPageRank( + self, + resetProbability: float = 0.15, + sourceIds: Optional[list[Any]] = None, + maxIter: Optional[int] = None, + ) -> "GraphFrame": """ Run the personalized PageRank algorithm on the graph, from the provided list of sources in parallel for a fixed number of iterations. @@ -393,15 +343,11 @@ def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, :param maxIter: the fixed number of iterations this algorithm runs :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" """ - assert sourceIds is not None and len(sourceIds) > 0, "Source vertices Ids sourceIds must be provided" - assert maxIter is not None, "Max number of iterations maxIter must be provided" - sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) - builder = self._jvm_graph.parallelPersonalizedPageRank() - builder = builder.resetProbability(resetProbability) - builder = builder.sourceIds(sourceIds) - builder = builder.maxIter(maxIter) - jgf = builder.run() - return _from_java_gf(jgf, self._spark) + return GraphFrame._from_impl( + self._impl.parallelPersonalizedPageRank( + resetProbability=resetProbability, sourceIds=sourceIds, maxIter=maxIter + ) + ) def shortestPaths(self, landmarks: list[Any]) -> DataFrame: """ @@ -412,8 +358,7 @@ def shortestPaths(self, landmarks: list[Any]) -> DataFrame: :param landmarks: a set of one or more landmarks :return: DataFrame with new vertices column "distances" """ - jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() - return DataFrame(jdf, self._spark) + return self._impl.shortestPaths(landmarks=landmarks) def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: """ @@ -424,13 +369,19 @@ def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: :param maxIter: the number of iterations to run :return: DataFrame with new vertex column "component" """ - jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() - return DataFrame(jdf, self._spark) + return self._impl.stronglyConnectedComponents(maxIter=maxIter) - def svdPlusPlus(self, rank: int = 10, maxIter: int = 2, - minValue: float = 0.0, maxValue: float = 5.0, - gamma1: float = 0.007, gamma2: float = 0.007, - gamma6: float = 0.005, gamma7: float = 0.015) -> tuple[DataFrame, float]: + def svdPlusPlus( + self, + rank: int = 10, + maxIter: int = 2, + minValue: float = 0.0, + maxValue: float = 5.0, + gamma1: float = 0.007, + gamma2: float = 0.007, + gamma6: float = 0.005, + gamma7: float = 0.015, + ) -> tuple[DataFrame, float]: """ Runs the SVD++ algorithm. @@ -438,14 +389,16 @@ def svdPlusPlus(self, rank: int = 10, maxIter: int = 2, :return: Tuple of DataFrame with new vertex columns storing learned model, and loss value """ - # This call is actually useless, because one needs to build the configuration first... - builder = self._jvm_graph.svdPlusPlus() - builder.rank(rank).maxIter(maxIter).minValue(minValue).maxValue(maxValue) - builder.gamma1(gamma1).gamma2(gamma2).gamma6(gamma6).gamma7(gamma7) - jdf = builder.run() - loss = builder.loss() - v = DataFrame(jdf, self._spark) - return (v, loss) + return self._impl.svdPlusPlus( + rank=rank, + maxIter=maxIter, + minValue=minValue, + maxValue=maxValue, + gamma1=gamma1, + gamma2=gamma2, + gamma6=gamma6, + gamma7=gamma7, + ) def triangleCount(self) -> DataFrame: """ @@ -455,19 +408,23 @@ def triangleCount(self) -> DataFrame: :return: DataFrame with new vertex column "count" """ - jdf = self._jvm_graph.triangleCount().run() - return DataFrame(jdf, self._spark) + return self._impl.triangleCount() def _test(): import doctest + import graphframe + from pyspark.sql import SparkSession + globs = graphframe.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - globs['spark'] = SparkSession(globs['sc']).builder.getOrCreate() + globs["spark"] = ( + SparkSession.builder.master("local[4]").appName("PythonTest").getOrCreate() + ) (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE + ) + globs["spark"].stop() if failure_count: exit(-1) diff --git a/python/graphframes/tests.py b/python/graphframes/tests.py deleted file mode 100644 index 9a7ad1371..000000000 --- a/python/graphframes/tests.py +++ /dev/null @@ -1,468 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sys -import tempfile -import shutil -import re - -if sys.version_info[:2] <= (2, 6): - try: - import unittest2 as unittest - except ImportError: - sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') - sys.exit(1) -else: - import unittest - -from pyspark import SparkContext -from pyspark.sql import functions as sqlfunctions, SparkSession - -from .graphframe import GraphFrame, Pregel, _java_api, _from_java_gf -from .lib import AggregateMessages as AM -from .examples import Graphs, BeliefPropagation - -class GraphFrameTestUtils(object): - - @classmethod - def parse_spark_version(cls, version_str): - """ take an input version string - return version items in a dictionary - """ - _sc_ver_patt = r'(\d+)\.(\d+)(\.(\d+)(-(.+))?)?' - m = re.match(_sc_ver_patt, version_str) - if not m: - raise TypeError("version {} shoud be in ..".format(version_str)) - version_info = {} - try: - version_info['major'] = int(m.group(1)) - except: - raise TypeError("invalid minor version") - try: - version_info['minor'] = int(m.group(2)) - except: - raise TypeError("invalid major version") - try: - version_info['maintenance'] = int(m.group(4)) - except: - version_info['maintenance'] = 0 - try: - version_info['special'] = m.group(6) - except: - pass - return version_info - - @classmethod - def createSparkContext(cls): - cls.sc = sc = SparkContext('local[4]', "GraphFramesTests") - cls.checkpointDir = tempfile.mkdtemp() - cls.sc.setCheckpointDir(cls.checkpointDir) - cls.spark_version = cls.parse_spark_version(sc.version) - - @classmethod - def stopSparkContext(cls): - cls.sc.stop() - cls.sc = None - shutil.rmtree(cls.checkpointDir) - - @classmethod - def spark_at_least_of_version(cls, version_str): - assert hasattr(cls, 'spark_version') - required_version = cls.parse_spark_version(version_str) - spark_version = cls.spark_version - for _name in ['major', 'minor', 'maintenance']: - sc_ver = spark_version[_name] - req_ver = required_version[_name] - if sc_ver != req_ver: - return sc_ver > req_ver - # All major.minor.maintenance equal - return True - -def setUpModule(): - GraphFrameTestUtils.createSparkContext() - -def tearDownModule(): - GraphFrameTestUtils.stopSparkContext() - - -class GraphFrameTestCase(unittest.TestCase): - - @classmethod - def setUpClass(cls): - # Small tests run much faster with spark.sql.shuffle.partitions = 4 - cls.spark = SparkSession(GraphFrameTestUtils.sc).builder.config('spark.sql.shuffle.partitions', 4).getOrCreate() - - @classmethod - def tearDownClass(cls): - cls.spark = None - - -class GraphFrameTest(GraphFrameTestCase): - def setUp(self): - super(GraphFrameTest, self).setUp() - localVertices = [(1, "A"), (2, "B"), (3, "C")] - localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] - v = self.spark.createDataFrame(localVertices, ["id", "name"]) - e = self.spark.createDataFrame(localEdges, ["src", "dst", "action"]) - self.g = GraphFrame(v, e) - - def test_spark_version_check(self): - gtu = GraphFrameTestUtils - gtu.spark_version = gtu.parse_spark_version("2.0.2") - self.assertTrue(gtu.spark_at_least_of_version("1.7")) - self.assertTrue(gtu.spark_at_least_of_version("2.0")) - self.assertTrue(gtu.spark_at_least_of_version("2.0.1")) - self.assertTrue(gtu.spark_at_least_of_version("2.0.2")) - self.assertFalse(gtu.spark_at_least_of_version("2.0.3")) - self.assertFalse(gtu.spark_at_least_of_version("2.1")) - - def test_construction(self): - g = self.g - vertexIDs = map(lambda x: x[0], g.vertices.select("id").collect()) - assert sorted(vertexIDs) == [1, 2, 3] - edgeActions = map(lambda x: x[0], g.edges.select("action").collect()) - assert sorted(edgeActions) == ["follow", "hate", "love"] - tripletsFirst = list(map(lambda x: (x[0][1], x[1][1], x[2][2]), - g.triplets.sort("src.id").select("src", "dst", "edge").take(1))) - assert tripletsFirst == [("A", "B", "love")], tripletsFirst - # Try with invalid vertices and edges DataFrames - v_invalid = self.spark.createDataFrame( - [(1, "A"), (2, "B"), (3, "C")], ["invalid_colname_1", "invalid_colname_2"]) - e_invalid = self.spark.createDataFrame( - [(1, 2), (2, 3), (3, 1)], ["invalid_colname_3", "invalid_colname_4"]) - with self.assertRaises(ValueError): - GraphFrame(v_invalid, e_invalid) - - def test_cache(self): - g = self.g - g.cache() - g.unpersist() - - def test_degrees(self): - g = self.g - outDeg = g.outDegrees - self.assertSetEqual(set(outDeg.columns), {"id", "outDegree"}) - inDeg = g.inDegrees - self.assertSetEqual(set(inDeg.columns), {"id", "inDegree"}) - deg = g.degrees - self.assertSetEqual(set(deg.columns), {"id", "degree"}) - - def test_motif_finding(self): - g = self.g - motifs = g.find("(a)-[e]->(b)") - assert motifs.count() == 3 - self.assertSetEqual(set(motifs.columns), {"a", "e", "b"}) - - def test_filterVertices(self): - g = self.g - conditions = ["id < 3", g.vertices.id < 3] - expected_v = [(1, "A"), (2, "B")] - expected_e = [(1, 2, "love"), (2, 1, "hate")] - for cond in conditions: - g2 = g.filterVertices(cond) - v2 = g2.vertices.select("id", "name").collect() - e2 = g2.edges.select("src", "dst", "action").collect() - assert len(v2) == len(expected_v) - assert len(e2) == len(expected_e) - self.assertSetEqual(set(v2), set(expected_v)) - self.assertSetEqual(set(e2), set(expected_e)) - - def test_filterEdges(self): - g = self.g - conditions = ["dst > 2", g.edges.dst > 2] - expected_v = [(1, "A"), (2, "B"), (3, "C")] - expected_e = [(2, 3, "follow")] - for cond in conditions: - g2 = g.filterEdges(cond) - v2 = g2.vertices.select("id", "name").collect() - e2 = g2.edges.select("src", "dst", "action").collect() - assert len(v2) == len(expected_v) - assert len(e2) == len(expected_e) - self.assertSetEqual(set(v2), set(expected_v)) - self.assertSetEqual(set(e2), set(expected_e)) - - def test_dropIsolatedVertices(self): - g = self.g - g2 = g.filterEdges("dst > 2").dropIsolatedVertices() - v2 = g2.vertices.select("id", "name").collect() - e2 = g2.edges.select("src", "dst", "action").collect() - expected_v = [(2, "B"), (3, "C")] - expected_e = [(2, 3, "follow")] - assert len(v2) == len(expected_v) - assert len(e2) == len(expected_e) - self.assertSetEqual(set(v2), set(expected_v)) - self.assertSetEqual(set(e2), set(expected_e)) - - def test_bfs(self): - g = self.g - paths = g.bfs("name='A'", "name='C'") - self.assertEqual(paths.count(), 1) - self.assertEqual(paths.select("v1.name").head()[0], "B") - paths2 = g.bfs("name='A'", "name='C'", edgeFilter="action!='follow'") - self.assertEqual(paths2.count(), 0) - paths3 = g.bfs("name='A'", "name='C'", maxPathLength=1) - self.assertEqual(paths3.count(), 0) - - -class PregelTest(GraphFrameTestCase): - def setUp(self): - super(PregelTest, self).setUp() - - def test_page_rank(self): - from pyspark.sql.functions import coalesce, col, lit, sum, when - edges = self.spark.createDataFrame([[0, 1], - [1, 2], - [2, 4], - [2, 0], - [3, 4], # 3 has no in-links - [4, 0], - [4, 2]], ["src", "dst"]) - edges.cache() - vertices = self.spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) - numVertices = vertices.count() - vertices = GraphFrame(vertices, edges).outDegrees - vertices.cache() - graph = GraphFrame(vertices, edges) - alpha = 0.15 - ranks = graph.pregel \ - .setMaxIter(5) \ - .withVertexColumn("rank", lit(1.0 / numVertices), - coalesce(Pregel.msg(), - lit(0.0)) * lit(1.0 - alpha) + lit(alpha / numVertices)) \ - .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) \ - .aggMsgs(sum(Pregel.msg())) \ - .run() - resultRows = ranks.sort(ranks.id).collect() - result = map(lambda x: x.rank, resultRows) - expected = [0.245, 0.224, 0.303, 0.03, 0.197] - for a, b in zip(result, expected): - self.assertAlmostEqual(a, b, delta = 1e-3) - - -class GraphFrameLibTest(GraphFrameTestCase): - def setUp(self): - super(GraphFrameLibTest, self).setUp() - self.japi = _java_api(self.spark._sc) - - def _hasCols(self, graph, vcols = [], ecols = []): - map(lambda c: self.assertIn(c, graph.vertices.columns), vcols) - map(lambda c: self.assertIn(c, graph.edges.columns), ecols) - - def _df_hasCols(self, vertices, vcols = []): - map(lambda c: self.assertIn(c, vertices.columns), vcols) - - def _graph(self, name, *args): - """ - Convenience to call one of the example graphs, passing the arguments and wrapping the result back - as a python object. - :param name: the name of the example graph - :param args: all the required arguments, without the initial spark session - :return: - """ - examples = self.japi.examples() - jgraph = getattr(examples, name)(*args) - return _from_java_gf(jgraph, self.spark) - - def test_aggregate_messages(self): - g = self._graph("friends") - # For each user, sum the ages of the adjacent users, - # plus 1 for the src's sum if the edge is "friend". - sendToSrc = ( - AM.dst['age'] + - sqlfunctions.when( - AM.edge['relationship'] == 'friend', - sqlfunctions.lit(1) - ).otherwise(0)) - sendToDst = AM.src['age'] - agg = g.aggregateMessages( - sqlfunctions.sum(AM.msg).alias('summedAges'), - sendToSrc=sendToSrc, - sendToDst=sendToDst) - # Run the aggregation again providing SQL expressions as String instead. - agg2 = g.aggregateMessages( - "sum(MSG) AS `summedAges`", - sendToSrc="(dst['age'] + CASE WHEN (edge['relationship'] = 'friend') THEN 1 ELSE 0 END)", - sendToDst="src['age']") - # Convert agg and agg2 to a mapping from id to the aggregated message. - aggMap = {id_: s for id_, s in agg.select('id', 'summedAges').collect()} - agg2Map = {id_: s for id_, s in agg2.select('id', 'summedAges').collect()} - # Compute the truth via brute force. - user2age = {id_: age for id_, age in g.vertices.select('id', 'age').collect()} - trueAgg = {} - for src, dst, rel in g.edges.select("src", "dst", "relationship").collect(): - trueAgg[src] = trueAgg.get(src, 0) + user2age[dst] + (1 if rel == 'friend' else 0) - trueAgg[dst] = trueAgg.get(dst, 0) + user2age[src] - # Compare if the agg mappings match the brute force mapping - self.assertEqual(aggMap, trueAgg) - self.assertEqual(agg2Map, trueAgg) - # Check that TypeError is raises with messages of wrong type - with self.assertRaises(TypeError): - g.aggregateMessages( - "sum(MSG) AS `summedAges`", - sendToSrc=object(), - sendToDst="src['age']") - with self.assertRaises(TypeError): - g.aggregateMessages( - "sum(MSG) AS `summedAges`", - sendToSrc=dst['age'], - sendToDst=object()) - - def test_connected_components(self): - v = self.spark.createDataFrame([ - (0, "a", "b")], ["id", "vattr", "gender"]) - e = self.spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter("src > 10") - g = GraphFrame(v, e) - comps = g.connectedComponents() - self._df_hasCols(comps, vcols=['id', 'component', 'vattr', 'gender']) - self.assertEqual(comps.count(), 1) - - def test_connected_components2(self): - v = self.spark.createDataFrame([(0, "a0", "b0"), (1, "a1", "b1")], ["id", "A", "B"]) - e = self.spark.createDataFrame([(0, 1, "a01", "b01")], ["src", "dst", "A", "B"]) - g = GraphFrame(v, e) - comps = g.connectedComponents() - self._df_hasCols(comps, vcols=['id', 'component', 'A', 'B']) - self.assertEqual(comps.count(), 2) - - def test_connected_components_friends(self): - g = self._graph("friends") - comps_tests = [] - comps_tests += [g.connectedComponents()] - comps_tests += [g.connectedComponents(broadcastThreshold=1)] - comps_tests += [g.connectedComponents(checkpointInterval=0)] - comps_tests += [g.connectedComponents(checkpointInterval=10)] - comps_tests += [g.connectedComponents(algorithm="graphx")] - for c in comps_tests: - self.assertEqual(c.groupBy("component").count().count(), 2) - - def test_label_progagation(self): - n = 5 - g = self._graph("twoBlobs", n) - labels = g.labelPropagation(maxIter=4 * n) - labels1 = labels.filter("id < 5").select("label").collect() - all1 = set([x.label for x in labels1]) - assert len(all1) == 1 - labels2 = labels.filter("id >= 5").select("label").collect() - all2 = set([x.label for x in labels2]) - assert len(all2) == 1 - assert all1 != all2 - - def test_page_rank(self): - n = 100 - g = self._graph("star", n) - resetProb = 0.15 - errorTol = 1.0e-5 - pr = g.pageRank(resetProb, tol=errorTol) - self._hasCols(pr, vcols=['id', 'pagerank'], ecols=['src', 'dst', 'weight']) - - def test_parallel_personalized_page_rank(self): - n = 100 - g = self._graph("star", n) - resetProb = 0.15 - maxIter = 15 - sourceIds = [1, 2, 3, 4] - pr = g.parallelPersonalizedPageRank(resetProb, sourceIds=sourceIds, maxIter=maxIter) - self._hasCols(pr, vcols=['id', 'pageranks'], ecols=['src', 'dst', 'weight']) - - def test_shortest_paths(self): - edges = [(1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6)] - all_edges = [z for (a, b) in edges for z in [(a, b), (b, a)]] - edges = self.spark.createDataFrame(all_edges, ["src", "dst"]) - vertices = self.spark.createDataFrame([(i,) for i in range(1, 7)], ["id"]) - g = GraphFrame(vertices, edges) - landmarks = [1, 4] - v2 = g.shortestPaths(landmarks) - self._df_hasCols(v2, vcols=["id", "distances"]) - - def test_svd_plus_plus(self): - g = self._graph("ALSSyntheticData") - (v2, cost) = g.svdPlusPlus() - self._df_hasCols(v2, vcols=['id', 'column1', 'column2', 'column3', 'column4']) - - def test_strongly_connected_components(self): - # Simple island test - vertices = self.spark.createDataFrame([(i,) for i in range(1, 6)], ["id"]) - edges = self.spark.createDataFrame([(7, 8)], ["src", "dst"]) - g = GraphFrame(vertices, edges) - c = g.stronglyConnectedComponents(5) - for row in c.collect(): - self.assertEqual(row.id, row.component) - - def test_triangle_counts(self): - edges = self.spark.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"]) - vertices = self.spark.createDataFrame([(0,), (1,), (2,)], ["id"]) - g = GraphFrame(vertices, edges) - c = g.triangleCount() - for row in c.select("id", "count").collect(): - self.assertEqual(row.asDict()['count'], 1) - - def test_mutithreaded_sparksession_usage(self): - # Test that we can use the GraphFrame API from multiple threads - localVertices = [(1, "A"), (2, "B"), (3, "C")] - localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] - v = self.spark.createDataFrame(localVertices, ["id", "name"]) - e = self.spark.createDataFrame(localEdges, ["src", "dst", "action"]) - - - exc = None - def run_graphframe() -> None: - try: - GraphFrame(v, e) - except Exception as _e: - nonlocal exc - exc = _e - - import threading - thread = threading.Thread(target=run_graphframe) - thread.start() - thread.join() - self.assertIsNone(exc, f"Exception was raised in thread: {exc}") - - -class GraphFrameExamplesTest(GraphFrameTestCase): - def setUp(self): - super(GraphFrameExamplesTest, self).setUp() - self.japi = _java_api(self.spark._sc) - - def test_belief_propagation(self): - # create graphical model g of size 3 x 3 - g = Graphs(self.spark).gridIsingModel(3) - # run BP for 5 iterations - numIter = 5 - results = BeliefPropagation.runBPwithGraphFrames(g, numIter) - # check beliefs are valid - for row in results.vertices.select('belief').collect(): - belief = row['belief'] - self.assertTrue( - 0 <= belief <= 1, - msg="Expected belief to be probability in [0,1], but found {}".format(belief)) - - def test_graph_friends(self): - # construct graph - g = Graphs(self.spark).friends() - # check that a GraphFrame instance was returned - self.assertIsInstance(g, GraphFrame) - - def test_graph_grid_ising_model(self): - # construct graph - n = 3 - g = Graphs(self.spark).gridIsingModel(n) - # check that all the vertices exist - ids = [v['id'] for v in g.vertices.collect()] - for i in range(n): - for j in range(n): - self.assertIn('{},{}'.format(i, j), ids) diff --git a/python/pyproject.toml b/python/pyproject.toml index 8c0c1ba05..a239412ce 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -5,7 +5,9 @@ description = "GraphFrames: Graph Processing Framework for Apache Spark" authors = ["GraphFrames Contributors "] license = "Apache 2.0" readme = "README.md" -packages = [{include = "graphframes"}] +packages = [ + { include = "graphframes" }, +] classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", @@ -16,6 +18,13 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12" ] +include = [ + { path = "graphframes/resources/*.jar", format = "wheel" } +] + +[tool.poetry.build] +script = "dev/build_jar.py" + [tool.poetry.urls] "Project Homepage" = "https://graphframes.github.io/graphframes" diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/tests/tests.py b/python/tests/tests.py index f8330cc16..c9a93fec6 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -15,65 +15,37 @@ # limitations under the License. # -import os -import re +import pathlib import shutil -import sys -import tempfile import unittest -from pathlib import Path +import warnings +from importlib import resources +from graphframes.classic.graphframe import _from_java_gf, _java_api +from graphframes.examples import BeliefPropagation, Graphs +from graphframes.graphframe import GraphFrame, Pregel +from graphframes.lib import AggregateMessages as AM from pyspark.sql import SparkSession from pyspark.sql import functions as sqlfunctions +from pyspark.version import __version__ -# Workaround that won't be needed after setting up setup.py -prj_root = Path(__file__).parent.parent -sys.path.insert(0, prj_root.absolute().__str__()) +if __version__[:2] >= "3.4": + from pyspark.sql.utils import is_remote +else: -from graphframes.examples import BeliefPropagation, Graphs -from graphframes.graphframe import GraphFrame, Pregel, _from_java_gf, _java_api -from graphframes.connect.graphframe_client import GraphFrameConnect, PregelConnect -from graphframes.lib import AggregateMessages as AM + def is_remote() -> bool: + return False -class GraphFrameTestUtils(object): +class GraphFrameTestCase(unittest.TestCase): @classmethod - def parse_spark_version(cls, version_str): - """take an input version string - return version items in a dictionary - """ - _sc_ver_patt = r"(\d+)\.(\d+)(\.(\d+)(-(.+))?)?" - m = re.match(_sc_ver_patt, version_str) - if not m: - raise TypeError( - "version {} shoud be in ..".format( - version_str - ) - ) - version_info = {} - try: - version_info["major"] = int(m.group(1)) - except: - raise TypeError("invalid minor version") - try: - version_info["minor"] = int(m.group(2)) - except: - raise TypeError("invalid major version") - try: - version_info["maintenance"] = int(m.group(4)) - except: - version_info["maintenance"] = 0 - try: - version_info["special"] = m.group(6) - except: - pass - return version_info + def setUpClass(cls): + warnings.filterwarnings("ignore", category=ResourceWarning) + warnings.filterwarnings("ignore", category=DeprecationWarning) + cls.checkpointDir = "/tmp/GFTestsCheckpointDir" - @classmethod - def createSparkSession(cls): - cls.checkpointDir = tempfile.mkdtemp() - if "SPARK_CONNECT_MODE_ENABLED" in os.environ: - cls.sc = ( + if is_remote(): + cls.spark = ( SparkSession.builder.remote("sc://localhost:15002") .appName("GraphFramesTest") .config("spark.sql.shuffle.partitions", 4) @@ -81,67 +53,33 @@ def createSparkSession(cls): .getOrCreate() ) else: - cls.sc = ( + spark = ( SparkSession.builder.master("local[4]") .appName("GraphFramesTest") .config("spark.sql.shuffle.partitions", 4) - .config("spark.checkpoint.dir", cls.checkpointDir) - .getOrCreate() ) - - assert cls.sc is not None - cls.spark_version = cls.parse_spark_version(cls.sc.version) - - @classmethod - def stopSparkSession(cls): - assert cls.sc is not None - cls.sc.stop() - cls.sc = None - shutil.rmtree(cls.checkpointDir) - - @classmethod - def spark_at_least_of_version(cls, version_str): - assert hasattr(cls, "spark_version") - required_version = cls.parse_spark_version(version_str) - spark_version = cls.spark_version - for _name in ["major", "minor", "maintenance"]: - sc_ver = spark_version[_name] - req_ver = required_version[_name] - if sc_ver != req_ver: - return sc_ver > req_ver - # All major.minor.maintenance equal - return True - - -def setUpModule(): - GraphFrameTestUtils.createSparkSession() - - -def tearDownModule(): - GraphFrameTestUtils.stopSparkSession() - - -class GraphFrameTestCase(unittest.TestCase): - @classmethod - def setUpClass(cls): - # Small tests run much faster with spark.sql.shuffle.partitions = 4 - if "SPARK_CONNECT_MODE_ENABLED" in os.environ: - cls.spark = SparkSession.builder.remote( - "sc://localhost:15002" - ).getOrCreate() - cls.connect_mode = 1 - else: - cls.spark = ( - SparkSession(GraphFrameTestUtils.sc) - .builder.config("spark.sql.shuffle.partitions", 4) - .getOrCreate() - ) - cls.connect_mode = 0 + resources_root = resources.files("graphframes").joinpath("resources") + spark_jars = [] + for pp in resources_root.iterdir(): + assert isinstance(pp, pathlib.PosixPath) # type checking + if pp.is_file() and pp.name.endswith(".jar"): + spark_jars.append(pp.absolute().__str__()) + if spark_jars: + jars_str = ",".join(spark_jars) + spark = spark.config("spark.jars", jars_str) + + cls.spark = spark.getOrCreate() + assert cls.spark is not None + cls.spark.sparkContext.setCheckpointDir(cls.checkpointDir) + + assert cls.spark is not None @classmethod def tearDownClass(cls): + assert cls.spark is not None + cls.spark.stop() cls.spark = None - cls.connect_mode = None + shutil.rmtree(cls.checkpointDir) class GraphFrameTest(GraphFrameTestCase): @@ -151,22 +89,7 @@ def setUp(self): localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")] v = self.spark.createDataFrame(localVertices, ["id", "name"]) e = self.spark.createDataFrame(localEdges, ["src", "dst", "action"]) - if self.connect_mode: - self.g = GraphFrameConnect(v, e) - else: - self.g = GraphFrame(v, e) - - def test_spark_version_check(self): - # SparkContext is not available in Spark Connect - if not self.connect_mode: - gtu = GraphFrameTestUtils - gtu.spark_version = gtu.parse_spark_version("2.0.2") - self.assertTrue(gtu.spark_at_least_of_version("1.7")) - self.assertTrue(gtu.spark_at_least_of_version("2.0")) - self.assertTrue(gtu.spark_at_least_of_version("2.0.1")) - self.assertTrue(gtu.spark_at_least_of_version("2.0.2")) - self.assertFalse(gtu.spark_at_least_of_version("2.0.3")) - self.assertFalse(gtu.spark_at_least_of_version("2.1")) + self.g = GraphFrame(v, e) def test_construction(self): g = self.g @@ -189,10 +112,7 @@ def test_construction(self): [(1, 2), (2, 3), (3, 1)], ["invalid_colname_3", "invalid_colname_4"] ) with self.assertRaises(ValueError): - if self.connect_mode: - GraphFrameConnect(v_invalid, e_invalid) - else: - GraphFrame(v_invalid, e_invalid) + GraphFrame(v_invalid, e_invalid) def test_cache(self): g = self.g @@ -270,7 +190,7 @@ def setUp(self): super(PregelTest, self).setUp() def test_page_rank(self): - from pyspark.sql.functions import coalesce, col, lit, sum, when + from pyspark.sql.functions import coalesce, lit, sum edges = self.spark.createDataFrame( [ @@ -289,16 +209,9 @@ def test_page_rank(self): vertices = self.spark.createDataFrame([[0], [1], [2], [3], [4]], ["id"]) numVertices = vertices.count() - if self.connect_mode: - vertices = GraphFrameConnect(vertices, edges).outDegrees - else: - vertices = GraphFrame(vertices, edges).outDegrees + vertices = GraphFrame(vertices, edges).outDegrees vertices.cache() - - if self.connect_mode: - graph = GraphFrameConnect(vertices, edges) - else: - graph = GraphFrame(vertices, edges) + graph = GraphFrame(vertices, edges) alpha = 0.15 ranks = ( graph.pregel.setMaxIter(5) @@ -322,7 +235,7 @@ def test_page_rank(self): class GraphFrameLibTest(GraphFrameTestCase): def setUp(self): super(GraphFrameLibTest, self).setUp() - if self.connect_mode: + if is_remote(): self.japi = None else: self.japi = _java_api(self.spark._sc) @@ -346,7 +259,7 @@ def _graph(self, name, *args): jgraph = getattr(examples, name)(*args) return _from_java_gf(jgraph, self.spark) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_aggregate_messages(self): g = self._graph("friends") # For each user, sum the ages of the adjacent users, @@ -395,10 +308,7 @@ def test_connected_components(self): e = self.spark.createDataFrame([(0, 0, 1)], ["src", "dst", "test"]).filter( "src > 10" ) - if self.connect_mode: - g = GraphFrameConnect(v, e) - else: - g = GraphFrame(v, e) + g = GraphFrame(v, e) comps = g.connectedComponents() self._df_hasCols(comps, vcols=["id", "component", "vattr", "gender"]) self.assertEqual(comps.count(), 1) @@ -408,15 +318,12 @@ def test_connected_components2(self): [(0, "a0", "b0"), (1, "a1", "b1")], ["id", "A", "B"] ) e = self.spark.createDataFrame([(0, 1, "a01", "b01")], ["src", "dst", "A", "B"]) - if self.connect_mode: - g = GraphFrameConnect(v, e) - else: - g = GraphFrame(v, e) + g = GraphFrame(v, e) comps = g.connectedComponents() self._df_hasCols(comps, vcols=["id", "component", "A", "B"]) self.assertEqual(comps.count(), 2) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_connected_components_friends(self): g = self._graph("friends") comps_tests = [] @@ -428,7 +335,7 @@ def test_connected_components_friends(self): for c in comps_tests: self.assertEqual(c.groupBy("component").count().count(), 2) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_label_progagation(self): n = 5 g = self._graph("twoBlobs", n) @@ -441,7 +348,7 @@ def test_label_progagation(self): assert len(all2) == 1 assert all1 != all2 - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_page_rank(self): n = 100 g = self._graph("star", n) @@ -450,7 +357,7 @@ def test_page_rank(self): pr = g.pageRank(resetProb, tol=errorTol) self._hasCols(pr, vcols=["id", "pagerank"], ecols=["src", "dst", "weight"]) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_parallel_personalized_page_rank(self): n = 100 g = self._graph("star", n) @@ -468,15 +375,12 @@ def test_shortest_paths(self): assert self.spark is not None edges = self.spark.createDataFrame(all_edges, ["src", "dst"]) vertices = self.spark.createDataFrame([(i,) for i in range(1, 7)], ["id"]) - if self.connect_mode: - g = GraphFrameConnect(vertices, edges) - else: - g = GraphFrame(vertices, edges) + g = GraphFrame(vertices, edges) landmarks = [1, 4] v2 = g.shortestPaths(landmarks) self._df_hasCols(v2, vcols=["id", "distances"]) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_svd_plus_plus(self): g = self._graph("ALSSyntheticData") (v2, cost) = g.svdPlusPlus() @@ -487,10 +391,7 @@ def test_strongly_connected_components(self): assert self.spark is not None vertices = self.spark.createDataFrame([(i,) for i in range(1, 6)], ["id"]) edges = self.spark.createDataFrame([(7, 8)], ["src", "dst"]) - if self.connect_mode: - g = GraphFrameConnect(vertices, edges) - else: - g = GraphFrame(vertices, edges) + g = GraphFrame(vertices, edges) c = g.stronglyConnectedComponents(5) for row in c.collect(): self.assertEqual(row.id, row.component) @@ -499,15 +400,12 @@ def test_triangle_counts(self): assert self.spark is not None edges = self.spark.createDataFrame([(0, 1), (1, 2), (2, 0)], ["src", "dst"]) vertices = self.spark.createDataFrame([(0,), (1,), (2,)], ["id"]) - if self.connect_mode: - g = GraphFrameConnect(vertices, edges) - else: - g = GraphFrame(vertices, edges) + g = GraphFrame(vertices, edges) c = g.triangleCount() for row in c.select("id", "count").collect(): self.assertEqual(row.asDict()["count"], 1) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_mutithreaded_sparksession_usage(self): # Test that we can use the GraphFrame API from multiple threads localVertices = [(1, "A"), (2, "B"), (3, "C")] @@ -537,7 +435,7 @@ def setUp(self): super(GraphFrameExamplesTest, self).setUp() self.japi = _java_api(self.spark._sc) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_belief_propagation(self): # create graphical model g of size 3 x 3 g = Graphs(self.spark).gridIsingModel(3) @@ -554,14 +452,14 @@ def test_belief_propagation(self): ), ) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_graph_friends(self): # construct graph g = Graphs(self.spark).friends() # check that a GraphFrame instance was returned self.assertIsInstance(g, GraphFrame) - @unittest.skipIf("SPARK_CONNECT_MODE_ENABLED" in os.environ, "SKIP FOR CONNECT") + @unittest.skipIf(is_remote(), "SKIP FOR CONNECT") def test_graph_grid_ising_model(self): # construct graph n = 3 From eee7b7b536ca26107eee92beeeaa0273dc46f319 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 22:53:10 +0100 Subject: [PATCH 07/27] Working version? --- .github/workflows/python-ci.yml | 10 + dev/run_connect.py | 7 + .../src/main/protobuf/graphframes.proto | 4 +- .../graphframes/GraphFramesConnectUtils.scala | 17 +- python/graphframes/__init__.py | 3 +- python/graphframes/classic/graphframe.py | 143 +++--- .../graphframes/connect/graphframe_client.py | 170 +++---- .../connect/proto/graphframes_pb2.py | 110 +++-- .../connect/proto/graphframes_pb2.pyi | 160 ++++++- .../connect/proto/graphframes_pb2_grpc.py | 1 - python/graphframes/connect/utils.py | 2 + python/graphframes/examples/__init__.py | 3 +- .../examples/belief_propagation.py | 49 +- python/graphframes/examples/graphs.py | 75 ++-- python/graphframes/graphframe.py | 20 +- python/graphframes/lib/__init__.py | 3 +- python/graphframes/lib/aggregate_messages.py | 8 +- python/graphframes/lib/pregel.py | 8 +- python/poetry.lock | 423 +++++++++++++++--- python/pyproject.toml | 3 +- python/tests/tests.py | 13 +- .../graphframes/lib/ConnectedComponents.scala | 27 +- .../scala/org/graphframes/lib/Pregel.scala | 14 + 23 files changed, 855 insertions(+), 418 deletions(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index e943f19d4..54cccdcb4 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -44,3 +44,13 @@ jobs: working-directory: ./python run: | poetry run python -m unittest discover -s tests/ + + - name: Test SparkConnect + run: | + cd python + export VENV_ROOT=$(poetry env info --path) + cd .. + $VENV_ROOT/bin/python dev/run-connect.sh + cd python + export SPARK_CONNECT_MODE_ENABLED=1 + poetry run python -m unittest discover -s tests/ diff --git a/dev/run_connect.py b/dev/run_connect.py index 14a7bf6e8..95fb0c6fb 100644 --- a/dev/run_connect.py +++ b/dev/run_connect.py @@ -90,6 +90,11 @@ .joinpath(f"graphframes-connect-assembly-{GRAPHFRAMES_VERSION}.jar") ) shutil.copyfile(gf_jar, spark_home.joinpath(gf_jar.name)) + checkpoint_dir = Path("/tmp/GFTestsCheckpointDir") + if checkpoint_dir.exists(): + shutil.rmtree(checkpoint_dir.absolute().__str__(), ignore_errors=True) + + checkpoint_dir.mkdir(exist_ok=True, parents=True) run_connect_command = [ "./sbin/start-connect-server.sh", @@ -100,6 +105,8 @@ "spark.connect.extensions.relation.classes=org.apache.spark.sql.graphframes.GraphFramesConnect", "--packages", f"org.apache.spark:spark-connect_{SCALA_VERSION}:{SPARK_VERSION}", + "--conf", + "spark.checkpoint.dir=/tmp/GFTestsCheckpointDir", ] print("Starting SparkConnect Server...") spark_connect = subprocess.run( diff --git a/graphframes-connect/src/main/protobuf/graphframes.proto b/graphframes-connect/src/main/protobuf/graphframes.proto index ca5594d64..41185b510 100644 --- a/graphframes-connect/src/main/protobuf/graphframes.proto +++ b/graphframes-connect/src/main/protobuf/graphframes.proto @@ -106,8 +106,8 @@ message ParallelPersonalizedPageRank { message Pregel { ColumnOrExpression agg_msgs = 1; - ColumnOrExpression send_msg_to_dst = 2; - ColumnOrExpression send_msg_to_src = 3; + repeated ColumnOrExpression send_msg_to_dst = 2; + repeated ColumnOrExpression send_msg_to_src = 3; int32 checkpoint_interval = 4; int32 max_iter = 5; string additional_col_name = 6; diff --git a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala index edc0e47d4..1a8d80b0e 100644 --- a/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala +++ b/graphframes-connect/src/main/scala/org/apache/spark/sql/graphframes/GraphFramesConnectUtils.scala @@ -162,18 +162,23 @@ object GraphFramesConnectUtils { } case MethodCase.PREGEL => { val pregelProto = apiMessage.getPregel - val pregel = graphFrame.pregel - pregel + var pregel = graphFrame.pregel .aggMsgs(parseColumnOrExpression(pregelProto.getAggMsgs, planner)) - .sendMsgToDst(parseColumnOrExpression(pregelProto.getSendMsgToDst, planner)) - .sendMsgToDst(parseColumnOrExpression(pregelProto.getSendMsgToDst, planner)) .setCheckpointInterval(pregelProto.getCheckpointInterval) - .setMaxIter(pregelProto.getMaxIter) .withVertexColumn( pregelProto.getAdditionalColName, parseColumnOrExpression(pregelProto.getAdditionalColInitial, planner), parseColumnOrExpression(pregelProto.getAdditionalColUpd, planner)) - .run() + .setMaxIter(pregelProto.getMaxIter) + + pregel = pregelProto.getSendMsgToSrcList.asScala + .map(parseColumnOrExpression(_, planner)) + .foldLeft(pregel)((p, col) => p.sendMsgToSrc(col)) + pregel = pregelProto.getSendMsgToDstList.asScala + .map(parseColumnOrExpression(_, planner)) + .foldLeft(pregel)((p, col) => p.sendMsgToDst(col)) + + pregel.run() } case MethodCase.SHORTEST_PATHS => { graphFrame.shortestPaths diff --git a/python/graphframes/__init__.py b/python/graphframes/__init__.py index 03f1e4943..bded262bc 100644 --- a/python/graphframes/__init__.py +++ b/python/graphframes/__init__.py @@ -1,4 +1,3 @@ - from .graphframe import GraphFrame -__all__ = ['GraphFrame'] +__all__ = ["GraphFrame"] diff --git a/python/graphframes/classic/graphframe.py b/python/graphframes/classic/graphframe.py index 5381ec8b5..23deef6f7 100644 --- a/python/graphframes/classic/graphframe.py +++ b/python/graphframes/classic/graphframe.py @@ -18,7 +18,7 @@ import sys from typing import Any, Union, Optional -if sys.version > '3': +if sys.version > "3": basestring = str from graphframes.lib import Pregel @@ -27,7 +27,7 @@ from pyspark.storagelevel import StorageLevel -def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': +def _from_java_gf(jgf: Any, spark: SparkSession) -> "GraphFrame": """ (internal) creates a python GraphFrame wrapper from a java GraphFrame. @@ -37,10 +37,15 @@ def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': pe = DataFrame(jgf.edges(), spark) return GraphFrame(pv, pe) + def _java_api(jsc: SparkContext) -> Any: javaClassName = "org.graphframes.GraphFramePythonAPI" - return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ - .newInstance() + return ( + jsc._jvm.Thread.currentThread() + .getContextClassLoader() + .loadClass(javaClassName) + .newInstance() + ) class GraphFrame: @@ -76,16 +81,22 @@ def __init__(self, v: DataFrame, e: DataFrame) -> None: # Check that provided DataFrames contain required columns if self.ID not in v.columns: raise ValueError( - "Vertex ID column {} missing from vertex DataFrame, which has columns: {}" - .format(self.ID, ",".join(v.columns))) + "Vertex ID column {} missing from vertex DataFrame, which has columns: {}".format( + self.ID, ",".join(v.columns) + ) + ) if self.SRC not in e.columns: raise ValueError( - "Source vertex ID column {} missing from edge DataFrame, which has columns: {}" - .format(self.SRC, ",".join(e.columns))) + "Source vertex ID column {} missing from edge DataFrame, which has columns: {}".format( + self.SRC, ",".join(e.columns) + ) + ) if self.DST not in e.columns: raise ValueError( - "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}" - .format(self.DST, ",".join(e.columns))) + "Destination vertex ID column {} missing from edge DataFrame, which has columns: {}".format( + self.DST, ",".join(e.columns) + ) + ) self._jvm_graph = self._jvm_gf_api.createGraph(v._jdf, e._jdf) @@ -109,8 +120,8 @@ def edges(self) -> DataFrame: def __repr__(self): return self._jvm_graph.toString() - def cache(self) -> 'GraphFrame': - """ Persist the dataframe representation of vertices and edges of the graph with the default + def cache(self) -> "GraphFrame": + """Persist the dataframe representation of vertices and edges of the graph with the default storage level. """ self._jvm_graph.cache() @@ -124,7 +135,7 @@ def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "Gra self._jvm_graph.persist(javaStorageLevel) return self - def unpersist(self, blocking: bool = False) -> 'GraphFrame': + def unpersist(self, blocking: bool = False) -> "GraphFrame": """Mark the dataframe representation of vertices and edges of the graph as non-persistent, and remove all blocks for it from memory and disk. """ @@ -209,12 +220,12 @@ def find(self, pattern: str) -> DataFrame: jdf = self._jvm_graph.find(pattern) return DataFrame(jdf, self._spark) - def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': + def filterVertices(self, condition: Union[str, Column]) -> "GraphFrame": """ Filters the vertices based on expression, remove edges containing any dropped vertices. - + :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered vertices and edges. + :return: GraphFrame with filtered vertices and edges. """ if isinstance(condition, basestring): @@ -225,12 +236,12 @@ def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': + def filterEdges(self, condition: Union[str, Column]) -> "GraphFrame": """ Filters the edges based on expression, keep all vertices. - + :param condition: String or Column describing the condition expression for filtering. - :return: GraphFrame with filtered edges. + :return: GraphFrame with filtered edges. """ if isinstance(condition, basestring): jdf = self._jvm_graph.filterEdges(condition) @@ -240,18 +251,18 @@ def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def dropIsolatedVertices(self) -> 'GraphFrame': + def dropIsolatedVertices(self) -> "GraphFrame": """ Drops isolated vertices, vertices are not contained in any edges. - :return: GraphFrame with filtered vertices. + :return: GraphFrame with filtered vertices. """ jdf = self._jvm_graph.dropIsolatedVertices() return _from_java_gf(jdf, self._spark) - def bfs(self, fromExpr: str, toExpr: str, - edgeFilter: Optional[str] = None, - maxPathLength: int = 10) -> DataFrame: + def bfs( + self, fromExpr: str, toExpr: str, edgeFilter: Optional[str] = None, maxPathLength: int = 10 + ) -> DataFrame: """ Breadth-first search (BFS). @@ -259,18 +270,20 @@ def bfs(self, fromExpr: str, toExpr: str, :return: DataFrame with one Row for each shortest path between matching vertices. """ - builder = self._jvm_graph.bfs()\ - .fromExpr(fromExpr)\ - .toExpr(toExpr)\ - .maxPathLength(maxPathLength) + builder = ( + self._jvm_graph.bfs().fromExpr(fromExpr).toExpr(toExpr).maxPathLength(maxPathLength) + ) if edgeFilter is not None: builder.edgeFilter(edgeFilter) jdf = builder.run() return DataFrame(jdf, self._spark) - def aggregateMessages(self, aggCol: Union[Column, str], - sendToSrc: Union[Column, str, None] = None, - sendToDst: Union[Column, str, None] = None) -> DataFrame: + def aggregateMessages( + self, + aggCol: Union[Column, str], + sendToSrc: Union[Column, str, None] = None, + sendToDst: Union[Column, str, None] = None, + ) -> DataFrame: """ Aggregates messages from the neighbours. @@ -314,9 +327,12 @@ def aggregateMessages(self, aggCol: Union[Column, str], # Standard algorithms - def connectedComponents(self, algorithm: str = 'graphframes', - checkpointInterval: int = 2, - broadcastThreshold: int = 1000000) -> DataFrame: + def connectedComponents( + self, + algorithm: str = "graphframes", + checkpointInterval: int = 2, + broadcastThreshold: int = 1000000, + ) -> DataFrame: """ Computes the connected components of the graph. @@ -330,11 +346,13 @@ def connectedComponents(self, algorithm: str = 'graphframes', :return: DataFrame with new vertices column "component" """ - jdf = self._jvm_graph.connectedComponents() \ - .setAlgorithm(algorithm) \ - .setCheckpointInterval(checkpointInterval) \ - .setBroadcastThreshold(broadcastThreshold) \ + jdf = ( + self._jvm_graph.connectedComponents() + .setAlgorithm(algorithm) + .setCheckpointInterval(checkpointInterval) + .setBroadcastThreshold(broadcastThreshold) .run() + ) return DataFrame(jdf, self._spark) def labelPropagation(self, maxIter: int) -> DataFrame: @@ -349,10 +367,13 @@ def labelPropagation(self, maxIter: int) -> DataFrame: jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() return DataFrame(jdf, self._spark) - def pageRank(self, resetProbability: float = 0.15, - sourceId: Optional[Any] = None, - maxIter: Optional[int] = None, - tol: Optional[float] = None) -> 'GraphFrame': + def pageRank( + self, + resetProbability: float = 0.15, + sourceId: Optional[Any] = None, + maxIter: Optional[int] = None, + tol: Optional[float] = None, + ) -> "GraphFrame": """ Runs the PageRank algorithm on the graph. Note: Exactly one of fixed_num_iter or tolerance must be set. @@ -379,9 +400,12 @@ def pageRank(self, resetProbability: float = 0.15, jgf = builder.run() return _from_java_gf(jgf, self._spark) - def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, - sourceIds: Optional[list[Any]] = None, - maxIter: Optional[int] = None) -> 'GraphFrame': + def parallelPersonalizedPageRank( + self, + resetProbability: float = 0.15, + sourceIds: Optional[list[Any]] = None, + maxIter: Optional[int] = None, + ) -> "GraphFrame": """ Run the personalized PageRank algorithm on the graph, from the provided list of sources in parallel for a fixed number of iterations. @@ -393,7 +417,9 @@ def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, :param maxIter: the fixed number of iterations this algorithm runs :return: GraphFrame with new vertices column "pageranks" and new edges column "weight" """ - assert sourceIds is not None and len(sourceIds) > 0, "Source vertices Ids sourceIds must be provided" + assert ( + sourceIds is not None and len(sourceIds) > 0 + ), "Source vertices Ids sourceIds must be provided" assert maxIter is not None, "Max number of iterations maxIter must be provided" sourceIds = self._sc._jvm.PythonUtils.toArray(sourceIds) builder = self._jvm_graph.parallelPersonalizedPageRank() @@ -427,10 +453,17 @@ def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() return DataFrame(jdf, self._spark) - def svdPlusPlus(self, rank: int = 10, maxIter: int = 2, - minValue: float = 0.0, maxValue: float = 5.0, - gamma1: float = 0.007, gamma2: float = 0.007, - gamma6: float = 0.005, gamma7: float = 0.015) -> tuple[DataFrame, float]: + def svdPlusPlus( + self, + rank: int = 10, + maxIter: int = 2, + minValue: float = 0.0, + maxValue: float = 5.0, + gamma1: float = 0.007, + gamma2: float = 0.007, + gamma6: float = 0.005, + gamma7: float = 0.015, + ) -> tuple[DataFrame, float]: """ Runs the SVD++ algorithm. @@ -462,12 +495,14 @@ def triangleCount(self) -> DataFrame: def _test(): import doctest import graphframe + globs = graphframe.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - globs['spark'] = SparkSession(globs['sc']).builder.getOrCreate() + globs["sc"] = SparkContext("local[4]", "PythonTest", batchSize=2) + globs["spark"] = SparkSession(globs["sc"]).builder.getOrCreate() (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE + ) + globs["sc"].stop() if failure_count: exit(-1) diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index f07e27c7a..b2578dcbd 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -22,8 +22,8 @@ def __init__(self, graph: "GraphFrameConnect") -> None: self._col_name = None self._initial_expr = None self._update_after_agg_msgs_expr = None - self._send_msg_to_src = None - self._send_msg_to_dst = None + self._send_msg_to_src = [] + self._send_msg_to_dst = [] self._agg_msg = None def setMaxIter(self, value: int) -> Self: @@ -46,11 +46,11 @@ def withVertexColumn( return self def sendMsgToSrc(self, msgExpr: Column | str) -> Self: - self._send_msg_to_src = msgExpr + self._send_msg_to_src.append(msgExpr) return self def sendMsgToDst(self, msgExpr: Column | str) -> Self: - self._send_msg_to_dst = msgExpr + self._send_msg_to_dst.append(msgExpr) return self def aggMsgs(self, aggExpr: Column) -> Self: @@ -65,13 +65,14 @@ def __init__( checkpoint_interval: int, vertex_col_name: str, agg_msg: Column | str, - send2dst: Column | str, - send2src: Column | str, + send2dst: list[Column | str], + send2src: list[Column | str], vertex_col_init: Column | str, vertex_col_upd: Column | str, vertices: DataFrame, edges: DataFrame, ) -> None: + super().__init__(None) self.max_iter = max_iter self.checkpoint_interval = checkpoint_interval self.vertex_col_name = vertex_col_name @@ -84,26 +85,26 @@ def __init__( self.edges = edges def plan(self, session: SparkConnectClient) -> proto.Relation: - plan = self._create_proto_relation() pregel = pb.Pregel( agg_msgs=make_column_or_expr(self.agg_msg, session), - send_msg_to_dst=make_column_or_expr(self.send2dst, session), - send_msg_to_src=make_column_or_expr(self.send2src, session), + send_msg_to_dst=[ + make_column_or_expr(c_or_e, session) for c_or_e in self.send2dst + ], + send_msg_to_src=[ + make_column_or_expr(c_or_e, session) for c_or_e in self.send2src + ], checkpoint_interval=self.checkpoint_interval, max_iter=self.max_iter, additional_col_name=self.vertex_col_name, - additional_col_initial=make_column_or_expr( - self.vertex_col_init, session - ), - additional_col_upd=make_column_or_expr( - self.vertex_col_upd, session - ), + additional_col_initial=make_column_or_expr(self.vertex_col_init, session), + additional_col_upd=make_column_or_expr(self.vertex_col_upd, session), ) pb_message = pb.GraphFramesAPI( vertices=dataframe_to_proto(self.vertices, session), edges=dataframe_to_proto(self.edges, session), ) pb_message.pregel.CopyFrom(pregel) + plan = self._create_proto_relation() plan.extension.Pack(pb_message) return plan @@ -117,12 +118,6 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: if self._agg_msg is None: raise ValueError("AggMsg is not initialized!") - if self._send_msg_to_src is None: - raise ValueError("Send-to-src column is not initialized!") - - if self._send_msg_to_dst is None: - raise ValueError("Send-to-dst column is not initialized!") - return DataFrame.withPlan( Pregel( max_iter=self._max_iter, @@ -219,9 +214,7 @@ def cache(self) -> "GraphFrameConnect": new_edges = self._edges.cache() return GraphFrameConnect(new_vertices, new_edges) - def persist( - self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY - ) -> "GraphFrameConnect": + def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrameConnect": new_vertices = self._vertices.persist(storageLevel=storageLevel) new_edges = self._edges.persist(storageLevel=storageLevel) return GraphFrameConnect(new_vertices, new_edges) @@ -233,60 +226,23 @@ def unpersist(self, blocking: bool = False) -> "GraphFrameConnect": @property def outDegrees(self) -> DataFrame: - class OutDegrees(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame) -> None: - super().__init__(None) - self.v = v - self.e = e - - def plan(self, session: SparkConnectClient) -> proto.Relation: - graphframes_api_call = GraphFrameConnect._get_pb_api_message( - self.v, self.e, session - ) - graphframes_api_call.out_degrees.CopyFrom(pb.OutDegrees()) - plan = self._create_proto_relation() - plan.extension.Pack(graphframes_api_call) - return plan - - return DataFrame.withPlan(OutDegrees(self._vertices, self._edges), self._spark) + return self._edges.groupBy(F.col(self.SRC).alias(self.ID)).agg( + F.count("*").alias("outDegree") + ) @property def inDegrees(self) -> DataFrame: - class InDegrees(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame) -> None: - super().__init__(None) - self.v = v - self.e = e - - def plan(self, session: SparkConnectClient) -> proto.Relation: - graphframes_api_call = GraphFrameConnect._get_pb_api_message( - self.v, self.e, session - ) - graphframes_api_call.in_degrees.CopyFrom(pb.InDegrees()) - plan = self._create_proto_relation() - plan.extension.Pack(graphframes_api_call) - return plan - - return DataFrame.withPlan(InDegrees(self._vertices, self._edges), self._spark) + return self._edges.groupBy(F.col(self.DST).alias(self.ID)).agg( + F.count("*").alias("inDegree") + ) @property def degrees(self) -> DataFrame: - class Degrees(LogicalPlan): - def __init__(self, v: DataFrame, e: DataFrame) -> None: - super().__init__(None) - self.v = v - self.e = e - - def plan(self, session: SparkConnectClient) -> proto.Relation: - graphframes_api_call = GraphFrameConnect._get_pb_api_message( - self.v, self.e, session - ) - graphframes_api_call.degrees.CopyFrom(pb.Degrees()) - plan = self._create_proto_relation() - plan.extension.Pack(graphframes_api_call) - return plan - - return DataFrame.withPlan(Degrees(self._vertices, self._edges), self._spark) + return ( + self._edges.select(F.explode(F.array(F.col(self.SRC), F.col(self.DST))).alias(self.ID)) + .groupBy(self.ID) + .agg(F.count("*").alias("degree")) + ) @property def triplets(self) -> DataFrame: @@ -328,15 +284,11 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - return DataFrame.withPlan( - Find(self._vertices, self._edges, pattern), self._spark - ) + return DataFrame.withPlan(Find(self._vertices, self._edges, pattern), self._spark) def filterVertices(self, condition: str | Column) -> "GraphFrameConnect": class FilterVertices(LogicalPlan): - def __init__( - self, v: DataFrame, e: DataFrame, condition: str | Column - ) -> None: + def __init__(self, v: DataFrame, e: DataFrame, condition: str | Column) -> None: super().__init__(None) self.v = v self.e = e @@ -371,9 +323,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def filterEdges(self, condition: str | Column) -> "GraphFrameConnect": class FilterEdges(LogicalPlan): - def __init__( - self, v: DataFrame, e: DataFrame, condition: str | Column - ) -> None: + def __init__(self, v: DataFrame, e: DataFrame, condition: str | Column) -> None: super().__init__(None) self.v = v self.e = e @@ -384,9 +334,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: self.v, self.e, session ) col_or_expr = make_column_or_expr(self.c, session) - graphframes_api_call.filter_edges.CopyFrom( - pb.FilterEdges(condition=col_or_expr) - ) + graphframes_api_call.filter_edges.CopyFrom(pb.FilterEdges(condition=col_or_expr)) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -407,9 +355,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call = GraphFrameConnect._get_pb_api_message( self.v, self.e, session ) - graphframes_api_call.drop_isolated_vertices.CopyFrom( - pb.DropIsolatedVertices() - ) + graphframes_api_call.drop_isolated_vertices.CopyFrom(pb.DropIsolatedVertices()) plan = self._create_proto_relation() plan.extension.Pack(graphframes_api_call) return plan @@ -504,12 +450,16 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.aggregate_messages.CopyFrom( pb.AggregateMessages( agg_col=make_column_or_expr(self.agg_col, session), - send_to_src=None - if self.send2src is None - else make_column_or_expr(self.send2src, session), - send_to_dst=None - if self.send2dst is None - else make_column_or_expr(self.send2dst, session), + send_to_src=( + None + if self.send2src is None + else make_column_or_expr(self.send2src, session) + ), + send_to_dst=( + None + if self.send2dst is None + else make_column_or_expr(self.send2dst, session) + ), ) ) plan = self._create_proto_relation() @@ -517,14 +467,10 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: return plan if sendToSrc is None and sendToDst is None: - raise ValueError( - "Either `sendToSrc`, `sendToDst`, or both have to be provided" - ) + raise ValueError("Either `sendToSrc`, `sendToDst`, or both have to be provided") return DataFrame.withPlan( - AggregateMessages( - self._vertices, self._edges, aggCol, sendToSrc, sendToDst - ), + AggregateMessages(self._vertices, self._edges, aggCol, sendToSrc, sendToDst), self._spark, ) @@ -599,9 +545,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: LabelPropagation(self._vertices, self._edges, maxIter), self._spark ) - def _update_page_rank_edge_weights( - self, new_vertices: DataFrame - ) -> "GraphFrameConnect": + def _update_page_rank_edge_weights(self, new_vertices: DataFrame) -> "GraphFrameConnect": cols2select = self.edges.columns + ["weight"] new_edges = ( self._edges.join( @@ -651,9 +595,9 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.page_rank.CopyFrom( pb.PageRank( reset_probability=self.reset_prob, - source_id=None - if self.source_id is None - else make_str_or_long_id(self.source_id), + source_id=( + None if self.source_id is None else make_str_or_long_id(self.source_id) + ), max_iter=self.max_iter, tol=self.tol, ) @@ -710,9 +654,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: graphframes_api_call.parallel_personalized_page_rank.CopyFrom( pb.ParallelPersonalizedPageRank( reset_probability=self.reset_prob, - source_ids=[ - make_str_or_long_id(raw_id) for raw_id in self.source_ids - ], + source_ids=[make_str_or_long_id(raw_id) for raw_id in self.source_ids], max_iter=self.max_iter, ) ) @@ -739,9 +681,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: def shortestPaths(self, landmarks: list[str | int]) -> DataFrame: class ShortestPaths(LogicalPlan): - def __init__( - self, v: DataFrame, e: DataFrame, landmarks: list[str | int] - ) -> None: + def __init__(self, v: DataFrame, e: DataFrame, landmarks: list[str | int]) -> None: super().__init__(None) self.v = v self.e = e @@ -753,9 +693,7 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: ) graphframes_api_call.shortest_paths.CopyFrom( pb.ShortestPaths( - landmarks=[ - make_str_or_long_id(raw_id) for raw_id in self.landmarks - ] + landmarks=[make_str_or_long_id(raw_id) for raw_id in self.landmarks] ) ) plan = self._create_proto_relation() @@ -886,6 +824,4 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: plan.extension.Pack(graphframes_api_call) return plan - return DataFrame.withPlan( - TriangleCount(self._vertices, self._edges), self._spark - ) + return DataFrame.withPlan(TriangleCount(self._vertices, self._edges), self._spark) diff --git a/python/graphframes/connect/proto/graphframes_pb2.py b/python/graphframes/connect/proto/graphframes_pb2.py index 429dfb59c..2b9d5b944 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.py +++ b/python/graphframes/connect/proto/graphframes_pb2.py @@ -9,71 +9,69 @@ from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + _runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 27, - 1, - '', - 'graphframes.proto' + _runtime_version.Domain.PUBLIC, 5, 27, 1, "", "graphframes.proto" ) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto\"\xba\r\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32\".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12\x42\n\x07\x64\x65grees\x18\x06 \x01(\x0b\x32&.org.graphframes.connect.proto.DegreesH\x00R\x07\x64\x65grees\x12k\n\x16\x64rop_isolated_vertices\x18\x07 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x08 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\t \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\n \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12I\n\nin_degrees\x18\x0b \x01(\x0b\x32(.org.graphframes.connect.proto.InDegreesH\x00R\tinDegrees\x12^\n\x11label_propagation\x18\x0c \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12L\n\x0bout_degrees\x18\r \x01(\x0b\x32).org.graphframes.connect.proto.OutDegreesH\x00R\noutDegrees\x12\x46\n\tpage_rank\x18\x0e \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0f \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12?\n\x06pregel\x18\x10 \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x11 \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x12 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x13 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x14 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x15 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method\"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr\"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id\"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst\"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength\"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold\"\t\n\x07\x44\x65grees\"\x16\n\x14\x44ropIsolatedVertices\"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition\"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition\" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern\"\x0b\n\tInDegrees\"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter\"\x0c\n\nOutDegrees\"\xc3\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter\x12\x10\n\x03tol\x18\x04 \x01(\x01R\x03tolB\x0c\n\n_source_id\"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter\"\xd0\x04\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd\"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks\"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter\"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7\"\x0f\n\rTriangleCount\"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xba\r\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12\x42\n\x07\x64\x65grees\x18\x06 \x01(\x0b\x32&.org.graphframes.connect.proto.DegreesH\x00R\x07\x64\x65grees\x12k\n\x16\x64rop_isolated_vertices\x18\x07 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x08 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\t \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\n \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12I\n\nin_degrees\x18\x0b \x01(\x0b\x32(.org.graphframes.connect.proto.InDegreesH\x00R\tinDegrees\x12^\n\x11label_propagation\x18\x0c \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12L\n\x0bout_degrees\x18\r \x01(\x0b\x32).org.graphframes.connect.proto.OutDegreesH\x00R\noutDegrees\x12\x46\n\tpage_rank\x18\x0e \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0f \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12?\n\x06pregel\x18\x10 \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x11 \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x12 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x13 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x14 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x15 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold"\t\n\x07\x44\x65grees"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"\x0b\n\tInDegrees"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\x0c\n\nOutDegrees"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"\xd0\x04\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"\x0f\n\rTriangleCount"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'graphframes_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "graphframes_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: - _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = b'\n!com.org.graphframes.connect.protoB\020GraphframesProtoH\001P\001\240\001\001\242\002\004OGCP\252\002\035Org.Graphframes.Connect.Proto\312\002\035Org\\Graphframes\\Connect\\Proto\342\002)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\352\002 Org::Graphframes::Connect::Proto' - _globals['_GRAPHFRAMESAPI']._serialized_start=53 - _globals['_GRAPHFRAMESAPI']._serialized_end=1775 - _globals['_COLUMNOREXPRESSION']._serialized_start=1777 - _globals['_COLUMNOREXPRESSION']._serialized_end=1854 - _globals['_STRINGORLONGID']._serialized_start=1856 - _globals['_STRINGORLONGID']._serialized_end=1936 - _globals['_AGGREGATEMESSAGES']._serialized_start=1939 - _globals['_AGGREGATEMESSAGES']._serialized_end=2242 - _globals['_BFS']._serialized_start=2245 - _globals['_BFS']._serialized_end=2530 - _globals['_CONNECTEDCOMPONENTS']._serialized_start=2533 - _globals['_CONNECTEDCOMPONENTS']._serialized_end=2682 - _globals['_DEGREES']._serialized_start=2684 - _globals['_DEGREES']._serialized_end=2693 - _globals['_DROPISOLATEDVERTICES']._serialized_start=2695 - _globals['_DROPISOLATEDVERTICES']._serialized_end=2717 - _globals['_FILTEREDGES']._serialized_start=2719 - _globals['_FILTEREDGES']._serialized_end=2813 - _globals['_FILTERVERTICES']._serialized_start=2815 - _globals['_FILTERVERTICES']._serialized_end=2912 - _globals['_FIND']._serialized_start=2914 - _globals['_FIND']._serialized_end=2946 - _globals['_INDEGREES']._serialized_start=2948 - _globals['_INDEGREES']._serialized_end=2959 - _globals['_LABELPROPAGATION']._serialized_start=2961 - _globals['_LABELPROPAGATION']._serialized_end=3006 - _globals['_OUTDEGREES']._serialized_start=3008 - _globals['_OUTDEGREES']._serialized_end=3020 - _globals['_PAGERANK']._serialized_start=3023 - _globals['_PAGERANK']._serialized_end=3218 - _globals['_PARALLELPERSONALIZEDPAGERANK']._serialized_start=3221 - _globals['_PARALLELPERSONALIZEDPAGERANK']._serialized_end=3401 - _globals['_PREGEL']._serialized_start=3404 - _globals['_PREGEL']._serialized_end=3996 - _globals['_SHORTESTPATHS']._serialized_start=3998 - _globals['_SHORTESTPATHS']._serialized_end=4090 - _globals['_STRONGLYCONNECTEDCOMPONENTS']._serialized_start=4092 - _globals['_STRONGLYCONNECTEDCOMPONENTS']._serialized_end=4148 - _globals['_SVDPLUSPLUS']._serialized_start=4151 - _globals['_SVDPLUSPLUS']._serialized_end=4365 - _globals['_TRIANGLECOUNT']._serialized_start=4367 - _globals['_TRIANGLECOUNT']._serialized_end=4382 - _globals['_TRIPLETS']._serialized_start=4384 - _globals['_TRIPLETS']._serialized_end=4394 + _globals["DESCRIPTOR"]._loaded_options = None + _globals["DESCRIPTOR"]._serialized_options = ( + b"\n!com.org.graphframes.connect.protoB\020GraphframesProtoH\001P\001\240\001\001\242\002\004OGCP\252\002\035Org.Graphframes.Connect.Proto\312\002\035Org\\Graphframes\\Connect\\Proto\342\002)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\352\002 Org::Graphframes::Connect::Proto" + ) + _globals["_GRAPHFRAMESAPI"]._serialized_start = 53 + _globals["_GRAPHFRAMESAPI"]._serialized_end = 1775 + _globals["_COLUMNOREXPRESSION"]._serialized_start = 1777 + _globals["_COLUMNOREXPRESSION"]._serialized_end = 1854 + _globals["_STRINGORLONGID"]._serialized_start = 1856 + _globals["_STRINGORLONGID"]._serialized_end = 1936 + _globals["_AGGREGATEMESSAGES"]._serialized_start = 1939 + _globals["_AGGREGATEMESSAGES"]._serialized_end = 2242 + _globals["_BFS"]._serialized_start = 2245 + _globals["_BFS"]._serialized_end = 2530 + _globals["_CONNECTEDCOMPONENTS"]._serialized_start = 2533 + _globals["_CONNECTEDCOMPONENTS"]._serialized_end = 2682 + _globals["_DEGREES"]._serialized_start = 2684 + _globals["_DEGREES"]._serialized_end = 2693 + _globals["_DROPISOLATEDVERTICES"]._serialized_start = 2695 + _globals["_DROPISOLATEDVERTICES"]._serialized_end = 2717 + _globals["_FILTEREDGES"]._serialized_start = 2719 + _globals["_FILTEREDGES"]._serialized_end = 2813 + _globals["_FILTERVERTICES"]._serialized_start = 2815 + _globals["_FILTERVERTICES"]._serialized_end = 2912 + _globals["_FIND"]._serialized_start = 2914 + _globals["_FIND"]._serialized_end = 2946 + _globals["_INDEGREES"]._serialized_start = 2948 + _globals["_INDEGREES"]._serialized_end = 2959 + _globals["_LABELPROPAGATION"]._serialized_start = 2961 + _globals["_LABELPROPAGATION"]._serialized_end = 3006 + _globals["_OUTDEGREES"]._serialized_start = 3008 + _globals["_OUTDEGREES"]._serialized_end = 3020 + _globals["_PAGERANK"]._serialized_start = 3023 + _globals["_PAGERANK"]._serialized_end = 3249 + _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_start = 3252 + _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_end = 3432 + _globals["_PREGEL"]._serialized_start = 3435 + _globals["_PREGEL"]._serialized_end = 4027 + _globals["_SHORTESTPATHS"]._serialized_start = 4029 + _globals["_SHORTESTPATHS"]._serialized_end = 4121 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 4123 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 4179 + _globals["_SVDPLUSPLUS"]._serialized_start = 4182 + _globals["_SVDPLUSPLUS"]._serialized_end = 4396 + _globals["_TRIANGLECOUNT"]._serialized_start = 4398 + _globals["_TRIANGLECOUNT"]._serialized_end = 4413 + _globals["_TRIPLETS"]._serialized_start = 4415 + _globals["_TRIPLETS"]._serialized_end = 4425 # @@protoc_insertion_point(module_scope) diff --git a/python/graphframes/connect/proto/graphframes_pb2.pyi b/python/graphframes/connect/proto/graphframes_pb2.pyi index df0f20f89..aa43d9cff 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.pyi +++ b/python/graphframes/connect/proto/graphframes_pb2.pyi @@ -1,12 +1,40 @@ from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union +from typing import ( + ClassVar as _ClassVar, + Iterable as _Iterable, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) DESCRIPTOR: _descriptor.FileDescriptor class GraphFramesAPI(_message.Message): - __slots__ = ("vertices", "edges", "aggregate_messages", "bfs", "connected_components", "degrees", "drop_isolated_vertices", "filter_edges", "filter_vertices", "find", "in_degrees", "label_propagation", "out_degrees", "page_rank", "parallel_personalized_page_rank", "pregel", "shortest_paths", "strongly_connected_components", "svd_plus_plus", "triangle_count", "triplets") + __slots__ = ( + "vertices", + "edges", + "aggregate_messages", + "bfs", + "connected_components", + "degrees", + "drop_isolated_vertices", + "filter_edges", + "filter_vertices", + "find", + "in_degrees", + "label_propagation", + "out_degrees", + "page_rank", + "parallel_personalized_page_rank", + "pregel", + "shortest_paths", + "strongly_connected_components", + "svd_plus_plus", + "triangle_count", + "triplets", + ) VERTICES_FIELD_NUMBER: _ClassVar[int] EDGES_FIELD_NUMBER: _ClassVar[int] AGGREGATE_MESSAGES_FIELD_NUMBER: _ClassVar[int] @@ -49,7 +77,34 @@ class GraphFramesAPI(_message.Message): svd_plus_plus: SVDPlusPlus triangle_count: TriangleCount triplets: Triplets - def __init__(self, vertices: _Optional[bytes] = ..., edges: _Optional[bytes] = ..., aggregate_messages: _Optional[_Union[AggregateMessages, _Mapping]] = ..., bfs: _Optional[_Union[BFS, _Mapping]] = ..., connected_components: _Optional[_Union[ConnectedComponents, _Mapping]] = ..., degrees: _Optional[_Union[Degrees, _Mapping]] = ..., drop_isolated_vertices: _Optional[_Union[DropIsolatedVertices, _Mapping]] = ..., filter_edges: _Optional[_Union[FilterEdges, _Mapping]] = ..., filter_vertices: _Optional[_Union[FilterVertices, _Mapping]] = ..., find: _Optional[_Union[Find, _Mapping]] = ..., in_degrees: _Optional[_Union[InDegrees, _Mapping]] = ..., label_propagation: _Optional[_Union[LabelPropagation, _Mapping]] = ..., out_degrees: _Optional[_Union[OutDegrees, _Mapping]] = ..., page_rank: _Optional[_Union[PageRank, _Mapping]] = ..., parallel_personalized_page_rank: _Optional[_Union[ParallelPersonalizedPageRank, _Mapping]] = ..., pregel: _Optional[_Union[Pregel, _Mapping]] = ..., shortest_paths: _Optional[_Union[ShortestPaths, _Mapping]] = ..., strongly_connected_components: _Optional[_Union[StronglyConnectedComponents, _Mapping]] = ..., svd_plus_plus: _Optional[_Union[SVDPlusPlus, _Mapping]] = ..., triangle_count: _Optional[_Union[TriangleCount, _Mapping]] = ..., triplets: _Optional[_Union[Triplets, _Mapping]] = ...) -> None: ... + def __init__( + self, + vertices: _Optional[bytes] = ..., + edges: _Optional[bytes] = ..., + aggregate_messages: _Optional[_Union[AggregateMessages, _Mapping]] = ..., + bfs: _Optional[_Union[BFS, _Mapping]] = ..., + connected_components: _Optional[_Union[ConnectedComponents, _Mapping]] = ..., + degrees: _Optional[_Union[Degrees, _Mapping]] = ..., + drop_isolated_vertices: _Optional[_Union[DropIsolatedVertices, _Mapping]] = ..., + filter_edges: _Optional[_Union[FilterEdges, _Mapping]] = ..., + filter_vertices: _Optional[_Union[FilterVertices, _Mapping]] = ..., + find: _Optional[_Union[Find, _Mapping]] = ..., + in_degrees: _Optional[_Union[InDegrees, _Mapping]] = ..., + label_propagation: _Optional[_Union[LabelPropagation, _Mapping]] = ..., + out_degrees: _Optional[_Union[OutDegrees, _Mapping]] = ..., + page_rank: _Optional[_Union[PageRank, _Mapping]] = ..., + parallel_personalized_page_rank: _Optional[ + _Union[ParallelPersonalizedPageRank, _Mapping] + ] = ..., + pregel: _Optional[_Union[Pregel, _Mapping]] = ..., + shortest_paths: _Optional[_Union[ShortestPaths, _Mapping]] = ..., + strongly_connected_components: _Optional[ + _Union[StronglyConnectedComponents, _Mapping] + ] = ..., + svd_plus_plus: _Optional[_Union[SVDPlusPlus, _Mapping]] = ..., + triangle_count: _Optional[_Union[TriangleCount, _Mapping]] = ..., + triplets: _Optional[_Union[Triplets, _Mapping]] = ..., + ) -> None: ... class ColumnOrExpression(_message.Message): __slots__ = ("col", "expr") @@ -75,7 +130,12 @@ class AggregateMessages(_message.Message): agg_col: ColumnOrExpression send_to_src: ColumnOrExpression send_to_dst: ColumnOrExpression - def __init__(self, agg_col: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_to_src: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_to_dst: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + def __init__( + self, + agg_col: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + send_to_src: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + send_to_dst: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + ) -> None: ... class BFS(_message.Message): __slots__ = ("from_expr", "to_expr", "edge_filter", "max_path_length") @@ -87,7 +147,13 @@ class BFS(_message.Message): to_expr: ColumnOrExpression edge_filter: ColumnOrExpression max_path_length: int - def __init__(self, from_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., to_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., edge_filter: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., max_path_length: _Optional[int] = ...) -> None: ... + def __init__( + self, + from_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + to_expr: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + edge_filter: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + max_path_length: _Optional[int] = ..., + ) -> None: ... class ConnectedComponents(_message.Message): __slots__ = ("algorithm", "checkpoint_interval", "broadcast_threshold") @@ -97,7 +163,12 @@ class ConnectedComponents(_message.Message): algorithm: str checkpoint_interval: int broadcast_threshold: int - def __init__(self, algorithm: _Optional[str] = ..., checkpoint_interval: _Optional[int] = ..., broadcast_threshold: _Optional[int] = ...) -> None: ... + def __init__( + self, + algorithm: _Optional[str] = ..., + checkpoint_interval: _Optional[int] = ..., + broadcast_threshold: _Optional[int] = ..., + ) -> None: ... class Degrees(_message.Message): __slots__ = () @@ -111,13 +182,17 @@ class FilterEdges(_message.Message): __slots__ = ("condition",) CONDITION_FIELD_NUMBER: _ClassVar[int] condition: ColumnOrExpression - def __init__(self, condition: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + def __init__( + self, condition: _Optional[_Union[ColumnOrExpression, _Mapping]] = ... + ) -> None: ... class FilterVertices(_message.Message): __slots__ = ("condition",) CONDITION_FIELD_NUMBER: _ClassVar[int] condition: ColumnOrExpression - def __init__(self, condition: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + def __init__( + self, condition: _Optional[_Union[ColumnOrExpression, _Mapping]] = ... + ) -> None: ... class Find(_message.Message): __slots__ = ("pattern",) @@ -149,7 +224,13 @@ class PageRank(_message.Message): source_id: StringOrLongID max_iter: int tol: float - def __init__(self, reset_probability: _Optional[float] = ..., source_id: _Optional[_Union[StringOrLongID, _Mapping]] = ..., max_iter: _Optional[int] = ..., tol: _Optional[float] = ...) -> None: ... + def __init__( + self, + reset_probability: _Optional[float] = ..., + source_id: _Optional[_Union[StringOrLongID, _Mapping]] = ..., + max_iter: _Optional[int] = ..., + tol: _Optional[float] = ..., + ) -> None: ... class ParallelPersonalizedPageRank(_message.Message): __slots__ = ("reset_probability", "source_ids", "max_iter") @@ -159,10 +240,24 @@ class ParallelPersonalizedPageRank(_message.Message): reset_probability: float source_ids: _containers.RepeatedCompositeFieldContainer[StringOrLongID] max_iter: int - def __init__(self, reset_probability: _Optional[float] = ..., source_ids: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ..., max_iter: _Optional[int] = ...) -> None: ... + def __init__( + self, + reset_probability: _Optional[float] = ..., + source_ids: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ..., + max_iter: _Optional[int] = ..., + ) -> None: ... class Pregel(_message.Message): - __slots__ = ("agg_msgs", "send_msg_to_dst", "send_msg_to_src", "checkpoint_interval", "max_iter", "additional_col_name", "additional_col_initial", "additional_col_upd") + __slots__ = ( + "agg_msgs", + "send_msg_to_dst", + "send_msg_to_src", + "checkpoint_interval", + "max_iter", + "additional_col_name", + "additional_col_initial", + "additional_col_upd", + ) AGG_MSGS_FIELD_NUMBER: _ClassVar[int] SEND_MSG_TO_DST_FIELD_NUMBER: _ClassVar[int] SEND_MSG_TO_SRC_FIELD_NUMBER: _ClassVar[int] @@ -172,20 +267,32 @@ class Pregel(_message.Message): ADDITIONAL_COL_INITIAL_FIELD_NUMBER: _ClassVar[int] ADDITIONAL_COL_UPD_FIELD_NUMBER: _ClassVar[int] agg_msgs: ColumnOrExpression - send_msg_to_dst: ColumnOrExpression - send_msg_to_src: ColumnOrExpression + send_msg_to_dst: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] + send_msg_to_src: _containers.RepeatedCompositeFieldContainer[ColumnOrExpression] checkpoint_interval: int max_iter: int additional_col_name: str additional_col_initial: ColumnOrExpression additional_col_upd: ColumnOrExpression - def __init__(self, agg_msgs: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_msg_to_dst: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., send_msg_to_src: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., checkpoint_interval: _Optional[int] = ..., max_iter: _Optional[int] = ..., additional_col_name: _Optional[str] = ..., additional_col_initial: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., additional_col_upd: _Optional[_Union[ColumnOrExpression, _Mapping]] = ...) -> None: ... + def __init__( + self, + agg_msgs: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + send_msg_to_dst: _Optional[_Iterable[_Union[ColumnOrExpression, _Mapping]]] = ..., + send_msg_to_src: _Optional[_Iterable[_Union[ColumnOrExpression, _Mapping]]] = ..., + checkpoint_interval: _Optional[int] = ..., + max_iter: _Optional[int] = ..., + additional_col_name: _Optional[str] = ..., + additional_col_initial: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + additional_col_upd: _Optional[_Union[ColumnOrExpression, _Mapping]] = ..., + ) -> None: ... class ShortestPaths(_message.Message): __slots__ = ("landmarks",) LANDMARKS_FIELD_NUMBER: _ClassVar[int] landmarks: _containers.RepeatedCompositeFieldContainer[StringOrLongID] - def __init__(self, landmarks: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ...) -> None: ... + def __init__( + self, landmarks: _Optional[_Iterable[_Union[StringOrLongID, _Mapping]]] = ... + ) -> None: ... class StronglyConnectedComponents(_message.Message): __slots__ = ("max_iter",) @@ -194,7 +301,16 @@ class StronglyConnectedComponents(_message.Message): def __init__(self, max_iter: _Optional[int] = ...) -> None: ... class SVDPlusPlus(_message.Message): - __slots__ = ("rank", "max_iter", "min_value", "max_value", "gamma1", "gamma2", "gamma6", "gamma7") + __slots__ = ( + "rank", + "max_iter", + "min_value", + "max_value", + "gamma1", + "gamma2", + "gamma6", + "gamma7", + ) RANK_FIELD_NUMBER: _ClassVar[int] MAX_ITER_FIELD_NUMBER: _ClassVar[int] MIN_VALUE_FIELD_NUMBER: _ClassVar[int] @@ -211,7 +327,17 @@ class SVDPlusPlus(_message.Message): gamma2: float gamma6: float gamma7: float - def __init__(self, rank: _Optional[int] = ..., max_iter: _Optional[int] = ..., min_value: _Optional[float] = ..., max_value: _Optional[float] = ..., gamma1: _Optional[float] = ..., gamma2: _Optional[float] = ..., gamma6: _Optional[float] = ..., gamma7: _Optional[float] = ...) -> None: ... + def __init__( + self, + rank: _Optional[int] = ..., + max_iter: _Optional[int] = ..., + min_value: _Optional[float] = ..., + max_value: _Optional[float] = ..., + gamma1: _Optional[float] = ..., + gamma2: _Optional[float] = ..., + gamma6: _Optional[float] = ..., + gamma7: _Optional[float] = ..., + ) -> None: ... class TriangleCount(_message.Message): __slots__ = () diff --git a/python/graphframes/connect/proto/graphframes_pb2_grpc.py b/python/graphframes/connect/proto/graphframes_pb2_grpc.py index 2daafffeb..8a9393943 100644 --- a/python/graphframes/connect/proto/graphframes_pb2_grpc.py +++ b/python/graphframes/connect/proto/graphframes_pb2_grpc.py @@ -1,4 +1,3 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc - diff --git a/python/graphframes/connect/utils.py b/python/graphframes/connect/utils.py index a4c270130..aacd5ed06 100644 --- a/python/graphframes/connect/utils.py +++ b/python/graphframes/connect/utils.py @@ -20,12 +20,14 @@ def column_to_proto(col: Column, client: SparkConnectClient) -> bytes: assert isinstance(expr, Expression) return expr.to_plan(client).SerializeToString() + def make_column_or_expr(col: Column | str, client: SparkConnectClient) -> ColumnOrExpression: if isinstance(col, Column): return ColumnOrExpression(col=column_to_proto(col, client)) else: return ColumnOrExpression(expr=col) + def make_str_or_long_id(str_or_long: str | int) -> StringOrLongID: if isinstance(str_or_long, str): return StringOrLongID(string_id=str_or_long) diff --git a/python/graphframes/examples/__init__.py b/python/graphframes/examples/__init__.py index 8b92ef01f..2003b0191 100644 --- a/python/graphframes/examples/__init__.py +++ b/python/graphframes/examples/__init__.py @@ -1,5 +1,4 @@ - from .belief_propagation import BeliefPropagation from .graphs import Graphs -__all__ = ['BeliefPropagation', 'Graphs'] +__all__ = ["BeliefPropagation", "Graphs"] diff --git a/python/graphframes/examples/belief_propagation.py b/python/graphframes/examples/belief_propagation.py index c013450d7..d0ff8dbfd 100644 --- a/python/graphframes/examples/belief_propagation.py +++ b/python/graphframes/examples/belief_propagation.py @@ -25,7 +25,7 @@ from graphframes.lib import AggregateMessages as AM from pyspark.sql import SparkSession, functions as sqlfunctions, types -__all__ = ['BeliefPropagation'] +__all__ = ["BeliefPropagation"] class BeliefPropagation: @@ -61,7 +61,7 @@ class BeliefPropagation: * Coloring the graph by assigning a color to each vertex such that no neighboring vertices share the same color. * In each step of BP, update all vertices of a single color. Alternate colors. - """ + """ @classmethod def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame: @@ -71,12 +71,12 @@ def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame: """ # choose colors for vertices for BP scheduling colorG = cls._colorGraph(g) - numColors = colorG.vertices.select('color').distinct().count() + numColors = colorG.vertices.select("color").distinct().count() # TODO: handle vertices without any edges # initialize vertex beliefs at 0.0 - gx = GraphFrame(colorG.vertices.withColumn('belief', sqlfunctions.lit(0.0)), colorG.edges) + gx = GraphFrame(colorG.vertices.withColumn("belief", sqlfunctions.lit(0.0)), colorG.edges) # run BP for numIter iterations for iter_ in range(numIter): @@ -85,37 +85,40 @@ def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame: # Send messages to vertices of the current color. # We may send to source or destination since edges are treated as undirected. msgForSrc = sqlfunctions.when( - AM.src['color'] == color, - AM.edge['b'] * AM.dst['belief']) + AM.src["color"] == color, AM.edge["b"] * AM.dst["belief"] + ) msgForDst = sqlfunctions.when( - AM.dst['color'] == color, - AM.edge['b'] * AM.src['belief']) + AM.dst["color"] == color, AM.edge["b"] * AM.src["belief"] + ) # numerically stable sigmoid logistic = sqlfunctions.udf(cls._sigmoid, returnType=types.DoubleType()) aggregates = gx.aggregateMessages( sqlfunctions.sum(AM.msg).alias("aggMess"), sendToSrc=msgForSrc, - sendToDst=msgForDst) + sendToDst=msgForDst, + ) v = gx.vertices # receive messages and update beliefs for vertices of the current color newBeliefCol = sqlfunctions.when( - (v['color'] == color) & (aggregates['aggMess'].isNotNull()), - logistic(aggregates['aggMess'] + v['a']) - ).otherwise(v['belief']) # keep old beliefs for other colors - newVertices = (v - .join(aggregates, on=(v['id'] == aggregates['id']), how='left_outer') - .drop(aggregates['id']) # drop duplicate ID column (from outer join) - .withColumn('newBelief', newBeliefCol) # compute new beliefs - .drop('aggMess') # drop messages - .drop('belief') # drop old beliefs - .withColumnRenamed('newBelief', 'belief') + (v["color"] == color) & (aggregates["aggMess"].isNotNull()), + logistic(aggregates["aggMess"] + v["a"]), + ).otherwise( + v["belief"] + ) # keep old beliefs for other colors + newVertices = ( + v.join(aggregates, on=(v["id"] == aggregates["id"]), how="left_outer") + .drop(aggregates["id"]) # drop duplicate ID column (from outer join) + .withColumn("newBelief", newBeliefCol) # compute new beliefs + .drop("aggMess") # drop messages + .drop("belief") # drop old beliefs + .withColumnRenamed("newBelief", "belief") ) # cache new vertices using workaround for SPARK-1334 cachedNewVertices = AM.getCachedDataFrame(newVertices) gx = GraphFrame(cachedNewVertices, gx.edges) # Drop the "color" column from vertices - return GraphFrame(gx.vertices.drop('color'), gx.edges) + return GraphFrame(gx.vertices.drop("color"), gx.edges) @staticmethod def _colorGraph(g: GraphFrame) -> GraphFrame: @@ -132,7 +135,7 @@ def _colorGraph(g: GraphFrame) -> GraphFrame: """ colorUDF = sqlfunctions.udf(lambda i, j: (i + j) % 2, returnType=types.IntegerType()) - v = g.vertices.withColumn('color', colorUDF(sqlfunctions.col('i'), sqlfunctions.col('j'))) + v = g.vertices.withColumn("color", colorUDF(sqlfunctions.col("i"), sqlfunctions.col("j"))) return GraphFrame(v, g.edges) @staticmethod @@ -164,12 +167,12 @@ def main() -> None: results = BeliefPropagation.runBPwithGraphFrames(g, numIter) # display beliefs - beliefs = results.vertices.select('id', 'belief') + beliefs = results.vertices.select("id", "belief") print("Done with BP. Final beliefs after {} iterations:".format(numIter)) beliefs.show() spark.stop() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/python/graphframes/examples/graphs.py b/python/graphframes/examples/graphs.py index 8db04aecc..53535327d 100644 --- a/python/graphframes/examples/graphs.py +++ b/python/graphframes/examples/graphs.py @@ -21,7 +21,7 @@ from graphframes import GraphFrame -__all__ = ['Graphs'] +__all__ = ["Graphs"] class Graphs: @@ -37,24 +37,30 @@ def __init__(self, spark: SparkSession) -> None: def friends(self) -> GraphFrame: """A GraphFrame of friends in a (fake) social network.""" # Vertex DataFrame - v = self._spark.createDataFrame([ - ("a", "Alice", 34), - ("b", "Bob", 36), - ("c", "Charlie", 30), - ("d", "David", 29), - ("e", "Esther", 32), - ("f", "Fanny", 36) - ], ["id", "name", "age"]) + v = self._spark.createDataFrame( + [ + ("a", "Alice", 34), + ("b", "Bob", 36), + ("c", "Charlie", 30), + ("d", "David", 29), + ("e", "Esther", 32), + ("f", "Fanny", 36), + ], + ["id", "name", "age"], + ) # Edge DataFrame - e = self._spark.createDataFrame([ - ("a", "b", "friend"), - ("b", "c", "follow"), - ("c", "b", "follow"), - ("f", "c", "follow"), - ("e", "f", "follow"), - ("e", "d", "friend"), - ("d", "a", "friend") - ], ["src", "dst", "relationship"]) + e = self._spark.createDataFrame( + [ + ("a", "b", "friend"), + ("b", "c", "follow"), + ("c", "b", "follow"), + ("f", "c", "follow"), + ("e", "f", "follow"), + ("e", "d", "friend"), + ("d", "a", "friend"), + ], + ["src", "dst", "relationship"], + ) # Create a GraphFrame return GraphFrame(v, e) @@ -87,37 +93,40 @@ def gridIsingModel(self, n: int, vStd: float = 1.0, eStd: float = 1.0) -> GraphF # check param n if n < 1: raise ValueError( - "Grid graph must have size >= 1, but was given invalid value n = {}" - .format(n)) + "Grid graph must have size >= 1, but was given invalid value n = {}".format(n) + ) # create coodinates grid coordinates = self._spark.createDataFrame( - itertools.product(range(n), range(n)), - schema=('i', 'j')) + itertools.product(range(n), range(n)), schema=("i", "j") + ) # create SQL expression for converting coordinates (i,j) to a string ID "i,j" # avoid Cartesian join due to SPARK-15425: use generator since n should be small - toIDudf = sqlfunctions.udf(lambda i, j: '{},{}'.format(i,j)) + toIDudf = sqlfunctions.udf(lambda i, j: "{},{}".format(i, j)) # create the vertex DataFrame # create SQL expression for converting coordinates (i,j) to a string ID "i,j" - vIDcol = toIDudf(sqlfunctions.col('i'), sqlfunctions.col('j')) + vIDcol = toIDudf(sqlfunctions.col("i"), sqlfunctions.col("j")) # add random parameters generated from a normal distribution seed = 12345 - vertices = (coordinates.withColumn('id', vIDcol) - .withColumn('a', sqlfunctions.randn(seed) * vStd)) + vertices = coordinates.withColumn("id", vIDcol).withColumn( + "a", sqlfunctions.randn(seed) * vStd + ) # create the edge DataFrame # create SQL expression for converting coordinates (i,j+1) and (i+1,j) to string IDs - rightIDcol = toIDudf(sqlfunctions.col('i'), sqlfunctions.col('j') + 1) - downIDcol = toIDudf(sqlfunctions.col('i') + 1, sqlfunctions.col('j')) - horizontalEdges = (coordinates.filter(sqlfunctions.col('j') != n - 1) - .select(vIDcol.alias('src'), rightIDcol.alias('dst'))) - verticalEdges = (coordinates.filter(sqlfunctions.col('i') != n - 1) - .select(vIDcol.alias('src'), downIDcol.alias('dst'))) + rightIDcol = toIDudf(sqlfunctions.col("i"), sqlfunctions.col("j") + 1) + downIDcol = toIDudf(sqlfunctions.col("i") + 1, sqlfunctions.col("j")) + horizontalEdges = coordinates.filter(sqlfunctions.col("j") != n - 1).select( + vIDcol.alias("src"), rightIDcol.alias("dst") + ) + verticalEdges = coordinates.filter(sqlfunctions.col("i") != n - 1).select( + vIDcol.alias("src"), downIDcol.alias("dst") + ) allEdges = horizontalEdges.unionAll(verticalEdges) # add random parameters from a normal distribution - edges = allEdges.withColumn('b', sqlfunctions.randn(seed + 1) * eStd) + edges = allEdges.withColumn("b", sqlfunctions.randn(seed + 1) * eStd) # create the GraphFrame g = GraphFrame(vertices, edges) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index f0dbd1f8a..bb6255de5 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -18,12 +18,13 @@ from __future__ import annotations from typing import Any, Optional +from typing_extensions import TYPE_CHECKING from pyspark.sql import Column, DataFrame from pyspark.storagelevel import StorageLevel from pyspark.version import __version__ -if __version__[:2] >= "3.4": +if __version__[:3] >= "3.4": from pyspark.sql.utils import is_remote else: # All the Connect-related utilities are accessible starting from 3.4.x @@ -33,11 +34,12 @@ def is_remote() -> bool: from graphframes.lib import Pregel -from .classic.graphframe import GraphFrame as GraphFrameClassic +from graphframes.classic.graphframe import GraphFrame as GraphFrameClassic -if __version__[:2] >= "3.4": +if __version__[:3] >= "3.4": from graphframes.connect.graphframe_client import GraphFrameConnect else: + class GraphFrameConnect: def __init__(self, *args, **kwargs) -> None: raise ValueError("Unreachable error happened!") @@ -97,9 +99,7 @@ def cache(self) -> "GraphFrame": """ return GraphFrame._from_impl(self._impl.cache()) - def persist( - self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY - ) -> "GraphFrame": + def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": """Persist the dataframe representation of vertices and edges of the graph with the given storage level. """ @@ -254,9 +254,7 @@ def aggregateMessages( :return: DataFrame with columns for the vertex ID and the resulting aggregated message """ - return self._impl.aggregateMessages( - aggCol=aggCol, sendToSrc=sendToSrc, sendToDst=sendToDst - ) + return self._impl.aggregateMessages(aggCol=aggCol, sendToSrc=sendToSrc, sendToDst=sendToDst) # Standard algorithms @@ -418,9 +416,7 @@ def _test(): from pyspark.sql import SparkSession globs = graphframe.__dict__.copy() - globs["spark"] = ( - SparkSession.builder.master("local[4]").appName("PythonTest").getOrCreate() - ) + globs["spark"] = SparkSession.builder.master("local[4]").appName("PythonTest").getOrCreate() (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE ) diff --git a/python/graphframes/lib/__init__.py b/python/graphframes/lib/__init__.py index 325e74543..076dd5232 100644 --- a/python/graphframes/lib/__init__.py +++ b/python/graphframes/lib/__init__.py @@ -1,5 +1,4 @@ - from .aggregate_messages import AggregateMessages from .pregel import Pregel -__all__ = ['AggregateMessages', 'Pregel'] +__all__ = ["AggregateMessages", "Pregel"] diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/lib/aggregate_messages.py index c0867dcd0..2ff1ba3b8 100644 --- a/python/graphframes/lib/aggregate_messages.py +++ b/python/graphframes/lib/aggregate_messages.py @@ -23,8 +23,12 @@ def _java_api(jsc: SparkContext) -> Any: javaClassName = "org.graphframes.GraphFramePythonAPI" - return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ - .newInstance() + return ( + jsc._jvm.Thread.currentThread() + .getContextClassLoader() + .loadClass(javaClassName) + .newInstance() + ) class _ClassProperty: diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/lib/pregel.py index 72077c25c..fe9580a04 100644 --- a/python/graphframes/lib/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -17,7 +17,8 @@ import sys from typing import Any -if sys.version > '3': + +if sys.version > "3": basestring = str from pyspark.sql import DataFrame, SparkSession @@ -80,6 +81,7 @@ class Pregel(JavaWrapper): def __init__(self, graph: "GraphFrame") -> None: super(Pregel, self).__init__() from graphframes import GraphFrame + self.graph = graph self._java_obj = self._new_java_obj("org.graphframes.lib.Pregel", graph._jvm_graph) @@ -102,7 +104,9 @@ def setCheckpointInterval(self, value: int) -> "Pregel": self._java_obj.setCheckpointInterval(int(value)) return self - def withVertexColumn(self, colName: str, initialExpr: Any, updateAfterAggMsgsExpr: Any) -> "Pregel": + def withVertexColumn( + self, colName: str, initialExpr: Any, updateAfterAggMsgsExpr: Any + ) -> "Pregel": """ Defines an additional vertex column at the start of run and how to update it in each iteration. diff --git a/python/poetry.lock b/python/poetry.lock index 0fb5fb139..616a08814 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "black" @@ -6,7 +6,6 @@ version = "25.1.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.9" -groups = ["dev"] files = [ {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, @@ -53,7 +52,6 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" -groups = ["dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -68,8 +66,6 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["dev"] -markers = "platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -81,7 +77,6 @@ version = "7.1.2" description = "the modular source code checker: pep8 pyflakes and co" optional = false python-versions = ">=3.8.1" -groups = ["dev"] files = [ {file = "flake8-7.1.2-py2.py3-none-any.whl", hash = "sha256:1cbc62e65536f65e6d754dfe6f1bada7f5cf392d6f5db3c2b85892466c3e7c1a"}, {file = "flake8-7.1.2.tar.gz", hash = "sha256:c586ffd0b41540951ae41af572e6790dbd49fc12b3aa2541685d253d9bd504bd"}, @@ -92,13 +87,112 @@ mccabe = ">=0.7.0,<0.8.0" pycodestyle = ">=2.12.0,<2.13.0" pyflakes = ">=3.2.0,<3.3.0" +[[package]] +name = "googleapis-common-protos" +version = "1.68.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis_common_protos-1.68.0-py2.py3-none-any.whl", hash = "sha256:aaf179b2f81df26dfadac95def3b16a95064c76a5f45f07e4c68a21bb371c4ac"}, + {file = "googleapis_common_protos-1.68.0.tar.gz", hash = "sha256:95d38161f4f9af0d9423eed8fb7b64ffd2568c3464eb542ff02c5bfa1953ab3c"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + +[[package]] +name = "grpcio" +version = "1.70.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.70.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:95469d1977429f45fe7df441f586521361e235982a0b39e33841549143ae2851"}, + {file = "grpcio-1.70.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:ed9718f17fbdb472e33b869c77a16d0b55e166b100ec57b016dc7de9c8d236bf"}, + {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:374d014f29f9dfdb40510b041792e0e2828a1389281eb590df066e1cc2b404e5"}, + {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2af68a6f5c8f78d56c145161544ad0febbd7479524a59c16b3e25053f39c87f"}, + {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce7df14b2dcd1102a2ec32f621cc9fab6695effef516efbc6b063ad749867295"}, + {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c78b339869f4dbf89881e0b6fbf376313e4f845a42840a7bdf42ee6caed4b11f"}, + {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58ad9ba575b39edef71f4798fdb5c7b6d02ad36d47949cd381d4392a5c9cbcd3"}, + {file = "grpcio-1.70.0-cp310-cp310-win32.whl", hash = "sha256:2b0d02e4b25a5c1f9b6c7745d4fa06efc9fd6a611af0fb38d3ba956786b95199"}, + {file = "grpcio-1.70.0-cp310-cp310-win_amd64.whl", hash = "sha256:0de706c0a5bb9d841e353f6343a9defc9fc35ec61d6eb6111802f3aa9fef29e1"}, + {file = "grpcio-1.70.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:17325b0be0c068f35770f944124e8839ea3185d6d54862800fc28cc2ffad205a"}, + {file = "grpcio-1.70.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:dbe41ad140df911e796d4463168e33ef80a24f5d21ef4d1e310553fcd2c4a386"}, + {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5ea67c72101d687d44d9c56068328da39c9ccba634cabb336075fae2eab0d04b"}, + {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb5277db254ab7586769e490b7b22f4ddab3876c490da0a1a9d7c695ccf0bf77"}, + {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7831a0fc1beeeb7759f737f5acd9fdcda520e955049512d68fda03d91186eea"}, + {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:27cc75e22c5dba1fbaf5a66c778e36ca9b8ce850bf58a9db887754593080d839"}, + {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d63764963412e22f0491d0d32833d71087288f4e24cbcddbae82476bfa1d81fd"}, + {file = "grpcio-1.70.0-cp311-cp311-win32.whl", hash = "sha256:bb491125103c800ec209d84c9b51f1c60ea456038e4734688004f377cfacc113"}, + {file = "grpcio-1.70.0-cp311-cp311-win_amd64.whl", hash = "sha256:d24035d49e026353eb042bf7b058fb831db3e06d52bee75c5f2f3ab453e71aca"}, + {file = "grpcio-1.70.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:ef4c14508299b1406c32bdbb9fb7b47612ab979b04cf2b27686ea31882387cff"}, + {file = "grpcio-1.70.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:aa47688a65643afd8b166928a1da6247d3f46a2784d301e48ca1cc394d2ffb40"}, + {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:880bfb43b1bb8905701b926274eafce5c70a105bc6b99e25f62e98ad59cb278e"}, + {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e654c4b17d07eab259d392e12b149c3a134ec52b11ecdc6a515b39aceeec898"}, + {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2394e3381071045a706ee2eeb6e08962dd87e8999b90ac15c55f56fa5a8c9597"}, + {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b3c76701428d2df01964bc6479422f20e62fcbc0a37d82ebd58050b86926ef8c"}, + {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f"}, + {file = "grpcio-1.70.0-cp312-cp312-win32.whl", hash = "sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528"}, + {file = "grpcio-1.70.0-cp312-cp312-win_amd64.whl", hash = "sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655"}, + {file = "grpcio-1.70.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa573896aeb7d7ce10b1fa425ba263e8dddd83d71530d1322fd3a16f31257b4a"}, + {file = "grpcio-1.70.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:d405b005018fd516c9ac529f4b4122342f60ec1cee181788249372524e6db429"}, + {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f32090238b720eb585248654db8e3afc87b48d26ac423c8dde8334a232ff53c9"}, + {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfa089a734f24ee5f6880c83d043e4f46bf812fcea5181dcb3a572db1e79e01c"}, + {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f19375f0300b96c0117aca118d400e76fede6db6e91f3c34b7b035822e06c35f"}, + {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:7c73c42102e4a5ec76608d9b60227d917cea46dff4d11d372f64cbeb56d259d0"}, + {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:0a5c78d5198a1f0aa60006cd6eb1c912b4a1520b6a3968e677dbcba215fabb40"}, + {file = "grpcio-1.70.0-cp313-cp313-win32.whl", hash = "sha256:fe9dbd916df3b60e865258a8c72ac98f3ac9e2a9542dcb72b7a34d236242a5ce"}, + {file = "grpcio-1.70.0-cp313-cp313-win_amd64.whl", hash = "sha256:4119fed8abb7ff6c32e3d2255301e59c316c22d31ab812b3fbcbaf3d0d87cc68"}, + {file = "grpcio-1.70.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:8058667a755f97407fca257c844018b80004ae8035565ebc2812cc550110718d"}, + {file = "grpcio-1.70.0-cp38-cp38-macosx_10_14_universal2.whl", hash = "sha256:879a61bf52ff8ccacbedf534665bb5478ec8e86ad483e76fe4f729aaef867cab"}, + {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:0ba0a173f4feacf90ee618fbc1a27956bfd21260cd31ced9bc707ef551ff7dc7"}, + {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558c386ecb0148f4f99b1a65160f9d4b790ed3163e8610d11db47838d452512d"}, + {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:412faabcc787bbc826f51be261ae5fa996b21263de5368a55dc2cf824dc5090e"}, + {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3b0f01f6ed9994d7a0b27eeddea43ceac1b7e6f3f9d86aeec0f0064b8cf50fdb"}, + {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7385b1cb064734005204bc8994eed7dcb801ed6c2eda283f613ad8c6c75cf873"}, + {file = "grpcio-1.70.0-cp38-cp38-win32.whl", hash = "sha256:07269ff4940f6fb6710951116a04cd70284da86d0a4368fd5a3b552744511f5a"}, + {file = "grpcio-1.70.0-cp38-cp38-win_amd64.whl", hash = "sha256:aba19419aef9b254e15011b230a180e26e0f6864c90406fdbc255f01d83bc83c"}, + {file = "grpcio-1.70.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4f1937f47c77392ccd555728f564a49128b6a197a05a5cd527b796d36f3387d0"}, + {file = "grpcio-1.70.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:0cd430b9215a15c10b0e7d78f51e8a39d6cf2ea819fd635a7214fae600b1da27"}, + {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e27585831aa6b57b9250abaf147003e126cd3a6c6ca0c531a01996f31709bed1"}, + {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1af8e15b0f0fe0eac75195992a63df17579553b0c4af9f8362cc7cc99ccddf4"}, + {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbce24409beaee911c574a3d75d12ffb8c3e3dd1b813321b1d7a96bbcac46bf4"}, + {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ff4a8112a79464919bb21c18e956c54add43ec9a4850e3949da54f61c241a4a6"}, + {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5413549fdf0b14046c545e19cfc4eb1e37e9e1ebba0ca390a8d4e9963cab44d2"}, + {file = "grpcio-1.70.0-cp39-cp39-win32.whl", hash = "sha256:b745d2c41b27650095e81dea7091668c040457483c9bdb5d0d9de8f8eb25e59f"}, + {file = "grpcio-1.70.0-cp39-cp39-win_amd64.whl", hash = "sha256:a31d7e3b529c94e930a117b2175b2efd179d96eb3c7a21ccb0289a8ab05b645c"}, + {file = "grpcio-1.70.0.tar.gz", hash = "sha256:8d1584a68d5922330025881e63a6c1b54cc8117291d382e4fa69339b6d914c56"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.70.0)"] + +[[package]] +name = "grpcio-status" +version = "1.70.0" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio_status-1.70.0-py3-none-any.whl", hash = "sha256:fc5a2ae2b9b1c1969cc49f3262676e6854aa2398ec69cb5bd6c47cd501904a85"}, + {file = "grpcio_status-1.70.0.tar.gz", hash = "sha256:0e7b42816512433b18b9d764285ff029bde059e9d41f8fe10a60631bd8348101"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.70.0" +protobuf = ">=5.26.1,<6.0dev" + [[package]] name = "isort" version = "6.0.0" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.9.0" -groups = ["dev"] files = [ {file = "isort-6.0.0-py3-none-any.whl", hash = "sha256:567954102bb47bb12e0fae62606570faacddd441e45683968c8d1734fb1af892"}, {file = "isort-6.0.0.tar.gz", hash = "sha256:75d9d8a1438a9432a7d7b54f2d3b45cad9a4a0fdba43617d9873379704a8bdf1"}, @@ -114,7 +208,6 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" -groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -126,7 +219,6 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" -groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -138,7 +230,6 @@ version = "1.3.7" description = "nose extends unittest to make testing easier" optional = false python-versions = "*" -groups = ["main"] files = [ {file = "nose-1.3.7-py2-none-any.whl", hash = "sha256:dadcddc0aefbf99eea214e0f1232b94f2fa9bd98fa8353711dacb112bfcbbb2a"}, {file = "nose-1.3.7-py3-none-any.whl", hash = "sha256:9ff7c6cc443f8c51994b34a667bbcf45afd6d945be7477b52e97516fd17c53ac"}, @@ -147,57 +238,47 @@ files = [ [[package]] name = "numpy" -version = "2.0.2" +version = "1.26.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.9" -groups = ["main"] files = [ - {file = "numpy-2.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece"}, - {file = "numpy-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04"}, - {file = "numpy-2.0.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66"}, - {file = "numpy-2.0.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b"}, - {file = "numpy-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd"}, - {file = "numpy-2.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318"}, - {file = "numpy-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8"}, - {file = "numpy-2.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326"}, - {file = "numpy-2.0.2-cp310-cp310-win32.whl", hash = "sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97"}, - {file = "numpy-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131"}, - {file = "numpy-2.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448"}, - {file = "numpy-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195"}, - {file = "numpy-2.0.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57"}, - {file = "numpy-2.0.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a"}, - {file = "numpy-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669"}, - {file = "numpy-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951"}, - {file = "numpy-2.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9"}, - {file = "numpy-2.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15"}, - {file = "numpy-2.0.2-cp311-cp311-win32.whl", hash = "sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4"}, - {file = "numpy-2.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc"}, - {file = "numpy-2.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b"}, - {file = "numpy-2.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e"}, - {file = "numpy-2.0.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c"}, - {file = "numpy-2.0.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c"}, - {file = "numpy-2.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692"}, - {file = "numpy-2.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a"}, - {file = "numpy-2.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c"}, - {file = "numpy-2.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded"}, - {file = "numpy-2.0.2-cp312-cp312-win32.whl", hash = "sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5"}, - {file = "numpy-2.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a"}, - {file = "numpy-2.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c"}, - {file = "numpy-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd"}, - {file = "numpy-2.0.2-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b"}, - {file = "numpy-2.0.2-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729"}, - {file = "numpy-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1"}, - {file = "numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd"}, - {file = "numpy-2.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d"}, - {file = "numpy-2.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d"}, - {file = "numpy-2.0.2-cp39-cp39-win32.whl", hash = "sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa"}, - {file = "numpy-2.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73"}, - {file = "numpy-2.0.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8"}, - {file = "numpy-2.0.2-pp39-pypy39_pp73-macosx_14_0_x86_64.whl", hash = "sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4"}, - {file = "numpy-2.0.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c"}, - {file = "numpy-2.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385"}, - {file = "numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] [[package]] @@ -206,19 +287,103 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] +[[package]] +name = "pandas" +version = "2.2.3" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, + {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + [[package]] name = "pathspec" version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -230,7 +395,6 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -241,25 +405,97 @@ docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-a test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] type = ["mypy (>=1.11.2)"] +[[package]] +name = "protobuf" +version = "5.29.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-5.29.3-cp310-abi3-win32.whl", hash = "sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888"}, + {file = "protobuf-5.29.3-cp310-abi3-win_amd64.whl", hash = "sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a"}, + {file = "protobuf-5.29.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e"}, + {file = "protobuf-5.29.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84"}, + {file = "protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f"}, + {file = "protobuf-5.29.3-cp38-cp38-win32.whl", hash = "sha256:84a57163a0ccef3f96e4b6a20516cedcf5bb3a95a657131c5c3ac62200d23252"}, + {file = "protobuf-5.29.3-cp38-cp38-win_amd64.whl", hash = "sha256:b89c115d877892a512f79a8114564fb435943b59067615894c3b13cd3e1fa107"}, + {file = "protobuf-5.29.3-cp39-cp39-win32.whl", hash = "sha256:0eb32bfa5219fc8d4111803e9a690658aa2e6366384fd0851064b963b6d1f2a7"}, + {file = "protobuf-5.29.3-cp39-cp39-win_amd64.whl", hash = "sha256:6ce8cc3389a20693bfde6c6562e03474c40851b44975c9b2bf6df7d8c4f864da"}, + {file = "protobuf-5.29.3-py3-none-any.whl", hash = "sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f"}, + {file = "protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620"}, +] + [[package]] name = "py4j" version = "0.10.9.7" description = "Enables Python programs to dynamically access arbitrary Java objects" optional = false python-versions = "*" -groups = ["main"] files = [ {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"}, {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"}, ] +[[package]] +name = "pyarrow" +version = "19.0.1" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, + {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, + {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad76aef7f5f7e4a757fddcdcf010a8290958f09e3470ea458c80d26f4316ae89"}, + {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d03c9d6f2a3dffbd62671ca070f13fc527bb1867b4ec2b98c7eeed381d4f389a"}, + {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:65cf9feebab489b19cdfcfe4aa82f62147218558d8d3f0fc1e9dea0ab8e7905a"}, + {file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:41f9706fbe505e0abc10e84bf3a906a1338905cbbcf1177b71486b03e6ea6608"}, + {file = "pyarrow-19.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6cb2335a411b713fdf1e82a752162f72d4a7b5dbc588e32aa18383318b05866"}, + {file = "pyarrow-19.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc55d71898ea30dc95900297d191377caba257612f384207fe9f8293b5850f90"}, + {file = "pyarrow-19.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:7a544ec12de66769612b2d6988c36adc96fb9767ecc8ee0a4d270b10b1c51e00"}, + {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0148bb4fc158bfbc3d6dfe5001d93ebeed253793fff4435167f6ce1dc4bddeae"}, + {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f24faab6ed18f216a37870d8c5623f9c044566d75ec586ef884e13a02a9d62c5"}, + {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:4982f8e2b7afd6dae8608d70ba5bd91699077323f812a0448d8b7abdff6cb5d3"}, + {file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:49a3aecb62c1be1d822f8bf629226d4a96418228a42f5b40835c1f10d42e4db6"}, + {file = "pyarrow-19.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:008a4009efdb4ea3d2e18f05cd31f9d43c388aad29c636112c2966605ba33466"}, + {file = "pyarrow-19.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:80b2ad2b193e7d19e81008a96e313fbd53157945c7be9ac65f44f8937a55427b"}, + {file = "pyarrow-19.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:ee8dec072569f43835932a3b10c55973593abc00936c202707a4ad06af7cb294"}, + {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d5d1ec7ec5324b98887bdc006f4d2ce534e10e60f7ad995e7875ffa0ff9cb14"}, + {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ad4c0eb4e2a9aeb990af6c09e6fa0b195c8c0e7b272ecc8d4d2b6574809d34"}, + {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d383591f3dcbe545f6cc62daaef9c7cdfe0dff0fb9e1c8121101cabe9098cfa6"}, + {file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b4c4156a625f1e35d6c0b2132635a237708944eb41df5fbe7d50f20d20c17832"}, + {file = "pyarrow-19.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bd1618ae5e5476b7654c7b55a6364ae87686d4724538c24185bbb2952679960"}, + {file = "pyarrow-19.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e45274b20e524ae5c39d7fc1ca2aa923aab494776d2d4b316b49ec7572ca324c"}, + {file = "pyarrow-19.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d9dedeaf19097a143ed6da37f04f4051aba353c95ef507764d344229b2b740ae"}, + {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ebfb5171bb5f4a52319344ebbbecc731af3f021e49318c74f33d520d31ae0c4"}, + {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a21d39fbdb948857f67eacb5bbaaf36802de044ec36fbef7a1c8f0dd3a4ab2"}, + {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:99bc1bec6d234359743b01e70d4310d0ab240c3d6b0da7e2a93663b0158616f6"}, + {file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:1b93ef2c93e77c442c979b0d596af45e4665d8b96da598db145b0fec014b9136"}, + {file = "pyarrow-19.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:d9d46e06846a41ba906ab25302cf0fd522f81aa2a85a71021826f34639ad31ef"}, + {file = "pyarrow-19.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:c0fe3dbbf054a00d1f162fda94ce236a899ca01123a798c561ba307ca38af5f0"}, + {file = "pyarrow-19.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:96606c3ba57944d128e8a8399da4812f56c7f61de8c647e3470b417f795d0ef9"}, + {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f04d49a6b64cf24719c080b3c2029a3a5b16417fd5fd7c4041f94233af732f3"}, + {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a9137cf7e1640dce4c190551ee69d478f7121b5c6f323553b319cac936395f6"}, + {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:7c1bca1897c28013db5e4c83944a2ab53231f541b9e0c3f4791206d0c0de389a"}, + {file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:58d9397b2e273ef76264b45531e9d552d8ec8a6688b7390b5be44c02a37aade8"}, + {file = "pyarrow-19.0.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b9766a47a9cb56fefe95cb27f535038b5a195707a08bf61b180e642324963b46"}, + {file = "pyarrow-19.0.1-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:6c5941c1aac89a6c2f2b16cd64fe76bcdb94b2b1e99ca6459de4e6f07638d755"}, + {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd44d66093a239358d07c42a91eebf5015aa54fccba959db899f932218ac9cc8"}, + {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:335d170e050bcc7da867a1ed8ffb8b44c57aaa6e0843b156a501298657b1e972"}, + {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:1c7556165bd38cf0cd992df2636f8bcdd2d4b26916c6b7e646101aff3c16f76f"}, + {file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:699799f9c80bebcf1da0983ba86d7f289c5a2a5c04b945e2f2bcf7e874a91911"}, + {file = "pyarrow-19.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8464c9fbe6d94a7fe1599e7e8965f350fd233532868232ab2596a71586c5a429"}, + {file = "pyarrow-19.0.1.tar.gz", hash = "sha256:3bf266b485df66a400f282ac0b6d1b500b9d2ae73314a153dbe97d6d5cc8a99e"}, +] + +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pycodestyle" version = "2.12.1" description = "Python style guide checker" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3"}, {file = "pycodestyle-2.12.1.tar.gz", hash = "sha256:6838eae08bbce4f6accd5d5572075c63626a15ee3e6f842df996bf62f6d73521"}, @@ -271,7 +507,6 @@ version = "3.2.0" description = "passive checker of Python programs" optional = false python-versions = ">=3.8" -groups = ["dev"] files = [ {file = "pyflakes-3.2.0-py2.py3-none-any.whl", hash = "sha256:84b5be138a2dfbb40689ca07e2152deb896a65c3a3e24c251c5c62489568074a"}, {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, @@ -283,13 +518,18 @@ version = "3.5.4" description = "Apache Spark Python API" optional = false python-versions = ">=3.8" -groups = ["main"] files = [ {file = "pyspark-3.5.4.tar.gz", hash = "sha256:1c2926d63020902163f58222466adf6f8016f6c43c1f319b8e7a71dbaa05fc51"}, ] [package.dependencies] +googleapis-common-protos = {version = ">=1.56.4", optional = true, markers = "extra == \"connect\""} +grpcio = {version = ">=1.56.0", optional = true, markers = "extra == \"connect\""} +grpcio-status = {version = ">=1.56.0", optional = true, markers = "extra == \"connect\""} +numpy = {version = ">=1.15,<2", optional = true, markers = "extra == \"connect\""} +pandas = {version = ">=1.0.5", optional = true, markers = "extra == \"connect\""} py4j = "0.10.9.7" +pyarrow = {version = ">=4.0.0", optional = true, markers = "extra == \"connect\""} [package.extras] connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.56.0)", "grpcio-status (>=1.56.0)", "numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] @@ -298,14 +538,48 @@ mllib = ["numpy (>=1.15,<2)"] pandas-on-spark = ["numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] sql = ["numpy (>=1.15,<2)", "pandas (>=1.0.5)", "pyarrow (>=4.0.0)"] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, +] + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2025.1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57"}, + {file = "pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e"}, +] + +[[package]] +name = "six" +version = "1.17.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, + {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, +] + [[package]] name = "tomli" version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" -groups = ["dev"] -markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -347,14 +621,23 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" -groups = ["dev"] -markers = "python_version < \"3.11\"" files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "tzdata" +version = "2025.1" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"}, + {file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"}, +] + [metadata] -lock-version = "2.1" +lock-version = "2.0" python-versions = ">=3.9 <3.13" -content-hash = "52c129fee3e94e69edf727f219bc7582ddbfcedf6c43547a7f67a876051bd7c4" +content-hash = "e7e3cc6021a3736ea422d8866904491341ab719a8a235f1a1b663bc0e6fb6561" diff --git a/python/pyproject.toml b/python/pyproject.toml index a239412ce..93f9811ec 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -35,13 +35,14 @@ script = "dev/build_jar.py" [tool.poetry.dependencies] python = ">=3.9 <3.13" nose = "1.3.7" -pyspark = "^3.4" +pyspark = ">=3.3" numpy = ">= 1.7" [tool.poetry.group.dev.dependencies] black = "^25.1.0" flake8 = "^7.1.1" isort = "^6.0.0" +pyspark = { version = "3.5.4", extras = ["connect"] } [build-system] requires = ["poetry-core"] diff --git a/python/tests/tests.py b/python/tests/tests.py index c9a93fec6..3e0a09f47 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -29,7 +29,7 @@ from pyspark.sql import functions as sqlfunctions from pyspark.version import __version__ -if __version__[:2] >= "3.4": +if __version__[:3] >= "3.4": from pyspark.sql.utils import is_remote else: @@ -43,6 +43,7 @@ def setUpClass(cls): warnings.filterwarnings("ignore", category=ResourceWarning) warnings.filterwarnings("ignore", category=DeprecationWarning) cls.checkpointDir = "/tmp/GFTestsCheckpointDir" + pathlib.Path(cls.checkpointDir).mkdir(parents=True, exist_ok=True) if is_remote(): cls.spark = ( @@ -210,22 +211,24 @@ def test_page_rank(self): numVertices = vertices.count() vertices = GraphFrame(vertices, edges).outDegrees + vertices.toPandas().head() vertices.cache() graph = GraphFrame(vertices, edges) alpha = 0.15 + pregel = graph.pregel ranks = ( graph.pregel.setMaxIter(5) .withVertexColumn( "rank", lit(1.0 / numVertices), - coalesce(Pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + coalesce(pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + lit(alpha / numVertices), ) - .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) - .aggMsgs(sum(Pregel.msg())) + .sendMsgToDst(pregel.src("rank") / pregel.src("outDegree")) + .aggMsgs(sum(pregel.msg())) .run() ) - resultRows = ranks.sort(ranks.id).collect() + resultRows = ranks.sort("id").collect() result = map(lambda x: x.rank, resultRows) expected = [0.245, 0.224, 0.303, 0.03, 0.197] for a, b in zip(result, expected): diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index 68c13b85b..1713c58da 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -21,14 +21,14 @@ import java.io.IOException import java.math.BigDecimal import java.util.UUID +import org.graphframes.{GraphFrame, Logging} + import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.storage.StorageLevel -import org.graphframes.{GraphFrame, Logging} - /** * Connected components algorithm. * @@ -201,14 +201,13 @@ object ConnectedComponents extends Logging { * Prepares the input graph for computing connected components by: * - de-duplicating vertices and assigning unique long IDs to each, * - changing edge directions to have increasing long IDs from src to dst, - * - de-duplicating edges and removing self-loops. - * In the returned GraphFrame, the vertex DataFrame has two columns: + * - de-duplicating edges and removing self-loops. In the returned GraphFrame, the vertex + * DataFrame has two columns: * - column `id` stores a long ID assigned to the vertex, - * - column `attr` stores the original vertex attributes. - * The edge DataFrame has two columns: + * - column `attr` stores the original vertex attributes. The edge DataFrame has two columns: * - column `src` stores the long ID of the source vertex, - * - column `dst` stores the long ID of the destination vertex, - * where we always have `src` < `dst`. + * - column `dst` stores the long ID of the destination vertex, where we always have `src` < + * `dst`. */ private def prepare(graph: GraphFrame): GraphFrame = { // TODO: This assignment job might fail if the graph is skewed. @@ -311,8 +310,14 @@ object ConnectedComponents extends Logging { new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString } .getOrElse { - throw new IOException( - "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir().") + // Spark-Connect workaround + spark.conf.getOption("spark.checkpoint.dir") match { + case Some(d) => new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString + case None => + throw new IOException( + "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" + + "or by specifying the conf 'spark.checkpoint.dir'.") + } } logInfo(s"$logPrefix Using $dir for checkpointing with interval $checkpointInterval.") Some(dir) diff --git a/src/main/scala/org/graphframes/lib/Pregel.scala b/src/main/scala/org/graphframes/lib/Pregel.scala index beba0eb4e..fd4ec96b7 100644 --- a/src/main/scala/org/graphframes/lib/Pregel.scala +++ b/src/main/scala/org/graphframes/lib/Pregel.scala @@ -17,8 +17,11 @@ package org.graphframes.lib +import java.io.IOException + import org.graphframes.GraphFrame import org.graphframes.GraphFrame._ + import org.apache.spark.sql.{Column, DataFrame} import org.apache.spark.sql.functions.{array, col, explode, struct} @@ -249,6 +252,17 @@ class Pregel(val graph: GraphFrame) { var newVertexUpdateColDF = verticesWithMsg.select((col(ID) :: updateVertexCols): _*) + if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) { + // Spark Connect workaround + graph.spark.conf.getOption("spark.checkpoint.dir") match { + case Some(d) => graph.spark.sparkContext.setCheckpointDir(d) + case None => + throw new IOException( + "Checkpoint directory is not set. Please set it first using sc.setCheckpointDir()" + + "or by specifying the conf 'spark.checkpoint.dir'.") + } + } + if (shouldCheckpoint && iteration % checkpointInterval == 0) { // do checkpoint, use lazy checkpoint because later we will materialize this DF. newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false) From 7e325aa4cf18914adcbe98e1b00dc316d4b93a8d Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 22:57:51 +0100 Subject: [PATCH 08/27] Fix tests --- python/graphframes/connect/utils.py | 2 ++ python/tests/tests.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/graphframes/connect/utils.py b/python/graphframes/connect/utils.py index aacd5ed06..77152137e 100644 --- a/python/graphframes/connect/utils.py +++ b/python/graphframes/connect/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.column import Column from pyspark.sql.connect.dataframe import DataFrame diff --git a/python/tests/tests.py b/python/tests/tests.py index 3e0a09f47..df22155e3 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -23,7 +23,7 @@ from graphframes.classic.graphframe import _from_java_gf, _java_api from graphframes.examples import BeliefPropagation, Graphs -from graphframes.graphframe import GraphFrame, Pregel +from graphframes.graphframe import GraphFrame from graphframes.lib import AggregateMessages as AM from pyspark.sql import SparkSession from pyspark.sql import functions as sqlfunctions From c47a57e048c95ee55ccef1eae8fcf16e6cb11fea Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 23:06:29 +0100 Subject: [PATCH 09/27] Fix tests --- .github/workflows/python-ci.yml | 4 +--- {dev => python/dev}/run_connect.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) rename {dev => python/dev}/run_connect.py (98%) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 54cccdcb4..3e6c941a5 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -46,11 +46,9 @@ jobs: poetry run python -m unittest discover -s tests/ - name: Test SparkConnect + working-directory: ./python run: | - cd python export VENV_ROOT=$(poetry env info --path) - cd .. $VENV_ROOT/bin/python dev/run-connect.sh - cd python export SPARK_CONNECT_MODE_ENABLED=1 poetry run python -m unittest discover -s tests/ diff --git a/dev/run_connect.py b/python/dev/run_connect.py similarity index 98% rename from dev/run_connect.py rename to python/dev/run_connect.py index 95fb0c6fb..5d2898404 100644 --- a/dev/run_connect.py +++ b/python/dev/run_connect.py @@ -15,7 +15,7 @@ if __name__ == "__main__": - prj_root = Path(__file__).parent.parent + prj_root = Path(__file__).parent.parent.parent scala_root = prj_root.joinpath("graphframes-connect") print("Build Graphframes...") From fc8ebaeed25203344a6e55582d4bbadb48291836 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 23:10:48 +0100 Subject: [PATCH 10/27] Fix CI typo --- .github/workflows/python-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 3e6c941a5..c37d99661 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -49,6 +49,6 @@ jobs: working-directory: ./python run: | export VENV_ROOT=$(poetry env info --path) - $VENV_ROOT/bin/python dev/run-connect.sh + $VENV_ROOT/bin/python dev/run_connect.sh export SPARK_CONNECT_MODE_ENABLED=1 poetry run python -m unittest discover -s tests/ From f13c754ed30f4b0a8cdc8ec65d699b3f40a5ed00 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 23:15:05 +0100 Subject: [PATCH 11/27] Fix typo in CI --- .github/workflows/python-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index c37d99661..e80d85bfd 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -49,6 +49,6 @@ jobs: working-directory: ./python run: | export VENV_ROOT=$(poetry env info --path) - $VENV_ROOT/bin/python dev/run_connect.sh + $VENV_ROOT/bin/python dev/run_connect.py export SPARK_CONNECT_MODE_ENABLED=1 poetry run python -m unittest discover -s tests/ From a21b5aa941e19f3e39663bec97f438315ebb171d Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 23:21:07 +0100 Subject: [PATCH 12/27] Fix wget's verbose + GHA bug --- python/dev/run_connect.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dev/run_connect.py b/python/dev/run_connect.py index 5d2898404..a0916ece7 100644 --- a/python/dev/run_connect.py +++ b/python/dev/run_connect.py @@ -50,6 +50,7 @@ get_spark = subprocess.run( [ "wget", + "--no-verbose", f"https://archive.apache.org/dist/spark/spark-{SPARK_VERSION}/spark-{SPARK_VERSION}-bin-hadoop3.tgz", ], stdout=subprocess.PIPE, From f4c91d6d1a014701de41040f021021b6a8816353 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Sun, 23 Feb 2025 23:47:52 +0100 Subject: [PATCH 13/27] Stop connect server --- .github/workflows/python-ci.yml | 1 + python/dev/stop_connect.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 python/dev/stop_connect.py diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index e80d85bfd..60b9d26f9 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -52,3 +52,4 @@ jobs: $VENV_ROOT/bin/python dev/run_connect.py export SPARK_CONNECT_MODE_ENABLED=1 poetry run python -m unittest discover -s tests/ + poetry run python dev/stop_connect.py diff --git a/python/dev/stop_connect.py b/python/dev/stop_connect.py new file mode 100644 index 000000000..849eefaea --- /dev/null +++ b/python/dev/stop_connect.py @@ -0,0 +1,25 @@ +#!/usr/bin/python + + +import shutil +import subprocess +from pathlib import Path + +if __name__ == "__main__": + prj_root = Path(__file__).parent.parent.parent + scala_root = prj_root.joinpath("graphframes-connect") + tmp_dir = prj_root.joinpath("tmp") + checkpoint_dir = Path("/tmp/GFTestsCheckpointDir") + + stop_connect_cmd = ["./sbin/stop-connect-server.sh"] + print("Stopping SparkConnect Server...") + spark_connect_stop = subprocess.run( + stop_connect_cmd, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + + if spark_connect_stop.returncode == 0: + print("Done.") + + shutil.rmtree(checkpoint_dir.absolute().__str__()) From e4d75f7f95ae8e6d52d434d41b469fcba1d7b69e Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 24 Feb 2025 00:35:30 +0100 Subject: [PATCH 14/27] An attempt to fix a bug in GHA with a non-stopping tests --- .github/workflows/python-ci.yml | 6 +++--- python/dev/run_connect.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 60b9d26f9..8dd7238b6 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -46,10 +46,10 @@ jobs: poetry run python -m unittest discover -s tests/ - name: Test SparkConnect + env: + SPARK_CONNECT_MODE_ENABLED: 1 working-directory: ./python run: | - export VENV_ROOT=$(poetry env info --path) - $VENV_ROOT/bin/python dev/run_connect.py - export SPARK_CONNECT_MODE_ENABLED=1 + poetry run python dev/run_connect.py poetry run python -m unittest discover -s tests/ poetry run python dev/stop_connect.py diff --git a/python/dev/run_connect.py b/python/dev/run_connect.py index a0916ece7..f2f81c0d7 100644 --- a/python/dev/run_connect.py +++ b/python/dev/run_connect.py @@ -118,3 +118,4 @@ if spark_connect.returncode == 0: print("Done.") + sys.exit(0) From 0950cfd3176de1ebc6b2e72b8fec74a7b0554717 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 24 Feb 2025 00:47:39 +0100 Subject: [PATCH 15/27] Maybe https://github.com/grpc/grpc/issues/38290? --- .github/workflows/python-ci.yml | 2 +- python/poetry.lock | 124 ++++++++++++++++---------------- python/pyproject.toml | 1 + 3 files changed, 64 insertions(+), 63 deletions(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 8dd7238b6..24beef786 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -8,7 +8,7 @@ jobs: include: - spark-version: 3.5.4 scala-version: 2.12.18 - python-version: 3.9.19 + python-version: 3.10.6 runs-on: ubuntu-22.04 env: # define Java options for both official sbt and sbt-extras diff --git a/python/poetry.lock b/python/poetry.lock index 336524bf5..c34972853 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -470,85 +470,85 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] [[package]] name = "grpcio" -version = "1.70.0" +version = "1.67.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.70.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:95469d1977429f45fe7df441f586521361e235982a0b39e33841549143ae2851"}, - {file = "grpcio-1.70.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:ed9718f17fbdb472e33b869c77a16d0b55e166b100ec57b016dc7de9c8d236bf"}, - {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:374d014f29f9dfdb40510b041792e0e2828a1389281eb590df066e1cc2b404e5"}, - {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2af68a6f5c8f78d56c145161544ad0febbd7479524a59c16b3e25053f39c87f"}, - {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce7df14b2dcd1102a2ec32f621cc9fab6695effef516efbc6b063ad749867295"}, - {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c78b339869f4dbf89881e0b6fbf376313e4f845a42840a7bdf42ee6caed4b11f"}, - {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58ad9ba575b39edef71f4798fdb5c7b6d02ad36d47949cd381d4392a5c9cbcd3"}, - {file = "grpcio-1.70.0-cp310-cp310-win32.whl", hash = "sha256:2b0d02e4b25a5c1f9b6c7745d4fa06efc9fd6a611af0fb38d3ba956786b95199"}, - {file = "grpcio-1.70.0-cp310-cp310-win_amd64.whl", hash = "sha256:0de706c0a5bb9d841e353f6343a9defc9fc35ec61d6eb6111802f3aa9fef29e1"}, - {file = "grpcio-1.70.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:17325b0be0c068f35770f944124e8839ea3185d6d54862800fc28cc2ffad205a"}, - {file = "grpcio-1.70.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:dbe41ad140df911e796d4463168e33ef80a24f5d21ef4d1e310553fcd2c4a386"}, - {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5ea67c72101d687d44d9c56068328da39c9ccba634cabb336075fae2eab0d04b"}, - {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb5277db254ab7586769e490b7b22f4ddab3876c490da0a1a9d7c695ccf0bf77"}, - {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7831a0fc1beeeb7759f737f5acd9fdcda520e955049512d68fda03d91186eea"}, - {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:27cc75e22c5dba1fbaf5a66c778e36ca9b8ce850bf58a9db887754593080d839"}, - {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d63764963412e22f0491d0d32833d71087288f4e24cbcddbae82476bfa1d81fd"}, - {file = "grpcio-1.70.0-cp311-cp311-win32.whl", hash = "sha256:bb491125103c800ec209d84c9b51f1c60ea456038e4734688004f377cfacc113"}, - {file = "grpcio-1.70.0-cp311-cp311-win_amd64.whl", hash = "sha256:d24035d49e026353eb042bf7b058fb831db3e06d52bee75c5f2f3ab453e71aca"}, - {file = "grpcio-1.70.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:ef4c14508299b1406c32bdbb9fb7b47612ab979b04cf2b27686ea31882387cff"}, - {file = "grpcio-1.70.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:aa47688a65643afd8b166928a1da6247d3f46a2784d301e48ca1cc394d2ffb40"}, - {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:880bfb43b1bb8905701b926274eafce5c70a105bc6b99e25f62e98ad59cb278e"}, - {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e654c4b17d07eab259d392e12b149c3a134ec52b11ecdc6a515b39aceeec898"}, - {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2394e3381071045a706ee2eeb6e08962dd87e8999b90ac15c55f56fa5a8c9597"}, - {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b3c76701428d2df01964bc6479422f20e62fcbc0a37d82ebd58050b86926ef8c"}, - {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f"}, - {file = "grpcio-1.70.0-cp312-cp312-win32.whl", hash = "sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528"}, - {file = "grpcio-1.70.0-cp312-cp312-win_amd64.whl", hash = "sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655"}, - {file = "grpcio-1.70.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa573896aeb7d7ce10b1fa425ba263e8dddd83d71530d1322fd3a16f31257b4a"}, - {file = "grpcio-1.70.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:d405b005018fd516c9ac529f4b4122342f60ec1cee181788249372524e6db429"}, - {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f32090238b720eb585248654db8e3afc87b48d26ac423c8dde8334a232ff53c9"}, - {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfa089a734f24ee5f6880c83d043e4f46bf812fcea5181dcb3a572db1e79e01c"}, - {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f19375f0300b96c0117aca118d400e76fede6db6e91f3c34b7b035822e06c35f"}, - {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:7c73c42102e4a5ec76608d9b60227d917cea46dff4d11d372f64cbeb56d259d0"}, - {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:0a5c78d5198a1f0aa60006cd6eb1c912b4a1520b6a3968e677dbcba215fabb40"}, - {file = "grpcio-1.70.0-cp313-cp313-win32.whl", hash = "sha256:fe9dbd916df3b60e865258a8c72ac98f3ac9e2a9542dcb72b7a34d236242a5ce"}, - {file = "grpcio-1.70.0-cp313-cp313-win_amd64.whl", hash = "sha256:4119fed8abb7ff6c32e3d2255301e59c316c22d31ab812b3fbcbaf3d0d87cc68"}, - {file = "grpcio-1.70.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:8058667a755f97407fca257c844018b80004ae8035565ebc2812cc550110718d"}, - {file = "grpcio-1.70.0-cp38-cp38-macosx_10_14_universal2.whl", hash = "sha256:879a61bf52ff8ccacbedf534665bb5478ec8e86ad483e76fe4f729aaef867cab"}, - {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:0ba0a173f4feacf90ee618fbc1a27956bfd21260cd31ced9bc707ef551ff7dc7"}, - {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558c386ecb0148f4f99b1a65160f9d4b790ed3163e8610d11db47838d452512d"}, - {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:412faabcc787bbc826f51be261ae5fa996b21263de5368a55dc2cf824dc5090e"}, - {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3b0f01f6ed9994d7a0b27eeddea43ceac1b7e6f3f9d86aeec0f0064b8cf50fdb"}, - {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7385b1cb064734005204bc8994eed7dcb801ed6c2eda283f613ad8c6c75cf873"}, - {file = "grpcio-1.70.0-cp38-cp38-win32.whl", hash = "sha256:07269ff4940f6fb6710951116a04cd70284da86d0a4368fd5a3b552744511f5a"}, - {file = "grpcio-1.70.0-cp38-cp38-win_amd64.whl", hash = "sha256:aba19419aef9b254e15011b230a180e26e0f6864c90406fdbc255f01d83bc83c"}, - {file = "grpcio-1.70.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4f1937f47c77392ccd555728f564a49128b6a197a05a5cd527b796d36f3387d0"}, - {file = "grpcio-1.70.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:0cd430b9215a15c10b0e7d78f51e8a39d6cf2ea819fd635a7214fae600b1da27"}, - {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e27585831aa6b57b9250abaf147003e126cd3a6c6ca0c531a01996f31709bed1"}, - {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1af8e15b0f0fe0eac75195992a63df17579553b0c4af9f8362cc7cc99ccddf4"}, - {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbce24409beaee911c574a3d75d12ffb8c3e3dd1b813321b1d7a96bbcac46bf4"}, - {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ff4a8112a79464919bb21c18e956c54add43ec9a4850e3949da54f61c241a4a6"}, - {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5413549fdf0b14046c545e19cfc4eb1e37e9e1ebba0ca390a8d4e9963cab44d2"}, - {file = "grpcio-1.70.0-cp39-cp39-win32.whl", hash = "sha256:b745d2c41b27650095e81dea7091668c040457483c9bdb5d0d9de8f8eb25e59f"}, - {file = "grpcio-1.70.0-cp39-cp39-win_amd64.whl", hash = "sha256:a31d7e3b529c94e930a117b2175b2efd179d96eb3c7a21ccb0289a8ab05b645c"}, - {file = "grpcio-1.70.0.tar.gz", hash = "sha256:8d1584a68d5922330025881e63a6c1b54cc8117291d382e4fa69339b6d914c56"}, + {file = "grpcio-1.67.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:8b0341d66a57f8a3119b77ab32207072be60c9bf79760fa609c5609f2deb1f3f"}, + {file = "grpcio-1.67.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:f5a27dddefe0e2357d3e617b9079b4bfdc91341a91565111a21ed6ebbc51b22d"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:43112046864317498a33bdc4797ae6a268c36345a910de9b9c17159d8346602f"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9b929f13677b10f63124c1a410994a401cdd85214ad83ab67cc077fc7e480f0"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7d1797a8a3845437d327145959a2c0c47c05947c9eef5ff1a4c80e499dcc6fa"}, + {file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0489063974d1452436139501bf6b180f63d4977223ee87488fe36858c5725292"}, + {file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9fd042de4a82e3e7aca44008ee2fb5da01b3e5adb316348c21980f7f58adc311"}, + {file = "grpcio-1.67.1-cp310-cp310-win32.whl", hash = "sha256:638354e698fd0c6c76b04540a850bf1db27b4d2515a19fcd5cf645c48d3eb1ed"}, + {file = "grpcio-1.67.1-cp310-cp310-win_amd64.whl", hash = "sha256:608d87d1bdabf9e2868b12338cd38a79969eaf920c89d698ead08f48de9c0f9e"}, + {file = "grpcio-1.67.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:7818c0454027ae3384235a65210bbf5464bd715450e30a3d40385453a85a70cb"}, + {file = "grpcio-1.67.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ea33986b70f83844cd00814cee4451055cd8cab36f00ac64a31f5bb09b31919e"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:c7a01337407dd89005527623a4a72c5c8e2894d22bead0895306b23c6695698f"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80b866f73224b0634f4312a4674c1be21b2b4afa73cb20953cbbb73a6b36c3cc"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fff78ba10d4250bfc07a01bd6254a6d87dc67f9627adece85c0b2ed754fa96"}, + {file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8a23cbcc5bb11ea7dc6163078be36c065db68d915c24f5faa4f872c573bb400f"}, + {file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1a65b503d008f066e994f34f456e0647e5ceb34cfcec5ad180b1b44020ad4970"}, + {file = "grpcio-1.67.1-cp311-cp311-win32.whl", hash = "sha256:e29ca27bec8e163dca0c98084040edec3bc49afd10f18b412f483cc68c712744"}, + {file = "grpcio-1.67.1-cp311-cp311-win_amd64.whl", hash = "sha256:786a5b18544622bfb1e25cc08402bd44ea83edfb04b93798d85dca4d1a0b5be5"}, + {file = "grpcio-1.67.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:267d1745894200e4c604958da5f856da6293f063327cb049a51fe67348e4f953"}, + {file = "grpcio-1.67.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:85f69fdc1d28ce7cff8de3f9c67db2b0ca9ba4449644488c1e0303c146135ddb"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f26b0b547eb8d00e195274cdfc63ce64c8fc2d3e2d00b12bf468ece41a0423a0"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4422581cdc628f77302270ff839a44f4c24fdc57887dc2a45b7e53d8fc2376af"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7616d2ded471231c701489190379e0c311ee0a6c756f3c03e6a62b95a7146e"}, + {file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8a00efecde9d6fcc3ab00c13f816313c040a28450e5e25739c24f432fc6d3c75"}, + {file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:699e964923b70f3101393710793289e42845791ea07565654ada0969522d0a38"}, + {file = "grpcio-1.67.1-cp312-cp312-win32.whl", hash = "sha256:4e7b904484a634a0fff132958dabdb10d63e0927398273917da3ee103e8d1f78"}, + {file = "grpcio-1.67.1-cp312-cp312-win_amd64.whl", hash = "sha256:5721e66a594a6c4204458004852719b38f3d5522082be9061d6510b455c90afc"}, + {file = "grpcio-1.67.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa0162e56fd10a5547fac8774c4899fc3e18c1aa4a4759d0ce2cd00d3696ea6b"}, + {file = "grpcio-1.67.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:beee96c8c0b1a75d556fe57b92b58b4347c77a65781ee2ac749d550f2a365dc1"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:a93deda571a1bf94ec1f6fcda2872dad3ae538700d94dc283c672a3b508ba3af"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e6f255980afef598a9e64a24efce87b625e3e3c80a45162d111a461a9f92955"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e838cad2176ebd5d4a8bb03955138d6589ce9e2ce5d51c3ada34396dbd2dba8"}, + {file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:a6703916c43b1d468d0756c8077b12017a9fcb6a1ef13faf49e67d20d7ebda62"}, + {file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:917e8d8994eed1d86b907ba2a61b9f0aef27a2155bca6cbb322430fc7135b7bb"}, + {file = "grpcio-1.67.1-cp313-cp313-win32.whl", hash = "sha256:e279330bef1744040db8fc432becc8a727b84f456ab62b744d3fdb83f327e121"}, + {file = "grpcio-1.67.1-cp313-cp313-win_amd64.whl", hash = "sha256:fa0c739ad8b1996bd24823950e3cb5152ae91fca1c09cc791190bf1627ffefba"}, + {file = "grpcio-1.67.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:178f5db771c4f9a9facb2ab37a434c46cb9be1a75e820f187ee3d1e7805c4f65"}, + {file = "grpcio-1.67.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0f3e49c738396e93b7ba9016e153eb09e0778e776df6090c1b8c91877cc1c426"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:24e8a26dbfc5274d7474c27759b54486b8de23c709d76695237515bc8b5baeab"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b6c16489326d79ead41689c4b84bc40d522c9a7617219f4ad94bc7f448c5085"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60e6a4dcf5af7bbc36fd9f81c9f372e8ae580870a9e4b6eafe948cd334b81cf3"}, + {file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:95b5f2b857856ed78d72da93cd7d09b6db8ef30102e5e7fe0961fe4d9f7d48e8"}, + {file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b49359977c6ec9f5d0573ea4e0071ad278ef905aa74e420acc73fd28ce39e9ce"}, + {file = "grpcio-1.67.1-cp38-cp38-win32.whl", hash = "sha256:f5b76ff64aaac53fede0cc93abf57894ab2a7362986ba22243d06218b93efe46"}, + {file = "grpcio-1.67.1-cp38-cp38-win_amd64.whl", hash = "sha256:804c6457c3cd3ec04fe6006c739579b8d35c86ae3298ffca8de57b493524b771"}, + {file = "grpcio-1.67.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:a25bdea92b13ff4d7790962190bf6bf5c4639876e01c0f3dda70fc2769616335"}, + {file = "grpcio-1.67.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cdc491ae35a13535fd9196acb5afe1af37c8237df2e54427be3eecda3653127e"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:85f862069b86a305497e74d0dc43c02de3d1d184fc2c180993aa8aa86fbd19b8"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec74ef02010186185de82cc594058a3ccd8d86821842bbac9873fd4a2cf8be8d"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01f616a964e540638af5130469451cf580ba8c7329f45ca998ab66e0c7dcdb04"}, + {file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:299b3d8c4f790c6bcca485f9963b4846dd92cf6f1b65d3697145d005c80f9fe8"}, + {file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:60336bff760fbb47d7e86165408126f1dded184448e9a4c892189eb7c9d3f90f"}, + {file = "grpcio-1.67.1-cp39-cp39-win32.whl", hash = "sha256:5ed601c4c6008429e3d247ddb367fe8c7259c355757448d7c1ef7bd4a6739e8e"}, + {file = "grpcio-1.67.1-cp39-cp39-win_amd64.whl", hash = "sha256:5db70d32d6703b89912af16d6d45d78406374a8b8ef0d28140351dd0ec610e98"}, + {file = "grpcio-1.67.1.tar.gz", hash = "sha256:3dc2ed4cabea4dc14d5e708c2b426205956077cc5de419b4d4079315017e9732"}, ] [package.extras] -protobuf = ["grpcio-tools (>=1.70.0)"] +protobuf = ["grpcio-tools (>=1.67.1)"] [[package]] name = "grpcio-status" -version = "1.70.0" +version = "1.67.1" description = "Status proto mapping for gRPC" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio_status-1.70.0-py3-none-any.whl", hash = "sha256:fc5a2ae2b9b1c1969cc49f3262676e6854aa2398ec69cb5bd6c47cd501904a85"}, - {file = "grpcio_status-1.70.0.tar.gz", hash = "sha256:0e7b42816512433b18b9d764285ff029bde059e9d41f8fe10a60631bd8348101"}, + {file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"}, + {file = "grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"}, ] [package.dependencies] googleapis-common-protos = ">=1.5.5" -grpcio = ">=1.70.0" +grpcio = ">=1.67.1" protobuf = ">=5.26.1,<6.0dev" [[package]] @@ -1459,4 +1459,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.9 <3.13" -content-hash = "7f7c75162a401092d71486ab62567acefae73c53b567bfb89cc078ec60c9bc64" +content-hash = "d827923683a5908e3b555f58ada2bfce8a678e349cdfd65328d7a92c7380dc2e" diff --git a/python/pyproject.toml b/python/pyproject.toml index 3926eefbf..29dc1ab0c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -43,6 +43,7 @@ black = "^25.1.0" flake8 = "^7.1.1" isort = "^6.0.0" pyspark = { version = "3.5.4", extras = ["connect"] } +grpcio = "<=1.67.1" [tool.poetry.group.tutorials.dependencies] py7zr = "^0.22.0" From 8cb430caf998ab429ee0fda06b33a922dba50a3d Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 24 Feb 2025 00:57:03 +0100 Subject: [PATCH 16/27] Fix broken stop-cript --- .github/workflows/python-ci.yml | 2 +- python/dev/stop_connect.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 24beef786..f19b7dcf6 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -39,7 +39,7 @@ jobs: - name: Build Python package and its dependencies working-directory: ./python run: | - poetry install + poetry install --with=dev - name: Test working-directory: ./python run: | diff --git a/python/dev/stop_connect.py b/python/dev/stop_connect.py index 849eefaea..59c3ee03f 100644 --- a/python/dev/stop_connect.py +++ b/python/dev/stop_connect.py @@ -1,14 +1,22 @@ #!/usr/bin/python +import os import shutil import subprocess from pathlib import Path +SPARK_VERSION = "3.5.4" + if __name__ == "__main__": prj_root = Path(__file__).parent.parent.parent scala_root = prj_root.joinpath("graphframes-connect") tmp_dir = prj_root.joinpath("tmp") + unpackaed_spark_binary = f"spark-{SPARK_VERSION}-bin-hadoop3" + spark_home = tmp_dir.joinpath(unpackaed_spark_binary) + + os.chdir(spark_home) + checkpoint_dir = Path("/tmp/GFTestsCheckpointDir") stop_connect_cmd = ["./sbin/stop-connect-server.sh"] From 19c1934e3a65e15757420127acd9fd96817e69a2 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 24 Feb 2025 01:04:06 +0100 Subject: [PATCH 17/27] Ignore errors in clean-up --- python/dev/stop_connect.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/dev/stop_connect.py b/python/dev/stop_connect.py index 59c3ee03f..a6ec8304f 100644 --- a/python/dev/stop_connect.py +++ b/python/dev/stop_connect.py @@ -4,6 +4,7 @@ import os import shutil import subprocess +import sys from pathlib import Path SPARK_VERSION = "3.5.4" @@ -30,4 +31,5 @@ if spark_connect_stop.returncode == 0: print("Done.") - shutil.rmtree(checkpoint_dir.absolute().__str__()) + shutil.rmtree(checkpoint_dir.absolute().__str__(), ignore_errors=True) + sys.exit(0) From 1eef32315c80c8df7df1769631f55d1ce94f454c Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 24 Feb 2025 01:05:17 +0100 Subject: [PATCH 18/27] Verbosity in ci tests --- .github/workflows/python-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index f19b7dcf6..07789a8a9 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -43,7 +43,7 @@ jobs: - name: Test working-directory: ./python run: | - poetry run python -m unittest discover -s tests/ + poetry run python -m unittest discover -s tests/ -v - name: Test SparkConnect env: @@ -51,5 +51,5 @@ jobs: working-directory: ./python run: | poetry run python dev/run_connect.py - poetry run python -m unittest discover -s tests/ + poetry run python -m unittest discover -s tests/ -v poetry run python dev/stop_connect.py From cc60bcb75407a1690d139167401acfb24cd3fa8e Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 10 Mar 2025 21:26:19 +0100 Subject: [PATCH 19/27] Typo --- .github/workflows/python-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 593845ea9..35ef4d567 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -44,7 +44,7 @@ jobs: - name: Test working-directory: ./python run: | - poetry run python -m pytets + poetry run python -m pytest - name: Test SparkConnect env: From 5528d65b3d051853b7f449b0a1cc204b8a412189 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 10 Mar 2025 21:32:57 +0100 Subject: [PATCH 20/27] Fix merge-artifacts --- build.sbt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index b8ffd4b2a..18017cfe5 100644 --- a/build.sbt +++ b/build.sbt @@ -51,7 +51,10 @@ lazy val commonSetting = Seq( "-XX:ReservedCodeCacheSize=384m", "-XX:MaxMetaspaceSize=384m", "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", - "--add-opens=java.base/java.lang=ALL-UNNAMED"), + "--add-opens=java.base/java.lang=ALL-UNNAMED", + "--add-opens=java.base/java.nio=ALL-UNNAMED", + "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", + "--add-opens=java.base/java.util=ALL-UNNAMED"), credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials")) lazy val root = (project in file(".")) From f88e19a6e0bc3fa397d5201e29b8f74e4ee09256 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 10 Mar 2025 21:36:18 +0100 Subject: [PATCH 21/27] Fix merge artifacts --- build.sbt | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 18017cfe5..b41af000b 100644 --- a/build.sbt +++ b/build.sbt @@ -96,4 +96,17 @@ lazy val connect = (project in file("graphframes-connect")) Compile / PB.includePaths ++= Seq(file("src/main/protobuf")), PB.protocVersion := "3.23.4", // Spark 3.5 branch libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-connect" % sparkVer % "provided" cross CrossVersion.for3Use2_13)) + "org.apache.spark" %% "spark-connect" % sparkVer % "provided" cross CrossVersion.for3Use2_13), + + // Assembly and shading + assembly / test := {}, + assembly / assemblyShadeRules := Seq( + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll), + assembly / assemblyMergeStrategy := { + case PathList("META-INF", xs @ _*) => MergeStrategy.discard + case x if x.endsWith("module-info.class") => MergeStrategy.discard + case x => + val oldStrategy = (assembly / assemblyMergeStrategy).value + oldStrategy(x) + } + ) From 59897fb045d0846ad03b7d5ec4e1d3fbe2b04f17 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 10 Mar 2025 21:38:29 +0100 Subject: [PATCH 22/27] Apply pre-commit rules --- .../connect/proto/graphframes_pb2.py | 84 +++++++++---------- .../connect/proto/graphframes_pb2.pyi | 47 +++++------ 2 files changed, 61 insertions(+), 70 deletions(-) diff --git a/python/graphframes/connect/proto/graphframes_pb2.py b/python/graphframes/connect/proto/graphframes_pb2.py index 2ba63f068..39e99776e 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.py +++ b/python/graphframes/connect/proto/graphframes_pb2.py @@ -19,7 +19,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xba\r\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12\x42\n\x07\x64\x65grees\x18\x06 \x01(\x0b\x32&.org.graphframes.connect.proto.DegreesH\x00R\x07\x64\x65grees\x12k\n\x16\x64rop_isolated_vertices\x18\x07 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x08 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\t \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\n \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12I\n\nin_degrees\x18\x0b \x01(\x0b\x32(.org.graphframes.connect.proto.InDegreesH\x00R\tinDegrees\x12^\n\x11label_propagation\x18\x0c \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12L\n\x0bout_degrees\x18\r \x01(\x0b\x32).org.graphframes.connect.proto.OutDegreesH\x00R\noutDegrees\x12\x46\n\tpage_rank\x18\x0e \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0f \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12?\n\x06pregel\x18\x10 \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x11 \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x12 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x13 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x14 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x15 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold"\t\n\x07\x44\x65grees"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"\x0b\n\tInDegrees"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\x0c\n\nOutDegrees"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"\xd0\x04\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"\x0f\n\rTriangleCount"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' + b'\n\x11graphframes.proto\x12\x1dorg.graphframes.connect.proto"\xd6\x0c\n\x0eGraphFramesAPI\x12\x1a\n\x08vertices\x18\x01 \x01(\x0cR\x08vertices\x12\x14\n\x05\x65\x64ges\x18\x02 \x01(\x0cR\x05\x65\x64ges\x12\x61\n\x12\x61ggregate_messages\x18\x03 \x01(\x0b\x32\x30.org.graphframes.connect.proto.AggregateMessagesH\x00R\x11\x61ggregateMessages\x12\x36\n\x03\x62\x66s\x18\x04 \x01(\x0b\x32".org.graphframes.connect.proto.BFSH\x00R\x03\x62\x66s\x12g\n\x14\x63onnected_components\x18\x05 \x01(\x0b\x32\x32.org.graphframes.connect.proto.ConnectedComponentsH\x00R\x13\x63onnectedComponents\x12k\n\x16\x64rop_isolated_vertices\x18\x06 \x01(\x0b\x32\x33.org.graphframes.connect.proto.DropIsolatedVerticesH\x00R\x14\x64ropIsolatedVertices\x12O\n\x0c\x66ilter_edges\x18\x07 \x01(\x0b\x32*.org.graphframes.connect.proto.FilterEdgesH\x00R\x0b\x66ilterEdges\x12X\n\x0f\x66ilter_vertices\x18\x08 \x01(\x0b\x32-.org.graphframes.connect.proto.FilterVerticesH\x00R\x0e\x66ilterVertices\x12\x39\n\x04\x66ind\x18\t \x01(\x0b\x32#.org.graphframes.connect.proto.FindH\x00R\x04\x66ind\x12^\n\x11label_propagation\x18\n \x01(\x0b\x32/.org.graphframes.connect.proto.LabelPropagationH\x00R\x10labelPropagation\x12\x46\n\tpage_rank\x18\x0b \x01(\x0b\x32\'.org.graphframes.connect.proto.PageRankH\x00R\x08pageRank\x12\x84\x01\n\x1fparallel_personalized_page_rank\x18\x0c \x01(\x0b\x32;.org.graphframes.connect.proto.ParallelPersonalizedPageRankH\x00R\x1cparallelPersonalizedPageRank\x12w\n\x1apower_iteration_clustering\x18\r \x01(\x0b\x32\x37.org.graphframes.connect.proto.PowerIterationClusteringH\x00R\x18powerIterationClustering\x12?\n\x06pregel\x18\x0e \x01(\x0b\x32%.org.graphframes.connect.proto.PregelH\x00R\x06pregel\x12U\n\x0eshortest_paths\x18\x0f \x01(\x0b\x32,.org.graphframes.connect.proto.ShortestPathsH\x00R\rshortestPaths\x12\x80\x01\n\x1dstrongly_connected_components\x18\x10 \x01(\x0b\x32:.org.graphframes.connect.proto.StronglyConnectedComponentsH\x00R\x1bstronglyConnectedComponents\x12P\n\rsvd_plus_plus\x18\x11 \x01(\x0b\x32*.org.graphframes.connect.proto.SVDPlusPlusH\x00R\x0bsvdPlusPlus\x12U\n\x0etriangle_count\x18\x12 \x01(\x0b\x32,.org.graphframes.connect.proto.TriangleCountH\x00R\rtriangleCount\x12\x45\n\x08triplets\x18\x13 \x01(\x0b\x32\'.org.graphframes.connect.proto.TripletsH\x00R\x08tripletsB\x08\n\x06method"M\n\x12\x43olumnOrExpression\x12\x12\n\x03\x63ol\x18\x01 \x01(\x0cH\x00R\x03\x63ol\x12\x14\n\x04\x65xpr\x18\x02 \x01(\tH\x00R\x04\x65xprB\r\n\x0b\x63ol_or_expr"P\n\x0eStringOrLongID\x12\x19\n\x07long_id\x18\x01 \x01(\x03H\x00R\x06longId\x12\x1d\n\tstring_id\x18\x02 \x01(\tH\x00R\x08stringIdB\x04\n\x02id"\xaf\x02\n\x11\x41ggregateMessages\x12J\n\x07\x61gg_col\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06\x61ggCol\x12V\n\x0bsend_to_src\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x00R\tsendToSrc\x88\x01\x01\x12V\n\x0bsend_to_dst\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionH\x01R\tsendToDst\x88\x01\x01\x42\x0e\n\x0c_send_to_srcB\x0e\n\x0c_send_to_dst"\x9d\x02\n\x03\x42\x46S\x12N\n\tfrom_expr\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x08\x66romExpr\x12J\n\x07to_expr\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x06toExpr\x12R\n\x0b\x65\x64ge_filter\x18\x03 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\nedgeFilter\x12&\n\x0fmax_path_length\x18\x04 \x01(\x05R\rmaxPathLength"\x95\x01\n\x13\x43onnectedComponents\x12\x1c\n\talgorithm\x18\x01 \x01(\tR\talgorithm\x12/\n\x13\x63heckpoint_interval\x18\x02 \x01(\x05R\x12\x63heckpointInterval\x12/\n\x13\x62roadcast_threshold\x18\x03 \x01(\x05R\x12\x62roadcastThreshold"\x16\n\x14\x44ropIsolatedVertices"^\n\x0b\x46ilterEdges\x12O\n\tcondition\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition"a\n\x0e\x46ilterVertices\x12O\n\tcondition\x18\x02 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\tcondition" \n\x04\x46ind\x12\x18\n\x07pattern\x18\x01 \x01(\tR\x07pattern"-\n\x10LabelPropagation\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xe2\x01\n\x08PageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12O\n\tsource_id\x18\x02 \x01(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDH\x00R\x08sourceId\x88\x01\x01\x12\x1e\n\x08max_iter\x18\x03 \x01(\x05H\x01R\x07maxIter\x88\x01\x01\x12\x15\n\x03tol\x18\x04 \x01(\x01H\x02R\x03tol\x88\x01\x01\x42\x0c\n\n_source_idB\x0b\n\t_max_iterB\x06\n\x04_tol"\xb4\x01\n\x1cParallelPersonalizedPageRank\x12+\n\x11reset_probability\x18\x01 \x01(\x01R\x10resetProbability\x12L\n\nsource_ids\x18\x02 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tsourceIds\x12\x19\n\x08max_iter\x18\x03 \x01(\x05R\x07maxIter"v\n\x18PowerIterationClustering\x12\x0c\n\x01k\x18\x01 \x01(\x05R\x01k\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12"\n\nweight_col\x18\x03 \x01(\tH\x00R\tweightCol\x88\x01\x01\x42\r\n\x0b_weight_col"\xd0\x04\n\x06Pregel\x12L\n\x08\x61gg_msgs\x18\x01 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x07\x61ggMsgs\x12X\n\x0fsend_msg_to_dst\x18\x02 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToDst\x12X\n\x0fsend_msg_to_src\x18\x03 \x03(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x0csendMsgToSrc\x12/\n\x13\x63heckpoint_interval\x18\x04 \x01(\x05R\x12\x63heckpointInterval\x12\x19\n\x08max_iter\x18\x05 \x01(\x05R\x07maxIter\x12.\n\x13\x61\x64\x64itional_col_name\x18\x06 \x01(\tR\x11\x61\x64\x64itionalColName\x12g\n\x16\x61\x64\x64itional_col_initial\x18\x07 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x14\x61\x64\x64itionalColInitial\x12_\n\x12\x61\x64\x64itional_col_upd\x18\x08 \x01(\x0b\x32\x31.org.graphframes.connect.proto.ColumnOrExpressionR\x10\x61\x64\x64itionalColUpd"\\\n\rShortestPaths\x12K\n\tlandmarks\x18\x01 \x03(\x0b\x32-.org.graphframes.connect.proto.StringOrLongIDR\tlandmarks"8\n\x1bStronglyConnectedComponents\x12\x19\n\x08max_iter\x18\x01 \x01(\x05R\x07maxIter"\xd6\x01\n\x0bSVDPlusPlus\x12\x12\n\x04rank\x18\x01 \x01(\x05R\x04rank\x12\x19\n\x08max_iter\x18\x02 \x01(\x05R\x07maxIter\x12\x1b\n\tmin_value\x18\x03 \x01(\x01R\x08minValue\x12\x1b\n\tmax_value\x18\x04 \x01(\x01R\x08maxValue\x12\x16\n\x06gamma1\x18\x05 \x01(\x01R\x06gamma1\x12\x16\n\x06gamma2\x18\x06 \x01(\x01R\x06gamma2\x12\x16\n\x06gamma6\x18\x07 \x01(\x01R\x06gamma6\x12\x16\n\x06gamma7\x18\x08 \x01(\x01R\x06gamma7"\x0f\n\rTriangleCount"\n\n\x08TripletsB\xd2\x01\n!com.org.graphframes.connect.protoB\x10GraphframesProtoH\x01P\x01\xa0\x01\x01\xa2\x02\x04OGCP\xaa\x02\x1dOrg.Graphframes.Connect.Proto\xca\x02\x1dOrg\\Graphframes\\Connect\\Proto\xe2\x02)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\xea\x02 Org::Graphframes::Connect::Protob\x06proto3' ) _globals = globals() @@ -31,47 +31,43 @@ "DESCRIPTOR" ]._serialized_options = b"\n!com.org.graphframes.connect.protoB\020GraphframesProtoH\001P\001\240\001\001\242\002\004OGCP\252\002\035Org.Graphframes.Connect.Proto\312\002\035Org\\Graphframes\\Connect\\Proto\342\002)Org\\Graphframes\\Connect\\Proto\\GPBMetadata\352\002 Org::Graphframes::Connect::Proto" _globals["_GRAPHFRAMESAPI"]._serialized_start = 53 - _globals["_GRAPHFRAMESAPI"]._serialized_end = 1775 - _globals["_COLUMNOREXPRESSION"]._serialized_start = 1777 - _globals["_COLUMNOREXPRESSION"]._serialized_end = 1854 - _globals["_STRINGORLONGID"]._serialized_start = 1856 - _globals["_STRINGORLONGID"]._serialized_end = 1936 - _globals["_AGGREGATEMESSAGES"]._serialized_start = 1939 - _globals["_AGGREGATEMESSAGES"]._serialized_end = 2242 - _globals["_BFS"]._serialized_start = 2245 - _globals["_BFS"]._serialized_end = 2530 - _globals["_CONNECTEDCOMPONENTS"]._serialized_start = 2533 - _globals["_CONNECTEDCOMPONENTS"]._serialized_end = 2682 - _globals["_DEGREES"]._serialized_start = 2684 - _globals["_DEGREES"]._serialized_end = 2693 - _globals["_DROPISOLATEDVERTICES"]._serialized_start = 2695 - _globals["_DROPISOLATEDVERTICES"]._serialized_end = 2717 - _globals["_FILTEREDGES"]._serialized_start = 2719 - _globals["_FILTEREDGES"]._serialized_end = 2813 - _globals["_FILTERVERTICES"]._serialized_start = 2815 - _globals["_FILTERVERTICES"]._serialized_end = 2912 - _globals["_FIND"]._serialized_start = 2914 - _globals["_FIND"]._serialized_end = 2946 - _globals["_INDEGREES"]._serialized_start = 2948 - _globals["_INDEGREES"]._serialized_end = 2959 - _globals["_LABELPROPAGATION"]._serialized_start = 2961 - _globals["_LABELPROPAGATION"]._serialized_end = 3006 - _globals["_OUTDEGREES"]._serialized_start = 3008 - _globals["_OUTDEGREES"]._serialized_end = 3020 - _globals["_PAGERANK"]._serialized_start = 3023 - _globals["_PAGERANK"]._serialized_end = 3249 - _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_start = 3252 - _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_end = 3432 - _globals["_PREGEL"]._serialized_start = 3435 - _globals["_PREGEL"]._serialized_end = 4027 - _globals["_SHORTESTPATHS"]._serialized_start = 4029 - _globals["_SHORTESTPATHS"]._serialized_end = 4121 - _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 4123 - _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 4179 - _globals["_SVDPLUSPLUS"]._serialized_start = 4182 - _globals["_SVDPLUSPLUS"]._serialized_end = 4396 - _globals["_TRIANGLECOUNT"]._serialized_start = 4398 - _globals["_TRIANGLECOUNT"]._serialized_end = 4413 - _globals["_TRIPLETS"]._serialized_start = 4415 - _globals["_TRIPLETS"]._serialized_end = 4425 + _globals["_GRAPHFRAMESAPI"]._serialized_end = 1675 + _globals["_COLUMNOREXPRESSION"]._serialized_start = 1677 + _globals["_COLUMNOREXPRESSION"]._serialized_end = 1754 + _globals["_STRINGORLONGID"]._serialized_start = 1756 + _globals["_STRINGORLONGID"]._serialized_end = 1836 + _globals["_AGGREGATEMESSAGES"]._serialized_start = 1839 + _globals["_AGGREGATEMESSAGES"]._serialized_end = 2142 + _globals["_BFS"]._serialized_start = 2145 + _globals["_BFS"]._serialized_end = 2430 + _globals["_CONNECTEDCOMPONENTS"]._serialized_start = 2433 + _globals["_CONNECTEDCOMPONENTS"]._serialized_end = 2582 + _globals["_DROPISOLATEDVERTICES"]._serialized_start = 2584 + _globals["_DROPISOLATEDVERTICES"]._serialized_end = 2606 + _globals["_FILTEREDGES"]._serialized_start = 2608 + _globals["_FILTEREDGES"]._serialized_end = 2702 + _globals["_FILTERVERTICES"]._serialized_start = 2704 + _globals["_FILTERVERTICES"]._serialized_end = 2801 + _globals["_FIND"]._serialized_start = 2803 + _globals["_FIND"]._serialized_end = 2835 + _globals["_LABELPROPAGATION"]._serialized_start = 2837 + _globals["_LABELPROPAGATION"]._serialized_end = 2882 + _globals["_PAGERANK"]._serialized_start = 2885 + _globals["_PAGERANK"]._serialized_end = 3111 + _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_start = 3114 + _globals["_PARALLELPERSONALIZEDPAGERANK"]._serialized_end = 3294 + _globals["_POWERITERATIONCLUSTERING"]._serialized_start = 3296 + _globals["_POWERITERATIONCLUSTERING"]._serialized_end = 3414 + _globals["_PREGEL"]._serialized_start = 3417 + _globals["_PREGEL"]._serialized_end = 4009 + _globals["_SHORTESTPATHS"]._serialized_start = 4011 + _globals["_SHORTESTPATHS"]._serialized_end = 4103 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_start = 4105 + _globals["_STRONGLYCONNECTEDCOMPONENTS"]._serialized_end = 4161 + _globals["_SVDPLUSPLUS"]._serialized_start = 4164 + _globals["_SVDPLUSPLUS"]._serialized_end = 4378 + _globals["_TRIANGLECOUNT"]._serialized_start = 4380 + _globals["_TRIANGLECOUNT"]._serialized_end = 4395 + _globals["_TRIPLETS"]._serialized_start = 4397 + _globals["_TRIPLETS"]._serialized_end = 4407 # @@protoc_insertion_point(module_scope) diff --git a/python/graphframes/connect/proto/graphframes_pb2.pyi b/python/graphframes/connect/proto/graphframes_pb2.pyi index 88d26ff24..00649305c 100644 --- a/python/graphframes/connect/proto/graphframes_pb2.pyi +++ b/python/graphframes/connect/proto/graphframes_pb2.pyi @@ -1,6 +1,6 @@ +from collections.abc import Iterable as _Iterable +from collections.abc import Mapping as _Mapping from typing import ClassVar as _ClassVar -from typing import Iterable as _Iterable -from typing import Mapping as _Mapping from typing import Optional as _Optional from typing import Union as _Union @@ -17,16 +17,14 @@ class GraphFramesAPI(_message.Message): "aggregate_messages", "bfs", "connected_components", - "degrees", "drop_isolated_vertices", "filter_edges", "filter_vertices", "find", - "in_degrees", "label_propagation", - "out_degrees", "page_rank", "parallel_personalized_page_rank", + "power_iteration_clustering", "pregel", "shortest_paths", "strongly_connected_components", @@ -39,16 +37,14 @@ class GraphFramesAPI(_message.Message): AGGREGATE_MESSAGES_FIELD_NUMBER: _ClassVar[int] BFS_FIELD_NUMBER: _ClassVar[int] CONNECTED_COMPONENTS_FIELD_NUMBER: _ClassVar[int] - DEGREES_FIELD_NUMBER: _ClassVar[int] DROP_ISOLATED_VERTICES_FIELD_NUMBER: _ClassVar[int] FILTER_EDGES_FIELD_NUMBER: _ClassVar[int] FILTER_VERTICES_FIELD_NUMBER: _ClassVar[int] FIND_FIELD_NUMBER: _ClassVar[int] - IN_DEGREES_FIELD_NUMBER: _ClassVar[int] LABEL_PROPAGATION_FIELD_NUMBER: _ClassVar[int] - OUT_DEGREES_FIELD_NUMBER: _ClassVar[int] PAGE_RANK_FIELD_NUMBER: _ClassVar[int] PARALLEL_PERSONALIZED_PAGE_RANK_FIELD_NUMBER: _ClassVar[int] + POWER_ITERATION_CLUSTERING_FIELD_NUMBER: _ClassVar[int] PREGEL_FIELD_NUMBER: _ClassVar[int] SHORTEST_PATHS_FIELD_NUMBER: _ClassVar[int] STRONGLY_CONNECTED_COMPONENTS_FIELD_NUMBER: _ClassVar[int] @@ -60,16 +56,14 @@ class GraphFramesAPI(_message.Message): aggregate_messages: AggregateMessages bfs: BFS connected_components: ConnectedComponents - degrees: Degrees drop_isolated_vertices: DropIsolatedVertices filter_edges: FilterEdges filter_vertices: FilterVertices find: Find - in_degrees: InDegrees label_propagation: LabelPropagation - out_degrees: OutDegrees page_rank: PageRank parallel_personalized_page_rank: ParallelPersonalizedPageRank + power_iteration_clustering: PowerIterationClustering pregel: Pregel shortest_paths: ShortestPaths strongly_connected_components: StronglyConnectedComponents @@ -83,18 +77,16 @@ class GraphFramesAPI(_message.Message): aggregate_messages: _Optional[_Union[AggregateMessages, _Mapping]] = ..., bfs: _Optional[_Union[BFS, _Mapping]] = ..., connected_components: _Optional[_Union[ConnectedComponents, _Mapping]] = ..., - degrees: _Optional[_Union[Degrees, _Mapping]] = ..., drop_isolated_vertices: _Optional[_Union[DropIsolatedVertices, _Mapping]] = ..., filter_edges: _Optional[_Union[FilterEdges, _Mapping]] = ..., filter_vertices: _Optional[_Union[FilterVertices, _Mapping]] = ..., find: _Optional[_Union[Find, _Mapping]] = ..., - in_degrees: _Optional[_Union[InDegrees, _Mapping]] = ..., label_propagation: _Optional[_Union[LabelPropagation, _Mapping]] = ..., - out_degrees: _Optional[_Union[OutDegrees, _Mapping]] = ..., page_rank: _Optional[_Union[PageRank, _Mapping]] = ..., parallel_personalized_page_rank: _Optional[ _Union[ParallelPersonalizedPageRank, _Mapping] ] = ..., + power_iteration_clustering: _Optional[_Union[PowerIterationClustering, _Mapping]] = ..., pregel: _Optional[_Union[Pregel, _Mapping]] = ..., shortest_paths: _Optional[_Union[ShortestPaths, _Mapping]] = ..., strongly_connected_components: _Optional[ @@ -169,10 +161,6 @@ class ConnectedComponents(_message.Message): broadcast_threshold: _Optional[int] = ..., ) -> None: ... -class Degrees(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - class DropIsolatedVertices(_message.Message): __slots__ = () def __init__(self) -> None: ... @@ -199,20 +187,12 @@ class Find(_message.Message): pattern: str def __init__(self, pattern: _Optional[str] = ...) -> None: ... -class InDegrees(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - class LabelPropagation(_message.Message): __slots__ = ("max_iter",) MAX_ITER_FIELD_NUMBER: _ClassVar[int] max_iter: int def __init__(self, max_iter: _Optional[int] = ...) -> None: ... -class OutDegrees(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - class PageRank(_message.Message): __slots__ = ("reset_probability", "source_id", "max_iter", "tol") RESET_PROBABILITY_FIELD_NUMBER: _ClassVar[int] @@ -246,6 +226,21 @@ class ParallelPersonalizedPageRank(_message.Message): max_iter: _Optional[int] = ..., ) -> None: ... +class PowerIterationClustering(_message.Message): + __slots__ = ("k", "max_iter", "weight_col") + K_FIELD_NUMBER: _ClassVar[int] + MAX_ITER_FIELD_NUMBER: _ClassVar[int] + WEIGHT_COL_FIELD_NUMBER: _ClassVar[int] + k: int + max_iter: int + weight_col: str + def __init__( + self, + k: _Optional[int] = ..., + max_iter: _Optional[int] = ..., + weight_col: _Optional[str] = ..., + ) -> None: ... + class Pregel(_message.Message): __slots__ = ( "agg_msgs", From 97054b05b72c5c34b6d32b0a0c1614e50ad8c1ac Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 10 Mar 2025 21:58:59 +0100 Subject: [PATCH 23/27] Add the missing method --- .../graphframes/connect/graphframe_client.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/python/graphframes/connect/graphframe_client.py b/python/graphframes/connect/graphframe_client.py index 762992120..bb74e19a6 100644 --- a/python/graphframes/connect/graphframe_client.py +++ b/python/graphframes/connect/graphframe_client.py @@ -678,6 +678,45 @@ def plan(self, session: SparkConnectClient) -> proto.Relation: ) return self._update_page_rank_edge_weights(new_vertices) + def powerIterationClustering( + self, k: int, maxIter: int, weightCol: str | None = None + ) -> DataFrame: + class PowerIterationClustering(LogicalPlan): + def __init__( + self, + v: DataFrame, + e: DataFrame, + k: int, + max_iter: int, + weight_col: str | None, + ) -> None: + super().__init__(None) + self.v = v + self.e = e + self.k = k + self.max_iter = max_iter + self.weight_col = weight_col + + def plan(self, session: SparkConnectClient) -> proto.Relation: + graphframes_api_call = GraphFrameConnect._get_pb_api_message( + self.v, self.e, session + ) + graphframes_api_call.power_iteration_clustering.CopyFrom( + pb.PowerIterationClustering( + k=self.k, + max_iter=self.max_iter, + weight_col=self.weight_col, + ) + ) + plan = self._create_proto_relation() + plan.extension.Pack(graphframes_api_call) + return plan + + return DataFrame.withPlan( + PowerIterationClustering(self._vertices, self._edges, k, maxIter, weightCol), + self._spark, + ) + def shortestPaths(self, landmarks: list[str | int]) -> DataFrame: class ShortestPaths(LogicalPlan): def __init__(self, v: DataFrame, e: DataFrame, landmarks: list[str | int]) -> None: From 90a326f767c368910bd646003f8f129a62ea807d Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 10 Mar 2025 22:34:39 +0100 Subject: [PATCH 24/27] Restore accidently deleted part of CI --- .github/workflows/python-ci.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index 35ef4d567..e13b5b0ae 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -41,6 +41,13 @@ jobs: working-directory: ./python run: | poetry install --with=dev + - name: Code style + workign-directory: ./python + run: | + poetry run python -m black --check graphframes + poetry run python -m flake8 graphframes + poetry run python -m isort --check graphframes + - name: Test working-directory: ./python run: | From c8bcf43b38fea3942c99275b3a2e50ee3d52bf0f Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Tue, 11 Mar 2025 01:12:23 +0100 Subject: [PATCH 25/27] Typo --- .github/workflows/python-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml index e13b5b0ae..1e0f41fbb 100644 --- a/.github/workflows/python-ci.yml +++ b/.github/workflows/python-ci.yml @@ -42,7 +42,7 @@ jobs: run: | poetry install --with=dev - name: Code style - workign-directory: ./python + working-directory: ./python run: | poetry run python -m black --check graphframes poetry run python -m flake8 graphframes From 9d7f71423c184e2b00a1e3d973e82bb28bbeda5c Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 17 Mar 2025 19:33:40 +0100 Subject: [PATCH 26/27] Fixes from comments --- python/MANIFEST.in | 1 + python/dev/build_jar.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/MANIFEST.in b/python/MANIFEST.in index e04dbff56..f9d333e9a 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -8,3 +8,4 @@ recursive-exclude * *.pyc include README.md include LICENSE include graphframes/tutorials/data/.exists +recursive-include graphframes/resources *.jar diff --git a/python/dev/build_jar.py b/python/dev/build_jar.py index 0c40d0e37..03e3e0171 100644 --- a/python/dev/build_jar.py +++ b/python/dev/build_jar.py @@ -25,7 +25,7 @@ def build(spark_version: str = "3.5.4"): if sbt_build.returncode != 0: print("Error during the build of GraphFrames JAR!") print("stdout: ", sbt_build.stdout) - print("stdeerr: ", sbt_build.stderr) + print("stderr: ", sbt_build.stderr) sys.exit(1) else: print("Building DONE successfully!") From 5a91659f503a46eeae44a2fa29c8f6040a0f4719 Mon Sep 17 00:00:00 2001 From: semyonsinchenko Date: Mon, 17 Mar 2025 20:14:02 +0100 Subject: [PATCH 27/27] Pin the pyspark version <4.0 and re-generate lock --- python/poetry.lock | 77 +++++++++++++++++++++---------------------- python/pyproject.toml | 2 +- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/python/poetry.lock b/python/poetry.lock index 2ac01b309..63733cfdd 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -467,20 +467,20 @@ pyflakes = ">=3.2.0,<3.3.0" [[package]] name = "googleapis-common-protos" -version = "1.69.1" +version = "1.69.2" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" files = [ - {file = "googleapis_common_protos-1.69.1-py2.py3-none-any.whl", hash = "sha256:4077f27a6900d5946ee5a369fab9c8ded4c0ef1c6e880458ea2f70c14f7b70d5"}, - {file = "googleapis_common_protos-1.69.1.tar.gz", hash = "sha256:e20d2d8dda87da6fe7340afbbdf4f0bcb4c8fae7e6cadf55926c31f946b0b9b1"}, + {file = "googleapis_common_protos-1.69.2-py3-none-any.whl", hash = "sha256:0b30452ff9c7a27d80bfc5718954063e8ab53dd3697093d3bc99581f5fd24212"}, + {file = "googleapis_common_protos-1.69.2.tar.gz", hash = "sha256:3e1b904a27a33c821b4b749fd31d334c0c9c30e6113023d495e48979a3dc9c5f"}, ] [package.dependencies] -protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0" [package.extras] -grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +grpc = ["grpcio (>=1.44.0,<2.0.0)"] [[package]] name = "grpcio" @@ -1110,43 +1110,40 @@ files = [ [[package]] name = "pycryptodomex" -version = "3.21.0" +version = "3.22.0" description = "Cryptographic library for Python" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ - {file = "pycryptodomex-3.21.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:dbeb84a399373df84a69e0919c1d733b89e049752426041deeb30d68e9867822"}, - {file = "pycryptodomex-3.21.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:a192fb46c95489beba9c3f002ed7d93979423d1b2a53eab8771dbb1339eb3ddd"}, - {file = "pycryptodomex-3.21.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:1233443f19d278c72c4daae749872a4af3787a813e05c3561c73ab0c153c7b0f"}, - {file = "pycryptodomex-3.21.0-cp27-cp27m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbb07f88e277162b8bfca7134b34f18b400d84eac7375ce73117f865e3c80d4c"}, - {file = "pycryptodomex-3.21.0-cp27-cp27m-musllinux_1_1_aarch64.whl", hash = "sha256:e859e53d983b7fe18cb8f1b0e29d991a5c93be2c8dd25db7db1fe3bd3617f6f9"}, - {file = "pycryptodomex-3.21.0-cp27-cp27m-win32.whl", hash = "sha256:ef046b2e6c425647971b51424f0f88d8a2e0a2a63d3531817968c42078895c00"}, - {file = "pycryptodomex-3.21.0-cp27-cp27m-win_amd64.whl", hash = "sha256:da76ebf6650323eae7236b54b1b1f0e57c16483be6e3c1ebf901d4ada47563b6"}, - {file = "pycryptodomex-3.21.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:c07e64867a54f7e93186a55bec08a18b7302e7bee1b02fd84c6089ec215e723a"}, - {file = "pycryptodomex-3.21.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:56435c7124dd0ce0c8bdd99c52e5d183a0ca7fdcd06c5d5509423843f487dd0b"}, - {file = "pycryptodomex-3.21.0-cp27-cp27mu-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65d275e3f866cf6fe891411be9c1454fb58809ccc5de6d3770654c47197acd65"}, - {file = "pycryptodomex-3.21.0-cp27-cp27mu-musllinux_1_1_aarch64.whl", hash = "sha256:5241bdb53bcf32a9568770a6584774b1b8109342bd033398e4ff2da052123832"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-macosx_10_9_universal2.whl", hash = "sha256:34325b84c8b380675fd2320d0649cdcbc9cf1e0d1526edbe8fce43ed858cdc7e"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:103c133d6cd832ae7266feb0a65b69e3a5e4dbbd6f3a3ae3211a557fd653f516"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77ac2ea80bcb4b4e1c6a596734c775a1615d23e31794967416afc14852a639d3"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9aa0cf13a1a1128b3e964dc667e5fe5c6235f7d7cfb0277213f0e2a783837cc2"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46eb1f0c8d309da63a2064c28de54e5e614ad17b7e2f88df0faef58ce192fc7b"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:cc7e111e66c274b0df5f4efa679eb31e23c7545d702333dfd2df10ab02c2a2ce"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-musllinux_1_2_i686.whl", hash = "sha256:770d630a5c46605ec83393feaa73a9635a60e55b112e1fb0c3cea84c2897aa0a"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:52e23a0a6e61691134aa8c8beba89de420602541afaae70f66e16060fdcd677e"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-win32.whl", hash = "sha256:a3d77919e6ff56d89aada1bd009b727b874d464cb0e2e3f00a49f7d2e709d76e"}, - {file = "pycryptodomex-3.21.0-cp36-abi3-win_amd64.whl", hash = "sha256:b0e9765f93fe4890f39875e6c90c96cb341767833cfa767f41b490b506fa9ec0"}, - {file = "pycryptodomex-3.21.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:feaecdce4e5c0045e7a287de0c4351284391fe170729aa9182f6bd967631b3a8"}, - {file = "pycryptodomex-3.21.0-pp27-pypy_73-win32.whl", hash = "sha256:365aa5a66d52fd1f9e0530ea97f392c48c409c2f01ff8b9a39c73ed6f527d36c"}, - {file = "pycryptodomex-3.21.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:3efddfc50ac0ca143364042324046800c126a1d63816d532f2e19e6f2d8c0c31"}, - {file = "pycryptodomex-3.21.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0df2608682db8279a9ebbaf05a72f62a321433522ed0e499bc486a6889b96bf3"}, - {file = "pycryptodomex-3.21.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5823d03e904ea3e53aebd6799d6b8ec63b7675b5d2f4a4bd5e3adcb512d03b37"}, - {file = "pycryptodomex-3.21.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:27e84eeff24250ffec32722334749ac2a57a5fd60332cd6a0680090e7c42877e"}, - {file = "pycryptodomex-3.21.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8ef436cdeea794015263853311f84c1ff0341b98fc7908e8a70595a68cefd971"}, - {file = "pycryptodomex-3.21.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a1058e6dfe827f4209c5cae466e67610bcd0d66f2f037465daa2a29d92d952b"}, - {file = "pycryptodomex-3.21.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ba09a5b407cbb3bcb325221e346a140605714b5e880741dc9a1e9ecf1688d42"}, - {file = "pycryptodomex-3.21.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8a9d8342cf22b74a746e3c6c9453cb0cfbb55943410e3a2619bd9164b48dc9d9"}, - {file = "pycryptodomex-3.21.0.tar.gz", hash = "sha256:222d0bd05381dd25c32dd6065c071ebf084212ab79bab4599ba9e6a3e0009e6c"}, + {file = "pycryptodomex-3.22.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:41673e5cc39a8524557a0472077635d981172182c9fe39ce0b5f5c19381ffaff"}, + {file = "pycryptodomex-3.22.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:276be1ed006e8fd01bba00d9bd9b60a0151e478033e86ea1cb37447bbc057edc"}, + {file = "pycryptodomex-3.22.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:813e57da5ceb4b549bab96fa548781d9a63f49f1d68fdb148eeac846238056b7"}, + {file = "pycryptodomex-3.22.0-cp27-cp27m-win32.whl", hash = "sha256:d7beeacb5394765aa8dabed135389a11ee322d3ee16160d178adc7f8ee3e1f65"}, + {file = "pycryptodomex-3.22.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:b3746dedf74787da43e4a2f85bd78f5ec14d2469eb299ddce22518b3891f16ea"}, + {file = "pycryptodomex-3.22.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:5ebc09b7d8964654aaf8a4f5ac325f2b0cc038af9bea12efff0cd4a5bb19aa42"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:aef4590263b9f2f6283469e998574d0bd45c14fb262241c27055b82727426157"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:5ac608a6dce9418d4f300fab7ba2f7d499a96b462f2b9b5c90d8d994cd36dcad"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a24f681365ec9757ccd69b85868bbd7216ba451d0f86f6ea0eed75eeb6975db"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:259664c4803a1fa260d5afb322972813c5fe30ea8b43e54b03b7e3a27b30856b"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7127d9de3c7ce20339e06bcd4f16f1a1a77f1471bcf04e3b704306dde101b719"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ee75067b35c93cc18b38af47b7c0664998d8815174cfc66dd00ea1e244eb27e6"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:1a8b0c5ba061ace4bcd03496d42702c3927003db805b8ec619ea6506080b381d"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:bfe4fe3233ef3e58028a3ad8f28473653b78c6d56e088ea04fe7550c63d4d16b"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-win32.whl", hash = "sha256:2cac9ed5c343bb3d0075db6e797e6112514764d08d667c74cb89b931aac9dddd"}, + {file = "pycryptodomex-3.22.0-cp37-abi3-win_amd64.whl", hash = "sha256:ff46212fda7ee86ec2f4a64016c994e8ad80f11ef748131753adb67e9b722ebd"}, + {file = "pycryptodomex-3.22.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:5bf3ce9211d2a9877b00b8e524593e2209e370a287b3d5e61a8c45f5198487e2"}, + {file = "pycryptodomex-3.22.0-pp27-pypy_73-win32.whl", hash = "sha256:684cb57812cd243217c3d1e01a720c5844b30f0b7b64bb1a49679f7e1e8a54ac"}, + {file = "pycryptodomex-3.22.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c8cffb03f5dee1026e3f892f7cffd79926a538c67c34f8b07c90c0bd5c834e27"}, + {file = "pycryptodomex-3.22.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:140b27caa68a36d0501b05eb247bd33afa5f854c1ee04140e38af63c750d4e39"}, + {file = "pycryptodomex-3.22.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:644834b1836bb8e1d304afaf794d5ae98a1d637bd6e140c9be7dd192b5374811"}, + {file = "pycryptodomex-3.22.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:72c506aba3318505dbeecf821ed7b9a9f86f422ed085e2d79c4fba0ae669920a"}, + {file = "pycryptodomex-3.22.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7cd39f7a110c1ab97ce9ee3459b8bc615920344dc00e56d1b709628965fba3f2"}, + {file = "pycryptodomex-3.22.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e4eaaf6163ff13788c1f8f615ad60cdc69efac6d3bf7b310b21e8cfe5f46c801"}, + {file = "pycryptodomex-3.22.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eac39e237d65981554c2d4c6668192dc7051ad61ab5fc383ed0ba049e4007ca2"}, + {file = "pycryptodomex-3.22.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ab0d89d1761959b608952c7b347b0e76a32d1a5bb278afbaa10a7f3eaef9a0a"}, + {file = "pycryptodomex-3.22.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e64164f816f5e43fd69f8ed98eb28f98157faf68208cd19c44ed9d8e72d33e8"}, + {file = "pycryptodomex-3.22.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f005de31efad6f9acefc417296c641f13b720be7dbfec90edeaca601c0fab048"}, + {file = "pycryptodomex-3.22.0.tar.gz", hash = "sha256:a1da61bacc22f93a91cbe690e3eb2022a03ab4123690ab16c46abb693a9df63d"}, ] [[package]] @@ -1521,4 +1518,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.9 <3.13" -content-hash = "9d995a3804df22f5e0d708728188e102eae377605dae89fd90ce013b456fcd29" +content-hash = "5868c8adea68ac43f8aaf157139b4eace30cde260718d170240e030ead629e19" diff --git a/python/pyproject.toml b/python/pyproject.toml index ed55a8978..bbae8404f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -35,7 +35,7 @@ script = "dev/build_jar.py" [tool.poetry.dependencies] python = ">=3.9 <3.13" nose = "1.3.7" -pyspark = ">=3.3" +pyspark = ">=3.4 <4.0" numpy = ">= 1.7" [tool.poetry.group.dev.dependencies]