From 25725444691a5ef051582ae528ab31f92db6e4ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 16 Jan 2025 18:17:11 +0000 Subject: [PATCH 1/6] 1. --- python/docs/epytext.py | 7 +- python/docs/underscores.py | 18 +++-- .../examples/belief_propagation.py | 10 +-- python/graphframes/examples/graphs.py | 8 +- python/graphframes/graphframe.py | 77 +++++++++++-------- python/graphframes/lib/aggregate_messages.py | 22 +++--- python/graphframes/lib/pregel.py | 26 ++++--- 7 files changed, 95 insertions(+), 73 deletions(-) diff --git a/python/docs/epytext.py b/python/docs/epytext.py index e884d5e6b..b02996415 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -1,4 +1,5 @@ import re +from sphinx.application import Sphinx RULES = ( (r"<(!BLANKLINE)[\w.]+>", r""), @@ -9,7 +10,7 @@ ('pyspark.rdd.RDD', 'RDD'), ) -def _convert_epytext(line): +def _convert_epytext(line: str) -> str: """ >>> _convert_epytext("L{A}") :class:`A` @@ -19,9 +20,9 @@ def _convert_epytext(line): line = re.sub(p, sub, line) return line -def _process_docstring(app, what, name, obj, options, lines): +def _process_docstring(app: "Sphinx", what: str, name: str, obj: object, options: dict, lines: list[str]) -> None: for i in range(len(lines)): lines[i] = _convert_epytext(lines[i]) -def setup(app): +def setup(app: "Sphinx") -> None: app.connect("autodoc-process-docstring", _process_docstring) diff --git a/python/docs/underscores.py b/python/docs/underscores.py index fc8df8142..cabad3313 100644 --- a/python/docs/underscores.py +++ b/python/docs/underscores.py @@ -29,27 +29,33 @@ """ import os import shutil +from collections.abc import Callable +from typing import Any +from sphinx.application import Sphinx - -def setup(app): +def setup(app: Sphinx) -> None: """ Add a html-page-context and a build-finished event handlers """ app.connect('html-page-context', change_pathto) app.connect('build-finished', move_private_folders) -def change_pathto(app, pagename, templatename, context, doctree): +def change_pathto(app: Sphinx, + pagename: str, + templatename: str, + context: dict[str, Any], + doctree: Any | None) -> None: """ Replace pathto helper to change paths to folders with a leading underscore. """ - pathto = context.get('pathto') - def gh_pathto(otheruri, *args, **kw): + pathto: Callable = context.get('pathto') + def gh_pathto(otheruri: str, *args: Any, **kw: Any) -> Any: if otheruri.startswith('_'): otheruri = otheruri[1:] return pathto(otheruri, *args, **kw) context['pathto'] = gh_pathto -def move_private_folders(app, e): +def move_private_folders(app: Sphinx, e: Exception | None) -> None: """ remove leading underscore from folders in in the output folder. diff --git a/python/graphframes/examples/belief_propagation.py b/python/graphframes/examples/belief_propagation.py index ae0d096fd..9b5ab2509 100644 --- a/python/graphframes/examples/belief_propagation.py +++ b/python/graphframes/examples/belief_propagation.py @@ -27,7 +27,7 @@ __all__ = ['BeliefPropagation'] -class BeliefPropagation(object): +class BeliefPropagation: """Example code for Belief Propagation (BP) This provides a template for building customized BP algorithms for different types of graphical @@ -63,7 +63,7 @@ class BeliefPropagation(object): """ @classmethod - def runBPwithGraphFrames(cls, g, numIter): + def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame: """Run Belief Propagation using GraphFrame. This implementation of BP shows how to use GraphFrame's aggregateMessages method. @@ -117,7 +117,7 @@ def runBPwithGraphFrames(cls, g, numIter): return GraphFrame(gx.vertices.drop('color'), gx.edges) @staticmethod - def _colorGraph(g): + def _colorGraph(g: GraphFrame) -> GraphFrame: """Given a GraphFrame, choose colors for each vertex. No neighboring vertices will share the same color. The number of colors is minimized. @@ -135,7 +135,7 @@ def _colorGraph(g): return GraphFrame(v, g.edges) @staticmethod - def _sigmoid(x): + def _sigmoid(x: int | float | None) -> float | None: """Numerically stable sigmoid function 1 / (1 + exp(-x))""" if not x: return None @@ -147,7 +147,7 @@ def _sigmoid(x): return z / (1 + z) -def main(): +def main() -> None: """Run the belief propagation algorithm for an example problem.""" # setup spark session spark = SparkSession.builder.appName("BeliefPropagation example").getOrCreate() diff --git a/python/graphframes/examples/graphs.py b/python/graphframes/examples/graphs.py index 6ee6e5c0e..cea05a2ce 100644 --- a/python/graphframes/examples/graphs.py +++ b/python/graphframes/examples/graphs.py @@ -24,17 +24,17 @@ __all__ = ['Graphs'] -class Graphs(object): +class Graphs: """Example GraphFrames for testing the API :param spark: SparkSession """ - def __init__(self, spark): + def __init__(self, spark: SparkSession) -> None: self._spark = spark self._sc = spark._sc - def friends(self): + def friends(self) -> GraphFrame: """A GraphFrame of friends in a (fake) social network.""" # Vertex DataFrame v = self._spark.createDataFrame([ @@ -58,7 +58,7 @@ def friends(self): # Create a GraphFrame return GraphFrame(v, e) - def gridIsingModel(self, n, vStd=1.0, eStd=1.0): + def gridIsingModel(self, n: int, vStd: float = 1.0, eStd: float = 1.0) -> GraphFrame: """Grid Ising model with random parameters. Ising models are probabilistic graphical models over binary variables x\ :sub:`i`. diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 92833fab8..7f58b634c 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -16,17 +16,18 @@ # import sys +from typing import Any + if sys.version > '3': basestring = str +from graphframes.lib import Pregel from pyspark import SparkContext from pyspark.sql import Column, DataFrame, SparkSession from pyspark.storagelevel import StorageLevel -from graphframes.lib import Pregel - -def _from_java_gf(jgf, spark): +def _from_java_gf(jgf: Any, spark: SparkSession) -> 'GraphFrame': """ (internal) creates a python GraphFrame wrapper from a java GraphFrame. @@ -36,13 +37,13 @@ def _from_java_gf(jgf, spark): pe = DataFrame(jgf.edges(), spark) return GraphFrame(pv, pe) -def _java_api(jsc): +def _java_api(jsc: SparkContext) -> Any: javaClassName = "org.graphframes.GraphFramePythonAPI" return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ .newInstance() -class GraphFrame(object): +class GraphFrame: """ Represents a graph with vertices and edges stored as DataFrames. @@ -60,7 +61,7 @@ class GraphFrame(object): >>> g = GraphFrame(v, e) """ - def __init__(self, v, e): + def __init__(self, v: DataFrame, e: DataFrame) -> None: self._vertices = v self._edges = e self._spark = SparkSession.getActiveSession() @@ -89,7 +90,7 @@ def __init__(self, v, e): self._jvm_graph = self._jvm_gf_api.createGraph(v._jdf, e._jdf) @property - def vertices(self): + def vertices(self) -> DataFrame: """ :class:`DataFrame` holding vertex information, with unique column "id" for vertex IDs. @@ -97,7 +98,7 @@ def vertices(self): return self._vertices @property - def edges(self): + def edges(self) -> DataFrame: """ :class:`DataFrame` holding edge information, with unique columns "src" and "dst" storing source vertex IDs and destination vertex IDs of edges, @@ -108,14 +109,14 @@ def edges(self): def __repr__(self): return self._jvm_graph.toString() - def cache(self): + def cache(self) -> 'GraphFrame': """ Persist the dataframe representation of vertices and edges of the graph with the default storage level. """ self._jvm_graph.cache() return self - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): + def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame": """Persist the dataframe representation of vertices and edges of the graph with the given storage level. """ @@ -123,7 +124,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): self._jvm_graph.persist(javaStorageLevel) return self - def unpersist(self, blocking=False): + def unpersist(self, blocking: bool = False) -> 'GraphFrame': """Mark the dataframe representation of vertices and edges of the graph as non-persistent, and remove all blocks for it from memory and disk. """ @@ -131,7 +132,7 @@ def unpersist(self, blocking=False): return self @property - def outDegrees(self): + def outDegrees(self) -> DataFrame: """ The out-degree of each vertex in the graph, returned as a DataFrame with two columns: - "id": the ID of the vertex @@ -145,7 +146,7 @@ def outDegrees(self): return DataFrame(jdf, self._spark) @property - def inDegrees(self): + def inDegrees(self) -> DataFrame: """ The in-degree of each vertex in the graph, returned as a DataFame with two columns: - "id": the ID of the vertex @@ -159,7 +160,7 @@ def inDegrees(self): return DataFrame(jdf, self._spark) @property - def degrees(self): + def degrees(self) -> DataFrame: """ The degree of each vertex in the graph, returned as a DataFrame with two columns: - "id": the ID of the vertex @@ -173,7 +174,7 @@ def degrees(self): return DataFrame(jdf, self._spark) @property - def triplets(self): + def triplets(self) -> DataFrame: """ The triplets (source vertex)-[edge]->(destination vertex) for all edges in the graph. @@ -196,7 +197,7 @@ def pregel(self): """ return Pregel(self) - def find(self, pattern): + def find(self, pattern: str) -> DataFrame: """ Motif finding. @@ -208,7 +209,7 @@ def find(self, pattern): jdf = self._jvm_graph.find(pattern) return DataFrame(jdf, self._spark) - def filterVertices(self, condition): + def filterVertices(self, condition: str | Column) -> 'GraphFrame': """ Filters the vertices based on expression, remove edges containing any dropped vertices. @@ -224,7 +225,7 @@ def filterVertices(self, condition): raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def filterEdges(self, condition): + def filterEdges(self, condition: str | Column) -> 'GraphFrame': """ Filters the edges based on expression, keep all vertices. @@ -239,7 +240,7 @@ def filterEdges(self, condition): raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def dropIsolatedVertices(self): + def dropIsolatedVertices(self) -> 'GraphFrame': """ Drops isolated vertices, vertices are not contained in any edges. @@ -248,7 +249,9 @@ def dropIsolatedVertices(self): jdf = self._jvm_graph.dropIsolatedVertices() return _from_java_gf(jdf, self._spark) - def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10): + def bfs(self, fromExpr: str, toExpr: str, + edgeFilter: str | None = None, + maxPathLength: int = 10) -> DataFrame: """ Breadth-first search (BFS). @@ -265,7 +268,9 @@ def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10): jdf = builder.run() return DataFrame(jdf, self._spark) - def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None): + def aggregateMessages(self, aggCol: Column | str, + sendToSrc: Column | str | None = None, + sendToDst: Column | str | None = None) -> DataFrame: """ Aggregates messages from the neighbours. @@ -309,8 +314,9 @@ def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None): # Standard algorithms - def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2, - broadcastThreshold = 1000000): + def connectedComponents(self, algorithm: str = 'graphframes', + checkpointInterval: int = 2, + broadcastThreshold: int = 1000000) -> DataFrame: """ Computes the connected components of the graph. @@ -331,7 +337,7 @@ def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2, .run() return DataFrame(jdf, self._spark) - def labelPropagation(self, maxIter): + def labelPropagation(self, maxIter: int) -> DataFrame: """ Runs static label propagation for detecting communities in networks. @@ -343,8 +349,10 @@ def labelPropagation(self, maxIter): jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run() return DataFrame(jdf, self._spark) - def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None, - tol = None): + def pageRank(self, resetProbability: float = 0.15, + sourceId: Any | None = None, + maxIter: int | None = None, + tol: float | None = None) -> 'GraphFrame': """ Runs the PageRank algorithm on the graph. Note: Exactly one of fixed_num_iter or tolerance must be set. @@ -371,8 +379,9 @@ def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None, jgf = builder.run() return _from_java_gf(jgf, self._spark) - def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None, - maxIter = None): + def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, + sourceIds: list[Any] | None = None, + maxIter: int | None = None) -> 'GraphFrame': """ Run the personalized PageRank algorithm on the graph, from the provided list of sources in parallel for a fixed number of iterations. @@ -394,7 +403,7 @@ def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None jgf = builder.run() return _from_java_gf(jgf, self._spark) - def shortestPaths(self, landmarks): + def shortestPaths(self, landmarks: list[Any]) -> DataFrame: """ Runs the shortest path algorithm from a set of landmark vertices in the graph. @@ -406,7 +415,7 @@ def shortestPaths(self, landmarks): jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run() return DataFrame(jdf, self._spark) - def stronglyConnectedComponents(self, maxIter): + def stronglyConnectedComponents(self, maxIter: int) -> DataFrame: """ Runs the strongly connected components algorithm on this graph. @@ -418,8 +427,10 @@ def stronglyConnectedComponents(self, maxIter): jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run() return DataFrame(jdf, self._spark) - def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0, - gamma1 = 0.007, gamma2 = 0.007, gamma6 = 0.005, gamma7 = 0.015): + def svdPlusPlus(self, rank: int = 10, maxIter: int = 2, + minValue: float = 0.0, maxValue: float = 5.0, + gamma1: float = 0.007, gamma2: float = 0.007, + gamma6: float = 0.005, gamma7: float = 0.015) -> tuple[DataFrame, float]: """ Runs the SVD++ algorithm. @@ -436,7 +447,7 @@ def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0, v = DataFrame(jdf, self._spark) return (v, loss) - def triangleCount(self): + def triangleCount(self) -> DataFrame: """ Counts the number of triangles passing through each vertex in this graph. diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/lib/aggregate_messages.py index 7fd7a81a0..c794fe345 100644 --- a/python/graphframes/lib/aggregate_messages.py +++ b/python/graphframes/lib/aggregate_messages.py @@ -15,59 +15,61 @@ # limitations under the License. # +from typing import Any + from pyspark import SparkContext from pyspark.sql import DataFrame, functions as sqlfunctions, SparkSession -def _java_api(jsc): +def _java_api(jsc: SparkContext) -> Any: javaClassName = "org.graphframes.GraphFramePythonAPI" return jsc._jvm.Thread.currentThread().getContextClassLoader().loadClass(javaClassName) \ .newInstance() -class _ClassProperty(object): +class _ClassProperty: """Custom read-only class property descriptor. The underlying method should take the class as the sole argument. """ - def __init__(self, f): + def __init__(self, f: callable) -> None: self.f = f self.__doc__ = f.__doc__ - def __get__(self, instance, owner): + def __get__(self, instance: Any, owner: type) -> Any: return self.f(owner) -class AggregateMessages(object): +class AggregateMessages: """Collection of utilities usable with :meth:`graphframes.GraphFrame.aggregateMessages()`.""" @_ClassProperty - def src(cls): + def src(cls) -> Column: """Reference for source column, used for specifying messages.""" jvm_gf_api = _java_api(SparkContext) return sqlfunctions.col(jvm_gf_api.SRC()) @_ClassProperty - def dst(cls): + def dst(cls) -> Column: """Reference for destination column, used for specifying messages.""" jvm_gf_api = _java_api(SparkContext) return sqlfunctions.col(jvm_gf_api.DST()) @_ClassProperty - def edge(cls): + def edge(cls) -> Column: """Reference for edge column, used for specifying messages.""" jvm_gf_api = _java_api(SparkContext) return sqlfunctions.col(jvm_gf_api.EDGE()) @_ClassProperty - def msg(cls): + def msg(cls) -> Column: """Reference for message column, used for specifying aggregation function.""" jvm_gf_api = _java_api(SparkContext) return sqlfunctions.col(jvm_gf_api.aggregateMessages().MSG_COL_NAME()) @staticmethod - def getCachedDataFrame(df): + def getCachedDataFrame(df: DataFrame) -> DataFrame: """ Create a new cached copy of a DataFrame. diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/lib/pregel.py index 9fcccfb9c..e432ff02d 100644 --- a/python/graphframes/lib/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -16,12 +16,14 @@ # import sys +from typing import Any if sys.version > '3': basestring = str from pyspark.sql import DataFrame, SparkSession from pyspark.sql.functions import col from pyspark.ml.wrapper import JavaWrapper +from graphframes import GraphFrame class Pregel(JavaWrapper): @@ -76,19 +78,19 @@ class Pregel(JavaWrapper): ... .run() """ - def __init__(self, graph): + def __init__(self, graph: GraphFrame) -> None: super(Pregel, self).__init__() self.graph = graph self._java_obj = self._new_java_obj("org.graphframes.lib.Pregel", graph._jvm_graph) - def setMaxIter(self, value): + def setMaxIter(self, value: int) -> "Pregel": """ Sets the max number of iterations (default: 10). """ self._java_obj.setMaxIter(int(value)) return self - def setCheckpointInterval(self, value): + def setCheckpointInterval(self, value: int) -> "Pregel": """ Sets the number of iterations between two checkpoints (default: 2). @@ -100,7 +102,7 @@ def setCheckpointInterval(self, value): self._java_obj.setCheckpointInterval(int(value)) return self - def withVertexColumn(self, colName, initialExpr, updateAfterAggMsgsExpr): + def withVertexColumn(self, colName: str, initialExpr: Any, updateAfterAggMsgsExpr: Any) -> "Pregel": """ Defines an additional vertex column at the start of run and how to update it in each iteration. @@ -118,7 +120,7 @@ def withVertexColumn(self, colName, initialExpr, updateAfterAggMsgsExpr): self._java_obj.withVertexColumn(colName, initialExpr._jc, updateAfterAggMsgsExpr._jc) return self - def sendMsgToSrc(self, msgExpr): + def sendMsgToSrc(self, msgExpr: Any) -> "Pregel": """ Defines a message to send to the source vertex of each edge triplet. @@ -135,7 +137,7 @@ def sendMsgToSrc(self, msgExpr): self._java_obj.sendMsgToSrc(msgExpr._jc) return self - def sendMsgToDst(self, msgExpr): + def sendMsgToDst(self, msgExpr: Any) -> "Pregel": """ Defines a message to send to the destination vertex of each edge triplet. @@ -152,7 +154,7 @@ def sendMsgToDst(self, msgExpr): self._java_obj.sendMsgToDst(msgExpr._jc) return self - def aggMsgs(self, aggExpr): + def aggMsgs(self, aggExpr: Any) -> "Pregel": """ Defines how messages are aggregated after grouped by target vertex IDs. @@ -163,7 +165,7 @@ def aggMsgs(self, aggExpr): self._java_obj.aggMsgs(aggExpr._jc) return self - def run(self): + def run(self) -> DataFrame: """ Runs the defined Pregel algorithm. @@ -172,7 +174,7 @@ def run(self): return DataFrame(self._java_obj.run(), SparkSession.getActiveSession()) @staticmethod - def msg(): + def msg() -> Any: """ References the message column in aggregating messages and updating additional vertex columns. @@ -181,7 +183,7 @@ def msg(): return col("_pregel_msg_") @staticmethod - def src(colName): + def src(colName: str) -> Any: """ References a source vertex column in generating messages to send. @@ -192,7 +194,7 @@ def src(colName): return col("src." + colName) @staticmethod - def dst(colName): + def dst(colName: str) -> Any: """ References a destination vertex column in generating messages to send. @@ -203,7 +205,7 @@ def dst(colName): return col("dst." + colName) @staticmethod - def edge(colName): + def edge(colName: str) -> Any: """ References an edge column in generating messages to send. From db32f9f671c1763bf8921fa7c776857108a243ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 16 Jan 2025 18:23:10 +0000 Subject: [PATCH 2/6] 2. --- python/graphframes/examples/graphs.py | 2 +- python/graphframes/lib/aggregate_messages.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/graphframes/examples/graphs.py b/python/graphframes/examples/graphs.py index cea05a2ce..8db04aecc 100644 --- a/python/graphframes/examples/graphs.py +++ b/python/graphframes/examples/graphs.py @@ -17,7 +17,7 @@ import itertools -from pyspark.sql import functions as sqlfunctions +from pyspark.sql import functions as sqlfunctions, SparkSession from graphframes import GraphFrame diff --git a/python/graphframes/lib/aggregate_messages.py b/python/graphframes/lib/aggregate_messages.py index c794fe345..c0867dcd0 100644 --- a/python/graphframes/lib/aggregate_messages.py +++ b/python/graphframes/lib/aggregate_messages.py @@ -18,7 +18,7 @@ from typing import Any from pyspark import SparkContext -from pyspark.sql import DataFrame, functions as sqlfunctions, SparkSession +from pyspark.sql import DataFrame, functions as sqlfunctions, SparkSession, Column def _java_api(jsc: SparkContext) -> Any: From 07128f5f14bb014b89b6f9015fe8fd852046dcf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 16 Jan 2025 18:27:34 +0000 Subject: [PATCH 3/6] 3. --- python/graphframes/lib/pregel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/graphframes/lib/pregel.py b/python/graphframes/lib/pregel.py index e432ff02d..72077c25c 100644 --- a/python/graphframes/lib/pregel.py +++ b/python/graphframes/lib/pregel.py @@ -23,7 +23,6 @@ from pyspark.sql import DataFrame, SparkSession from pyspark.sql.functions import col from pyspark.ml.wrapper import JavaWrapper -from graphframes import GraphFrame class Pregel(JavaWrapper): @@ -78,8 +77,9 @@ class Pregel(JavaWrapper): ... .run() """ - def __init__(self, graph: GraphFrame) -> None: + def __init__(self, graph: "GraphFrame") -> None: super(Pregel, self).__init__() + from graphframes import GraphFrame self.graph = graph self._java_obj = self._new_java_obj("org.graphframes.lib.Pregel", graph._jvm_graph) From 62df4aeaa9b41311c6cbd6f5e092c4f3d638958f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 16 Jan 2025 18:31:55 +0000 Subject: [PATCH 4/6] Union for python 3.9 --- python/graphframes/graphframe.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index 7f58b634c..cd7670f7d 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -16,7 +16,7 @@ # import sys -from typing import Any +from typing import Any, Union if sys.version > '3': basestring = str @@ -209,7 +209,7 @@ def find(self, pattern: str) -> DataFrame: jdf = self._jvm_graph.find(pattern) return DataFrame(jdf, self._spark) - def filterVertices(self, condition: str | Column) -> 'GraphFrame': + def filterVertices(self, condition: Union[str, Column]) -> 'GraphFrame': """ Filters the vertices based on expression, remove edges containing any dropped vertices. @@ -225,7 +225,7 @@ def filterVertices(self, condition: str | Column) -> 'GraphFrame': raise TypeError("condition should be string or Column") return _from_java_gf(jdf, self._spark) - def filterEdges(self, condition: str | Column) -> 'GraphFrame': + def filterEdges(self, condition: Union[str, Column]) -> 'GraphFrame': """ Filters the edges based on expression, keep all vertices. @@ -268,9 +268,9 @@ def bfs(self, fromExpr: str, toExpr: str, jdf = builder.run() return DataFrame(jdf, self._spark) - def aggregateMessages(self, aggCol: Column | str, - sendToSrc: Column | str | None = None, - sendToDst: Column | str | None = None) -> DataFrame: + def aggregateMessages(self, aggCol: Union[Column, str], + sendToSrc: Union[Column, str, None] = None, + sendToDst: Union[Column, str, None] = None) -> DataFrame: """ Aggregates messages from the neighbours. From 79be367ae0298a8a89856408f35891fc18235673 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 16 Jan 2025 18:35:53 +0000 Subject: [PATCH 5/6] old type hints --- python/graphframes/graphframe.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index cd7670f7d..3fc379a21 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -16,7 +16,7 @@ # import sys -from typing import Any, Union +from typing import Any, Union, Optional if sys.version > '3': basestring = str @@ -250,7 +250,7 @@ def dropIsolatedVertices(self) -> 'GraphFrame': return _from_java_gf(jdf, self._spark) def bfs(self, fromExpr: str, toExpr: str, - edgeFilter: str | None = None, + edgeFilter: Optional[str] = None, maxPathLength: int = 10) -> DataFrame: """ Breadth-first search (BFS). @@ -350,9 +350,9 @@ def labelPropagation(self, maxIter: int) -> DataFrame: return DataFrame(jdf, self._spark) def pageRank(self, resetProbability: float = 0.15, - sourceId: Any | None = None, - maxIter: int | None = None, - tol: float | None = None) -> 'GraphFrame': + sourceId: Optional[Any] = None, + maxIter: Optional[int] = None, + tol: Optional[float] = None) -> 'GraphFrame': """ Runs the PageRank algorithm on the graph. Note: Exactly one of fixed_num_iter or tolerance must be set. @@ -380,8 +380,8 @@ def pageRank(self, resetProbability: float = 0.15, return _from_java_gf(jgf, self._spark) def parallelPersonalizedPageRank(self, resetProbability: float = 0.15, - sourceIds: list[Any] | None = None, - maxIter: int | None = None) -> 'GraphFrame': + sourceIds: Optional[list[Any]] = None, + maxIter: Optional[int] = None) -> 'GraphFrame': """ Run the personalized PageRank algorithm on the graph, from the provided list of sources in parallel for a fixed number of iterations. From 824fda74726aea3e7ffd7c78e338e5655292200e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 16 Jan 2025 18:39:18 +0000 Subject: [PATCH 6/6] import Union --- python/graphframes/examples/belief_propagation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/graphframes/examples/belief_propagation.py b/python/graphframes/examples/belief_propagation.py index 9b5ab2509..c013450d7 100644 --- a/python/graphframes/examples/belief_propagation.py +++ b/python/graphframes/examples/belief_propagation.py @@ -16,6 +16,7 @@ # import math +from typing import Union # Import subpackage examples here explicitly so that # this module can be run directly with spark-submit. @@ -135,7 +136,7 @@ def _colorGraph(g: GraphFrame) -> GraphFrame: return GraphFrame(v, g.edges) @staticmethod - def _sigmoid(x: int | float | None) -> float | None: + def _sigmoid(x: Union[int, float, None]) -> Union[float, None]: """Numerically stable sigmoid function 1 / (1 + exp(-x))""" if not x: return None