diff --git a/.travis.yml b/.travis.yml index 9977551a3..785aa2a6b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -89,22 +89,7 @@ jobs: --python-version=3.7 --ignore-missing-imports --no-implicit-optional - can/bit_timing.py - can/broadcastmanager.py - can/bus.py - can/interface.py - can/interfaces/udp_multicast/**.py - can/interfaces/slcan.py - can/interfaces/socketcan/**.py - can/interfaces/virtual.py - can/listener.py - can/logger.py - can/message.py - can/notifier.py - can/player.py - can/thread_safe_bus.py - can/typechecking.py - can/util.py + can/*.py can/io/**.py scripts/**.py examples/**.py diff --git a/can/broadcastmanager.py b/can/broadcastmanager.py index 62fe34988..6dd4f3fd7 100644 --- a/can/broadcastmanager.py +++ b/can/broadcastmanager.py @@ -36,7 +36,7 @@ class CyclicTask: """ @abc.abstractmethod - def stop(self): + def stop(self) -> None: """Cancel this periodic task. :raises can.CanError: @@ -49,7 +49,9 @@ class CyclicSendTaskABC(CyclicTask): Message send task with defined period """ - def __init__(self, messages: Union[Sequence[Message], Message], period: float): + def __init__( + self, messages: Union[Sequence[Message], Message], period: float + ) -> None: """ :param messages: The messages to be sent periodically. @@ -104,7 +106,7 @@ def __init__( messages: Union[Sequence[Message], Message], period: float, duration: Optional[float], - ): + ) -> None: """Message send task with a defined duration and period. :param messages: @@ -122,14 +124,14 @@ class RestartableCyclicTaskABC(CyclicSendTaskABC): """Adds support for restarting a stopped cyclic task""" @abc.abstractmethod - def start(self): + def start(self) -> None: """Restart a stopped periodic task.""" class ModifiableCyclicTaskABC(CyclicSendTaskABC): """Adds support for modifying a periodic message""" - def _check_modified_messages(self, messages: Tuple[Message, ...]): + def _check_modified_messages(self, messages: Tuple[Message, ...]) -> None: """Helper function to perform error checking when modifying the data in the cyclic task. @@ -149,7 +151,7 @@ def _check_modified_messages(self, messages: Tuple[Message, ...]): "from when the task was created" ) - def modify_data(self, messages: Union[Sequence[Message], Message]): + def modify_data(self, messages: Union[Sequence[Message], Message]) -> None: """Update the contents of the periodically sent messages, without altering the timing. @@ -178,7 +180,7 @@ def __init__( count: int, initial_period: float, subsequent_period: float, - ): + ) -> None: """ Transmits a message `count` times at `initial_period` then continues to transmit messages at `subsequent_period`. @@ -206,12 +208,12 @@ def __init__( period: float, duration: Optional[float] = None, on_error: Optional[Callable[[Exception], bool]] = None, - ): + ) -> None: """Transmits `messages` with a `period` seconds for `duration` seconds on a `bus`. The `on_error` is called if any error happens on `bus` while sending `messages`. If `on_error` present, and returns ``False`` when invoked, thread is - stopped immediately, otherwise, thread continuiously tries to send `messages` + stopped immediately, otherwise, thread continuously tries to send `messages` ignoring errors on a `bus`. Absence of `on_error` means that thread exits immediately on error. @@ -224,22 +226,24 @@ def __init__( self.bus = bus self.send_lock = lock self.stopped = True - self.thread = None - self.end_time = time.perf_counter() + duration if duration else None + self.thread: Optional[threading.Thread] = None + self.end_time: Optional[float] = ( + time.perf_counter() + duration if duration else None + ) self.on_error = on_error if HAS_EVENTS: - self.period_ms: int = int(round(period * 1000, 0)) + self.period_ms = int(round(period * 1000, 0)) self.event = win32event.CreateWaitableTimer(None, False, None) self.start() - def stop(self): + def stop(self) -> None: if HAS_EVENTS: win32event.CancelWaitableTimer(self.event.handle) self.stopped = True - def start(self): + def start(self) -> None: self.stopped = False if self.thread is None or not self.thread.is_alive(): name = "Cyclic send task for 0x%X" % (self.messages[0].arbitration_id) @@ -253,7 +257,7 @@ def start(self): self.thread.start() - def _run(self): + def _run(self) -> None: msg_index = 0 while not self.stopped: # Prevent calling bus.send from multiple threads diff --git a/can/bus.py b/can/bus.py index 6258e9a82..9fba7948e 100644 --- a/can/bus.py +++ b/can/bus.py @@ -6,14 +6,14 @@ import can.typechecking -from abc import ABCMeta, abstractmethod +from abc import ABC, ABCMeta, abstractmethod import can import logging import threading from time import time from enum import Enum, auto -from can.broadcastmanager import ThreadBasedCyclicSendTask +from can.broadcastmanager import ThreadBasedCyclicSendTask, CyclicSendTaskABC from can.message import Message LOG = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def __init__( :param dict kwargs: Any backend dependent configurations are passed in this dictionary """ - self._periodic_tasks: List[can.broadcastmanager.CyclicSendTaskABC] = [] + self._periodic_tasks: List[_SelfRemovingCyclicTask] = [] self.set_filters(can_filters) def __str__(self) -> str: @@ -172,7 +172,7 @@ def send(self, msg: Message, timeout: Optional[float] = None): def send_periodic( self, - msgs: Union[Sequence[Message], Message], + msgs: Union[Message, Sequence[Message]], period: float, duration: Optional[float] = None, store_task: bool = True, @@ -188,7 +188,7 @@ def send_periodic( - the task's :meth:`CyclicTask.stop()` method is called. :param msgs: - Messages to transmit + Message(s) to transmit :param period: Period in seconds between each message :param duration: @@ -215,26 +215,35 @@ def send_periodic( appropriate as the stopped tasks are still taking up memory as they are associated with the Bus instance. """ - if not isinstance(msgs, (list, tuple)): - if isinstance(msgs, Message): - msgs = [msgs] - else: - raise ValueError("Must be either a list, tuple, or a Message") - if not msgs: - raise ValueError("Must be at least a list or tuple of length 1") - task = self._send_periodic_internal(msgs, period, duration) + if isinstance(msgs, Message): + msgs = [msgs] + elif isinstance(msgs, Sequence): + # A Sequence does not necessarily provide __bool__ we need to use len() + if len(msgs) == 0: + raise ValueError("Must be a sequence at least of length 1") + else: + raise ValueError("Must be either a message or a sequence of messages") + + # Create a backend specific task; will be patched to a _SelfRemovingCyclicTask later + task = cast( + _SelfRemovingCyclicTask, + self._send_periodic_internal(msgs, period, duration), + ) + # we wrap the task's stop method to also remove it from the Bus's list of tasks + periodic_tasks = self._periodic_tasks original_stop_method = task.stop - def wrapped_stop_method(remove_task=True): + def wrapped_stop_method(remove_task: bool = True) -> None: + nonlocal task, periodic_tasks, original_stop_method if remove_task: try: - self._periodic_tasks.remove(task) + periodic_tasks.remove(task) except ValueError: - pass + pass # allow the task to be already removed original_stop_method() - setattr(task, "stop", wrapped_stop_method) + task.stop = wrapped_stop_method # type: ignore if store_task: self._periodic_tasks.append(task) @@ -273,13 +282,13 @@ def _send_periodic_internal( ) return task - def stop_all_periodic_tasks(self, remove_tasks=True): + def stop_all_periodic_tasks(self, remove_tasks: bool = True) -> None: """Stop sending any messages that were started using **bus.send_periodic**. .. note:: The result is undefined if a single task throws an exception while being stopped. - :param bool remove_tasks: + :param remove_tasks: Stop tracking the stopped tasks. """ for task in self._periodic_tasks: @@ -288,7 +297,7 @@ def stop_all_periodic_tasks(self, remove_tasks=True): task.stop(remove_task=False) if remove_tasks: - self._periodic_tasks = [] + self._periodic_tasks.clear() def __iter__(self) -> Iterator[Message]: """Allow iteration on messages as they are received. @@ -432,3 +441,10 @@ def _detect_available_configs() -> List[can.typechecking.AutoDetectedConfig]: def fileno(self) -> int: raise NotImplementedError("fileno is not implemented using current CAN bus") + + +class _SelfRemovingCyclicTask(CyclicSendTaskABC, ABC): + """Removes itself from a bus. Only needed for typing :meth:`Bus._periodic_tasks`. Do not instantiate.""" + + def stop(self, remove_task: bool = True) -> None: + raise NotImplementedError() diff --git a/can/ctypesutil.py b/can/ctypesutil.py index 13668d1e6..1063130d3 100644 --- a/can/ctypesutil.py +++ b/can/ctypesutil.py @@ -6,23 +6,40 @@ import logging import sys +from typing import Any, Callable, Optional, Tuple, Union + log = logging.getLogger("can.ctypesutil") __all__ = ["CLibrary", "HANDLE", "PHANDLE", "HRESULT"] + try: - _LibBase = ctypes.WinDLL + _LibBase = ctypes.WinDLL # type: ignore + _FUNCTION_TYPE = ctypes.WINFUNCTYPE # type: ignore except AttributeError: _LibBase = ctypes.CDLL + _FUNCTION_TYPE = ctypes.CFUNCTYPE -class LibraryMixin: - def map_symbol(self, func_name, restype=None, argtypes=(), errcheck=None): +class CLibrary(_LibBase): # type: ignore + def __init__(self, library_or_path: Union[str, ctypes.CDLL]) -> None: + if isinstance(library_or_path, str): + super().__init__(library_or_path) + else: + super().__init__(library_or_path._name, library_or_path._handle) + + def map_symbol( + self, + func_name: str, + restype: Any = None, + argtypes: Tuple[Any, ...] = (), + errcheck: Optional[Callable[..., Any]] = None, + ) -> Any: """ Map and return a symbol (function) from a C library. A reference to the mapped symbol is also held in the instance - :param str func_name: + :param func_name: symbol_name :param ctypes.c_* restype: function result type (i.e. ctypes.c_ulong...), defaults to void @@ -32,67 +49,36 @@ def map_symbol(self, func_name, restype=None, argtypes=(), errcheck=None): optional error checking function, see ctypes docs for _FuncPtr """ if argtypes: - prototype = self.function_type(restype, *argtypes) + prototype = _FUNCTION_TYPE(restype, *argtypes) else: - prototype = self.function_type(restype) + prototype = _FUNCTION_TYPE(restype) try: - symbol = prototype((func_name, self)) + symbol: Any = prototype((func_name, self)) except AttributeError: raise ImportError( - "Could not map function '{}' from library {}".format( - func_name, self._name - ) + f'Could not map function "{func_name}" from library {self._name}' ) from None - setattr(symbol, "_name", func_name) + symbol._name = func_name log.debug( f'Wrapped function "{func_name}", result type: {type(restype)}, error_check {errcheck}' ) - if errcheck: + if errcheck is not None: symbol.errcheck = errcheck - setattr(self, func_name, symbol) - return symbol - - -class CLibrary_Win32(_LibBase, LibraryMixin): - " Basic ctypes.WinDLL derived class + LibraryMixin " - - def __init__(self, library_or_path): - if isinstance(library_or_path, str): - super().__init__(library_or_path) - else: - super().__init__(library_or_path._name, library_or_path._handle) - - @property - def function_type(self): - return ctypes.WINFUNCTYPE - + self.func_name = symbol -class CLibrary_Unix(ctypes.CDLL, LibraryMixin): - " Basic ctypes.CDLL derived class + LibraryMixin " - - def __init__(self, library_or_path): - if isinstance(library_or_path, str): - super().__init__(library_or_path) - else: - super().__init__(library_or_path._name, library_or_path._handle) - - @property - def function_type(self): - return ctypes.CFUNCTYPE + return symbol if sys.platform == "win32": - CLibrary = CLibrary_Win32 HRESULT = ctypes.HRESULT -else: - CLibrary = CLibrary_Unix - if sys.platform == "cygwin": - # Define HRESULT for cygwin - class HRESULT(ctypes.c_long): - pass + +elif sys.platform == "cygwin": + + class HRESULT(ctypes.c_long): + pass # Common win32 definitions diff --git a/test/test_socketcan.py b/test/test_socketcan.py index d4077e36f..c322bbf75 100644 --- a/test/test_socketcan.py +++ b/test/test_socketcan.py @@ -245,12 +245,13 @@ def side_effect_ctypes_alignment(value): "Should only run on platforms where sizeof(long) == 4 and alignof(long) == 4", ) def test_build_bcm_header_sizeof_long_4_alignof_long_4(self): - expected_result = b"" - expected_result += b"\x02\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x01\x04\x00\x00" - expected_result += b"\x01\x00\x00\x00\x00\x00\x00\x00" + expected_result = ( + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x01\x04\x00\x00" + b"\x01\x00\x00\x00\x00\x00\x00\x00" + ) self.assertEqual( expected_result, @@ -274,14 +275,15 @@ def test_build_bcm_header_sizeof_long_4_alignof_long_4(self): "Should only run on platforms where sizeof(long) == 8 and alignof(long) == 8", ) def test_build_bcm_header_sizeof_long_8_alignof_long_8(self): - expected_result = b"" - expected_result += b"\x02\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x00\x00\x00\x00\x00\x00\x00\x00" - expected_result += b"\x01\x04\x00\x00\x01\x00\x00\x00" + expected_result = ( + b"\x02\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x01\x04\x00\x00\x01\x00\x00\x00" + ) self.assertEqual( expected_result,