diff --git a/Lib/profile.py b/Lib/profile.py new file mode 100644 index 0000000000..a5afb12c9d --- /dev/null +++ b/Lib/profile.py @@ -0,0 +1,615 @@ +# +# Class for profiling python code. rev 1.0 6/2/94 +# +# Written by James Roskind +# Based on prior profile module by Sjoerd Mullender... +# which was hacked somewhat by: Guido van Rossum + +"""Class for profiling Python code.""" + +# Copyright Disney Enterprises, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +# either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + + +import importlib.machinery +import io +import sys +import time +import marshal + +__all__ = ["run", "runctx", "Profile"] + +# Sample timer for use with +#i_count = 0 +#def integer_timer(): +# global i_count +# i_count = i_count + 1 +# return i_count +#itimes = integer_timer # replace with C coded timer returning integers + +class _Utils: + """Support class for utility functions which are shared by + profile.py and cProfile.py modules. + Not supposed to be used directly. + """ + + def __init__(self, profiler): + self.profiler = profiler + + def run(self, statement, filename, sort): + prof = self.profiler() + try: + prof.run(statement) + except SystemExit: + pass + finally: + self._show(prof, filename, sort) + + def runctx(self, statement, globals, locals, filename, sort): + prof = self.profiler() + try: + prof.runctx(statement, globals, locals) + except SystemExit: + pass + finally: + self._show(prof, filename, sort) + + def _show(self, prof, filename, sort): + if filename is not None: + prof.dump_stats(filename) + else: + prof.print_stats(sort) + + +#************************************************************************** +# The following are the static member functions for the profiler class +# Note that an instance of Profile() is *not* needed to call them. +#************************************************************************** + +def run(statement, filename=None, sort=-1): + """Run statement under profiler optionally saving results in filename + + This function takes a single argument that can be passed to the + "exec" statement, and an optional file name. In all cases this + routine attempts to "exec" its first argument and gather profiling + statistics from the execution. If no file name is present, then this + function automatically prints a simple profiling report, sorted by the + standard name string (file/line/function-name) that is presented in + each line. + """ + return _Utils(Profile).run(statement, filename, sort) + +def runctx(statement, globals, locals, filename=None, sort=-1): + """Run statement under profiler, supplying your own globals and locals, + optionally saving results in filename. + + statement and filename have the same semantics as profile.run + """ + return _Utils(Profile).runctx(statement, globals, locals, filename, sort) + + +class Profile: + """Profiler class. + + self.cur is always a tuple. Each such tuple corresponds to a stack + frame that is currently active (self.cur[-2]). The following are the + definitions of its members. We use this external "parallel stack" to + avoid contaminating the program that we are profiling. (old profiler + used to write into the frames local dictionary!!) Derived classes + can change the definition of some entries, as long as they leave + [-2:] intact (frame and previous tuple). In case an internal error is + detected, the -3 element is used as the function name. + + [ 0] = Time that needs to be charged to the parent frame's function. + It is used so that a function call will not have to access the + timing data for the parent frame. + [ 1] = Total time spent in this frame's function, excluding time in + subfunctions (this latter is tallied in cur[2]). + [ 2] = Total time spent in subfunctions, excluding time executing the + frame's function (this latter is tallied in cur[1]). + [-3] = Name of the function that corresponds to this frame. + [-2] = Actual frame that we correspond to (used to sync exception handling). + [-1] = Our parent 6-tuple (corresponds to frame.f_back). + + Timing data for each function is stored as a 5-tuple in the dictionary + self.timings[]. The index is always the name stored in self.cur[-3]. + The following are the definitions of the members: + + [0] = The number of times this function was called, not counting direct + or indirect recursion, + [1] = Number of times this function appears on the stack, minus one + [2] = Total time spent internal to this function + [3] = Cumulative time that this function was present on the stack. In + non-recursive functions, this is the total execution time from start + to finish of each invocation of a function, including time spent in + all subfunctions. + [4] = A dictionary indicating for each function name, the number of times + it was called by us. + """ + + bias = 0 # calibration constant + + def __init__(self, timer=None, bias=None): + self.timings = {} + self.cur = None + self.cmd = "" + self.c_func_name = "" + + if bias is None: + bias = self.bias + self.bias = bias # Materialize in local dict for lookup speed. + + if not timer: + self.timer = self.get_time = time.process_time + self.dispatcher = self.trace_dispatch_i + else: + self.timer = timer + t = self.timer() # test out timer function + try: + length = len(t) + except TypeError: + self.get_time = timer + self.dispatcher = self.trace_dispatch_i + else: + if length == 2: + self.dispatcher = self.trace_dispatch + else: + self.dispatcher = self.trace_dispatch_l + # This get_time() implementation needs to be defined + # here to capture the passed-in timer in the parameter + # list (for performance). Note that we can't assume + # the timer() result contains two values in all + # cases. + def get_time_timer(timer=timer, sum=sum): + return sum(timer()) + self.get_time = get_time_timer + self.t = self.get_time() + self.simulate_call('profiler') + + # Heavily optimized dispatch routine for time.process_time() timer + + def trace_dispatch(self, frame, event, arg): + timer = self.timer + t = timer() + t = t[0] + t[1] - self.t - self.bias + + if event == "c_call": + self.c_func_name = arg.__name__ + + if self.dispatch[event](self, frame,t): + t = timer() + self.t = t[0] + t[1] + else: + r = timer() + self.t = r[0] + r[1] - t # put back unrecorded delta + + # Dispatch routine for best timer program (return = scalar, fastest if + # an integer but float works too -- and time.process_time() relies on that). + + def trace_dispatch_i(self, frame, event, arg): + timer = self.timer + t = timer() - self.t - self.bias + + if event == "c_call": + self.c_func_name = arg.__name__ + + if self.dispatch[event](self, frame, t): + self.t = timer() + else: + self.t = timer() - t # put back unrecorded delta + + # Dispatch routine for macintosh (timer returns time in ticks of + # 1/60th second) + + def trace_dispatch_mac(self, frame, event, arg): + timer = self.timer + t = timer()/60.0 - self.t - self.bias + + if event == "c_call": + self.c_func_name = arg.__name__ + + if self.dispatch[event](self, frame, t): + self.t = timer()/60.0 + else: + self.t = timer()/60.0 - t # put back unrecorded delta + + # SLOW generic dispatch routine for timer returning lists of numbers + + def trace_dispatch_l(self, frame, event, arg): + get_time = self.get_time + t = get_time() - self.t - self.bias + + if event == "c_call": + self.c_func_name = arg.__name__ + + if self.dispatch[event](self, frame, t): + self.t = get_time() + else: + self.t = get_time() - t # put back unrecorded delta + + # In the event handlers, the first 3 elements of self.cur are unpacked + # into vrbls w/ 3-letter names. The last two characters are meant to be + # mnemonic: + # _pt self.cur[0] "parent time" time to be charged to parent frame + # _it self.cur[1] "internal time" time spent directly in the function + # _et self.cur[2] "external time" time spent in subfunctions + + def trace_dispatch_exception(self, frame, t): + rpt, rit, ret, rfn, rframe, rcur = self.cur + if (rframe is not frame) and rcur: + return self.trace_dispatch_return(rframe, t) + self.cur = rpt, rit+t, ret, rfn, rframe, rcur + return 1 + + + def trace_dispatch_call(self, frame, t): + if self.cur and frame.f_back is not self.cur[-2]: + rpt, rit, ret, rfn, rframe, rcur = self.cur + if not isinstance(rframe, Profile.fake_frame): + assert rframe.f_back is frame.f_back, ("Bad call", rfn, + rframe, rframe.f_back, + frame, frame.f_back) + self.trace_dispatch_return(rframe, 0) + assert (self.cur is None or \ + frame.f_back is self.cur[-2]), ("Bad call", + self.cur[-3]) + fcode = frame.f_code + fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name) + self.cur = (t, 0, 0, fn, frame, self.cur) + timings = self.timings + if fn in timings: + cc, ns, tt, ct, callers = timings[fn] + timings[fn] = cc, ns + 1, tt, ct, callers + else: + timings[fn] = 0, 0, 0, 0, {} + return 1 + + def trace_dispatch_c_call (self, frame, t): + fn = ("", 0, self.c_func_name) + self.cur = (t, 0, 0, fn, frame, self.cur) + timings = self.timings + if fn in timings: + cc, ns, tt, ct, callers = timings[fn] + timings[fn] = cc, ns+1, tt, ct, callers + else: + timings[fn] = 0, 0, 0, 0, {} + return 1 + + def trace_dispatch_return(self, frame, t): + if frame is not self.cur[-2]: + assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3]) + self.trace_dispatch_return(self.cur[-2], 0) + + # Prefix "r" means part of the Returning or exiting frame. + # Prefix "p" means part of the Previous or Parent or older frame. + + rpt, rit, ret, rfn, frame, rcur = self.cur + rit = rit + t + frame_total = rit + ret + + ppt, pit, pet, pfn, pframe, pcur = rcur + self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur + + timings = self.timings + cc, ns, tt, ct, callers = timings[rfn] + if not ns: + # This is the only occurrence of the function on the stack. + # Else this is a (directly or indirectly) recursive call, and + # its cumulative time will get updated when the topmost call to + # it returns. + ct = ct + frame_total + cc = cc + 1 + + if pfn in callers: + callers[pfn] = callers[pfn] + 1 # hack: gather more + # stats such as the amount of time added to ct courtesy + # of this specific call, and the contribution to cc + # courtesy of this call. + else: + callers[pfn] = 1 + + timings[rfn] = cc, ns - 1, tt + rit, ct, callers + + return 1 + + + dispatch = { + "call": trace_dispatch_call, + "exception": trace_dispatch_exception, + "return": trace_dispatch_return, + "c_call": trace_dispatch_c_call, + "c_exception": trace_dispatch_return, # the C function returned + "c_return": trace_dispatch_return, + } + + + # The next few functions play with self.cmd. By carefully preloading + # our parallel stack, we can force the profiled result to include + # an arbitrary string as the name of the calling function. + # We use self.cmd as that string, and the resulting stats look + # very nice :-). + + def set_cmd(self, cmd): + if self.cur[-1]: return # already set + self.cmd = cmd + self.simulate_call(cmd) + + class fake_code: + def __init__(self, filename, line, name): + self.co_filename = filename + self.co_line = line + self.co_name = name + self.co_firstlineno = 0 + + def __repr__(self): + return repr((self.co_filename, self.co_line, self.co_name)) + + class fake_frame: + def __init__(self, code, prior): + self.f_code = code + self.f_back = prior + + def simulate_call(self, name): + code = self.fake_code('profile', 0, name) + if self.cur: + pframe = self.cur[-2] + else: + pframe = None + frame = self.fake_frame(code, pframe) + self.dispatch['call'](self, frame, 0) + + # collect stats from pending stack, including getting final + # timings for self.cmd frame. + + def simulate_cmd_complete(self): + get_time = self.get_time + t = get_time() - self.t + while self.cur[-1]: + # We *can* cause assertion errors here if + # dispatch_trace_return checks for a frame match! + self.dispatch['return'](self, self.cur[-2], t) + t = 0 + self.t = get_time() - t + + + def print_stats(self, sort=-1): + import pstats + if not isinstance(sort, tuple): + sort = (sort,) + pstats.Stats(self).strip_dirs().sort_stats(*sort).print_stats() + + def dump_stats(self, file): + with open(file, 'wb') as f: + self.create_stats() + marshal.dump(self.stats, f) + + def create_stats(self): + self.simulate_cmd_complete() + self.snapshot_stats() + + def snapshot_stats(self): + self.stats = {} + for func, (cc, ns, tt, ct, callers) in self.timings.items(): + callers = callers.copy() + nc = 0 + for callcnt in callers.values(): + nc += callcnt + self.stats[func] = cc, nc, tt, ct, callers + + + # The following two methods can be called by clients to use + # a profiler to profile a statement, given as a string. + + def run(self, cmd): + import __main__ + dict = __main__.__dict__ + return self.runctx(cmd, dict, dict) + + def runctx(self, cmd, globals, locals): + self.set_cmd(cmd) + sys.setprofile(self.dispatcher) + try: + exec(cmd, globals, locals) + finally: + sys.setprofile(None) + return self + + # This method is more useful to profile a single function call. + def runcall(self, func, /, *args, **kw): + self.set_cmd(repr(func)) + sys.setprofile(self.dispatcher) + try: + return func(*args, **kw) + finally: + sys.setprofile(None) + + + #****************************************************************** + # The following calculates the overhead for using a profiler. The + # problem is that it takes a fair amount of time for the profiler + # to stop the stopwatch (from the time it receives an event). + # Similarly, there is a delay from the time that the profiler + # re-starts the stopwatch before the user's code really gets to + # continue. The following code tries to measure the difference on + # a per-event basis. + # + # Note that this difference is only significant if there are a lot of + # events, and relatively little user code per event. For example, + # code with small functions will typically benefit from having the + # profiler calibrated for the current platform. This *could* be + # done on the fly during init() time, but it is not worth the + # effort. Also note that if too large a value specified, then + # execution time on some functions will actually appear as a + # negative number. It is *normal* for some functions (with very + # low call counts) to have such negative stats, even if the + # calibration figure is "correct." + # + # One alternative to profile-time calibration adjustments (i.e., + # adding in the magic little delta during each event) is to track + # more carefully the number of events (and cumulatively, the number + # of events during sub functions) that are seen. If this were + # done, then the arithmetic could be done after the fact (i.e., at + # display time). Currently, we track only call/return events. + # These values can be deduced by examining the callees and callers + # vectors for each functions. Hence we *can* almost correct the + # internal time figure at print time (note that we currently don't + # track exception event processing counts). Unfortunately, there + # is currently no similar information for cumulative sub-function + # time. It would not be hard to "get all this info" at profiler + # time. Specifically, we would have to extend the tuples to keep + # counts of this in each frame, and then extend the defs of timing + # tuples to include the significant two figures. I'm a bit fearful + # that this additional feature will slow the heavily optimized + # event/time ratio (i.e., the profiler would run slower, fur a very + # low "value added" feature.) + #************************************************************** + + def calibrate(self, m, verbose=0): + if self.__class__ is not Profile: + raise TypeError("Subclasses must override .calibrate().") + + saved_bias = self.bias + self.bias = 0 + try: + return self._calibrate_inner(m, verbose) + finally: + self.bias = saved_bias + + def _calibrate_inner(self, m, verbose): + get_time = self.get_time + + # Set up a test case to be run with and without profiling. Include + # lots of calls, because we're trying to quantify stopwatch overhead. + # Do not raise any exceptions, though, because we want to know + # exactly how many profile events are generated (one call event, + + # one return event, per Python-level call). + + def f1(n): + for i in range(n): + x = 1 + + def f(m, f1=f1): + for i in range(m): + f1(100) + + f(m) # warm up the cache + + # elapsed_noprofile <- time f(m) takes without profiling. + t0 = get_time() + f(m) + t1 = get_time() + elapsed_noprofile = t1 - t0 + if verbose: + print("elapsed time without profiling =", elapsed_noprofile) + + # elapsed_profile <- time f(m) takes with profiling. The difference + # is profiling overhead, only some of which the profiler subtracts + # out on its own. + p = Profile() + t0 = get_time() + p.runctx('f(m)', globals(), locals()) + t1 = get_time() + elapsed_profile = t1 - t0 + if verbose: + print("elapsed time with profiling =", elapsed_profile) + + # reported_time <- "CPU seconds" the profiler charged to f and f1. + total_calls = 0.0 + reported_time = 0.0 + for (filename, line, funcname), (cc, ns, tt, ct, callers) in \ + p.timings.items(): + if funcname in ("f", "f1"): + total_calls += cc + reported_time += tt + + if verbose: + print("'CPU seconds' profiler reported =", reported_time) + print("total # calls =", total_calls) + if total_calls != m + 1: + raise ValueError("internal error: total calls = %d" % total_calls) + + # reported_time - elapsed_noprofile = overhead the profiler wasn't + # able to measure. Divide by twice the number of calls (since there + # are two profiler events per call in this test) to get the hidden + # overhead per event. + mean = (reported_time - elapsed_noprofile) / 2.0 / total_calls + if verbose: + print("mean stopwatch overhead per profile event =", mean) + return mean + +#**************************************************************************** + +def main(): + import os + from optparse import OptionParser + + usage = "profile.py [-o output_file_path] [-s sort] [-m module | scriptfile] [arg] ..." + parser = OptionParser(usage=usage) + parser.allow_interspersed_args = False + parser.add_option('-o', '--outfile', dest="outfile", + help="Save stats to ", default=None) + parser.add_option('-m', dest="module", action="store_true", + help="Profile a library module.", default=False) + parser.add_option('-s', '--sort', dest="sort", + help="Sort order when printing to stdout, based on pstats.Stats class", + default=-1) + + if not sys.argv[1:]: + parser.print_usage() + sys.exit(2) + + (options, args) = parser.parse_args() + sys.argv[:] = args + + # The script that we're profiling may chdir, so capture the absolute path + # to the output file at startup. + if options.outfile is not None: + options.outfile = os.path.abspath(options.outfile) + + if len(args) > 0: + if options.module: + import runpy + code = "run_module(modname, run_name='__main__')" + globs = { + 'run_module': runpy.run_module, + 'modname': args[0] + } + else: + progname = args[0] + sys.path.insert(0, os.path.dirname(progname)) + with io.open_code(progname) as fp: + code = compile(fp.read(), progname, 'exec') + spec = importlib.machinery.ModuleSpec(name='__main__', loader=None, + origin=progname) + globs = { + '__spec__': spec, + '__file__': spec.origin, + '__name__': spec.name, + '__package__': None, + '__cached__': None, + } + try: + runctx(code, globs, None, options.outfile, options.sort) + except BrokenPipeError as exc: + # Prevent "Exception ignored" during interpreter shutdown. + sys.stdout = None + sys.exit(exc.errno) + else: + parser.print_usage() + return parser + +# When invoked as main program, invoke the profiler on a script +if __name__ == '__main__': + main() diff --git a/Lib/pstats.py b/Lib/pstats.py new file mode 100644 index 0000000000..becaf35580 --- /dev/null +++ b/Lib/pstats.py @@ -0,0 +1,777 @@ +"""Class for printing reports on profiled python code.""" + +# Written by James Roskind +# Based on prior profile module by Sjoerd Mullender... +# which was hacked somewhat by: Guido van Rossum + +# Copyright Disney Enterprises, Inc. All Rights Reserved. +# Licensed to PSF under a Contributor Agreement +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +# either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + + +import sys +import os +import time +import marshal +import re + +from enum import StrEnum, _simple_enum +from functools import cmp_to_key +from dataclasses import dataclass + +__all__ = ["Stats", "SortKey", "FunctionProfile", "StatsProfile"] + +@_simple_enum(StrEnum) +class SortKey: + CALLS = 'calls', 'ncalls' + CUMULATIVE = 'cumulative', 'cumtime' + FILENAME = 'filename', 'module' + LINE = 'line' + NAME = 'name' + NFL = 'nfl' + PCALLS = 'pcalls' + STDNAME = 'stdname' + TIME = 'time', 'tottime' + + def __new__(cls, *values): + value = values[0] + obj = str.__new__(cls, value) + obj._value_ = value + for other_value in values[1:]: + cls._value2member_map_[other_value] = obj + obj._all_values = values + return obj + + +@dataclass(unsafe_hash=True) +class FunctionProfile: + ncalls: str + tottime: float + percall_tottime: float + cumtime: float + percall_cumtime: float + file_name: str + line_number: int + +@dataclass(unsafe_hash=True) +class StatsProfile: + '''Class for keeping track of an item in inventory.''' + total_tt: float + func_profiles: dict[str, FunctionProfile] + +class Stats: + """This class is used for creating reports from data generated by the + Profile class. It is a "friend" of that class, and imports data either + by direct access to members of Profile class, or by reading in a dictionary + that was emitted (via marshal) from the Profile class. + + The big change from the previous Profiler (in terms of raw functionality) + is that an "add()" method has been provided to combine Stats from + several distinct profile runs. Both the constructor and the add() + method now take arbitrarily many file names as arguments. + + All the print methods now take an argument that indicates how many lines + to print. If the arg is a floating-point number between 0 and 1.0, then + it is taken as a decimal percentage of the available lines to be printed + (e.g., .1 means print 10% of all available lines). If it is an integer, + it is taken to mean the number of lines of data that you wish to have + printed. + + The sort_stats() method now processes some additional options (i.e., in + addition to the old -1, 0, 1, or 2 that are respectively interpreted as + 'stdname', 'calls', 'time', and 'cumulative'). It takes either an + arbitrary number of quoted strings or SortKey enum to select the sort + order. + + For example sort_stats('time', 'name') or sort_stats(SortKey.TIME, + SortKey.NAME) sorts on the major key of 'internal function time', and on + the minor key of 'the name of the function'. Look at the two tables in + sort_stats() and get_sort_arg_defs(self) for more examples. + + All methods return self, so you can string together commands like: + Stats('foo', 'goo').strip_dirs().sort_stats('calls').\ + print_stats(5).print_callers(5) + """ + + def __init__(self, *args, stream=None): + self.stream = stream or sys.stdout + if not len(args): + arg = None + else: + arg = args[0] + args = args[1:] + self.init(arg) + self.add(*args) + + def init(self, arg): + self.all_callees = None # calc only if needed + self.files = [] + self.fcn_list = None + self.total_tt = 0 + self.total_calls = 0 + self.prim_calls = 0 + self.max_name_len = 0 + self.top_level = set() + self.stats = {} + self.sort_arg_dict = {} + self.load_stats(arg) + try: + self.get_top_level_stats() + except Exception: + print("Invalid timing data %s" % + (self.files[-1] if self.files else ''), file=self.stream) + raise + + def load_stats(self, arg): + if arg is None: + self.stats = {} + return + elif isinstance(arg, str): + with open(arg, 'rb') as f: + self.stats = marshal.load(f) + try: + file_stats = os.stat(arg) + arg = time.ctime(file_stats.st_mtime) + " " + arg + except: # in case this is not unix + pass + self.files = [arg] + elif hasattr(arg, 'create_stats'): + arg.create_stats() + self.stats = arg.stats + arg.stats = {} + if not self.stats: + raise TypeError("Cannot create or construct a %r object from %r" + % (self.__class__, arg)) + return + + def get_top_level_stats(self): + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + self.total_calls += nc + self.prim_calls += cc + self.total_tt += tt + if ("jprofile", 0, "profiler") in callers: + self.top_level.add(func) + if len(func_std_string(func)) > self.max_name_len: + self.max_name_len = len(func_std_string(func)) + + def add(self, *arg_list): + if not arg_list: + return self + for item in reversed(arg_list): + if type(self) != type(item): + item = Stats(item) + self.files += item.files + self.total_calls += item.total_calls + self.prim_calls += item.prim_calls + self.total_tt += item.total_tt + for func in item.top_level: + self.top_level.add(func) + + if self.max_name_len < item.max_name_len: + self.max_name_len = item.max_name_len + + self.fcn_list = None + + for func, stat in item.stats.items(): + if func in self.stats: + old_func_stat = self.stats[func] + else: + old_func_stat = (0, 0, 0, 0, {},) + self.stats[func] = add_func_stats(old_func_stat, stat) + return self + + def dump_stats(self, filename): + """Write the profile data to a file we know how to load back.""" + with open(filename, 'wb') as f: + marshal.dump(self.stats, f) + + # list the tuple indices and directions for sorting, + # along with some printable description + sort_arg_dict_default = { + "calls" : (((1,-1), ), "call count"), + "ncalls" : (((1,-1), ), "call count"), + "cumtime" : (((3,-1), ), "cumulative time"), + "cumulative": (((3,-1), ), "cumulative time"), + "filename" : (((4, 1), ), "file name"), + "line" : (((5, 1), ), "line number"), + "module" : (((4, 1), ), "file name"), + "name" : (((6, 1), ), "function name"), + "nfl" : (((6, 1),(4, 1),(5, 1),), "name/file/line"), + "pcalls" : (((0,-1), ), "primitive call count"), + "stdname" : (((7, 1), ), "standard name"), + "time" : (((2,-1), ), "internal time"), + "tottime" : (((2,-1), ), "internal time"), + } + + def get_sort_arg_defs(self): + """Expand all abbreviations that are unique.""" + if not self.sort_arg_dict: + self.sort_arg_dict = dict = {} + bad_list = {} + for word, tup in self.sort_arg_dict_default.items(): + fragment = word + while fragment: + if fragment in dict: + bad_list[fragment] = 0 + break + dict[fragment] = tup + fragment = fragment[:-1] + for word in bad_list: + del dict[word] + return self.sort_arg_dict + + def sort_stats(self, *field): + if not field: + self.fcn_list = 0 + return self + if len(field) == 1 and isinstance(field[0], int): + # Be compatible with old profiler + field = [ {-1: "stdname", + 0: "calls", + 1: "time", + 2: "cumulative"}[field[0]] ] + elif len(field) >= 2: + for arg in field[1:]: + if type(arg) != type(field[0]): + raise TypeError("Can't have mixed argument type") + + sort_arg_defs = self.get_sort_arg_defs() + + sort_tuple = () + self.sort_type = "" + connector = "" + for word in field: + if isinstance(word, SortKey): + word = word.value + sort_tuple = sort_tuple + sort_arg_defs[word][0] + self.sort_type += connector + sort_arg_defs[word][1] + connector = ", " + + stats_list = [] + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + stats_list.append((cc, nc, tt, ct) + func + + (func_std_string(func), func)) + + stats_list.sort(key=cmp_to_key(TupleComp(sort_tuple).compare)) + + self.fcn_list = fcn_list = [] + for tuple in stats_list: + fcn_list.append(tuple[-1]) + return self + + def reverse_order(self): + if self.fcn_list: + self.fcn_list.reverse() + return self + + def strip_dirs(self): + oldstats = self.stats + self.stats = newstats = {} + max_name_len = 0 + for func, (cc, nc, tt, ct, callers) in oldstats.items(): + newfunc = func_strip_path(func) + if len(func_std_string(newfunc)) > max_name_len: + max_name_len = len(func_std_string(newfunc)) + newcallers = {} + for func2, caller in callers.items(): + newcallers[func_strip_path(func2)] = caller + + if newfunc in newstats: + newstats[newfunc] = add_func_stats( + newstats[newfunc], + (cc, nc, tt, ct, newcallers)) + else: + newstats[newfunc] = (cc, nc, tt, ct, newcallers) + old_top = self.top_level + self.top_level = new_top = set() + for func in old_top: + new_top.add(func_strip_path(func)) + + self.max_name_len = max_name_len + + self.fcn_list = None + self.all_callees = None + return self + + def calc_callees(self): + if self.all_callees: + return + self.all_callees = all_callees = {} + for func, (cc, nc, tt, ct, callers) in self.stats.items(): + if not func in all_callees: + all_callees[func] = {} + for func2, caller in callers.items(): + if not func2 in all_callees: + all_callees[func2] = {} + all_callees[func2][func] = caller + return + + #****************************************************************** + # The following functions support actual printing of reports + #****************************************************************** + + # Optional "amount" is either a line count, or a percentage of lines. + + def eval_print_amount(self, sel, list, msg): + new_list = list + if isinstance(sel, str): + try: + rex = re.compile(sel) + except re.PatternError: + msg += " \n" % sel + return new_list, msg + new_list = [] + for func in list: + if rex.search(func_std_string(func)): + new_list.append(func) + else: + count = len(list) + if isinstance(sel, float) and 0.0 <= sel < 1.0: + count = int(count * sel + .5) + new_list = list[:count] + elif isinstance(sel, int) and 0 <= sel < count: + count = sel + new_list = list[:count] + if len(list) != len(new_list): + msg += " List reduced from %r to %r due to restriction <%r>\n" % ( + len(list), len(new_list), sel) + + return new_list, msg + + def get_stats_profile(self): + """This method returns an instance of StatsProfile, which contains a mapping + of function names to instances of FunctionProfile. Each FunctionProfile + instance holds information related to the function's profile such as how + long the function took to run, how many times it was called, etc... + """ + func_list = self.fcn_list[:] if self.fcn_list else list(self.stats.keys()) + if not func_list: + return StatsProfile(0, {}) + + total_tt = float(f8(self.total_tt)) + func_profiles = {} + stats_profile = StatsProfile(total_tt, func_profiles) + + for func in func_list: + cc, nc, tt, ct, callers = self.stats[func] + file_name, line_number, func_name = func + ncalls = str(nc) if nc == cc else (str(nc) + '/' + str(cc)) + tottime = float(f8(tt)) + percall_tottime = -1 if nc == 0 else float(f8(tt/nc)) + cumtime = float(f8(ct)) + percall_cumtime = -1 if cc == 0 else float(f8(ct/cc)) + func_profile = FunctionProfile( + ncalls, + tottime, # time spent in this function alone + percall_tottime, + cumtime, # time spent in the function plus all functions that this function called, + percall_cumtime, + file_name, + line_number + ) + func_profiles[func_name] = func_profile + + return stats_profile + + def get_print_list(self, sel_list): + width = self.max_name_len + if self.fcn_list: + stat_list = self.fcn_list[:] + msg = " Ordered by: " + self.sort_type + '\n' + else: + stat_list = list(self.stats.keys()) + msg = " Random listing order was used\n" + + for selection in sel_list: + stat_list, msg = self.eval_print_amount(selection, stat_list, msg) + + count = len(stat_list) + + if not stat_list: + return 0, stat_list + print(msg, file=self.stream) + if count < len(self.stats): + width = 0 + for func in stat_list: + if len(func_std_string(func)) > width: + width = len(func_std_string(func)) + return width+2, stat_list + + def print_stats(self, *amount): + for filename in self.files: + print(filename, file=self.stream) + if self.files: + print(file=self.stream) + indent = ' ' * 8 + for func in self.top_level: + print(indent, func_get_function_name(func), file=self.stream) + + print(indent, self.total_calls, "function calls", end=' ', file=self.stream) + if self.total_calls != self.prim_calls: + print("(%d primitive calls)" % self.prim_calls, end=' ', file=self.stream) + print("in %.3f seconds" % self.total_tt, file=self.stream) + print(file=self.stream) + width, list = self.get_print_list(amount) + if list: + self.print_title() + for func in list: + self.print_line(func) + print(file=self.stream) + print(file=self.stream) + return self + + def print_callees(self, *amount): + width, list = self.get_print_list(amount) + if list: + self.calc_callees() + + self.print_call_heading(width, "called...") + for func in list: + if func in self.all_callees: + self.print_call_line(width, func, self.all_callees[func]) + else: + self.print_call_line(width, func, {}) + print(file=self.stream) + print(file=self.stream) + return self + + def print_callers(self, *amount): + width, list = self.get_print_list(amount) + if list: + self.print_call_heading(width, "was called by...") + for func in list: + cc, nc, tt, ct, callers = self.stats[func] + self.print_call_line(width, func, callers, "<-") + print(file=self.stream) + print(file=self.stream) + return self + + def print_call_heading(self, name_size, column_title): + print("Function ".ljust(name_size) + column_title, file=self.stream) + # print sub-header only if we have new-style callers + subheader = False + for cc, nc, tt, ct, callers in self.stats.values(): + if callers: + value = next(iter(callers.values())) + subheader = isinstance(value, tuple) + break + if subheader: + print(" "*name_size + " ncalls tottime cumtime", file=self.stream) + + def print_call_line(self, name_size, source, call_dict, arrow="->"): + print(func_std_string(source).ljust(name_size) + arrow, end=' ', file=self.stream) + if not call_dict: + print(file=self.stream) + return + clist = sorted(call_dict.keys()) + indent = "" + for func in clist: + name = func_std_string(func) + value = call_dict[func] + if isinstance(value, tuple): + nc, cc, tt, ct = value + if nc != cc: + substats = '%d/%d' % (nc, cc) + else: + substats = '%d' % (nc,) + substats = '%s %s %s %s' % (substats.rjust(7+2*len(indent)), + f8(tt), f8(ct), name) + left_width = name_size + 1 + else: + substats = '%s(%r) %s' % (name, value, f8(self.stats[func][3])) + left_width = name_size + 3 + print(indent*left_width + substats, file=self.stream) + indent = " " + + def print_title(self): + print(' ncalls tottime percall cumtime percall', end=' ', file=self.stream) + print('filename:lineno(function)', file=self.stream) + + def print_line(self, func): # hack: should print percentages + cc, nc, tt, ct, callers = self.stats[func] + c = str(nc) + if nc != cc: + c = c + '/' + str(cc) + print(c.rjust(9), end=' ', file=self.stream) + print(f8(tt), end=' ', file=self.stream) + if nc == 0: + print(' '*8, end=' ', file=self.stream) + else: + print(f8(tt/nc), end=' ', file=self.stream) + print(f8(ct), end=' ', file=self.stream) + if cc == 0: + print(' '*8, end=' ', file=self.stream) + else: + print(f8(ct/cc), end=' ', file=self.stream) + print(func_std_string(func), file=self.stream) + +class TupleComp: + """This class provides a generic function for comparing any two tuples. + Each instance records a list of tuple-indices (from most significant + to least significant), and sort direction (ascending or descending) for + each tuple-index. The compare functions can then be used as the function + argument to the system sort() function when a list of tuples need to be + sorted in the instances order.""" + + def __init__(self, comp_select_list): + self.comp_select_list = comp_select_list + + def compare (self, left, right): + for index, direction in self.comp_select_list: + l = left[index] + r = right[index] + if l < r: + return -direction + if l > r: + return direction + return 0 + + +#************************************************************************** +# func_name is a triple (file:string, line:int, name:string) + +def func_strip_path(func_name): + filename, line, name = func_name + return os.path.basename(filename), line, name + +def func_get_function_name(func): + return func[2] + +def func_std_string(func_name): # match what old profile produced + if func_name[:2] == ('~', 0): + # special case for built-in functions + name = func_name[2] + if name.startswith('<') and name.endswith('>'): + return '{%s}' % name[1:-1] + else: + return name + else: + return "%s:%d(%s)" % func_name + +#************************************************************************** +# The following functions combine statistics for pairs functions. +# The bulk of the processing involves correctly handling "call" lists, +# such as callers and callees. +#************************************************************************** + +def add_func_stats(target, source): + """Add together all the stats for two profile entries.""" + cc, nc, tt, ct, callers = source + t_cc, t_nc, t_tt, t_ct, t_callers = target + return (cc+t_cc, nc+t_nc, tt+t_tt, ct+t_ct, + add_callers(t_callers, callers)) + +def add_callers(target, source): + """Combine two caller lists in a single list.""" + new_callers = {} + for func, caller in target.items(): + new_callers[func] = caller + for func, caller in source.items(): + if func in new_callers: + if isinstance(caller, tuple): + # format used by cProfile + new_callers[func] = tuple(i + j for i, j in zip(caller, new_callers[func])) + else: + # format used by profile + new_callers[func] += caller + else: + new_callers[func] = caller + return new_callers + +def count_calls(callers): + """Sum the caller statistics to get total number of calls received.""" + nc = 0 + for calls in callers.values(): + nc += calls + return nc + +#************************************************************************** +# The following functions support printing of reports +#************************************************************************** + +def f8(x): + return "%8.3f" % x + +#************************************************************************** +# Statistics browser added by ESR, April 2001 +#************************************************************************** + +if __name__ == '__main__': + import cmd + try: + import readline # noqa: F401 + except ImportError: + pass + + class ProfileBrowser(cmd.Cmd): + def __init__(self, profile=None): + cmd.Cmd.__init__(self) + self.prompt = "% " + self.stats = None + self.stream = sys.stdout + if profile is not None: + self.do_read(profile) + + def generic(self, fn, line): + args = line.split() + processed = [] + for term in args: + try: + processed.append(int(term)) + continue + except ValueError: + pass + try: + frac = float(term) + if frac > 1 or frac < 0: + print("Fraction argument must be in [0, 1]", file=self.stream) + continue + processed.append(frac) + continue + except ValueError: + pass + processed.append(term) + if self.stats: + getattr(self.stats, fn)(*processed) + else: + print("No statistics object is loaded.", file=self.stream) + return 0 + def generic_help(self): + print("Arguments may be:", file=self.stream) + print("* An integer maximum number of entries to print.", file=self.stream) + print("* A decimal fractional number between 0 and 1, controlling", file=self.stream) + print(" what fraction of selected entries to print.", file=self.stream) + print("* A regular expression; only entries with function names", file=self.stream) + print(" that match it are printed.", file=self.stream) + + def do_add(self, line): + if self.stats: + try: + self.stats.add(line) + except OSError as e: + print("Failed to load statistics for %s: %s" % (line, e), file=self.stream) + else: + print("No statistics object is loaded.", file=self.stream) + return 0 + def help_add(self): + print("Add profile info from given file to current statistics object.", file=self.stream) + + def do_callees(self, line): + return self.generic('print_callees', line) + def help_callees(self): + print("Print callees statistics from the current stat object.", file=self.stream) + self.generic_help() + + def do_callers(self, line): + return self.generic('print_callers', line) + def help_callers(self): + print("Print callers statistics from the current stat object.", file=self.stream) + self.generic_help() + + def do_EOF(self, line): + print("", file=self.stream) + return 1 + def help_EOF(self): + print("Leave the profile browser.", file=self.stream) + + def do_quit(self, line): + return 1 + def help_quit(self): + print("Leave the profile browser.", file=self.stream) + + def do_read(self, line): + if line: + try: + self.stats = Stats(line) + except OSError as err: + print(err.args[1], file=self.stream) + return + except Exception as err: + print(err.__class__.__name__ + ':', err, file=self.stream) + return + self.prompt = line + "% " + elif len(self.prompt) > 2: + line = self.prompt[:-2] + self.do_read(line) + else: + print("No statistics object is current -- cannot reload.", file=self.stream) + return 0 + def help_read(self): + print("Read in profile data from a specified file.", file=self.stream) + print("Without argument, reload the current file.", file=self.stream) + + def do_reverse(self, line): + if self.stats: + self.stats.reverse_order() + else: + print("No statistics object is loaded.", file=self.stream) + return 0 + def help_reverse(self): + print("Reverse the sort order of the profiling report.", file=self.stream) + + def do_sort(self, line): + if not self.stats: + print("No statistics object is loaded.", file=self.stream) + return + abbrevs = self.stats.get_sort_arg_defs() + if line and all((x in abbrevs) for x in line.split()): + self.stats.sort_stats(*line.split()) + else: + print("Valid sort keys (unique prefixes are accepted):", file=self.stream) + for (key, value) in Stats.sort_arg_dict_default.items(): + print("%s -- %s" % (key, value[1]), file=self.stream) + return 0 + def help_sort(self): + print("Sort profile data according to specified keys.", file=self.stream) + print("(Typing `sort' without arguments lists valid keys.)", file=self.stream) + def complete_sort(self, text, *args): + return [a for a in Stats.sort_arg_dict_default if a.startswith(text)] + + def do_stats(self, line): + return self.generic('print_stats', line) + def help_stats(self): + print("Print statistics from the current stat object.", file=self.stream) + self.generic_help() + + def do_strip(self, line): + if self.stats: + self.stats.strip_dirs() + else: + print("No statistics object is loaded.", file=self.stream) + def help_strip(self): + print("Strip leading path information from filenames in the report.", file=self.stream) + + def help_help(self): + print("Show help for a given command.", file=self.stream) + + def postcmd(self, stop, line): + if stop: + return stop + return None + + if len(sys.argv) > 1: + initprofile = sys.argv[1] + else: + initprofile = None + try: + browser = ProfileBrowser(initprofile) + for profile in sys.argv[2:]: + browser.do_add(profile) + print("Welcome to the profile statistics browser.", file=browser.stream) + browser.cmdloop() + print("Goodbye.", file=browser.stream) + except KeyboardInterrupt: + pass + +# That's all, folks. diff --git a/Lib/test/pstats.pck b/Lib/test/pstats.pck new file mode 100644 index 0000000000..c48ccb73a9 Binary files /dev/null and b/Lib/test/pstats.pck differ diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py new file mode 100644 index 0000000000..48057e8f03 --- /dev/null +++ b/Lib/test/test_cprofile.py @@ -0,0 +1,234 @@ +"""Test suite for the cProfile module.""" + +import sys +import unittest + +# rip off all interesting stuff from test_profile +try: + import cProfile +except ImportError: + # TODO: RUSTPYTHON; _lsprof not implemented + raise unittest.SkipTest('cProfile requires _lsprof') +import tempfile +import textwrap +from test.test_profile import ProfileTest, regenerate_expected_output +from test.support.script_helper import assert_python_failure, assert_python_ok +from test import support + + +class CProfileTest(ProfileTest): + profilerclass = cProfile.Profile + profilermodule = cProfile + expected_max_output = "{built-in method builtins.max}" + + def get_expected_output(self): + return _ProfileOutput + + def test_bad_counter_during_dealloc(self): + # bpo-3895 + import _lsprof + + with support.catch_unraisable_exception() as cm: + obj = _lsprof.Profiler(lambda: int) + obj.enable() + obj.disable() + obj.clear() + + self.assertEqual(cm.unraisable.exc_type, TypeError) + + def test_crash_with_not_enough_args(self): + # gh-126220 + import _lsprof + + for profile in [_lsprof.Profiler(), cProfile.Profile()]: + for method in [ + "_pystart_callback", + "_pyreturn_callback", + "_ccall_callback", + "_creturn_callback", + ]: + with self.subTest(profile=profile, method=method): + method_obj = getattr(profile, method) + with self.assertRaises(TypeError): + method_obj() # should not crash + + def test_evil_external_timer(self): + # gh-120289 + # Disabling profiler in external timer should not crash + import _lsprof + class EvilTimer(): + def __init__(self, disable_count): + self.count = 0 + self.disable_count = disable_count + + def __call__(self): + self.count += 1 + if self.count == self.disable_count: + profiler_with_evil_timer.disable() + return self.count + + # this will trigger external timer to disable profiler at + # call event - in initContext in _lsprof.c + with support.catch_unraisable_exception() as cm: + profiler_with_evil_timer = _lsprof.Profiler(EvilTimer(1)) + profiler_with_evil_timer.enable() + # Make a call to trigger timer + (lambda: None)() + profiler_with_evil_timer.disable() + profiler_with_evil_timer.clear() + self.assertEqual(cm.unraisable.exc_type, RuntimeError) + + # this will trigger external timer to disable profiler at + # return event - in Stop in _lsprof.c + with support.catch_unraisable_exception() as cm: + profiler_with_evil_timer = _lsprof.Profiler(EvilTimer(2)) + profiler_with_evil_timer.enable() + # Make a call to trigger timer + (lambda: None)() + profiler_with_evil_timer.disable() + profiler_with_evil_timer.clear() + self.assertEqual(cm.unraisable.exc_type, RuntimeError) + + def test_profile_enable_disable(self): + prof = self.profilerclass() + # Make sure we clean ourselves up if the test fails for some reason. + self.addCleanup(prof.disable) + + prof.enable() + self.assertEqual( + sys.monitoring.get_tool(sys.monitoring.PROFILER_ID), "cProfile") + + prof.disable() + self.assertIs(sys.monitoring.get_tool(sys.monitoring.PROFILER_ID), None) + + def test_profile_as_context_manager(self): + prof = self.profilerclass() + # Make sure we clean ourselves up if the test fails for some reason. + self.addCleanup(prof.disable) + + with prof as __enter__return_value: + # profile.__enter__ should return itself. + self.assertIs(prof, __enter__return_value) + + # profile should be set as the global profiler inside the + # with-block + self.assertEqual( + sys.monitoring.get_tool(sys.monitoring.PROFILER_ID), "cProfile") + + # profile shouldn't be set once we leave the with-block. + self.assertIs(sys.monitoring.get_tool(sys.monitoring.PROFILER_ID), None) + + def test_second_profiler(self): + pr = self.profilerclass() + pr2 = self.profilerclass() + pr.enable() + self.assertRaises(ValueError, pr2.enable) + pr.disable() + + def test_throw(self): + """ + gh-106152 + generator.throw() should trigger a call in cProfile + """ + + def gen(): + yield + + pr = self.profilerclass() + pr.enable() + g = gen() + try: + g.throw(SyntaxError) + except SyntaxError: + pass + pr.disable() + pr.create_stats() + + self.assertTrue(any("throw" in func[2] for func in pr.stats.keys())), + + def test_bad_descriptor(self): + # gh-132250 + # cProfile should not crash when the profiler callback fails to locate + # the actual function of a method. + with self.profilerclass() as prof: + with self.assertRaises(TypeError): + bytes.find(str()) + + +class TestCommandLine(unittest.TestCase): + def test_sort(self): + rc, out, err = assert_python_failure('-m', 'cProfile', '-s', 'demo') + self.assertGreater(rc, 0) + self.assertIn(b"option -s: invalid choice: 'demo'", err) + + def test_profile_script_importing_main(self): + """Check that scripts that reference __main__ see their own namespace + when being profiled.""" + with tempfile.NamedTemporaryFile("w+", delete_on_close=False) as f: + f.write(textwrap.dedent("""\ + class Foo: + pass + import __main__ + assert Foo == __main__.Foo + """)) + f.close() + assert_python_ok('-m', "cProfile", f.name) + + +def main(): + if '-r' not in sys.argv: + unittest.main() + else: + regenerate_expected_output(__file__, CProfileTest) + + +# Don't remove this comment. Everything below it is auto-generated. +#--cut-------------------------------------------------------------------------- +_ProfileOutput = {} +_ProfileOutput['print_stats'] = """\ + 28 0.028 0.001 0.028 0.001 profilee.py:110(__getattr__) + 1 0.270 0.270 1.000 1.000 profilee.py:25(testfunc) + 23/3 0.150 0.007 0.170 0.057 profilee.py:35(factorial) + 20 0.020 0.001 0.020 0.001 profilee.py:48(mul) + 2 0.040 0.020 0.600 0.300 profilee.py:55(helper) + 4 0.116 0.029 0.120 0.030 profilee.py:73(helper1) + 2 0.000 0.000 0.140 0.070 profilee.py:84(helper2_indirect) + 8 0.312 0.039 0.400 0.050 profilee.py:88(helper2) + 8 0.064 0.008 0.080 0.010 profilee.py:98(subhelper)""" +_ProfileOutput['print_callers'] = """\ +profilee.py:110(__getattr__) <- 16 0.016 0.016 profilee.py:98(subhelper) +profilee.py:25(testfunc) <- 1 0.270 1.000 :1() +profilee.py:35(factorial) <- 1 0.014 0.130 profilee.py:25(testfunc) + 20/3 0.130 0.147 profilee.py:35(factorial) + 2 0.006 0.040 profilee.py:84(helper2_indirect) +profilee.py:48(mul) <- 20 0.020 0.020 profilee.py:35(factorial) +profilee.py:55(helper) <- 2 0.040 0.600 profilee.py:25(testfunc) +profilee.py:73(helper1) <- 4 0.116 0.120 profilee.py:55(helper) +profilee.py:84(helper2_indirect) <- 2 0.000 0.140 profilee.py:55(helper) +profilee.py:88(helper2) <- 6 0.234 0.300 profilee.py:55(helper) + 2 0.078 0.100 profilee.py:84(helper2_indirect) +profilee.py:98(subhelper) <- 8 0.064 0.080 profilee.py:88(helper2) +{built-in method builtins.hasattr} <- 4 0.000 0.004 profilee.py:73(helper1) + 8 0.000 0.008 profilee.py:88(helper2) +{built-in method sys.exception} <- 4 0.000 0.000 profilee.py:73(helper1) +{method 'append' of 'list' objects} <- 4 0.000 0.000 profilee.py:73(helper1)""" +_ProfileOutput['print_callees'] = """\ +:1() -> 1 0.270 1.000 profilee.py:25(testfunc) +profilee.py:110(__getattr__) -> +profilee.py:25(testfunc) -> 1 0.014 0.130 profilee.py:35(factorial) + 2 0.040 0.600 profilee.py:55(helper) +profilee.py:35(factorial) -> 20/3 0.130 0.147 profilee.py:35(factorial) + 20 0.020 0.020 profilee.py:48(mul) +profilee.py:48(mul) -> +profilee.py:55(helper) -> 4 0.116 0.120 profilee.py:73(helper1) + 2 0.000 0.140 profilee.py:84(helper2_indirect) + 6 0.234 0.300 profilee.py:88(helper2) +profilee.py:73(helper1) -> 4 0.000 0.004 {built-in method builtins.hasattr} +profilee.py:84(helper2_indirect) -> 2 0.006 0.040 profilee.py:35(factorial) + 2 0.078 0.100 profilee.py:88(helper2) +profilee.py:88(helper2) -> 8 0.064 0.080 profilee.py:98(subhelper) +profilee.py:98(subhelper) -> 16 0.016 0.016 profilee.py:110(__getattr__) +{built-in method builtins.hasattr} -> 12 0.012 0.012 profilee.py:110(__getattr__)""" + +if __name__ == "__main__": + main() diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py new file mode 100644 index 0000000000..da0a7689d7 --- /dev/null +++ b/Lib/test/test_multibytecodec.py @@ -0,0 +1,415 @@ +# +# test_multibytecodec.py +# Unit test for multibytecodec itself +# + +import codecs +import io +import sys +import textwrap +import unittest +try: + import _multibytecodec +except ImportError: + # TODO: RUSTPYTHON; _multibytecodec not implemented + raise unittest.SkipTest('_multibytecodec not available') +from test import support +from test.support import os_helper +from test.support.os_helper import TESTFN +from test.support.import_helper import import_module + +ALL_CJKENCODINGS = [ +# _codecs_cn + 'gb2312', 'gbk', 'gb18030', 'hz', +# _codecs_hk + 'big5hkscs', +# _codecs_jp + 'cp932', 'shift_jis', 'euc_jp', 'euc_jisx0213', 'shift_jisx0213', + 'euc_jis_2004', 'shift_jis_2004', +# _codecs_kr + 'cp949', 'euc_kr', 'johab', +# _codecs_tw + 'big5', 'cp950', +# _codecs_iso2022 + 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2', 'iso2022_jp_2004', + 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', +] + +class Test_MultibyteCodec(unittest.TestCase): + + def test_nullcoding(self): + for enc in ALL_CJKENCODINGS: + self.assertEqual(b''.decode(enc), '') + self.assertEqual(str(b'', enc), '') + self.assertEqual(''.encode(enc), b'') + + def test_str_decode(self): + for enc in ALL_CJKENCODINGS: + self.assertEqual('abcd'.encode(enc), b'abcd') + + def test_errorcallback_longindex(self): + dec = codecs.getdecoder('euc-kr') + myreplace = lambda exc: ('', sys.maxsize+1) + codecs.register_error('test.cjktest', myreplace) + self.assertRaises(IndexError, dec, + b'apple\x92ham\x93spam', 'test.cjktest') + + def test_errorcallback_custom_ignore(self): + # Issue #23215: MemoryError with custom error handlers and multibyte codecs + data = 100 * "\udc00" + codecs.register_error("test.ignore", codecs.ignore_errors) + for enc in ALL_CJKENCODINGS: + self.assertEqual(data.encode(enc, "test.ignore"), b'') + + def test_codingspec(self): + try: + for enc in ALL_CJKENCODINGS: + code = '# coding: {}\n'.format(enc) + exec(code) + finally: + os_helper.unlink(TESTFN) + + def test_init_segfault(self): + # bug #3305: this used to segfault + self.assertRaises(AttributeError, + _multibytecodec.MultibyteStreamReader, None) + self.assertRaises(AttributeError, + _multibytecodec.MultibyteStreamWriter, None) + + def test_decode_unicode(self): + # Trying to decode a unicode string should raise a TypeError + for enc in ALL_CJKENCODINGS: + self.assertRaises(TypeError, codecs.getdecoder(enc), "") + +class Test_IncrementalEncoder(unittest.TestCase): + + def test_stateless(self): + # cp949 encoder isn't stateful at all. + encoder = codecs.getincrementalencoder('cp949')() + self.assertEqual(encoder.encode('\ud30c\uc774\uc36c \ub9c8\uc744'), + b'\xc6\xc4\xc0\xcc\xbd\xe3 \xb8\xb6\xc0\xbb') + self.assertEqual(encoder.reset(), None) + self.assertEqual(encoder.encode('\u2606\u223c\u2606', True), + b'\xa1\xd9\xa1\xad\xa1\xd9') + self.assertEqual(encoder.reset(), None) + self.assertEqual(encoder.encode('', True), b'') + self.assertEqual(encoder.encode('', False), b'') + self.assertEqual(encoder.reset(), None) + + def test_stateful(self): + # jisx0213 encoder is stateful for a few code points. eg) + # U+00E6 => A9DC + # U+00E6 U+0300 => ABC4 + # U+0300 => ABDC + + encoder = codecs.getincrementalencoder('jisx0213')() + self.assertEqual(encoder.encode('\u00e6\u0300'), b'\xab\xc4') + self.assertEqual(encoder.encode('\u00e6'), b'') + self.assertEqual(encoder.encode('\u0300'), b'\xab\xc4') + self.assertEqual(encoder.encode('\u00e6', True), b'\xa9\xdc') + + self.assertEqual(encoder.reset(), None) + self.assertEqual(encoder.encode('\u0300'), b'\xab\xdc') + + self.assertEqual(encoder.encode('\u00e6'), b'') + self.assertEqual(encoder.encode('', True), b'\xa9\xdc') + self.assertEqual(encoder.encode('', True), b'') + + def test_stateful_keep_buffer(self): + encoder = codecs.getincrementalencoder('jisx0213')() + self.assertEqual(encoder.encode('\u00e6'), b'') + self.assertRaises(UnicodeEncodeError, encoder.encode, '\u0123') + self.assertEqual(encoder.encode('\u0300\u00e6'), b'\xab\xc4') + self.assertRaises(UnicodeEncodeError, encoder.encode, '\u0123') + self.assertEqual(encoder.reset(), None) + self.assertEqual(encoder.encode('\u0300'), b'\xab\xdc') + self.assertEqual(encoder.encode('\u00e6'), b'') + self.assertRaises(UnicodeEncodeError, encoder.encode, '\u0123') + self.assertEqual(encoder.encode('', True), b'\xa9\xdc') + + def test_state_methods_with_buffer_state(self): + # euc_jis_2004 stores state as a buffer of pending bytes + encoder = codecs.getincrementalencoder('euc_jis_2004')() + + initial_state = encoder.getstate() + self.assertEqual(encoder.encode('\u00e6\u0300'), b'\xab\xc4') + encoder.setstate(initial_state) + self.assertEqual(encoder.encode('\u00e6\u0300'), b'\xab\xc4') + + self.assertEqual(encoder.encode('\u00e6'), b'') + partial_state = encoder.getstate() + self.assertEqual(encoder.encode('\u0300'), b'\xab\xc4') + encoder.setstate(partial_state) + self.assertEqual(encoder.encode('\u0300'), b'\xab\xc4') + + def test_state_methods_with_non_buffer_state(self): + # iso2022_jp stores state without using a buffer + encoder = codecs.getincrementalencoder('iso2022_jp')() + + self.assertEqual(encoder.encode('z'), b'z') + en_state = encoder.getstate() + + self.assertEqual(encoder.encode('\u3042'), b'\x1b\x24\x42\x24\x22') + jp_state = encoder.getstate() + self.assertEqual(encoder.encode('z'), b'\x1b\x28\x42z') + + encoder.setstate(jp_state) + self.assertEqual(encoder.encode('\u3042'), b'\x24\x22') + + encoder.setstate(en_state) + self.assertEqual(encoder.encode('z'), b'z') + + def test_getstate_returns_expected_value(self): + # Note: getstate is implemented such that these state values + # are expected to be the same across all builds of Python, + # regardless of x32/64 bit, endianness and compiler. + + # euc_jis_2004 stores state as a buffer of pending bytes + buffer_state_encoder = codecs.getincrementalencoder('euc_jis_2004')() + self.assertEqual(buffer_state_encoder.getstate(), 0) + buffer_state_encoder.encode('\u00e6') + self.assertEqual(buffer_state_encoder.getstate(), + int.from_bytes( + b"\x02" + b"\xc3\xa6" + b"\x00\x00\x00\x00\x00\x00\x00\x00", + 'little')) + buffer_state_encoder.encode('\u0300') + self.assertEqual(buffer_state_encoder.getstate(), 0) + + # iso2022_jp stores state without using a buffer + non_buffer_state_encoder = codecs.getincrementalencoder('iso2022_jp')() + self.assertEqual(non_buffer_state_encoder.getstate(), + int.from_bytes( + b"\x00" + b"\x42\x42\x00\x00\x00\x00\x00\x00", + 'little')) + non_buffer_state_encoder.encode('\u3042') + self.assertEqual(non_buffer_state_encoder.getstate(), + int.from_bytes( + b"\x00" + b"\xc2\x42\x00\x00\x00\x00\x00\x00", + 'little')) + + def test_setstate_validates_input_size(self): + encoder = codecs.getincrementalencoder('euc_jp')() + pending_size_nine = int.from_bytes( + b"\x09" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00", + 'little') + self.assertRaises(UnicodeError, encoder.setstate, pending_size_nine) + + def test_setstate_validates_input_bytes(self): + encoder = codecs.getincrementalencoder('euc_jp')() + invalid_utf8 = int.from_bytes( + b"\x01" + b"\xff" + b"\x00\x00\x00\x00\x00\x00\x00\x00", + 'little') + self.assertRaises(UnicodeDecodeError, encoder.setstate, invalid_utf8) + + def test_issue5640(self): + encoder = codecs.getincrementalencoder('shift-jis')('backslashreplace') + self.assertEqual(encoder.encode('\xff'), b'\\xff') + self.assertEqual(encoder.encode('\n'), b'\n') + + @support.cpython_only + def test_subinterp(self): + # bpo-42846: Test a CJK codec in a subinterpreter + _testcapi = import_module("_testcapi") + encoding = 'cp932' + text = "Python の開発は、1990 年ごろから開始されています。" + code = textwrap.dedent(""" + import codecs + encoding = %r + text = %r + encoder = codecs.getincrementalencoder(encoding)() + text2 = encoder.encode(text).decode(encoding) + if text2 != text: + raise ValueError(f"encoding issue: {text2!a} != {text!a}") + """) % (encoding, text) + res = _testcapi.run_in_subinterp(code) + self.assertEqual(res, 0) + +class Test_IncrementalDecoder(unittest.TestCase): + + def test_dbcs(self): + # cp949 decoder is simple with only 1 or 2 bytes sequences. + decoder = codecs.getincrementaldecoder('cp949')() + self.assertEqual(decoder.decode(b'\xc6\xc4\xc0\xcc\xbd'), + '\ud30c\uc774') + self.assertEqual(decoder.decode(b'\xe3 \xb8\xb6\xc0\xbb'), + '\uc36c \ub9c8\uc744') + self.assertEqual(decoder.decode(b''), '') + + def test_dbcs_keep_buffer(self): + decoder = codecs.getincrementaldecoder('cp949')() + self.assertEqual(decoder.decode(b'\xc6\xc4\xc0'), '\ud30c') + self.assertRaises(UnicodeDecodeError, decoder.decode, b'', True) + self.assertEqual(decoder.decode(b'\xcc'), '\uc774') + + self.assertEqual(decoder.decode(b'\xc6\xc4\xc0'), '\ud30c') + self.assertRaises(UnicodeDecodeError, decoder.decode, + b'\xcc\xbd', True) + self.assertEqual(decoder.decode(b'\xcc'), '\uc774') + + def test_iso2022(self): + decoder = codecs.getincrementaldecoder('iso2022-jp')() + ESC = b'\x1b' + self.assertEqual(decoder.decode(ESC + b'('), '') + self.assertEqual(decoder.decode(b'B', True), '') + self.assertEqual(decoder.decode(ESC + b'$'), '') + self.assertEqual(decoder.decode(b'B@$'), '\u4e16') + self.assertEqual(decoder.decode(b'@$@'), '\u4e16') + self.assertEqual(decoder.decode(b'$', True), '\u4e16') + self.assertEqual(decoder.reset(), None) + self.assertEqual(decoder.decode(b'@$'), '@$') + self.assertEqual(decoder.decode(ESC + b'$'), '') + self.assertRaises(UnicodeDecodeError, decoder.decode, b'', True) + self.assertEqual(decoder.decode(b'B@$'), '\u4e16') + + def test_decode_unicode(self): + # Trying to decode a unicode string should raise a TypeError + for enc in ALL_CJKENCODINGS: + decoder = codecs.getincrementaldecoder(enc)() + self.assertRaises(TypeError, decoder.decode, "") + + def test_state_methods(self): + decoder = codecs.getincrementaldecoder('euc_jp')() + + # Decode a complete input sequence + self.assertEqual(decoder.decode(b'\xa4\xa6'), '\u3046') + pending1, _ = decoder.getstate() + self.assertEqual(pending1, b'') + + # Decode first half of a partial input sequence + self.assertEqual(decoder.decode(b'\xa4'), '') + pending2, flags2 = decoder.getstate() + self.assertEqual(pending2, b'\xa4') + + # Decode second half of a partial input sequence + self.assertEqual(decoder.decode(b'\xa6'), '\u3046') + pending3, _ = decoder.getstate() + self.assertEqual(pending3, b'') + + # Jump back and decode second half of partial input sequence again + decoder.setstate((pending2, flags2)) + self.assertEqual(decoder.decode(b'\xa6'), '\u3046') + pending4, _ = decoder.getstate() + self.assertEqual(pending4, b'') + + # Ensure state values are preserved correctly + decoder.setstate((b'abc', 123456789)) + self.assertEqual(decoder.getstate(), (b'abc', 123456789)) + + def test_setstate_validates_input(self): + decoder = codecs.getincrementaldecoder('euc_jp')() + self.assertRaises(TypeError, decoder.setstate, 123) + self.assertRaises(TypeError, decoder.setstate, ("invalid", 0)) + self.assertRaises(TypeError, decoder.setstate, (b"1234", "invalid")) + self.assertRaises(UnicodeDecodeError, decoder.setstate, (b"123456789", 0)) + +class Test_StreamReader(unittest.TestCase): + def test_bug1728403(self): + try: + f = open(TESTFN, 'wb') + try: + f.write(b'\xa1') + finally: + f.close() + with self.assertWarns(DeprecationWarning): + f = codecs.open(TESTFN, encoding='cp949') + try: + self.assertRaises(UnicodeDecodeError, f.read, 2) + finally: + f.close() + finally: + os_helper.unlink(TESTFN) + +class Test_StreamWriter(unittest.TestCase): + def test_gb18030(self): + s= io.BytesIO() + c = codecs.getwriter('gb18030')(s) + c.write('123') + self.assertEqual(s.getvalue(), b'123') + c.write('\U00012345') + self.assertEqual(s.getvalue(), b'123\x907\x959') + c.write('\uac00\u00ac') + self.assertEqual(s.getvalue(), + b'123\x907\x959\x827\xcf5\x810\x851') + + def test_utf_8(self): + s= io.BytesIO() + c = codecs.getwriter('utf-8')(s) + c.write('123') + self.assertEqual(s.getvalue(), b'123') + c.write('\U00012345') + self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85') + c.write('\uac00\u00ac') + self.assertEqual(s.getvalue(), + b'123\xf0\x92\x8d\x85' + b'\xea\xb0\x80\xc2\xac') + + def test_streamwriter_strwrite(self): + s = io.BytesIO() + wr = codecs.getwriter('gb18030')(s) + wr.write('abcd') + self.assertEqual(s.getvalue(), b'abcd') + +class Test_ISO2022(unittest.TestCase): + def test_g2(self): + iso2022jp2 = b'\x1b(B:hu4:unit\x1b.A\x1bNi de famille' + uni = ':hu4:unit\xe9 de famille' + self.assertEqual(iso2022jp2.decode('iso2022-jp-2'), uni) + + def test_iso2022_jp_g0(self): + self.assertNotIn(b'\x0e', '\N{SOFT HYPHEN}'.encode('iso-2022-jp-2')) + for encoding in ('iso-2022-jp-2004', 'iso-2022-jp-3'): + e = '\u3406'.encode(encoding) + self.assertFalse(any(x > 0x80 for x in e)) + + @support.requires_resource('cpu') + def test_bug1572832(self): + for x in range(0x10000, 0x110000): + # Any ISO 2022 codec will cause the segfault + chr(x).encode('iso_2022_jp', 'ignore') + +class TestStateful(unittest.TestCase): + text = '\u4E16\u4E16' + encoding = 'iso-2022-jp' + expected = b'\x1b$B@$@$' + reset = b'\x1b(B' + expected_reset = expected + reset + + def test_encode(self): + self.assertEqual(self.text.encode(self.encoding), self.expected_reset) + + def test_incrementalencoder(self): + encoder = codecs.getincrementalencoder(self.encoding)() + output = b''.join( + encoder.encode(char) + for char in self.text) + self.assertEqual(output, self.expected) + self.assertEqual(encoder.encode('', final=True), self.reset) + self.assertEqual(encoder.encode('', final=True), b'') + + def test_incrementalencoder_final(self): + encoder = codecs.getincrementalencoder(self.encoding)() + last_index = len(self.text) - 1 + output = b''.join( + encoder.encode(char, index == last_index) + for index, char in enumerate(self.text)) + self.assertEqual(output, self.expected_reset) + self.assertEqual(encoder.encode('', final=True), b'') + +class TestHZStateful(TestStateful): + text = '\u804a\u804a' + encoding = 'hz' + expected = b'~{ADAD' + reset = b'~}' + expected_reset = expected + reset + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_profile.py b/Lib/test/test_profile.py new file mode 100644 index 0000000000..56ccc84a2e --- /dev/null +++ b/Lib/test/test_profile.py @@ -0,0 +1,223 @@ +"""Test suite for the profile module.""" + +import sys +import pstats +import unittest +import os +from difflib import unified_diff +from io import StringIO +from test.support.os_helper import TESTFN, unlink, temp_dir, change_cwd +from contextlib import contextmanager, redirect_stdout + +import profile +from test.profilee import testfunc, timer +from test.support.script_helper import assert_python_failure, assert_python_ok + + +class ProfileTest(unittest.TestCase): + + profilerclass = profile.Profile + profilermodule = profile + methodnames = ['print_stats', 'print_callers', 'print_callees'] + expected_max_output = ':0(max)' + + def tearDown(self): + unlink(TESTFN) + + def get_expected_output(self): + return _ProfileOutput + + @classmethod + def do_profiling(cls): + results = [] + prof = cls.profilerclass(timer, 0.001) + start_timer = timer() + prof.runctx("testfunc()", globals(), locals()) + results.append(timer() - start_timer) + for methodname in cls.methodnames: + s = StringIO() + stats = pstats.Stats(prof, stream=s) + stats.strip_dirs().sort_stats("stdname") + getattr(stats, methodname)() + output = s.getvalue().splitlines() + mod_name = testfunc.__module__.rsplit('.', 1)[1] + # Only compare against stats originating from the test file. + # Prevents outside code (e.g., the io module) from causing + # unexpected output. + output = [line.rstrip() for line in output if mod_name in line] + results.append('\n'.join(output)) + return results + + @unittest.expectedFailure # TODO: RUSTPYTHON; print_callees output differs from CPython + def test_cprofile(self): + results = self.do_profiling() + expected = self.get_expected_output() + self.assertEqual(results[0], 1000) + fail = [] + for i, method in enumerate(self.methodnames): + a = expected[method] + b = results[i+1] + if a != b: + fail.append(f"\nStats.{method} output for " + f"{self.profilerclass.__name__} " + "does not fit expectation:") + fail.extend(unified_diff(a.split('\n'), b.split('\n'), + lineterm="")) + if fail: + self.fail("\n".join(fail)) + + def test_calling_conventions(self): + # Issue #5330: profile and cProfile wouldn't report C functions called + # with keyword arguments. We test all calling conventions. + stmts = [ + "max([0])", + "max([0], key=int)", + "max([0], **dict(key=int))", + "max(*([0],))", + "max(*([0],), key=int)", + "max(*([0],), **dict(key=int))", + ] + for stmt in stmts: + s = StringIO() + prof = self.profilerclass(timer, 0.001) + prof.runctx(stmt, globals(), locals()) + stats = pstats.Stats(prof, stream=s) + stats.print_stats() + res = s.getvalue() + self.assertIn(self.expected_max_output, res, + "Profiling {0!r} didn't report max:\n{1}".format(stmt, res)) + + def test_run(self): + with silent(): + self.profilermodule.run("int('1')") + self.profilermodule.run("int('1')", filename=TESTFN) + self.assertTrue(os.path.exists(TESTFN)) + + def test_run_with_sort_by_values(self): + with redirect_stdout(StringIO()) as f: + self.profilermodule.run("int('1')", sort=('tottime', 'stdname')) + self.assertIn("Ordered by: internal time, standard name", f.getvalue()) + + def test_runctx(self): + with silent(): + self.profilermodule.runctx("testfunc()", globals(), locals()) + self.profilermodule.runctx("testfunc()", globals(), locals(), + filename=TESTFN) + self.assertTrue(os.path.exists(TESTFN)) + + def test_run_profile_as_module(self): + # Test that -m switch needs an argument + assert_python_failure('-m', self.profilermodule.__name__, '-m') + + # Test failure for not-existent module + assert_python_failure('-m', self.profilermodule.__name__, + '-m', 'random_module_xyz') + + # Test successful run + assert_python_ok('-m', self.profilermodule.__name__, + '-m', 'timeit', '-n', '1') + + def test_output_file_when_changing_directory(self): + with temp_dir() as tmpdir, change_cwd(tmpdir): + os.mkdir('dest') + with open('demo.py', 'w', encoding="utf-8") as f: + f.write('import os; os.chdir("dest")') + + assert_python_ok( + '-m', self.profilermodule.__name__, + '-o', 'out.pstats', + 'demo.py', + ) + + self.assertTrue(os.path.exists('out.pstats')) + + +def regenerate_expected_output(filename, cls): + filename = filename.rstrip('co') + print('Regenerating %s...' % filename) + results = cls.do_profiling() + + newfile = [] + with open(filename, 'r') as f: + for line in f: + newfile.append(line) + if line.startswith('#--cut'): + break + + with open(filename, 'w') as f: + f.writelines(newfile) + f.write("_ProfileOutput = {}\n") + for i, method in enumerate(cls.methodnames): + f.write('_ProfileOutput[%r] = """\\\n%s"""\n' % ( + method, results[i+1])) + f.write('\nif __name__ == "__main__":\n main()\n') + +@contextmanager +def silent(): + stdout = sys.stdout + try: + sys.stdout = StringIO() + yield + finally: + sys.stdout = stdout + + +def main(): + if '-r' not in sys.argv: + unittest.main() + else: + regenerate_expected_output(__file__, ProfileTest) + + +# Don't remove this comment. Everything below it is auto-generated. +#--cut-------------------------------------------------------------------------- +_ProfileOutput = {} +_ProfileOutput['print_stats'] = """\ + 28 27.972 0.999 27.972 0.999 profilee.py:110(__getattr__) + 1 269.996 269.996 999.769 999.769 profilee.py:25(testfunc) + 23/3 149.937 6.519 169.917 56.639 profilee.py:35(factorial) + 20 19.980 0.999 19.980 0.999 profilee.py:48(mul) + 2 39.986 19.993 599.830 299.915 profilee.py:55(helper) + 4 115.984 28.996 119.964 29.991 profilee.py:73(helper1) + 2 -0.006 -0.003 139.946 69.973 profilee.py:84(helper2_indirect) + 8 311.976 38.997 399.912 49.989 profilee.py:88(helper2) + 8 63.976 7.997 79.960 9.995 profilee.py:98(subhelper)""" +_ProfileOutput['print_callers'] = """\ +:0(append) <- profilee.py:73(helper1)(4) 119.964 +:0(exception) <- profilee.py:73(helper1)(4) 119.964 +:0(hasattr) <- profilee.py:73(helper1)(4) 119.964 + profilee.py:88(helper2)(8) 399.912 +profilee.py:110(__getattr__) <- :0(hasattr)(12) 11.964 + profilee.py:98(subhelper)(16) 79.960 +profilee.py:25(testfunc) <- :1()(1) 999.767 +profilee.py:35(factorial) <- profilee.py:25(testfunc)(1) 999.769 + profilee.py:35(factorial)(20) 169.917 + profilee.py:84(helper2_indirect)(2) 139.946 +profilee.py:48(mul) <- profilee.py:35(factorial)(20) 169.917 +profilee.py:55(helper) <- profilee.py:25(testfunc)(2) 999.769 +profilee.py:73(helper1) <- profilee.py:55(helper)(4) 599.830 +profilee.py:84(helper2_indirect) <- profilee.py:55(helper)(2) 599.830 +profilee.py:88(helper2) <- profilee.py:55(helper)(6) 599.830 + profilee.py:84(helper2_indirect)(2) 139.946 +profilee.py:98(subhelper) <- profilee.py:88(helper2)(8) 399.912""" +_ProfileOutput['print_callees'] = """\ +:0(hasattr) -> profilee.py:110(__getattr__)(12) 27.972 +:1() -> profilee.py:25(testfunc)(1) 999.769 +profilee.py:110(__getattr__) -> +profilee.py:25(testfunc) -> profilee.py:35(factorial)(1) 169.917 + profilee.py:55(helper)(2) 599.830 +profilee.py:35(factorial) -> profilee.py:35(factorial)(20) 169.917 + profilee.py:48(mul)(20) 19.980 +profilee.py:48(mul) -> +profilee.py:55(helper) -> profilee.py:73(helper1)(4) 119.964 + profilee.py:84(helper2_indirect)(2) 139.946 + profilee.py:88(helper2)(6) 399.912 +profilee.py:73(helper1) -> :0(append)(4) -0.004 +profilee.py:84(helper2_indirect) -> profilee.py:35(factorial)(2) 169.917 + profilee.py:88(helper2)(2) 399.912 +profilee.py:88(helper2) -> :0(hasattr)(8) 11.964 + profilee.py:98(subhelper)(8) 79.960 +profilee.py:98(subhelper) -> profilee.py:110(__getattr__)(16) 27.972""" + +if __name__ == "__main__": + main() diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py new file mode 100644 index 0000000000..20d3afebde --- /dev/null +++ b/Lib/test/test_pstats.py @@ -0,0 +1,164 @@ +import unittest + +from test import support +from test.support.import_helper import ensure_lazy_imports +from io import StringIO +from pstats import SortKey +from enum import StrEnum, _test_simple_enum + +import os +import pstats +import tempfile +try: + import cProfile +except ImportError: + cProfile = None + +class LazyImportTest(unittest.TestCase): + @support.cpython_only + def test_lazy_import(self): + ensure_lazy_imports("pstats", {"typing"}) + + +class AddCallersTestCase(unittest.TestCase): + """Tests for pstats.add_callers helper.""" + + def test_combine_results(self): + # pstats.add_callers should combine the call results of both target + # and source by adding the call time. See issue1269. + # new format: used by the cProfile module + target = {"a": (1, 2, 3, 4)} + source = {"a": (1, 2, 3, 4), "b": (5, 6, 7, 8)} + new_callers = pstats.add_callers(target, source) + self.assertEqual(new_callers, {'a': (2, 4, 6, 8), 'b': (5, 6, 7, 8)}) + # old format: used by the profile module + target = {"a": 1} + source = {"a": 1, "b": 5} + new_callers = pstats.add_callers(target, source) + self.assertEqual(new_callers, {'a': 2, 'b': 5}) + + +class StatsTestCase(unittest.TestCase): + def setUp(self): + stats_file = support.findfile('pstats.pck') + self.stats = pstats.Stats(stats_file) + + def test_add(self): + stream = StringIO() + stats = pstats.Stats(stream=stream) + stats.add(self.stats, self.stats) + + def test_dump_and_load_works_correctly(self): + temp_storage_new = tempfile.NamedTemporaryFile(delete=False) + try: + self.stats.dump_stats(filename=temp_storage_new.name) + tmp_stats = pstats.Stats(temp_storage_new.name) + self.assertEqual(self.stats.stats, tmp_stats.stats) + finally: + temp_storage_new.close() + os.remove(temp_storage_new.name) + + @unittest.skipUnless(cProfile, 'TODO: RUSTPYTHON; _lsprof not implemented') + def test_load_equivalent_to_init(self): + stats = pstats.Stats() + self.temp_storage = tempfile.NamedTemporaryFile(delete=False) + try: + cProfile.run('import os', filename=self.temp_storage.name) + stats.load_stats(self.temp_storage.name) + created = pstats.Stats(self.temp_storage.name) + self.assertEqual(stats.stats, created.stats) + finally: + self.temp_storage.close() + os.remove(self.temp_storage.name) + + def test_loading_wrong_types(self): + stats = pstats.Stats() + with self.assertRaises(TypeError): + stats.load_stats(42) + + def test_sort_stats_int(self): + valid_args = {-1: 'stdname', + 0: 'calls', + 1: 'time', + 2: 'cumulative'} + for arg_int, arg_str in valid_args.items(): + self.stats.sort_stats(arg_int) + self.assertEqual(self.stats.sort_type, + self.stats.sort_arg_dict_default[arg_str][-1]) + + def test_sort_stats_string(self): + for sort_name in ['calls', 'ncalls', 'cumtime', 'cumulative', + 'filename', 'line', 'module', 'name', 'nfl', 'pcalls', + 'stdname', 'time', 'tottime']: + self.stats.sort_stats(sort_name) + self.assertEqual(self.stats.sort_type, + self.stats.sort_arg_dict_default[sort_name][-1]) + + def test_sort_stats_partial(self): + sortkey = 'filename' + for sort_name in ['f', 'fi', 'fil', 'file', 'filen', 'filena', + 'filenam', 'filename']: + self.stats.sort_stats(sort_name) + self.assertEqual(self.stats.sort_type, + self.stats.sort_arg_dict_default[sortkey][-1]) + + def test_sort_stats_enum(self): + for member in SortKey: + self.stats.sort_stats(member) + self.assertEqual( + self.stats.sort_type, + self.stats.sort_arg_dict_default[member.value][-1]) + class CheckedSortKey(StrEnum): + CALLS = 'calls', 'ncalls' + CUMULATIVE = 'cumulative', 'cumtime' + FILENAME = 'filename', 'module' + LINE = 'line' + NAME = 'name' + NFL = 'nfl' + PCALLS = 'pcalls' + STDNAME = 'stdname' + TIME = 'time', 'tottime' + def __new__(cls, *values): + value = values[0] + obj = str.__new__(cls, value) + obj._value_ = value + for other_value in values[1:]: + cls._value2member_map_[other_value] = obj + obj._all_values = values + return obj + _test_simple_enum(CheckedSortKey, SortKey) + + def test_sort_starts_mix(self): + self.assertRaises(TypeError, self.stats.sort_stats, + 'calls', + SortKey.TIME) + self.assertRaises(TypeError, self.stats.sort_stats, + SortKey.TIME, + 'calls') + + @unittest.skipUnless(cProfile, 'TODO: RUSTPYTHON; _lsprof not implemented') + def test_get_stats_profile(self): + def pass1(): pass + def pass2(): pass + def pass3(): pass + + pr = cProfile.Profile() + pr.enable() + pass1() + pass2() + pass3() + pr.create_stats() + ps = pstats.Stats(pr) + + stats_profile = ps.get_stats_profile() + funcs_called = set(stats_profile.func_profiles.keys()) + self.assertIn('pass1', funcs_called) + self.assertIn('pass2', funcs_called) + self.assertIn('pass3', funcs_called) + + def test_SortKey_enum(self): + self.assertEqual(SortKey.FILENAME, 'filename') + self.assertNotEqual(SortKey.FILENAME, SortKey.CALLS) + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_tracemalloc.py b/Lib/test/test_tracemalloc.py new file mode 100644 index 0000000000..e6634f89d5 --- /dev/null +++ b/Lib/test/test_tracemalloc.py @@ -0,0 +1,1151 @@ +import contextlib +import os +import sys +import textwrap +import unittest +try: + import tracemalloc +except ImportError: + # TODO: RUSTPYTHON; _tracemalloc not implemented + raise unittest.SkipTest('tracemalloc requires _tracemalloc') +from unittest.mock import patch +from test.support.script_helper import (assert_python_ok, assert_python_failure, + interpreter_requires_environment) +from test import support +from test.support import force_not_colorized +from test.support import os_helper +from test.support import threading_helper + +try: + import _testcapi + import _testinternalcapi +except ImportError: + _testcapi = None + _testinternalcapi = None + + +DEFAULT_DOMAIN = 0 +EMPTY_STRING_SIZE = sys.getsizeof(b'') +INVALID_NFRAME = (-1, 2**30) + + +def get_frames(nframe, lineno_delta): + frames = [] + frame = sys._getframe(1) + for index in range(nframe): + code = frame.f_code + lineno = frame.f_lineno + lineno_delta + frames.append((code.co_filename, lineno)) + lineno_delta = 0 + frame = frame.f_back + if frame is None: + break + return tuple(frames) + +def allocate_bytes(size): + nframe = tracemalloc.get_traceback_limit() + bytes_len = (size - EMPTY_STRING_SIZE) + frames = get_frames(nframe, 1) + data = b'x' * bytes_len + return data, tracemalloc.Traceback(frames, min(len(frames), nframe)) + +def create_snapshots(): + traceback_limit = 2 + + # _tracemalloc._get_traces() returns a list of (domain, size, + # traceback_frames) tuples. traceback_frames is a tuple of (filename, + # line_number) tuples. + raw_traces = [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + + (1, 2, (('a.py', 5), ('b.py', 4)), 3), + + (2, 66, (('b.py', 1),), 1), + + (3, 7, (('', 0),), 1), + ] + snapshot = tracemalloc.Snapshot(raw_traces, traceback_limit) + + raw_traces2 = [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + + (2, 2, (('a.py', 5), ('b.py', 4)), 3), + (2, 5000, (('a.py', 5), ('b.py', 4)), 3), + + (4, 400, (('c.py', 578),), 1), + ] + snapshot2 = tracemalloc.Snapshot(raw_traces2, traceback_limit) + + return (snapshot, snapshot2) + +def frame(filename, lineno): + return tracemalloc._Frame((filename, lineno)) + +def traceback(*frames): + return tracemalloc.Traceback(frames) + +def traceback_lineno(filename, lineno): + return traceback((filename, lineno)) + +def traceback_filename(filename): + return traceback_lineno(filename, 0) + + +class TestTraceback(unittest.TestCase): + def test_repr(self): + def get_repr(*args) -> str: + return repr(tracemalloc.Traceback(*args)) + + self.assertEqual(get_repr(()), "") + self.assertEqual(get_repr((), 0), "") + + frames = (("f1", 1), ("f2", 2)) + exp_repr_frames = ( + "(," + " )" + ) + self.assertEqual(get_repr(frames), + f"") + self.assertEqual(get_repr(frames, 2), + f"") + + +class TestTracemallocEnabled(unittest.TestCase): + def setUp(self): + if tracemalloc.is_tracing(): + self.skipTest("tracemalloc must be stopped before the test") + + tracemalloc.start(1) + + def tearDown(self): + tracemalloc.stop() + + def test_get_tracemalloc_memory(self): + data = [allocate_bytes(123) for count in range(1000)] + size = tracemalloc.get_tracemalloc_memory() + self.assertGreaterEqual(size, 0) + + tracemalloc.clear_traces() + size2 = tracemalloc.get_tracemalloc_memory() + self.assertGreaterEqual(size2, 0) + self.assertLessEqual(size2, size) + + def test_get_object_traceback(self): + tracemalloc.clear_traces() + obj_size = 12345 + obj, obj_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj) + self.assertEqual(traceback, obj_traceback) + + def test_new_reference(self): + tracemalloc.clear_traces() + # gc.collect() indirectly calls PyList_ClearFreeList() + support.gc_collect() + + # Create a list and "destroy it": put it in the PyListObject free list + obj = [] + obj = None + + # Create a list which should reuse the previously created empty list + obj = [] + + nframe = tracemalloc.get_traceback_limit() + frames = get_frames(nframe, -3) + obj_traceback = tracemalloc.Traceback(frames, min(len(frames), nframe)) + + traceback = tracemalloc.get_object_traceback(obj) + self.assertIsNotNone(traceback) + self.assertEqual(traceback, obj_traceback) + + def test_set_traceback_limit(self): + obj_size = 10 + + tracemalloc.stop() + self.assertRaises(ValueError, tracemalloc.start, -1) + + tracemalloc.stop() + tracemalloc.start(10) + obj2, obj2_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj2) + self.assertEqual(len(traceback), 10) + self.assertEqual(traceback, obj2_traceback) + + tracemalloc.stop() + tracemalloc.start(1) + obj, obj_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj) + self.assertEqual(len(traceback), 1) + self.assertEqual(traceback, obj_traceback) + + def find_trace(self, traces, traceback, size): + # filter also by size to ignore the memory allocated by + # _PyRefchain_Trace() if Python is built with Py_TRACE_REFS. + for trace in traces: + if trace[2] == traceback._frames and trace[1] == size: + return trace + + self.fail("trace not found") + + def test_get_traces(self): + tracemalloc.clear_traces() + obj_size = 12345 + obj, obj_traceback = allocate_bytes(obj_size) + + traces = tracemalloc._get_traces() + trace = self.find_trace(traces, obj_traceback, obj_size) + + self.assertIsInstance(trace, tuple) + domain, size, traceback, length = trace + self.assertEqual(traceback, obj_traceback._frames) + + tracemalloc.stop() + self.assertEqual(tracemalloc._get_traces(), []) + + def test_get_traces_intern_traceback(self): + # dummy wrappers to get more useful and identical frames in the traceback + def allocate_bytes2(size): + return allocate_bytes(size) + def allocate_bytes3(size): + return allocate_bytes2(size) + def allocate_bytes4(size): + return allocate_bytes3(size) + + # Ensure that two identical tracebacks are not duplicated + tracemalloc.stop() + tracemalloc.start(4) + obj1_size = 123 + obj2_size = 125 + obj1, obj1_traceback = allocate_bytes4(obj1_size) + obj2, obj2_traceback = allocate_bytes4(obj2_size) + + traces = tracemalloc._get_traces() + + obj1_traceback._frames = tuple(reversed(obj1_traceback._frames)) + obj2_traceback._frames = tuple(reversed(obj2_traceback._frames)) + + trace1 = self.find_trace(traces, obj1_traceback, obj1_size) + trace2 = self.find_trace(traces, obj2_traceback, obj2_size) + domain1, size1, traceback1, length1 = trace1 + domain2, size2, traceback2, length2 = trace2 + self.assertIs(traceback2, traceback1) + + def test_get_traced_memory(self): + # Python allocates some internals objects, so the test must tolerate + # a small difference between the expected size and the real usage + max_error = 2048 + + # allocate one object + obj_size = 1024 * 1024 + tracemalloc.clear_traces() + obj, obj_traceback = allocate_bytes(obj_size) + size, peak_size = tracemalloc.get_traced_memory() + self.assertGreaterEqual(size, obj_size) + self.assertGreaterEqual(peak_size, size) + + self.assertLessEqual(size - obj_size, max_error) + self.assertLessEqual(peak_size - size, max_error) + + # destroy the object + obj = None + size2, peak_size2 = tracemalloc.get_traced_memory() + self.assertLess(size2, size) + self.assertGreaterEqual(size - size2, obj_size - max_error) + self.assertGreaterEqual(peak_size2, peak_size) + + # clear_traces() must reset traced memory counters + tracemalloc.clear_traces() + self.assertEqual(tracemalloc.get_traced_memory(), (0, 0)) + + # allocate another object + obj, obj_traceback = allocate_bytes(obj_size) + size, peak_size = tracemalloc.get_traced_memory() + self.assertGreaterEqual(size, obj_size) + + # stop() also resets traced memory counters + tracemalloc.stop() + self.assertEqual(tracemalloc.get_traced_memory(), (0, 0)) + + def test_clear_traces(self): + obj, obj_traceback = allocate_bytes(123) + traceback = tracemalloc.get_object_traceback(obj) + self.assertIsNotNone(traceback) + + tracemalloc.clear_traces() + traceback2 = tracemalloc.get_object_traceback(obj) + self.assertIsNone(traceback2) + + def test_reset_peak(self): + # Python allocates some internals objects, so the test must tolerate + # a small difference between the expected size and the real usage + tracemalloc.clear_traces() + + # Example: allocate a large piece of memory, temporarily + large_sum = sum(list(range(100000))) + size1, peak1 = tracemalloc.get_traced_memory() + + # reset_peak() resets peak to traced memory: peak2 < peak1 + tracemalloc.reset_peak() + size2, peak2 = tracemalloc.get_traced_memory() + self.assertGreaterEqual(peak2, size2) + self.assertLess(peak2, peak1) + + # check that peak continue to be updated if new memory is allocated: + # peak3 > peak2 + obj_size = 1024 * 1024 + obj, obj_traceback = allocate_bytes(obj_size) + size3, peak3 = tracemalloc.get_traced_memory() + self.assertGreaterEqual(peak3, size3) + self.assertGreater(peak3, peak2) + self.assertGreaterEqual(peak3 - peak2, obj_size) + + def test_is_tracing(self): + tracemalloc.stop() + self.assertFalse(tracemalloc.is_tracing()) + + tracemalloc.start() + self.assertTrue(tracemalloc.is_tracing()) + + def test_snapshot(self): + obj, source = allocate_bytes(123) + + # take a snapshot + snapshot = tracemalloc.take_snapshot() + + # This can vary + self.assertGreater(snapshot.traces[1].traceback.total_nframe, 10) + + # write on disk + snapshot.dump(os_helper.TESTFN) + self.addCleanup(os_helper.unlink, os_helper.TESTFN) + + # load from disk + snapshot2 = tracemalloc.Snapshot.load(os_helper.TESTFN) + self.assertEqual(snapshot2.traces, snapshot.traces) + + # tracemalloc must be tracing memory allocations to take a snapshot + tracemalloc.stop() + with self.assertRaises(RuntimeError) as cm: + tracemalloc.take_snapshot() + self.assertEqual(str(cm.exception), + "the tracemalloc module must be tracing memory " + "allocations to take a snapshot") + + def test_snapshot_save_attr(self): + # take a snapshot with a new attribute + snapshot = tracemalloc.take_snapshot() + snapshot.test_attr = "new" + snapshot.dump(os_helper.TESTFN) + self.addCleanup(os_helper.unlink, os_helper.TESTFN) + + # load() should recreate the attribute + snapshot2 = tracemalloc.Snapshot.load(os_helper.TESTFN) + self.assertEqual(snapshot2.test_attr, "new") + + def fork_child(self): + if not tracemalloc.is_tracing(): + return 2 + + obj_size = 12345 + obj, obj_traceback = allocate_bytes(obj_size) + traceback = tracemalloc.get_object_traceback(obj) + if traceback is None: + return 3 + + # everything is fine + return 0 + + @support.requires_fork() + def test_fork(self): + # check that tracemalloc is still working after fork + pid = os.fork() + if not pid: + # child + exitcode = 1 + try: + exitcode = self.fork_child() + finally: + os._exit(exitcode) + else: + support.wait_process(pid, exitcode=0) + + def test_no_incomplete_frames(self): + tracemalloc.stop() + tracemalloc.start(8) + + def f(x): + def g(): + return x + return g + + obj = f(0).__closure__[0] + traceback = tracemalloc.get_object_traceback(obj) + self.assertIn("test_tracemalloc", traceback[-1].filename) + self.assertNotIn("test_tracemalloc", traceback[-2].filename) + + +class TestSnapshot(unittest.TestCase): + maxDiff = 4000 + + def test_create_snapshot(self): + raw_traces = [(0, 5, (('a.py', 2),), 10)] + + with contextlib.ExitStack() as stack: + stack.enter_context(patch.object(tracemalloc, 'is_tracing', + return_value=True)) + stack.enter_context(patch.object(tracemalloc, 'get_traceback_limit', + return_value=5)) + stack.enter_context(patch.object(tracemalloc, '_get_traces', + return_value=raw_traces)) + + snapshot = tracemalloc.take_snapshot() + self.assertEqual(snapshot.traceback_limit, 5) + self.assertEqual(len(snapshot.traces), 1) + trace = snapshot.traces[0] + self.assertEqual(trace.size, 5) + self.assertEqual(trace.traceback.total_nframe, 10) + self.assertEqual(len(trace.traceback), 1) + self.assertEqual(trace.traceback[0].filename, 'a.py') + self.assertEqual(trace.traceback[0].lineno, 2) + + def test_filter_traces(self): + snapshot, snapshot2 = create_snapshots() + filter1 = tracemalloc.Filter(False, "b.py") + filter2 = tracemalloc.Filter(True, "a.py", 2) + filter3 = tracemalloc.Filter(True, "a.py", 5) + + original_traces = list(snapshot.traces._traces) + + # exclude b.py + snapshot3 = snapshot.filter_traces((filter1,)) + self.assertEqual(snapshot3.traces._traces, [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (1, 2, (('a.py', 5), ('b.py', 4)), 3), + (3, 7, (('', 0),), 1), + ]) + + # filter_traces() must not touch the original snapshot + self.assertEqual(snapshot.traces._traces, original_traces) + + # only include two lines of a.py + snapshot4 = snapshot3.filter_traces((filter2, filter3)) + self.assertEqual(snapshot4.traces._traces, [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (1, 2, (('a.py', 5), ('b.py', 4)), 3), + ]) + + # No filter: just duplicate the snapshot + snapshot5 = snapshot.filter_traces(()) + self.assertIsNot(snapshot5, snapshot) + self.assertIsNot(snapshot5.traces, snapshot.traces) + self.assertEqual(snapshot5.traces, snapshot.traces) + + self.assertRaises(TypeError, snapshot.filter_traces, filter1) + + def test_filter_traces_domain(self): + snapshot, snapshot2 = create_snapshots() + filter1 = tracemalloc.Filter(False, "a.py", domain=1) + filter2 = tracemalloc.Filter(True, "a.py", domain=1) + + original_traces = list(snapshot.traces._traces) + + # exclude a.py of domain 1 + snapshot3 = snapshot.filter_traces((filter1,)) + self.assertEqual(snapshot3.traces._traces, [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (2, 66, (('b.py', 1),), 1), + (3, 7, (('', 0),), 1), + ]) + + # include domain 1 + snapshot3 = snapshot.filter_traces((filter1,)) + self.assertEqual(snapshot3.traces._traces, [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (2, 66, (('b.py', 1),), 1), + (3, 7, (('', 0),), 1), + ]) + + def test_filter_traces_domain_filter(self): + snapshot, snapshot2 = create_snapshots() + filter1 = tracemalloc.DomainFilter(False, domain=3) + filter2 = tracemalloc.DomainFilter(True, domain=3) + + # exclude domain 2 + snapshot3 = snapshot.filter_traces((filter1,)) + self.assertEqual(snapshot3.traces._traces, [ + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (0, 10, (('a.py', 2), ('b.py', 4)), 3), + (1, 2, (('a.py', 5), ('b.py', 4)), 3), + (2, 66, (('b.py', 1),), 1), + ]) + + # include domain 2 + snapshot3 = snapshot.filter_traces((filter2,)) + self.assertEqual(snapshot3.traces._traces, [ + (3, 7, (('', 0),), 1), + ]) + + def test_snapshot_group_by_line(self): + snapshot, snapshot2 = create_snapshots() + tb_0 = traceback_lineno('', 0) + tb_a_2 = traceback_lineno('a.py', 2) + tb_a_5 = traceback_lineno('a.py', 5) + tb_b_1 = traceback_lineno('b.py', 1) + tb_c_578 = traceback_lineno('c.py', 578) + + # stats per file and line + stats1 = snapshot.statistics('lineno') + self.assertEqual(stats1, [ + tracemalloc.Statistic(tb_b_1, 66, 1), + tracemalloc.Statistic(tb_a_2, 30, 3), + tracemalloc.Statistic(tb_0, 7, 1), + tracemalloc.Statistic(tb_a_5, 2, 1), + ]) + + # stats per file and line (2) + stats2 = snapshot2.statistics('lineno') + self.assertEqual(stats2, [ + tracemalloc.Statistic(tb_a_5, 5002, 2), + tracemalloc.Statistic(tb_c_578, 400, 1), + tracemalloc.Statistic(tb_a_2, 30, 3), + ]) + + # stats diff per file and line + statistics = snapshot2.compare_to(snapshot, 'lineno') + self.assertEqual(statistics, [ + tracemalloc.StatisticDiff(tb_a_5, 5002, 5000, 2, 1), + tracemalloc.StatisticDiff(tb_c_578, 400, 400, 1, 1), + tracemalloc.StatisticDiff(tb_b_1, 0, -66, 0, -1), + tracemalloc.StatisticDiff(tb_0, 0, -7, 0, -1), + tracemalloc.StatisticDiff(tb_a_2, 30, 0, 3, 0), + ]) + + def test_snapshot_group_by_file(self): + snapshot, snapshot2 = create_snapshots() + tb_0 = traceback_filename('') + tb_a = traceback_filename('a.py') + tb_b = traceback_filename('b.py') + tb_c = traceback_filename('c.py') + + # stats per file + stats1 = snapshot.statistics('filename') + self.assertEqual(stats1, [ + tracemalloc.Statistic(tb_b, 66, 1), + tracemalloc.Statistic(tb_a, 32, 4), + tracemalloc.Statistic(tb_0, 7, 1), + ]) + + # stats per file (2) + stats2 = snapshot2.statistics('filename') + self.assertEqual(stats2, [ + tracemalloc.Statistic(tb_a, 5032, 5), + tracemalloc.Statistic(tb_c, 400, 1), + ]) + + # stats diff per file + diff = snapshot2.compare_to(snapshot, 'filename') + self.assertEqual(diff, [ + tracemalloc.StatisticDiff(tb_a, 5032, 5000, 5, 1), + tracemalloc.StatisticDiff(tb_c, 400, 400, 1, 1), + tracemalloc.StatisticDiff(tb_b, 0, -66, 0, -1), + tracemalloc.StatisticDiff(tb_0, 0, -7, 0, -1), + ]) + + def test_snapshot_group_by_traceback(self): + snapshot, snapshot2 = create_snapshots() + + # stats per file + tb1 = traceback(('a.py', 2), ('b.py', 4)) + tb2 = traceback(('a.py', 5), ('b.py', 4)) + tb3 = traceback(('b.py', 1)) + tb4 = traceback(('', 0)) + stats1 = snapshot.statistics('traceback') + self.assertEqual(stats1, [ + tracemalloc.Statistic(tb3, 66, 1), + tracemalloc.Statistic(tb1, 30, 3), + tracemalloc.Statistic(tb4, 7, 1), + tracemalloc.Statistic(tb2, 2, 1), + ]) + + # stats per file (2) + tb5 = traceback(('c.py', 578)) + stats2 = snapshot2.statistics('traceback') + self.assertEqual(stats2, [ + tracemalloc.Statistic(tb2, 5002, 2), + tracemalloc.Statistic(tb5, 400, 1), + tracemalloc.Statistic(tb1, 30, 3), + ]) + + # stats diff per file + diff = snapshot2.compare_to(snapshot, 'traceback') + self.assertEqual(diff, [ + tracemalloc.StatisticDiff(tb2, 5002, 5000, 2, 1), + tracemalloc.StatisticDiff(tb5, 400, 400, 1, 1), + tracemalloc.StatisticDiff(tb3, 0, -66, 0, -1), + tracemalloc.StatisticDiff(tb4, 0, -7, 0, -1), + tracemalloc.StatisticDiff(tb1, 30, 0, 3, 0), + ]) + + self.assertRaises(ValueError, + snapshot.statistics, 'traceback', cumulative=True) + + def test_snapshot_group_by_cumulative(self): + snapshot, snapshot2 = create_snapshots() + tb_0 = traceback_filename('') + tb_a = traceback_filename('a.py') + tb_b = traceback_filename('b.py') + tb_a_2 = traceback_lineno('a.py', 2) + tb_a_5 = traceback_lineno('a.py', 5) + tb_b_1 = traceback_lineno('b.py', 1) + tb_b_4 = traceback_lineno('b.py', 4) + + # per file + stats = snapshot.statistics('filename', True) + self.assertEqual(stats, [ + tracemalloc.Statistic(tb_b, 98, 5), + tracemalloc.Statistic(tb_a, 32, 4), + tracemalloc.Statistic(tb_0, 7, 1), + ]) + + # per line + stats = snapshot.statistics('lineno', True) + self.assertEqual(stats, [ + tracemalloc.Statistic(tb_b_1, 66, 1), + tracemalloc.Statistic(tb_b_4, 32, 4), + tracemalloc.Statistic(tb_a_2, 30, 3), + tracemalloc.Statistic(tb_0, 7, 1), + tracemalloc.Statistic(tb_a_5, 2, 1), + ]) + + def test_trace_format(self): + snapshot, snapshot2 = create_snapshots() + trace = snapshot.traces[0] + self.assertEqual(str(trace), 'b.py:4: 10 B') + traceback = trace.traceback + self.assertEqual(str(traceback), 'b.py:4') + frame = traceback[0] + self.assertEqual(str(frame), 'b.py:4') + + def test_statistic_format(self): + snapshot, snapshot2 = create_snapshots() + stats = snapshot.statistics('lineno') + stat = stats[0] + self.assertEqual(str(stat), + 'b.py:1: size=66 B, count=1, average=66 B') + + def test_statistic_diff_format(self): + snapshot, snapshot2 = create_snapshots() + stats = snapshot2.compare_to(snapshot, 'lineno') + stat = stats[0] + self.assertEqual(str(stat), + 'a.py:5: size=5002 B (+5000 B), count=2 (+1), average=2501 B') + + def test_slices(self): + snapshot, snapshot2 = create_snapshots() + self.assertEqual(snapshot.traces[:2], + (snapshot.traces[0], snapshot.traces[1])) + + traceback = snapshot.traces[0].traceback + self.assertEqual(traceback[:2], + (traceback[0], traceback[1])) + + def test_format_traceback(self): + snapshot, snapshot2 = create_snapshots() + def getline(filename, lineno): + return ' <%s, %s>' % (filename, lineno) + with unittest.mock.patch('tracemalloc.linecache.getline', + side_effect=getline): + tb = snapshot.traces[0].traceback + self.assertEqual(tb.format(), + [' File "b.py", line 4', + ' ', + ' File "a.py", line 2', + ' ']) + + self.assertEqual(tb.format(limit=1), + [' File "a.py", line 2', + ' ']) + + self.assertEqual(tb.format(limit=-1), + [' File "b.py", line 4', + ' ']) + + self.assertEqual(tb.format(most_recent_first=True), + [' File "a.py", line 2', + ' ', + ' File "b.py", line 4', + ' ']) + + self.assertEqual(tb.format(limit=1, most_recent_first=True), + [' File "a.py", line 2', + ' ']) + + self.assertEqual(tb.format(limit=-1, most_recent_first=True), + [' File "b.py", line 4', + ' ']) + + +class TestFilters(unittest.TestCase): + maxDiff = 2048 + + def test_filter_attributes(self): + # test default values + f = tracemalloc.Filter(True, "abc") + self.assertEqual(f.inclusive, True) + self.assertEqual(f.filename_pattern, "abc") + self.assertIsNone(f.lineno) + self.assertEqual(f.all_frames, False) + + # test custom values + f = tracemalloc.Filter(False, "test.py", 123, True) + self.assertEqual(f.inclusive, False) + self.assertEqual(f.filename_pattern, "test.py") + self.assertEqual(f.lineno, 123) + self.assertEqual(f.all_frames, True) + + # parameters passed by keyword + f = tracemalloc.Filter(inclusive=False, filename_pattern="test.py", lineno=123, all_frames=True) + self.assertEqual(f.inclusive, False) + self.assertEqual(f.filename_pattern, "test.py") + self.assertEqual(f.lineno, 123) + self.assertEqual(f.all_frames, True) + + # read-only attribute + self.assertRaises(AttributeError, setattr, f, "filename_pattern", "abc") + + def test_filter_match(self): + # filter without line number + f = tracemalloc.Filter(True, "abc") + self.assertTrue(f._match_frame("abc", 0)) + self.assertTrue(f._match_frame("abc", 5)) + self.assertTrue(f._match_frame("abc", 10)) + self.assertFalse(f._match_frame("12356", 0)) + self.assertFalse(f._match_frame("12356", 5)) + self.assertFalse(f._match_frame("12356", 10)) + + f = tracemalloc.Filter(False, "abc") + self.assertFalse(f._match_frame("abc", 0)) + self.assertFalse(f._match_frame("abc", 5)) + self.assertFalse(f._match_frame("abc", 10)) + self.assertTrue(f._match_frame("12356", 0)) + self.assertTrue(f._match_frame("12356", 5)) + self.assertTrue(f._match_frame("12356", 10)) + + # filter with line number > 0 + f = tracemalloc.Filter(True, "abc", 5) + self.assertFalse(f._match_frame("abc", 0)) + self.assertTrue(f._match_frame("abc", 5)) + self.assertFalse(f._match_frame("abc", 10)) + self.assertFalse(f._match_frame("12356", 0)) + self.assertFalse(f._match_frame("12356", 5)) + self.assertFalse(f._match_frame("12356", 10)) + + f = tracemalloc.Filter(False, "abc", 5) + self.assertTrue(f._match_frame("abc", 0)) + self.assertFalse(f._match_frame("abc", 5)) + self.assertTrue(f._match_frame("abc", 10)) + self.assertTrue(f._match_frame("12356", 0)) + self.assertTrue(f._match_frame("12356", 5)) + self.assertTrue(f._match_frame("12356", 10)) + + # filter with line number 0 + f = tracemalloc.Filter(True, "abc", 0) + self.assertTrue(f._match_frame("abc", 0)) + self.assertFalse(f._match_frame("abc", 5)) + self.assertFalse(f._match_frame("abc", 10)) + self.assertFalse(f._match_frame("12356", 0)) + self.assertFalse(f._match_frame("12356", 5)) + self.assertFalse(f._match_frame("12356", 10)) + + f = tracemalloc.Filter(False, "abc", 0) + self.assertFalse(f._match_frame("abc", 0)) + self.assertTrue(f._match_frame("abc", 5)) + self.assertTrue(f._match_frame("abc", 10)) + self.assertTrue(f._match_frame("12356", 0)) + self.assertTrue(f._match_frame("12356", 5)) + self.assertTrue(f._match_frame("12356", 10)) + + def test_filter_match_filename(self): + def fnmatch(inclusive, filename, pattern): + f = tracemalloc.Filter(inclusive, pattern) + return f._match_frame(filename, 0) + + self.assertTrue(fnmatch(True, "abc", "abc")) + self.assertFalse(fnmatch(True, "12356", "abc")) + self.assertFalse(fnmatch(True, "", "abc")) + + self.assertFalse(fnmatch(False, "abc", "abc")) + self.assertTrue(fnmatch(False, "12356", "abc")) + self.assertTrue(fnmatch(False, "", "abc")) + + def test_filter_match_filename_joker(self): + def fnmatch(filename, pattern): + filter = tracemalloc.Filter(True, pattern) + return filter._match_frame(filename, 0) + + # empty string + self.assertFalse(fnmatch('abc', '')) + self.assertFalse(fnmatch('', 'abc')) + self.assertTrue(fnmatch('', '')) + self.assertTrue(fnmatch('', '*')) + + # no * + self.assertTrue(fnmatch('abc', 'abc')) + self.assertFalse(fnmatch('abc', 'abcd')) + self.assertFalse(fnmatch('abc', 'def')) + + # a* + self.assertTrue(fnmatch('abc', 'a*')) + self.assertTrue(fnmatch('abc', 'abc*')) + self.assertFalse(fnmatch('abc', 'b*')) + self.assertFalse(fnmatch('abc', 'abcd*')) + + # a*b + self.assertTrue(fnmatch('abc', 'a*c')) + self.assertTrue(fnmatch('abcdcx', 'a*cx')) + self.assertFalse(fnmatch('abb', 'a*c')) + self.assertFalse(fnmatch('abcdce', 'a*cx')) + + # a*b*c + self.assertTrue(fnmatch('abcde', 'a*c*e')) + self.assertTrue(fnmatch('abcbdefeg', 'a*bd*eg')) + self.assertFalse(fnmatch('abcdd', 'a*c*e')) + self.assertFalse(fnmatch('abcbdefef', 'a*bd*eg')) + + # replace .pyc suffix with .py + self.assertTrue(fnmatch('a.pyc', 'a.py')) + self.assertTrue(fnmatch('a.py', 'a.pyc')) + + if os.name == 'nt': + # case insensitive + self.assertTrue(fnmatch('aBC', 'ABc')) + self.assertTrue(fnmatch('aBcDe', 'Ab*dE')) + + self.assertTrue(fnmatch('a.pyc', 'a.PY')) + self.assertTrue(fnmatch('a.py', 'a.PYC')) + else: + # case sensitive + self.assertFalse(fnmatch('aBC', 'ABc')) + self.assertFalse(fnmatch('aBcDe', 'Ab*dE')) + + self.assertFalse(fnmatch('a.pyc', 'a.PY')) + self.assertFalse(fnmatch('a.py', 'a.PYC')) + + if os.name == 'nt': + # normalize alternate separator "/" to the standard separator "\" + self.assertTrue(fnmatch(r'a/b', r'a\b')) + self.assertTrue(fnmatch(r'a\b', r'a/b')) + self.assertTrue(fnmatch(r'a/b\c', r'a\b/c')) + self.assertTrue(fnmatch(r'a/b/c', r'a\b\c')) + else: + # there is no alternate separator + self.assertFalse(fnmatch(r'a/b', r'a\b')) + self.assertFalse(fnmatch(r'a\b', r'a/b')) + self.assertFalse(fnmatch(r'a/b\c', r'a\b/c')) + self.assertFalse(fnmatch(r'a/b/c', r'a\b\c')) + + # as of 3.5, .pyo is no longer munged to .py + self.assertFalse(fnmatch('a.pyo', 'a.py')) + + def test_filter_match_trace(self): + t1 = (("a.py", 2), ("b.py", 3)) + t2 = (("b.py", 4), ("b.py", 5)) + t3 = (("c.py", 5), ('', 0)) + unknown = (('', 0),) + + f = tracemalloc.Filter(True, "b.py", all_frames=True) + self.assertTrue(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertFalse(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + f = tracemalloc.Filter(True, "b.py", all_frames=False) + self.assertFalse(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertFalse(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "b.py", all_frames=True) + self.assertFalse(f._match_traceback(t1)) + self.assertFalse(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertTrue(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "b.py", all_frames=False) + self.assertTrue(f._match_traceback(t1)) + self.assertFalse(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertTrue(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "", all_frames=False) + self.assertTrue(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + f = tracemalloc.Filter(True, "", all_frames=True) + self.assertFalse(f._match_traceback(t1)) + self.assertFalse(f._match_traceback(t2)) + self.assertTrue(f._match_traceback(t3)) + self.assertTrue(f._match_traceback(unknown)) + + f = tracemalloc.Filter(False, "", all_frames=True) + self.assertTrue(f._match_traceback(t1)) + self.assertTrue(f._match_traceback(t2)) + self.assertFalse(f._match_traceback(t3)) + self.assertFalse(f._match_traceback(unknown)) + + +class TestCommandLine(unittest.TestCase): + def test_env_var_disabled_by_default(self): + # not tracing by default + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-c', code) + stdout = stdout.rstrip() + self.assertEqual(stdout, b'False') + + @unittest.skipIf(interpreter_requires_environment(), + 'Cannot run -E tests when PYTHON env vars are required.') + def test_env_var_ignored_with_E(self): + """PYTHON* environment variables must be ignored when -E is present.""" + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-E', '-c', code, PYTHONTRACEMALLOC='1') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'False') + + def test_env_var_disabled(self): + # tracing at startup + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-c', code, PYTHONTRACEMALLOC='0') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'False') + + def test_env_var_enabled_at_startup(self): + # tracing at startup + code = 'import tracemalloc; print(tracemalloc.is_tracing())' + ok, stdout, stderr = assert_python_ok('-c', code, PYTHONTRACEMALLOC='1') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'True') + + def test_env_limit(self): + # start and set the number of frames + code = 'import tracemalloc; print(tracemalloc.get_traceback_limit())' + ok, stdout, stderr = assert_python_ok('-c', code, PYTHONTRACEMALLOC='10') + stdout = stdout.rstrip() + self.assertEqual(stdout, b'10') + + @force_not_colorized + def check_env_var_invalid(self, nframe): + with support.SuppressCrashReport(): + ok, stdout, stderr = assert_python_failure( + '-c', 'pass', + PYTHONTRACEMALLOC=str(nframe)) + + if b'ValueError: the number of frames must be in range' in stderr: + return + if b'PYTHONTRACEMALLOC: invalid number of frames' in stderr: + return + self.fail(f"unexpected output: {stderr!a}") + + def test_env_var_invalid(self): + for nframe in INVALID_NFRAME: + with self.subTest(nframe=nframe): + self.check_env_var_invalid(nframe) + + def test_sys_xoptions(self): + for xoptions, nframe in ( + ('tracemalloc', 1), + ('tracemalloc=1', 1), + ('tracemalloc=15', 15), + ): + with self.subTest(xoptions=xoptions, nframe=nframe): + code = 'import tracemalloc; print(tracemalloc.get_traceback_limit())' + ok, stdout, stderr = assert_python_ok('-X', xoptions, '-c', code) + stdout = stdout.rstrip() + self.assertEqual(stdout, str(nframe).encode('ascii')) + + def check_sys_xoptions_invalid(self, nframe): + args = ('-X', 'tracemalloc=%s' % nframe, '-c', 'pass') + with support.SuppressCrashReport(): + ok, stdout, stderr = assert_python_failure(*args) + + if b'ValueError: the number of frames must be in range' in stderr: + return + if b'-X tracemalloc=NFRAME: invalid number of frames' in stderr: + return + self.fail(f"unexpected output: {stderr!a}") + + @force_not_colorized + def test_sys_xoptions_invalid(self): + for nframe in INVALID_NFRAME: + with self.subTest(nframe=nframe): + self.check_sys_xoptions_invalid(nframe) + + @unittest.skipIf(_testcapi is None, 'need _testcapi') + def test_pymem_alloc0(self): + # Issue #21639: Check that PyMem_Malloc(0) with tracemalloc enabled + # does not crash. + code = 'import _testcapi; _testcapi.test_pymem_alloc0(); 1' + assert_python_ok('-X', 'tracemalloc', '-c', code) + + +@unittest.skipIf(_testcapi is None, 'need _testcapi') +class TestCAPI(unittest.TestCase): + maxDiff = 80 * 20 + + def setUp(self): + if tracemalloc.is_tracing(): + self.skipTest("tracemalloc must be stopped before the test") + + self.domain = 5 + self.size = 123 + self.obj = allocate_bytes(self.size)[0] + + # for the type "object", id(obj) is the address of its memory block. + # This type is not tracked by the garbage collector + self.ptr = id(self.obj) + + def tearDown(self): + tracemalloc.stop() + + def get_traceback(self): + frames = _testinternalcapi._PyTraceMalloc_GetTraceback(self.domain, self.ptr) + if frames is not None: + return tracemalloc.Traceback(frames) + else: + return None + + def track(self, release_gil=False, nframe=1): + frames = get_frames(nframe, 1) + _testcapi.tracemalloc_track(self.domain, self.ptr, self.size, + release_gil) + return frames + + def untrack(self, release_gil=False): + _testcapi.tracemalloc_untrack(self.domain, self.ptr, release_gil) + + def get_traced_memory(self): + # Get the traced size in the domain + snapshot = tracemalloc.take_snapshot() + domain_filter = tracemalloc.DomainFilter(True, self.domain) + snapshot = snapshot.filter_traces([domain_filter]) + return sum(trace.size for trace in snapshot.traces) + + def check_track(self, release_gil): + nframe = 5 + tracemalloc.start(nframe) + + size = tracemalloc.get_traced_memory()[0] + + frames = self.track(release_gil, nframe) + self.assertEqual(self.get_traceback(), + tracemalloc.Traceback(frames)) + + self.assertEqual(self.get_traced_memory(), self.size) + + def test_track(self): + self.check_track(False) + + def test_track_without_gil(self): + # check that calling _PyTraceMalloc_Track() without holding the GIL + # works too + self.check_track(True) + + def test_track_already_tracked(self): + nframe = 5 + tracemalloc.start(nframe) + + # track a first time + self.track() + + # calling _PyTraceMalloc_Track() must remove the old trace and add + # a new trace with the new traceback + frames = self.track(nframe=nframe) + self.assertEqual(self.get_traceback(), + tracemalloc.Traceback(frames)) + + def check_untrack(self, release_gil): + tracemalloc.start() + + self.track() + self.assertIsNotNone(self.get_traceback()) + self.assertEqual(self.get_traced_memory(), self.size) + + # untrack must remove the trace + self.untrack(release_gil) + self.assertIsNone(self.get_traceback()) + self.assertEqual(self.get_traced_memory(), 0) + + # calling _PyTraceMalloc_Untrack() multiple times must not crash + self.untrack(release_gil) + self.untrack(release_gil) + + def test_untrack(self): + self.check_untrack(False) + + def test_untrack_without_gil(self): + self.check_untrack(True) + + def test_stop_track(self): + tracemalloc.start() + tracemalloc.stop() + + with self.assertRaises(RuntimeError): + self.track() + self.assertIsNone(self.get_traceback()) + + def test_stop_untrack(self): + tracemalloc.start() + self.track() + + tracemalloc.stop() + with self.assertRaises(RuntimeError): + self.untrack() + + @unittest.skipIf(_testcapi is None, 'need _testcapi') + @threading_helper.requires_working_threading() + # gh-128679: Test crash on a debug build (especially on FreeBSD). + @unittest.skipIf(support.Py_DEBUG, 'need release build') + @support.skip_if_sanitizer('gh-131566: race when setting allocator', thread=True) + def test_tracemalloc_track_race(self): + # gh-128679: Test fix for tracemalloc.stop() race condition + _testcapi.tracemalloc_track_race() + + def test_late_untrack(self): + code = textwrap.dedent(f""" + from test import support + import tracemalloc + import _testcapi + + class Tracked: + def __init__(self, domain, size): + self.domain = domain + self.ptr = id(self) + self.size = size + _testcapi.tracemalloc_track(self.domain, self.ptr, self.size) + + def __del__(self, untrack=_testcapi.tracemalloc_untrack): + untrack(self.domain, self.ptr, 1) + + domain = {DEFAULT_DOMAIN} + tracemalloc.start() + obj = Tracked(domain, 1024 * 1024) + support.late_deletion(obj) + """) + assert_python_ok("-c", code) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/tracemalloc.py b/Lib/tracemalloc.py new file mode 100644 index 0000000000..cec99c5970 --- /dev/null +++ b/Lib/tracemalloc.py @@ -0,0 +1,560 @@ +from collections.abc import Sequence, Iterable +from functools import total_ordering +import fnmatch +import linecache +import os.path +import pickle + +# Import types and functions implemented in C +from _tracemalloc import * +from _tracemalloc import _get_object_traceback, _get_traces + + +def _format_size(size, sign): + for unit in ('B', 'KiB', 'MiB', 'GiB', 'TiB'): + if abs(size) < 100 and unit != 'B': + # 3 digits (xx.x UNIT) + if sign: + return "%+.1f %s" % (size, unit) + else: + return "%.1f %s" % (size, unit) + if abs(size) < 10 * 1024 or unit == 'TiB': + # 4 or 5 digits (xxxx UNIT) + if sign: + return "%+.0f %s" % (size, unit) + else: + return "%.0f %s" % (size, unit) + size /= 1024 + + +class Statistic: + """ + Statistic difference on memory allocations between two Snapshot instance. + """ + + __slots__ = ('traceback', 'size', 'count') + + def __init__(self, traceback, size, count): + self.traceback = traceback + self.size = size + self.count = count + + def __hash__(self): + return hash((self.traceback, self.size, self.count)) + + def __eq__(self, other): + if not isinstance(other, Statistic): + return NotImplemented + return (self.traceback == other.traceback + and self.size == other.size + and self.count == other.count) + + def __str__(self): + text = ("%s: size=%s, count=%i" + % (self.traceback, + _format_size(self.size, False), + self.count)) + if self.count: + average = self.size / self.count + text += ", average=%s" % _format_size(average, False) + return text + + def __repr__(self): + return ('' + % (self.traceback, self.size, self.count)) + + def _sort_key(self): + return (self.size, self.count, self.traceback) + + +class StatisticDiff: + """ + Statistic difference on memory allocations between an old and a new + Snapshot instance. + """ + __slots__ = ('traceback', 'size', 'size_diff', 'count', 'count_diff') + + def __init__(self, traceback, size, size_diff, count, count_diff): + self.traceback = traceback + self.size = size + self.size_diff = size_diff + self.count = count + self.count_diff = count_diff + + def __hash__(self): + return hash((self.traceback, self.size, self.size_diff, + self.count, self.count_diff)) + + def __eq__(self, other): + if not isinstance(other, StatisticDiff): + return NotImplemented + return (self.traceback == other.traceback + and self.size == other.size + and self.size_diff == other.size_diff + and self.count == other.count + and self.count_diff == other.count_diff) + + def __str__(self): + text = ("%s: size=%s (%s), count=%i (%+i)" + % (self.traceback, + _format_size(self.size, False), + _format_size(self.size_diff, True), + self.count, + self.count_diff)) + if self.count: + average = self.size / self.count + text += ", average=%s" % _format_size(average, False) + return text + + def __repr__(self): + return ('' + % (self.traceback, self.size, self.size_diff, + self.count, self.count_diff)) + + def _sort_key(self): + return (abs(self.size_diff), self.size, + abs(self.count_diff), self.count, + self.traceback) + + +def _compare_grouped_stats(old_group, new_group): + statistics = [] + for traceback, stat in new_group.items(): + previous = old_group.pop(traceback, None) + if previous is not None: + stat = StatisticDiff(traceback, + stat.size, stat.size - previous.size, + stat.count, stat.count - previous.count) + else: + stat = StatisticDiff(traceback, + stat.size, stat.size, + stat.count, stat.count) + statistics.append(stat) + + for traceback, stat in old_group.items(): + stat = StatisticDiff(traceback, 0, -stat.size, 0, -stat.count) + statistics.append(stat) + return statistics + + +@total_ordering +class Frame: + """ + Frame of a traceback. + """ + __slots__ = ("_frame",) + + def __init__(self, frame): + # frame is a tuple: (filename: str, lineno: int) + self._frame = frame + + @property + def filename(self): + return self._frame[0] + + @property + def lineno(self): + return self._frame[1] + + def __eq__(self, other): + if not isinstance(other, Frame): + return NotImplemented + return (self._frame == other._frame) + + def __lt__(self, other): + if not isinstance(other, Frame): + return NotImplemented + return (self._frame < other._frame) + + def __hash__(self): + return hash(self._frame) + + def __str__(self): + return "%s:%s" % (self.filename, self.lineno) + + def __repr__(self): + return "" % (self.filename, self.lineno) + + +@total_ordering +class Traceback(Sequence): + """ + Sequence of Frame instances sorted from the oldest frame + to the most recent frame. + """ + __slots__ = ("_frames", '_total_nframe') + + def __init__(self, frames, total_nframe=None): + Sequence.__init__(self) + # frames is a tuple of frame tuples: see Frame constructor for the + # format of a frame tuple; it is reversed, because _tracemalloc + # returns frames sorted from most recent to oldest, but the + # Python API expects oldest to most recent + self._frames = tuple(reversed(frames)) + self._total_nframe = total_nframe + + @property + def total_nframe(self): + return self._total_nframe + + def __len__(self): + return len(self._frames) + + def __getitem__(self, index): + if isinstance(index, slice): + return tuple(Frame(trace) for trace in self._frames[index]) + else: + return Frame(self._frames[index]) + + def __contains__(self, frame): + return frame._frame in self._frames + + def __hash__(self): + return hash(self._frames) + + def __eq__(self, other): + if not isinstance(other, Traceback): + return NotImplemented + return (self._frames == other._frames) + + def __lt__(self, other): + if not isinstance(other, Traceback): + return NotImplemented + return (self._frames < other._frames) + + def __str__(self): + return str(self[0]) + + def __repr__(self): + s = f"" + return s + + def format(self, limit=None, most_recent_first=False): + lines = [] + if limit is not None: + if limit > 0: + frame_slice = self[-limit:] + else: + frame_slice = self[:limit] + else: + frame_slice = self + + if most_recent_first: + frame_slice = reversed(frame_slice) + for frame in frame_slice: + lines.append(' File "%s", line %s' + % (frame.filename, frame.lineno)) + line = linecache.getline(frame.filename, frame.lineno).strip() + if line: + lines.append(' %s' % line) + return lines + + +def get_object_traceback(obj): + """ + Get the traceback where the Python object *obj* was allocated. + Return a Traceback instance. + + Return None if the tracemalloc module is not tracing memory allocations or + did not trace the allocation of the object. + """ + frames = _get_object_traceback(obj) + if frames is not None: + return Traceback(frames) + else: + return None + + +class Trace: + """ + Trace of a memory block. + """ + __slots__ = ("_trace",) + + def __init__(self, trace): + # trace is a tuple: (domain: int, size: int, traceback: tuple). + # See Traceback constructor for the format of the traceback tuple. + self._trace = trace + + @property + def domain(self): + return self._trace[0] + + @property + def size(self): + return self._trace[1] + + @property + def traceback(self): + return Traceback(*self._trace[2:]) + + def __eq__(self, other): + if not isinstance(other, Trace): + return NotImplemented + return (self._trace == other._trace) + + def __hash__(self): + return hash(self._trace) + + def __str__(self): + return "%s: %s" % (self.traceback, _format_size(self.size, False)) + + def __repr__(self): + return ("" + % (self.domain, _format_size(self.size, False), self.traceback)) + + +class _Traces(Sequence): + def __init__(self, traces): + Sequence.__init__(self) + # traces is a tuple of trace tuples: see Trace constructor + self._traces = traces + + def __len__(self): + return len(self._traces) + + def __getitem__(self, index): + if isinstance(index, slice): + return tuple(Trace(trace) for trace in self._traces[index]) + else: + return Trace(self._traces[index]) + + def __contains__(self, trace): + return trace._trace in self._traces + + def __eq__(self, other): + if not isinstance(other, _Traces): + return NotImplemented + return (self._traces == other._traces) + + def __repr__(self): + return "" % len(self) + + +def _normalize_filename(filename): + filename = os.path.normcase(filename) + if filename.endswith('.pyc'): + filename = filename[:-1] + return filename + + +class BaseFilter: + def __init__(self, inclusive): + self.inclusive = inclusive + + def _match(self, trace): + raise NotImplementedError + + +class Filter(BaseFilter): + def __init__(self, inclusive, filename_pattern, + lineno=None, all_frames=False, domain=None): + super().__init__(inclusive) + self.inclusive = inclusive + self._filename_pattern = _normalize_filename(filename_pattern) + self.lineno = lineno + self.all_frames = all_frames + self.domain = domain + + @property + def filename_pattern(self): + return self._filename_pattern + + def _match_frame_impl(self, filename, lineno): + filename = _normalize_filename(filename) + if not fnmatch.fnmatch(filename, self._filename_pattern): + return False + if self.lineno is None: + return True + else: + return (lineno == self.lineno) + + def _match_frame(self, filename, lineno): + return self._match_frame_impl(filename, lineno) ^ (not self.inclusive) + + def _match_traceback(self, traceback): + if self.all_frames: + if any(self._match_frame_impl(filename, lineno) + for filename, lineno in traceback): + return self.inclusive + else: + return (not self.inclusive) + else: + filename, lineno = traceback[0] + return self._match_frame(filename, lineno) + + def _match(self, trace): + domain, size, traceback, total_nframe = trace + res = self._match_traceback(traceback) + if self.domain is not None: + if self.inclusive: + return res and (domain == self.domain) + else: + return res or (domain != self.domain) + return res + + +class DomainFilter(BaseFilter): + def __init__(self, inclusive, domain): + super().__init__(inclusive) + self._domain = domain + + @property + def domain(self): + return self._domain + + def _match(self, trace): + domain, size, traceback, total_nframe = trace + return (domain == self.domain) ^ (not self.inclusive) + + +class Snapshot: + """ + Snapshot of traces of memory blocks allocated by Python. + """ + + def __init__(self, traces, traceback_limit): + # traces is a tuple of trace tuples: see _Traces constructor for + # the exact format + self.traces = _Traces(traces) + self.traceback_limit = traceback_limit + + def dump(self, filename): + """ + Write the snapshot into a file. + """ + with open(filename, "wb") as fp: + pickle.dump(self, fp, pickle.HIGHEST_PROTOCOL) + + @staticmethod + def load(filename): + """ + Load a snapshot from a file. + """ + with open(filename, "rb") as fp: + return pickle.load(fp) + + def _filter_trace(self, include_filters, exclude_filters, trace): + if include_filters: + if not any(trace_filter._match(trace) + for trace_filter in include_filters): + return False + if exclude_filters: + if any(not trace_filter._match(trace) + for trace_filter in exclude_filters): + return False + return True + + def filter_traces(self, filters): + """ + Create a new Snapshot instance with a filtered traces sequence, filters + is a list of Filter or DomainFilter instances. If filters is an empty + list, return a new Snapshot instance with a copy of the traces. + """ + if not isinstance(filters, Iterable): + raise TypeError("filters must be a list of filters, not %s" + % type(filters).__name__) + if filters: + include_filters = [] + exclude_filters = [] + for trace_filter in filters: + if trace_filter.inclusive: + include_filters.append(trace_filter) + else: + exclude_filters.append(trace_filter) + new_traces = [trace for trace in self.traces._traces + if self._filter_trace(include_filters, + exclude_filters, + trace)] + else: + new_traces = self.traces._traces.copy() + return Snapshot(new_traces, self.traceback_limit) + + def _group_by(self, key_type, cumulative): + if key_type not in ('traceback', 'filename', 'lineno'): + raise ValueError("unknown key_type: %r" % (key_type,)) + if cumulative and key_type not in ('lineno', 'filename'): + raise ValueError("cumulative mode cannot by used " + "with key type %r" % key_type) + + stats = {} + tracebacks = {} + if not cumulative: + for trace in self.traces._traces: + domain, size, trace_traceback, total_nframe = trace + try: + traceback = tracebacks[trace_traceback] + except KeyError: + if key_type == 'traceback': + frames = trace_traceback + elif key_type == 'lineno': + frames = trace_traceback[:1] + else: # key_type == 'filename': + frames = ((trace_traceback[0][0], 0),) + traceback = Traceback(frames) + tracebacks[trace_traceback] = traceback + try: + stat = stats[traceback] + stat.size += size + stat.count += 1 + except KeyError: + stats[traceback] = Statistic(traceback, size, 1) + else: + # cumulative statistics + for trace in self.traces._traces: + domain, size, trace_traceback, total_nframe = trace + for frame in trace_traceback: + try: + traceback = tracebacks[frame] + except KeyError: + if key_type == 'lineno': + frames = (frame,) + else: # key_type == 'filename': + frames = ((frame[0], 0),) + traceback = Traceback(frames) + tracebacks[frame] = traceback + try: + stat = stats[traceback] + stat.size += size + stat.count += 1 + except KeyError: + stats[traceback] = Statistic(traceback, size, 1) + return stats + + def statistics(self, key_type, cumulative=False): + """ + Group statistics by key_type. Return a sorted list of Statistic + instances. + """ + grouped = self._group_by(key_type, cumulative) + statistics = list(grouped.values()) + statistics.sort(reverse=True, key=Statistic._sort_key) + return statistics + + def compare_to(self, old_snapshot, key_type, cumulative=False): + """ + Compute the differences with an old snapshot old_snapshot. Get + statistics as a sorted list of StatisticDiff instances, grouped by + group_by. + """ + new_group = self._group_by(key_type, cumulative) + old_group = old_snapshot._group_by(key_type, cumulative) + statistics = _compare_grouped_stats(old_group, new_group) + statistics.sort(reverse=True, key=StatisticDiff._sort_key) + return statistics + + +def take_snapshot(): + """ + Take a snapshot of traces of memory blocks allocated by Python. + """ + if not is_tracing(): + raise RuntimeError("the tracemalloc module must be tracing memory " + "allocations to take a snapshot") + traces = _get_traces() + traceback_limit = get_traceback_limit() + return Snapshot(traces, traceback_limit)