Skip to content

Commit 3078bfb

Browse files
authored
Use contextvars to maintain a call stack during the usage calls (feast-dev#1882)
* Use contextvars to maintain a call stack during the usage calls Signed-off-by: Achal Shah <achals@gmail.com> * refactor Signed-off-by: Achal Shah <achals@gmail.com> * Remove default values Signed-off-by: Achal Shah <achals@gmail.com>
1 parent 09f525c commit 3078bfb

1 file changed

Lines changed: 21 additions & 10 deletions

File tree

sdk/python/feast/usage.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import concurrent.futures
15+
import contextvars
1516
import enum
1617
import logging
1718
import os
1819
import sys
1920
import uuid
21+
from collections import defaultdict
2022
from datetime import datetime
2123
from functools import wraps
2224
from os.path import expanduser, join
2325
from pathlib import Path
24-
from typing import List, Optional, Tuple
26+
from typing import List, Optional, Tuple, Union
2527

2628
import requests
2729

@@ -31,6 +33,7 @@
3133
_logger = logging.getLogger(__name__)
3234

3335
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
36+
call_stack: contextvars.ContextVar = contextvars.ContextVar("call_stack", default=[])
3437

3538

3639
@enum.unique
@@ -51,6 +54,8 @@ def __str__(self):
5154
class Usage:
5255
def __init__(self):
5356
self._usage_enabled: bool = False
57+
self._is_test = os.getenv("FEAST_IS_USAGE_TEST", "False") == "True"
58+
self._usage_counter = defaultdict(lambda: 0)
5459
self.check_env_and_configure()
5560

5661
def check_env_and_configure(self):
@@ -68,9 +73,6 @@ def check_env_and_configure(self):
6873
Path(feast_home_dir).mkdir(exist_ok=True)
6974
usage_filepath = join(feast_home_dir, "usage")
7075

71-
self._is_test = os.getenv("FEAST_IS_USAGE_TEST", "False") == "True"
72-
self._usage_counter = {}
73-
7476
if os.path.exists(usage_filepath):
7577
with open(usage_filepath, "r") as f:
7678
self._usage_id = f.read()
@@ -106,9 +108,8 @@ def _send_usage_request(self, json):
106108
def log_function(self, function_name: str):
107109
self.check_env_and_configure()
108110
if self._usage_enabled and self.usage_id:
109-
if (
110-
function_name == "get_online_features"
111-
and not self.should_log_for_get_online_features_event(function_name)
111+
if "get_online_features" in call_stack.get() and not self.should_log_for_get_online_features_event(
112+
"get_online_features"
112113
):
113114
return
114115
json = {
@@ -121,10 +122,10 @@ def log_function(self, function_name: str):
121122
}
122123
self._send_usage_request(json)
123124

124-
def should_log_for_get_online_features_event(self, event_name: str):
125-
if event_name not in self._usage_counter:
126-
self._usage_counter[event_name] = 0
125+
def increment_event_count(self, event_name: Union[UsageEvent, str]):
127126
self._usage_counter[event_name] += 1
127+
128+
def should_log_for_get_online_features_event(self, event_name: str):
128129
if self._usage_counter[event_name] % 10000 != 2:
129130
return False
130131
self._usage_counter[event_name] = 2 # avoid overflow
@@ -174,6 +175,7 @@ def log_exceptions(func):
174175
@wraps(func)
175176
def exception_logging_wrapper(*args, **kwargs):
176177
try:
178+
call_stack.set(call_stack.get() + [func.__name__])
177179
result = func(*args, **kwargs)
178180
except Exception as e:
179181
error_type = type(e).__name__
@@ -190,6 +192,9 @@ def exception_logging_wrapper(*args, **kwargs):
190192
tb = tb.tb_next
191193
usage.log_exception(error_type, trace_to_log)
192194
raise
195+
finally:
196+
if len(call_stack.get()) > 0:
197+
call_stack.set(call_stack.get()[:-1])
193198
return result
194199

195200
return exception_logging_wrapper
@@ -199,6 +204,8 @@ def log_exceptions_and_usage(func):
199204
@wraps(func)
200205
def exception_logging_wrapper(*args, **kwargs):
201206
try:
207+
call_stack.set(call_stack.get() + [func.__name__])
208+
usage.increment_event_count(func.__name__)
202209
result = func(*args, **kwargs)
203210
usage.log_function(func.__name__)
204211
except Exception as e:
@@ -216,12 +223,16 @@ def exception_logging_wrapper(*args, **kwargs):
216223
tb = tb.tb_next
217224
usage.log_exception(error_type, trace_to_log)
218225
raise
226+
finally:
227+
if len(call_stack.get()) > 0:
228+
call_stack.set(call_stack.get()[:-1])
219229
return result
220230

221231
return exception_logging_wrapper
222232

223233

224234
def log_event(event: UsageEvent):
235+
usage.increment_event_count(event)
225236
usage.log_event(event)
226237

227238

0 commit comments

Comments
 (0)