diff --git a/docs/src/04-user-guide/11-property-graphs.md b/docs/src/04-user-guide/11-property-graphs.md index 76c9f1298..1b9d220e9 100644 --- a/docs/src/04-user-guide/11-property-graphs.md +++ b/docs/src/04-user-guide/11-property-graphs.md @@ -24,11 +24,11 @@ GraphFrames represent a property graph as a combination of multiple logical stru ### Vertex Property Group -For API details see @:scaladoc(org.graphframes.propertygraph.property.VertexPropertyGroup). It contains a name of the property group, for example, `movies`, a name of ID column and underlying data in the form of a `DataFrame`. +For API details see @:scaladoc(org.graphframes.propertygraph.property.VertexPropertyGroup) (Scala) or `graphframes.pg.VertexPropertyGroup` (Python). It contains a name of the property group, for example, `movies`, a name of ID column and underlying data in the form of a `DataFrame`. The simple example below creates two property groups: `people` and `movies`. -```scala +````scala import org.graphframes.propertygraph.property.VertexPropertyGroup val peopleData = spark @@ -43,15 +43,31 @@ val moviesData = spark .toDF("id", "title") val moviesGroup = VertexPropertyGroup("movies", moviesData, "id") +```` + +```python +from graphframes.pg import VertexPropertyGroup + +people_data = spark.createDataFrame( + [(1, "Alice"), (2, "Bob"), (3, "Charlie"), (4, "David"), (5, "Eve")], + ["id", "name"] +) +people_group = VertexPropertyGroup("people", people_data, "id") + +movies_data = spark.createDataFrame( + [(1, "Matrix"), (2, "Inception"), (3, "Interstellar")], + ["id", "title"] +) +movies_group = VertexPropertyGroup("movies", movies_data, "id") ``` ### Edge Property Group -For API details see @:scaladoc(org.graphframes.propertygraph.property.EdgePropertyGroup). It contains a name of the property group, links to the source and target vertex property groups, direction of the edges (`directed` or `undirected`), and underlying data in the form of a `DataFrame`. Optionally, it can contain a column with edge weights as well as names of source and target vertex ID columns. +For API details see @:scaladoc(org.graphframes.propertygraph.property.EdgePropertyGroup) (Scala) or `graphframes.pg.EdgePropertyGroup` (Python). It contains a name of the property group, links to the source and target vertex property groups, direction of the edges (`directed` or `undirected`), and underlying data in the form of a `DataFrame`. Optionally, it can contain a column with edge weights as well as names of source and target vertex ID columns. The simple example below creates an edge property group with the name `likes` and links to the `people` and `movies` vertex property groups as well as `messages` property group that links people to people. -```scala +````scala import org.graphframes.propertygraph.property.EdgePropertyGroup val likesData = spark @@ -82,22 +98,68 @@ val messagesGroup = EdgePropertyGroup( "src", "dst", col("weight")) +```` + +```python +from pyspark.sql.functions import col, lit +from graphframes.pg import EdgePropertyGroup + +likes_data = spark.createDataFrame( + [(1, 1), (1, 2), (2, 1), (3, 2), (4, 3), (5, 2)], + ["src", "dst"] +).withColumn("weight", lit(1.0)) + +likes_group = EdgePropertyGroup( + "likes", + likes_data, + people_group, + movies_group, + is_directed=False, + src_column_name="src", + dst_column_name="dst", + weight_column_name="weight" +) + +messages_data = spark.createDataFrame( + [(1, 2, 5.0), (2, 3, 8.0), (3, 4, 3.0), (4, 5, 6.0), (5, 1, 9.0)], + ["src", "dst", "weight"] +) + +messages_group = EdgePropertyGroup( + "messages", + messages_data, + people_group, + people_group, + is_directed=True, + src_column_name="src", + dst_column_name="dst", + weight_column_name="weight" +) ``` ### Property GraphFrame Having defined the property groups, we can create a `PropertyGraphFrame` by passing the property groups to the constructor. -```scala +````scala import org.graphframes.propertygraph.PropertyGraphFrame peopleMoviesGraph = PropertyGraphFrame(Seq(peopleGroup, moviesGroup), Seq(likesGroup, messagesGroup)) +```` + +```python +from graphframes.pg import PropertyGraphFrame + +people_movies_graph = PropertyGraphFrame( + [people_group, movies_group], + [likes_group, messages_group] +) ``` ### Conversion to GraphFrames -The `PropertyGraphFrame` can be converted to a `GraphFrame` by calling `toGraphFrame`. Users can select a subset of vertex and edge property groups to be included in the resulting `GraphFrame`. Under the hood, the conversion will take care about handling potential vertex and edge ID collisions by applying hashing to both vertex and edge IDs. +The `PropertyGraphFrame` can be converted to a `GraphFrame` by calling `toGraphFrame` (Scala) or `to_graphframe` (Python). Users can select a subset of vertex and edge property groups to be included in the resulting `GraphFrame`. Under the hood, the conversion will take care about handling potential vertex and edge ID collisions by applying hashing to both vertex and edge IDs. ```scala val graph = peopleMoviesGraph.toGraphFrame( @@ -107,16 +169,37 @@ val graph = peopleMoviesGraph.toGraphFrame( Map("people" -> lit(true))) ``` -For more details see @:scaladoc(org.graphframes.propertygraph.PropertyGraphFrame). +```python +from pyspark.sql.functions import lit + +graph = people_movies_graph.to_graphframe( + vertex_property_groups=["people"], + edge_property_groups=["messages"], + edge_group_filters={"messages": lit(True)}, + vertex_group_filters={"people": lit(True)} +) +``` + +For more details see @:scaladoc(org.graphframes.propertygraph.PropertyGraphFrame) (Scala) or @:pydoc(graphframes.pg.PropertyGraphFrame) (Python). This operation is not free, so user can also explicitly specify for each of `VertexGroup` does it need to be hashed or not. -```scala +````scala val moviesData = spark .createDataFrame(Seq((1L, "Matrix"), (2L, "Inception"), (3L, "Interstellar"))) .toDF("id", "title") val moviesGroup = VertexPropertyGroup("movies", moviesData, "id", applyMaskOnId = false) +```` + +```python +movies_data = spark.createDataFrame( + [(1, "Matrix"), (2, "Inception"), (3, "Interstellar")], + ["id", "title"] +) +movies_group = VertexPropertyGroup( + "movies", movies_data, "id", apply_mask_on_id=False +) ``` ### Projection @@ -127,3 +210,21 @@ The `PropertyGraphFrame` support projection of edges groups to a new edge group. val projectedGraph = peopleMoviesGraph.projectionBy("people", "movies", "likes") ``` +```python +projected_graph = people_movies_graph.projection_by("people", "movies", "likes") +``` + +### Joining Algorithm Results + +After running graph algorithms on a `GraphFrame` created from a `PropertyGraphFrame`, you can join the results back to the original vertex data using `join_vertices` (Python) or `joinVertices` (Scala). + +```scala +val components = graph.connectedComponents() +val joinedBack = peopleMoviesGraph.joinVertices(components, Seq("people", "movies")) +``` + +```python +components = graph.connectedComponents() +joined_back = people_movies_graph.join_vertices(components, ["people", "movies"]) +``` + diff --git a/python/graphframes/graphframe.py b/python/graphframes/graphframe.py index e4abf688e..33c4fcd46 100644 --- a/python/graphframes/graphframe.py +++ b/python/graphframes/graphframe.py @@ -53,6 +53,9 @@ def is_remote() -> bool: """Constant for the edge column name.""" EDGE = "edge" +"""Constant for the weight column name.""" +WEIGHT = "weight" + class GraphFrame: """ @@ -76,6 +79,7 @@ class GraphFrame: SRC: str = SRC DST: str = DST EDGE: str = EDGE + WEIGHT: str = WEIGHT @staticmethod def _from_impl(impl: "GraphFrameClassic | GraphFrameConnect") -> "GraphFrame": diff --git a/python/graphframes/pg/__init__.py b/python/graphframes/pg/__init__.py new file mode 100644 index 000000000..8ea68c3dd --- /dev/null +++ b/python/graphframes/pg/__init__.py @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from graphframes.pg.property_graphframe import PropertyGraphFrame +from graphframes.pg.property_groups import EdgePropertyGroup, VertexPropertyGroup + +__all__ = [ + "VertexPropertyGroup", + "EdgePropertyGroup", + "PropertyGraphFrame", +] diff --git a/python/graphframes/pg/property_graphframe.py b/python/graphframes/pg/property_graphframe.py new file mode 100644 index 000000000..72553a04c --- /dev/null +++ b/python/graphframes/pg/property_graphframe.py @@ -0,0 +1,382 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""PropertyGraphFrame implementation for PySpark.""" + +from __future__ import annotations + +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING + +from pyspark.sql.functions import col, lit + +from graphframes.pg.property_groups import EdgePropertyGroup, VertexPropertyGroup + +if TYPE_CHECKING: + from pyspark.sql import Column, DataFrame + +from graphframes import GraphFrame + + +class PropertyGraphFrame: + """ + A high-level abstraction for working with property graphs in PySpark. + + PropertyGraphFrame serves as a logical structure that manages collections of vertex and edge + property groups, providing a user-friendly API for graph operations. It handles various + internal complexities such as: + + - ID conversion and collision prevention + - Management of directed/undirected graph representations + - Handling of weighted/unweighted edges + - Data consistency across different property groups + + The class maintains separate collections for vertex and edge properties, allowing for flexible + graph construction while ensuring data integrity. + + Example: + >>> from graphframes.pg import VertexPropertyGroup, EdgePropertyGroup, PropertyGraphFrame + >>> from graphframes import GraphFrame + >>> + >>> # Create vertex groups + >>> people_data = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"]) + >>> people_group = VertexPropertyGroup("people", people_data, "id") + >>> + >>> movies_data = spark.createDataFrame([(1, "Matrix"), (2, "Inception")], ["id", "title"]) + >>> movies_group = VertexPropertyGroup("movies", movies_data, "id") + >>> + >>> # Create edge group + >>> likes_data = spark.createDataFrame([(1, 1, 1.0)], ["src", "dst", "weight"]) + >>> likes_group = EdgePropertyGroup( + ... "likes", likes_data, people_group, movies_group, + ... is_directed=False, src_column_name="src", dst_column_name="dst", + ... weight_column_name="weight" + ... ) + >>> + >>> # Create property graph + >>> pg = PropertyGraphFrame([people_group, movies_group], [likes_group]) + + :param vertex_property_groups: Sequence of vertex property groups + :param edges_property_groups: Sequence of edge property groups + """ + + PROPERTY_GROUP_COL_NAME = "property_group" + EXTERNAL_ID = "external_id" + + def __init__( + self, + vertex_property_groups: Sequence, + edges_property_groups: Sequence, + ) -> None: + """ + Initialize a PropertyGraphFrame. + + :param vertex_property_groups: Sequence of vertex property groups + :param edges_property_groups: Sequence of edge property groups + """ + + # Validate input types + for group in vertex_property_groups: + if not isinstance(group, VertexPropertyGroup): + raise TypeError( + f"All vertex_property_groups must be VertexPropertyGroup instances, " + f"got {type(group)}" + ) + + for group in edges_property_groups: + if not isinstance(group, EdgePropertyGroup): + raise TypeError( + f"All edges_property_groups must be EdgePropertyGroup instances, " + f"got {type(group)}" + ) + + self._vertex_property_groups = list(vertex_property_groups) + self._edges_property_groups = list(edges_property_groups) + + # Create lookup maps + self._vertex_groups: dict[str, VertexPropertyGroup] = { + group.name: group for group in self._vertex_property_groups + } + self._edge_groups: dict[str, EdgePropertyGroup] = { + group.name: group for group in self._edges_property_groups + } + + @property + def vertex_property_groups(self) -> list[VertexPropertyGroup]: + """Return the list of vertex property groups.""" + + return self._vertex_property_groups + + @property + def edges_property_groups(self) -> list[EdgePropertyGroup]: + """Return the list of edge property groups.""" + + return self._edges_property_groups + + def to_graphframe( + self, + vertex_property_groups: Sequence[str], + edge_property_groups: Sequence[str], + edge_group_filters: dict[str, Column] | None = None, + vertex_group_filters: dict[str, Column] | None = None, + ) -> GraphFrame: + """ + Convert the property graph to a unified GraphFrame representation. + + This method transforms a property graph that may contain multiple vertex types and both + directed and undirected edges into a single GraphFrame object where all vertices and edges + share the same schema. The conversion process handles: + + - Internal ID generation and collision prevention by hashing vertex/edge IDs with their + group names + - Merging of different vertex types into a unified vertex DataFrame + - Conversion of directed/undirected edge relationships into a consistent edge DataFrame + - Filtering of vertices and edges based on provided predicates + + :param vertex_property_groups: Sequence of vertex property group names to include + :param edge_property_groups: Sequence of edge property group names to include + :param edge_group_filters: Optional dict mapping edge group names to filter predicates + :param vertex_group_filters: Optional dict mapping vertex group names to filter predicates + :return: A GraphFrame containing the unified representation + :raises ValueError: If a specified group name does not exist + + Example: + >>> from pyspark.sql.functions import lit + >>> graph = pg.to_graph_frame( + ... vertex_property_groups=["people", "movies"], + ... edge_property_groups=["likes", "messages"], + ... edge_group_filters={"likes": lit(True), "messages": lit(True)}, + ... vertex_group_filters={"people": lit(True), "movies": lit(True)} + ... ) + """ + # Set default filters if not provided + if edge_group_filters is None: + edge_group_filters = {} + if vertex_group_filters is None: + vertex_group_filters = {} + + # Validate group names + for name in vertex_property_groups: + if name not in self._vertex_groups: + raise ValueError(f"Vertex property group '{name}' does not exist") + + for name in edge_property_groups: + if name not in self._edge_groups: + raise ValueError(f"Edge property group '{name}' does not exist") + + # Combine vertices from all specified groups + if not vertex_property_groups: + raise ValueError("At least one vertex property group must be specified") + + vertices_list = [] + for name in vertex_property_groups: + filter_col = vertex_group_filters.get(name, lit(True)) + group_data = self._vertex_groups[name].get_data(filter_col) + vertices_list.append(group_data) + + vertices = vertices_list[0] + for v in vertices_list[1:]: + vertices = vertices.union(v) + + # Combine edges from all specified groups + if not edge_property_groups: + raise ValueError("At least one edge property group must be specified") + + edges_list = [] + for name in edge_property_groups: + filter_col = edge_group_filters.get(name, lit(True)) + group_data = self._edge_groups[name].get_data(filter_col) + edges_list.append(group_data) + + edges = edges_list[0] + for e in edges_list[1:]: + edges = edges.union(e) + + return GraphFrame(vertices, edges) + + def projection_by( + self, + left_bi_graph_part: str, + right_bi_graph_part: str, + edge_group: str, + new_edge_weight: Callable[[Column, Column], Column] | None = None, + ) -> "PropertyGraphFrame": + """ + Project a bipartite graph onto one of its parts. + + Creates edges between vertices that share neighbors in the other part. Drops the property + group used for projection and returns a new property graph. + + :param left_bi_graph_part: Name of the vertex property group to project onto + :param right_bi_graph_part: Name of the vertex property group to project through + :param edge_group: Name of the edge property group connecting the two parts + :param new_edge_weight: Optional function that takes two weight columns and returns + a new weight column. If None, uses weight 1.0 for all edges. + :return: A new PropertyGraphFrame containing the projected graph + :raises ValueError: If group names are invalid or edge group doesn't connect the parts + + Example: + >>> # Project people through movies they both like + >>> projected = pg.projection_by("people", "movies", "likes") + >>> # Custom weight function + >>> from pyspark.sql.functions import col + >>> projected = pg.projection_by( + ... "people", "movies", "likes", + ... new_edge_weight=lambda w1, w2: w1 + w2 + ... ) + """ + # Validate inputs + if edge_group not in self._edge_groups: + raise ValueError(f"Edge property group '{edge_group}' does not exist") + + if left_bi_graph_part not in self._vertex_groups: + raise ValueError(f"Vertex property group '{left_bi_graph_part}' does not exist") + + if right_bi_graph_part not in self._vertex_groups: + raise ValueError(f"Vertex property group '{right_bi_graph_part}' does not exist") + + old_group = self._edge_groups[edge_group] + + # Validate edge group connects the specified parts + if old_group.src_property_group.name != left_bi_graph_part: + raise ValueError( + f"Edge property group should have '{left_bi_graph_part}' as source " + f"but has '{old_group.src_property_group.name}'" + ) + + if old_group.dst_property_group.name != right_bi_graph_part: + raise ValueError( + f"Edge property group should have '{right_bi_graph_part}' as destination " + f"but has '{old_group.dst_property_group.name}'" + ) + + # Get vertex groups to keep + kept_v_property_groups = [ + g for g in self._vertex_property_groups if g.name != right_bi_graph_part + ] + + # Get edge groups to keep (excluding the one being projected) + kept_e_property_groups = [g for g in self._edges_property_groups if g.name != edge_group] + + # Create projected edges by joining edges through common neighbors + old_edges_data = old_group.data + + e1 = old_edges_data.alias("e1") + e2 = old_edges_data.alias("e2") + + # Join edges on common destination (the right part) + joined = e1.join( + e2, col("e1." + old_group.dst_column_name) == col("e2." + old_group.dst_column_name) + ) + + # Filter to avoid duplicates (e1.src < e2.src) + joined = joined.filter( + col("e1." + old_group.src_column_name) < col("e2." + old_group.src_column_name) + ) + + # Add weight column + if new_edge_weight is not None: + w1 = col(f"e1.{old_group.weight_column_name}") + w2 = col(f"e2.{old_group.weight_column_name}") + weight_col = new_edge_weight(w1, w2) + else: + weight_col = lit(1.0) + + # Select source and destination for new edges + projected_edges = joined.select( + col("e1." + old_group.src_column_name).alias(GraphFrame.SRC), + col("e2." + old_group.src_column_name).alias(GraphFrame.DST), + weight_col.alias(GraphFrame.WEIGHT), + ) + + # Create new edge property group + left_group = self._vertex_groups[left_bi_graph_part] + + new_edge_group = EdgePropertyGroup( + name=f"projected_{edge_group}", + data=projected_edges, + src_property_group=left_group, + dst_property_group=left_group, + is_directed=False, + src_column_name=GraphFrame.SRC, + dst_column_name=GraphFrame.DST, + weight_column_name=GraphFrame.WEIGHT, + ) + + return PropertyGraphFrame(kept_v_property_groups, kept_e_property_groups + [new_edge_group]) + + def join_vertices( + self, + vertices_data: DataFrame, + vertex_groups: Sequence[str], + ) -> DataFrame: + """ + Join algorithm results back to the original vertex data. + + Joins the vertices data (typically output from graph algorithms) with the specified + vertex property groups to produce a unified DataFrame with original vertex attributes. + + :param vertices_data: DataFrame containing vertex algorithm results (from to_graph_frame) + :param vertex_groups: Sequence of vertex group names to join + :return: A DataFrame with joined vertex data + :raises ValueError: If a specified group name does not exist + + Example: + >>> # Run connected components and join results back + >>> graph = pg.to_graph_frame(["people"], ["messages"], {}, {}) + >>> components = graph.connectedComponents() + >>> joined = pg.join_vertices(components, ["people"]) + """ + # Validate group names + for name in vertex_groups: + if name not in self._vertex_groups: + raise ValueError(f"Vertex property group '{name}' does not exist") + + if not vertex_groups: + raise ValueError("At least one vertex group must be specified") + + # Join each group separately + result_dfs = [] + + for vg_name in vertex_groups: + group = self._vertex_groups[vg_name] + + # Filter vertices data for this group + filtered = vertices_data.filter( + col(PropertyGraphFrame.PROPERTY_GROUP_COL_NAME) == lit(vg_name) + ) + + if group.apply_mask_on_id: + # Use internal ID mapping to join back to original data + id_mapping = group._get_internal_id_mapping() + joined = id_mapping.join(filtered, [GraphFrame.ID], "left").drop(GraphFrame.ID) + else: + # Direct join on ID + joined = ( + group.get_data() + .join(filtered, GraphFrame.ID, "left") + .withColumnRenamed(GraphFrame.ID, PropertyGraphFrame.EXTERNAL_ID) + ) + + result_dfs.append(joined) + + # Union all results + result = result_dfs[0] + for df in result_dfs[1:]: + result = result.union(df) + + return result diff --git a/python/graphframes/pg/property_groups.py b/python/graphframes/pg/property_groups.py new file mode 100644 index 000000000..87e75377b --- /dev/null +++ b/python/graphframes/pg/property_groups.py @@ -0,0 +1,387 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Property group classes for property graphs.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from pyspark.sql.functions import col, concat, lit, sha2 +from pyspark.sql.types import ( + ByteType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, +) + +from graphframes import GraphFrame + +if TYPE_CHECKING: + from pyspark.sql import Column, DataFrame + + +class InvalidPropertyGroupException(Exception): + """Exception raised when a property group is invalid.""" + + pass + + +class PropertyGroup(ABC): + """Abstract base class for property groups.""" + + def __init__(self, name: str, data: DataFrame) -> None: + """ + Initialize a property group. + + :param name: The unique identifier for this property group + :param data: The DataFrame containing the property data + """ + self._name = name + self._data = data + self._validate() + + @property + def name(self) -> str: + """Return the name of the property group.""" + return self._name + + @property + def data(self) -> DataFrame: + """Return the DataFrame containing the property data.""" + return self._data + + @abstractmethod + def _validate(self) -> None: + """Validate the property group. Must be implemented by subclasses.""" + pass + + def get_data(self, filter_col: Column | None = None) -> DataFrame: + """ + Return a view of the data for the property group. + + :param filter_col: An optional filter condition (Column) to apply to the data + :return: A DataFrame containing the filtered and optionally transformed data + """ + + if filter_col is None: + filter_col = lit(True) + return self._get_data(filter_col) + + @abstractmethod + def _get_data(self, filter_col: Column) -> DataFrame: + """Internal method to get filtered data. Must be implemented by subclasses.""" + pass + + +class VertexPropertyGroup(PropertyGroup): + """ + Represents a logical group of vertices in a property graph. + + A VertexPropertyGroup organizes and manages vertices that share common characteristics + or belong to the same logical group within a property graph. Each group maintains its + own data in the form of a DataFrame and uses a primary key column for unique vertex + identification. + + When vertices from different groups are combined into a GraphFrame, their IDs are + hashed with the group name to prevent collisions. + + Example: + >>> people_data = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"]) + >>> people_group = VertexPropertyGroup("people", people_data, "id") + + :param name: The unique identifier for this vertex property group + :param data: The DataFrame containing the vertex data + :param primary_key_column: The column name used to uniquely identify vertices + :param apply_mask_on_id: Whether to hash IDs with group name (default: True) + """ + + def __init__( + self, + name: str, + data: DataFrame, + primary_key_column: str = "id", + apply_mask_on_id: bool = True, + ) -> None: + """ + Initialize a VertexPropertyGroup. + + :param name: Name of the vertex property group + :param data: DataFrame containing vertex data + :param primary_key_column: Name of the column to use as primary key (default: "id") + :param apply_mask_on_id: Whether to apply masking on vertex IDs (default: True) + """ + self._primary_key_column = primary_key_column + self._apply_mask_on_id = apply_mask_on_id + super().__init__(name, data) + + @property + def primary_key_column(self) -> str: + """Return the primary key column name.""" + return self._primary_key_column + + @property + def apply_mask_on_id(self) -> bool: + """Return whether ID masking is applied.""" + return self._apply_mask_on_id + + def _validate(self) -> None: + """Validate that the primary key column exists in the data.""" + if self._primary_key_column not in self._data.columns: + raise InvalidPropertyGroupException( + f"source column {self._primary_key_column} does not exist, " + f"existed columns [{', '.join(self._data.columns)}]" + ) + + def _get_internal_id_mapping(self) -> DataFrame: + """ + Create a mapping from external IDs to internal hashed IDs. + + :return: DataFrame with columns 'external_id' and 'id' + """ + + EXTERNAL_ID = "external_id" + + return self._data.select(col(self._primary_key_column).alias(EXTERNAL_ID)).withColumn( + GraphFrame.ID, + concat( + lit(self._name), + sha2(col(EXTERNAL_ID).cast(StringType()), 256), + ), + ) + + def _get_data(self, filter_col: Column) -> DataFrame: + """ + Return filtered vertex data with internal IDs and property group column. + + :param filter_col: Filter condition to apply + :return: DataFrame with columns 'id' and 'property_group' + """ + PROPERTY_GROUP_COL_NAME = "property_group" + + filtered_data = self._data.filter(filter_col) + + if self._apply_mask_on_id: + result = filtered_data.select( + concat( + lit(self._name), + sha2(col(self._primary_key_column).cast(StringType()), 256), + ).alias(GraphFrame.ID) + ) + else: + result = filtered_data.select( + col(self._primary_key_column).cast(StringType()).alias(GraphFrame.ID) + ) + + return result.select( + col(GraphFrame.ID), + lit(self._name).alias(PROPERTY_GROUP_COL_NAME), + ) + + +class EdgePropertyGroup(PropertyGroup): + """ + Represents a logical group of edges in a property graph. + + EdgePropertyGroup encapsulates edge data stored in a DataFrame along with metadata + describing how to interpret the data as graph edges. Each edge group has: + + - A unique name identifier + - DataFrame containing the actual edge data + - Source and destination vertex property groups + - Direction flag indicating if edges are directed or undirected + - Column names specifying source vertex, destination vertex, and edge weight + + When edges from different groups are combined into a GraphFrame, their src and dst + are hashed with the group name to prevent ID collisions. + + Example: + >>> edges_data = spark.createDataFrame([(1, 2, 1.0)], ["src", "dst", "weight"]) + >>> edges_group = EdgePropertyGroup( + ... "likes", edges_data, people_group, movies_group, + ... is_directed=False, src_column="src", dst_column="dst", weight_column="weight" + ... ) + + :param name: Unique identifier for this edge property group + :param data: DataFrame containing the edge data + :param src_property_group: Source vertex property group + :param dst_property_group: Destination vertex property group + :param is_directed: Whether edges should be treated as directed + :param src_column_name: Name of the source vertex column in the data + :param dst_column_name: Name of the destination vertex column in the data + :param weight_column_name: Name of the edge weight column in the data + """ + + def __init__( + self, + name: str, + data: DataFrame, + src_property_group: VertexPropertyGroup, + dst_property_group: VertexPropertyGroup, + is_directed: bool, + src_column_name: str, + dst_column_name: str, + weight_column_name: str | None = None, + ) -> None: + """ + Initialize an EdgePropertyGroup. + + :param name: Unique identifier for this edge property group + :param data: DataFrame containing the edge data with required columns + :param src_property_group: Source vertex property group + :param dst_property_group: Destination vertex property group + :param is_directed: Whether edges are directed (True) or undirected (False) + :param src_column_name: Name of the source vertex column + :param dst_column_name: Name of the destination vertex column + :param weight_column_name: Name of the edge weight column + (None means the lit(1).alias("weight") will be used) + """ + if weight_column_name is None: + data = data.withColumn("weight", lit(1.0)) + weight_column_name = "weight" + + self._src_property_group = src_property_group + self._dst_property_group = dst_property_group + self._is_directed = is_directed + self._src_column_name = src_column_name + self._dst_column_name = dst_column_name + self._weight_column_name = weight_column_name + super().__init__(name, data) + + @property + def src_property_group(self) -> VertexPropertyGroup: + """Return the source vertex property group.""" + return self._src_property_group + + @property + def dst_property_group(self) -> VertexPropertyGroup: + """Return the destination vertex property group.""" + return self._dst_property_group + + @property + def is_directed(self) -> bool: + """Return whether edges are directed.""" + return self._is_directed + + @property + def src_column_name(self) -> str: + """Return the source column name.""" + return self._src_column_name + + @property + def dst_column_name(self) -> str: + """Return the destination column name.""" + return self._dst_column_name + + @property + def weight_column_name(self) -> str: + """Return the weight column name.""" + return self._weight_column_name + + def _validate(self) -> None: + """Validate that required columns exist and weight column is numeric.""" + if self._src_column_name not in self._data.columns: + raise InvalidPropertyGroupException( + f"source column {self._src_column_name} does not exist, " + f"existed columns [{', '.join(self._data.columns)}]" + ) + if self._dst_column_name not in self._data.columns: + raise InvalidPropertyGroupException( + f"dest column {self._dst_column_name} does not exist, " + f"existed columns [{', '.join(self._data.columns)}]" + ) + if self._weight_column_name not in self._data.columns: + raise InvalidPropertyGroupException( + f"weight column {self._weight_column_name} does not exist, " + f"existed columns [{', '.join(self._data.columns)}]" + ) + + # Check weight column type + weight_column_type = self._data.schema[self._weight_column_name].dataType + if not self._is_numeric_type(weight_column_type): + _msg = "weight column {} must be numeric type, but was {}" + raise InvalidPropertyGroupException( + _msg.format(self._weight_column_name, weight_column_type) + ) + + def _is_numeric_type(self, data_type) -> bool: + """Check if a Spark data type is numeric.""" + + numeric_types = ( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType, + ) + return isinstance(data_type, numeric_types) + + def _hash_src_edge(self) -> Column: + """Hash the source edge ID based on the source property group settings.""" + + if self._src_property_group.apply_mask_on_id: + return concat( + lit(self._src_property_group.name), + sha2(col(self._src_column_name).cast(StringType()), 256), + ) + else: + return col(self._src_column_name).cast(StringType()) + + def _hash_dst_edge(self) -> Column: + """Hash the destination edge ID based on the destination property group settings.""" + if self._dst_property_group.apply_mask_on_id: + return concat( + lit(self._dst_property_group.name), + sha2(col(self._dst_column_name).cast(StringType()), 256), + ) + else: + return col(self._dst_column_name).cast(StringType()) + + def _get_data(self, filter_col: Column) -> DataFrame: + """ + Return filtered edge data with hashed IDs and weights. + + For undirected edges, creates bidirectional edges. + + :param filter_col: Filter condition to apply + :return: DataFrame with columns 'src', 'dst', and 'weight' + """ + filtered_data = self._data.filter(filter_col) + + base_edges = filtered_data.select( + self._hash_src_edge().alias(GraphFrame.SRC), + self._hash_dst_edge().alias(GraphFrame.DST), + col(self._weight_column_name).alias(GraphFrame.WEIGHT), + ) + + if self._is_directed: + return base_edges + else: + # For undirected edges, create bidirectional edges + reverse_edges = base_edges.select( + col(GraphFrame.DST).alias(GraphFrame.SRC), + col(GraphFrame.SRC).alias(GraphFrame.DST), + col(GraphFrame.WEIGHT).alias(GraphFrame.WEIGHT), + ) + return base_edges.union(reverse_edges) diff --git a/python/tests/pg/__init__.py b/python/tests/pg/__init__.py new file mode 100644 index 000000000..7fa495369 --- /dev/null +++ b/python/tests/pg/__init__.py @@ -0,0 +1 @@ +# Tests for property graph module diff --git a/python/tests/pg/test_property_graphframe.py b/python/tests/pg/test_property_graphframe.py new file mode 100644 index 000000000..d809351be --- /dev/null +++ b/python/tests/pg/test_property_graphframe.py @@ -0,0 +1,405 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import hashlib + +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.functions import lit + +from graphframes import GraphFrame +from graphframes.pg import EdgePropertyGroup, PropertyGraphFrame, VertexPropertyGroup + + +def sha256_hash(id_val, group_name): + """Helper to compute SHA256 hash like Scala does.""" + hash_val = hashlib.sha256(str(id_val).encode("utf-8")).hexdigest() + return f"{group_name}{hash_val}" + + +@pytest.fixture(scope="module") +def people_group(spark: SparkSession): + people_data = spark.createDataFrame( + [(1, "Alice"), (2, "Bob"), (3, "Charlie"), (4, "David"), (5, "Eve")], + ["id", "name"], + ) + return VertexPropertyGroup("people", people_data, "id") + + +@pytest.fixture(scope="module") +def movies_group(spark: SparkSession): + movies_data = spark.createDataFrame( + [(1, "Matrix"), (2, "Inception"), (3, "Interstellar")], + ["id", "title"], + ) + return VertexPropertyGroup("movies", movies_data, "id") + + +@pytest.fixture(scope="module") +def likes_group(spark: SparkSession, people_group: VertexPropertyGroup, movies_group: VertexPropertyGroup): + likes_data = spark.createDataFrame( + [(1, 1), (1, 2), (2, 1), (3, 2), (4, 3), (5, 2)], + ["src", "dst"], + ) + likes_data_with_weight = likes_data.withColumn("weight", lit(1.0)) + return EdgePropertyGroup( + "likes", + likes_data_with_weight, + people_group, + movies_group, + is_directed=False, + src_column_name="src", + dst_column_name="dst", + weight_column_name="weight", + ) + + +@pytest.fixture(scope="module") +def messages_group(spark: SparkSession, people_group: VertexPropertyGroup): + messages_data = spark.createDataFrame( + [(1, 2, 5.0), (2, 3, 8.0), (3, 4, 3.0), (4, 5, 6.0), (5, 1, 9.0)], + ["src", "dst", "weight"], + ) + return EdgePropertyGroup( + "messages", + messages_data, + people_group, + people_group, + is_directed=True, + src_column_name="src", + dst_column_name="dst", + weight_column_name="weight", + ) + + +@pytest.fixture(scope="module") +def people_movies_graph( + people_group: VertexPropertyGroup, + movies_group: VertexPropertyGroup, + likes_group: EdgePropertyGroup, + messages_group: EdgePropertyGroup, +): + return PropertyGraphFrame( + [people_group, movies_group], + [likes_group, messages_group], + ) + + +def test_property_graph_frame_constructor(people_movies_graph: PropertyGraphFrame) -> None: + assert len(people_movies_graph.vertex_property_groups) == 2 + assert len(people_movies_graph.edges_property_groups) == 2 + + +def test_vertex_property_group_creation(people_group: VertexPropertyGroup) -> None: + assert people_group.name == "people" + assert people_group.primary_key_column == "id" + assert people_group.apply_mask_on_id + + +def test_edge_property_group_creation( + likes_group: EdgePropertyGroup, +) -> None: + assert likes_group.name == "likes" + assert likes_group.src_property_group.name == "people" + assert likes_group.dst_property_group.name == "movies" + assert not likes_group.is_directed + + +def test_projection_by_movies(people_movies_graph: PropertyGraphFrame) -> None: + projected_graph = people_movies_graph.projection_by("people", "movies", "likes") + + assert len(projected_graph.vertex_property_groups) == 1 + assert projected_graph.vertex_property_groups[0].name == "people" + + assert len(projected_graph.edges_property_groups) == 2 + assert any(group.name == "messages" for group in projected_graph.edges_property_groups) + + projected_edges_group = next( + ( + group + for group in projected_graph.edges_property_groups + if group.name == "projected_likes" + ), + None, + ) + assert projected_edges_group is not None + assert projected_edges_group.src_column_name == GraphFrame.SRC + assert projected_edges_group.dst_column_name == GraphFrame.DST + assert projected_edges_group.weight_column_name == GraphFrame.WEIGHT + assert not projected_edges_group.is_directed + + projected_edges = projected_edges_group.data.collect() + edge_pairs = {(row.src, row.dst) for row in projected_edges} + + expected_edges = { + (1, 2), # Alice and Bob both like Matrix + (1, 3), # Alice and Charlie both like Inception + (1, 5), # Alice and Eve both like Inception + (3, 5), # Charlie and Eve both like Inception + } + assert edge_pairs == expected_edges + + +def test_projection_with_custom_weight(people_movies_graph: PropertyGraphFrame) -> None: + projected_graph = people_movies_graph.projection_by( + "people", "movies", "likes", new_edge_weight=lambda w1, w2: w1 + w2 + ) + + projected_edges_group = next( + ( + group + for group in projected_graph.edges_property_groups + if group.name == "projected_likes" + ), + None, + ) + assert projected_edges_group is not None + + projected_edges = projected_edges_group.data.collect() + edge_triples = {(row.src, row.dst, row.weight) for row in projected_edges} + + expected_edges = { + (1, 2, 2.0), + (1, 3, 2.0), + (1, 5, 2.0), + (3, 5, 2.0), + } + assert edge_triples == expected_edges + + +def test_to_graph_frame_messages_only(people_movies_graph: PropertyGraphFrame) -> None: + graph = people_movies_graph.to_graphframe( + vertex_property_groups=["people"], + edge_property_groups=["messages"], + edge_group_filters={"messages": lit(True)}, + vertex_group_filters={"people": lit(True)}, + ) + + vertices = {row.id for row in graph.vertices.collect()} + edges = {(row.src, row.dst, row.weight) for row in graph.edges.collect()} + + expected_vertices = {sha256_hash(i, "people") for i in range(1, 6)} + assert vertices == expected_vertices + + expected_edges = { + (sha256_hash(1, "people"), sha256_hash(2, "people"), 5.0), + (sha256_hash(2, "people"), sha256_hash(3, "people"), 8.0), + (sha256_hash(3, "people"), sha256_hash(4, "people"), 3.0), + (sha256_hash(4, "people"), sha256_hash(5, "people"), 6.0), + (sha256_hash(5, "people"), sha256_hash(1, "people"), 9.0), + } + assert edges == expected_edges + + +def test_to_graph_frame_all_groups(people_movies_graph: PropertyGraphFrame) -> None: + graph = people_movies_graph.to_graphframe( + vertex_property_groups=["people", "movies"], + edge_property_groups=["messages", "likes"], + edge_group_filters={"messages": lit(True), "likes": lit(True)}, + vertex_group_filters={"people": lit(True), "movies": lit(True)}, + ) + + vertices = graph.vertices.collect() + edges = graph.edges.collect() + + assert len(vertices) == 8 # 5 people + 3 movies + + vertex_ids = {row.id for row in vertices} + assert sha256_hash(1, "movies") in vertex_ids + assert sha256_hash(1, "people") in vertex_ids + + message_edges = [e for e in edges if e.weight != 1.0] + like_edges = [e for e in edges if e.weight == 1.0] + + assert len(message_edges) == 5 # Directed messages + assert len(like_edges) == 12 # 6 undirected edges * 2 + + +def test_to_graph_frame_unmasked_ids( + spark: SparkSession, + people_group: VertexPropertyGroup, + likes_group: EdgePropertyGroup, + messages_group: EdgePropertyGroup, +) -> None: + movies_data = spark.createDataFrame( + [(1, "Matrix"), (2, "Inception"), (3, "Interstellar")], + ["id", "title"], + ) + unmasked_movies_group = VertexPropertyGroup( + "movies", movies_data, "id", apply_mask_on_id=False + ) + + new_likes_group = EdgePropertyGroup( + "likes", + likes_group.data, + likes_group.src_property_group, + unmasked_movies_group, + likes_group.is_directed, + likes_group.src_column_name, + likes_group.dst_column_name, + likes_group.weight_column_name, + ) + + modified_graph = PropertyGraphFrame( + [people_group, unmasked_movies_group], + [new_likes_group, messages_group], + ) + + graph = modified_graph.to_graphframe( + vertex_property_groups=["people", "movies"], + edge_property_groups=["messages", "likes"], + edge_group_filters={"messages": lit(True), "likes": lit(True)}, + vertex_group_filters={"people": lit(True), "movies": lit(True)}, + ) + + vertices = {row.id for row in graph.vertices.collect()} + edges = graph.edges.collect() + + assert "1" in vertices + assert "2" in vertices + assert "3" in vertices + assert sha256_hash(1, "people") in vertices + + likes_edges = [e for e in edges if e.weight == 1.0] + assert any( + e.src == sha256_hash(1, "people") and e.dst == "1" for e in likes_edges + ) + assert any( + e.src == "1" and e.dst == sha256_hash(1, "people") for e in likes_edges + ) + + +def test_join_vertices_with_connected_components( + people_movies_graph: PropertyGraphFrame, +) -> None: + graph = people_movies_graph.to_graphframe( + vertex_property_groups=["people", "movies"], + edge_property_groups=["messages", "likes"], + edge_group_filters={"messages": lit(True), "likes": lit(True)}, + vertex_group_filters={"people": lit(True), "movies": lit(True)}, + ) + + components = graph.connectedComponents() + + joined_back = people_movies_graph.join_vertices( + components, vertex_groups=["people", "movies"] + ) + + joined_data = joined_back.collect() + + by_group = {} + for row in joined_data: + group = row.property_group + if group not in by_group: + by_group[group] = [] + by_group[group].append(row) + + assert "movies" in by_group + assert "people" in by_group + assert len(by_group["movies"]) == 3 + assert len(by_group["people"]) == 5 + + +def test_vertex_property_group_validation(people_group: VertexPropertyGroup) -> None: + from graphframes.pg.property_groups import InvalidPropertyGroupException + + with pytest.raises(InvalidPropertyGroupException): + VertexPropertyGroup("test", people_group.data, "nonexistent_column") + + +def test_edge_property_group_validation( + people_group: VertexPropertyGroup, + movies_group: VertexPropertyGroup, + likes_group: EdgePropertyGroup, +) -> None: + from graphframes.pg.property_groups import InvalidPropertyGroupException + + with pytest.raises(InvalidPropertyGroupException): + EdgePropertyGroup( + "test", + likes_group.data, + people_group, + movies_group, + is_directed=True, + src_column_name="nonexistent", + dst_column_name="dst", + weight_column_name="weight", + ) + + with pytest.raises(InvalidPropertyGroupException): + EdgePropertyGroup( + "test", + likes_group.data, + people_group, + movies_group, + is_directed=True, + src_column_name="src", + dst_column_name="nonexistent", + weight_column_name="weight", + ) + + with pytest.raises(InvalidPropertyGroupException): + EdgePropertyGroup( + "test", + likes_group.data, + people_group, + movies_group, + is_directed=True, + src_column_name="src", + dst_column_name="dst", + weight_column_name="nonexistent", + ) + + +def test_to_graph_frame_invalid_group(people_movies_graph: PropertyGraphFrame) -> None: + with pytest.raises(ValueError): + people_movies_graph.to_graphframe( + vertex_property_groups=["nonexistent"], + edge_property_groups=["likes"], + ) + + with pytest.raises(ValueError): + people_movies_graph.to_graphframe( + vertex_property_groups=["people"], + edge_property_groups=["nonexistent"], + ) + + +def test_projection_by_invalid_group(people_movies_graph: PropertyGraphFrame) -> None: + with pytest.raises(ValueError): + people_movies_graph.projection_by("nonexistent", "movies", "likes") + + with pytest.raises(ValueError): + people_movies_graph.projection_by("people", "nonexistent", "likes") + + with pytest.raises(ValueError): + people_movies_graph.projection_by("people", "movies", "nonexistent") + + +def test_property_graph_frame_to_graph_frame_conversion( + people_movies_graph: PropertyGraphFrame, +) -> None: + graph = people_movies_graph.to_graphframe( + vertex_property_groups=["people"], + edge_property_groups=["messages"], + ) + + assert isinstance(graph, GraphFrame) + assert GraphFrame.ID in graph.vertices.columns + assert GraphFrame.SRC in graph.edges.columns + assert GraphFrame.DST in graph.edges.columns + assert GraphFrame.WEIGHT in graph.edges.columns