|
10 | 10 | import pytest |
11 | 11 | from google.cloud import bigquery |
12 | 12 | from pandas.testing import assert_frame_equal |
| 13 | +from pytz import utc |
13 | 14 |
|
14 | 15 | import feast.driver_test_data as driver_data |
| 16 | +from feast import utils |
15 | 17 | from feast.data_source import BigQuerySource, FileSource |
16 | 18 | from feast.entity import Entity |
17 | 19 | from feast.feature import Feature |
@@ -98,74 +100,93 @@ def create_customer_daily_profile_feature_view(source): |
98 | 100 | return customer_profile_feature_view |
99 | 101 |
|
100 | 102 |
|
def convert_timestamp_records_to_utc(records, column):
    """Rewrite one timestamp field of every record so that it is expressed in UTC.

    Args:
        records: Iterable of dict-like rows (e.g. ``DataFrame.to_dict("records")``).
        column: Key of the timestamp field to convert in each row.

    Returns:
        The same ``records`` object that was passed in; rows are modified in place.
    """
    for row in records:
        # NOTE(review): utils.make_tzaware presumably attaches a timezone to
        # naive values so that astimezone() is well defined -- confirm in feast.utils.
        aware_ts = utils.make_tzaware(row[column])
        row[column] = aware_ts.astimezone(utc)
    return records
| 108 | + |
| 109 | + |
def find_asof_record(records, ts_key, ts_start, ts_end, filter_key, filter_value):
    """Return the most recent matching record inside a closed time window.

    A record matches when ``record[filter_key] == filter_value`` and its
    ``record[ts_key]`` lies within ``[ts_start, ts_end]`` (both ends inclusive).

    Args:
        records: Iterable of dict-like rows to search.
        ts_key: Key holding each row's timestamp.
        ts_start: Earliest timestamp accepted (inclusive).
        ts_end: Latest timestamp accepted (inclusive).
        filter_key: Key whose value must equal ``filter_value``.
        filter_value: Value a row must carry under ``filter_key``.

    Returns:
        The matching row with the greatest timestamp, or an empty dict when
        no row matches. When several rows share the greatest timestamp, the
        one that appears first in ``records`` is returned.
    """
    in_window = (
        row
        for row in records
        if row[filter_key] == filter_value and ts_start <= row[ts_key] <= ts_end
    )
    # max() keeps the first maximal element, so ties resolve to the earliest row.
    return max(in_window, key=lambda row: row[ts_key], default={})
| 118 | + |
| 119 | + |
def get_expected_training_df(
    customer_df: pd.DataFrame,
    customer_fv: FeatureView,
    driver_df: pd.DataFrame,
    driver_fv: FeatureView,
    orders_df: pd.DataFrame,
):
    """Build the expected training dataframe with a hand-rolled point-in-time join.

    Every order row is enriched with the latest driver and customer feature
    values whose timestamp falls within ``[order_ts - ttl, order_ts]`` for the
    respective feature view. Orders without a match get ``None`` for the
    corresponding feature columns.

    Args:
        customer_df: Source rows for the customer feature view.
        customer_fv: Customer feature view; supplies the timestamp column
            name (``input.event_timestamp_column``) and ``ttl``.
        driver_df: Source rows for the driver feature view.
        driver_fv: Driver feature view; supplies the timestamp column name
            and ``ttl``.
        orders_df: Entity rows; read via the ``event_timestamp``,
            ``driver_id`` and ``customer_id`` keys.

    Returns:
        A dataframe with the entity timestamp column first, followed by the
        remaining order columns and the ``driver_stats__*`` /
        ``customer_profile__*`` feature columns.
    """
    # Flatten each dataframe into plain dict rows whose timestamps are in UTC.
    order_rows = convert_timestamp_records_to_utc(
        orders_df.to_dict("records"), "event_timestamp"
    )
    driver_ts_col = driver_fv.input.event_timestamp_column
    driver_rows = convert_timestamp_records_to_utc(
        driver_df.to_dict("records"), driver_ts_col
    )
    customer_ts_col = customer_fv.input.event_timestamp_column
    customer_rows = convert_timestamp_records_to_utc(
        customer_df.to_dict("records"), customer_ts_col
    )

    driver_features = ("conv_rate", "avg_daily_trips")
    customer_features = (
        "current_balance",
        "avg_passenger_count",
        "lifetime_trip_count",
    )

    # Point-in-time join: look up the freshest feature row per order.
    for row in order_rows:
        order_ts = row["event_timestamp"]
        matched_driver = find_asof_record(
            driver_rows,
            ts_key=driver_ts_col,
            ts_start=order_ts - driver_fv.ttl,
            ts_end=order_ts,
            filter_key="driver_id",
            filter_value=row["driver_id"],
        )
        matched_customer = find_asof_record(
            customer_rows,
            ts_key=customer_ts_col,
            ts_start=order_ts - customer_fv.ttl,
            ts_end=order_ts,
            filter_key="customer_id",
            filter_value=row["customer_id"],
        )
        # An empty match dict yields None for every feature below.
        for feature in driver_features:
            row[f"driver_stats__{feature}"] = matched_driver.get(feature)
        for feature in customer_features:
            row[f"customer_profile__{feature}"] = matched_customer.get(feature)

    # Back to a dataframe now that every order row carries its features.
    expected_df = pd.DataFrame(order_rows)

    # Put the entity timestamp column first, keeping the other columns in order.
    column_order = expected_df.columns.tolist()
    column_order.remove(ENTITY_DF_EVENT_TIMESTAMP_COL)
    column_order.insert(0, ENTITY_DF_EVENT_TIMESTAMP_COL)
    expected_df = expected_df[column_order]

    # Restore dtypes that were lost in the round trip through Python objects.
    restored_dtypes = (
        ("order_is_success", "int32"),
        ("customer_profile__current_balance", "float32"),
        ("customer_profile__avg_passenger_count", "float32"),
    )
    for col_name, dtype in restored_dtypes:
        expected_df[col_name] = expected_df[col_name].astype(dtype)

    return expected_df
170 | 191 |
|
171 | 192 |
|
|
0 commit comments