2929import threading
3030import sys
3131from typing import (NamedTuple , Any , Union , TYPE_CHECKING , Optional , Tuple ,
32- Dict , Iterable , List , Sequence )
32+ Dict , Iterable , List , Sequence , Callable , TypeVar )
3333import concurrent
3434from concurrent import futures
35+ from functools import wraps , partial
3536
3637from .i18n import _
3738from .util import (profiler , DaemonThread , UserCancelled , ThreadJob , UserFacingException )
@@ -334,11 +335,37 @@ class HardwarePluginToScan(NamedTuple):
334335# https://github.com/signal11/hidapi/pull/414#issuecomment-445164238
335336# It is not entirely clear to me, exactly what is safe and what isn't, when
336337# using multiple threads...
337- # For now, we use a dedicated thread to enumerate devices (_hid_executor),
338- # and we synchronize all device opens/closes/enumeration (_hid_lock).
339- # FIXME there are still probably threading issues with how we use hidapi...
340- _hid_executor = None # type: Optional[concurrent.futures.Executor]
341- _hid_lock = threading .Lock ()
338+ # Hence, we use a single thread for all device communications, including
339+ # enumeration. Everything that uses hidapi, libusb, etc, MUST run on
340+ # the following thread:
341+ _hwd_comms_executor = concurrent .futures .ThreadPoolExecutor (
342+ max_workers = 1 ,
343+ thread_name_prefix = 'hwd_comms_thread'
344+ )
345+
346+
347+ T = TypeVar ('T' )
348+
349+
350+ def run_in_hwd_thread (func : Callable [[], T ]) -> T :
351+ if threading .current_thread ().name .startswith ("hwd_comms_thread" ):
352+ return func ()
353+ else :
354+ fut = _hwd_comms_executor .submit (func )
355+ return fut .result ()
356+ #except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
357+
358+
359+ def runs_in_hwd_thread (func ):
360+ @wraps (func )
361+ def wrapper (* args , ** kwargs ):
362+ return run_in_hwd_thread (partial (func , * args , ** kwargs ))
363+ return wrapper
364+
365+
366+ def assert_runs_in_hwd_thread ():
367+ if not threading .current_thread ().name .startswith ("hwd_comms_thread" ):
368+ raise Exception ("must only be called from HWD communication thread" )
342369
343370
344371class DeviceMgr (ThreadJob ):
@@ -384,24 +411,11 @@ def __init__(self, config: SimpleConfig):
384411 self ._recognised_hardware = {} # type: Dict[Tuple[int, int], HW_PluginBase]
385412 # Custom enumerate functions for devices we don't know about.
386413 self ._enumerate_func = set () # Needs self.lock.
387- # locks: if you need to take multiple ones, acquire them in the order they are defined here!
388- self ._scan_lock = threading .RLock ()
414+
389415 self .lock = threading .RLock ()
390- self .hid_lock = _hid_lock
391416
392417 self .config = config
393418
394- global _hid_executor
395- if _hid_executor is None :
396- _hid_executor = concurrent .futures .ThreadPoolExecutor (max_workers = 1 ,
397- thread_name_prefix = 'hid_enumerate_thread' )
398-
399- def with_scan_lock (func ):
400- def func_wrapper (self : 'DeviceMgr' , * args , ** kwargs ):
401- with self ._scan_lock :
402- return func (self , * args , ** kwargs )
403- return func_wrapper
404-
405419 def thread_jobs (self ):
406420 # Thread job to handle device timeouts
407421 return [self ]
@@ -423,6 +437,7 @@ def register_enumerate_func(self, func):
423437 with self .lock :
424438 self ._enumerate_func .add (func )
425439
440+ @runs_in_hwd_thread
426441 def create_client (self , device : 'Device' , handler : Optional ['HardwareHandlerBase' ],
427442 plugin : 'HW_PluginBase' ) -> Optional ['HardwareClientBase' ]:
428443 # Get from cache first
@@ -452,7 +467,7 @@ def unpair_xpub(self, xpub):
452467 if xpub not in self .xpub_ids :
453468 return
454469 _id = self .xpub_ids .pop (xpub )
455- self ._close_client (_id )
470+ self ._close_client (_id )
456471
457472 def unpair_id (self , id_ ):
458473 xpub = self .xpub_by_id (id_ )
@@ -462,8 +477,9 @@ def unpair_id(self, id_):
462477 self ._close_client (id_ )
463478
464479 def _close_client (self , id_ ):
465- client = self ._client_by_id (id_ )
466- self .clients .pop (client , None )
480+ with self .lock :
481+ client = self ._client_by_id (id_ )
482+ self .clients .pop (client , None )
467483 if client :
468484 client .close ()
469485
@@ -486,7 +502,7 @@ def client_by_id(self, id_, *, scan_now: bool = True) -> Optional['HardwareClien
486502 self .scan_devices ()
487503 return self ._client_by_id (id_ )
488504
489- @with_scan_lock
505+ @runs_in_hwd_thread
490506 def client_for_keystore (self , plugin : 'HW_PluginBase' , handler : Optional ['HardwareHandlerBase' ],
491507 keystore : 'Hardware_KeyStore' ,
492508 force_pair : bool , * ,
@@ -655,33 +671,23 @@ def select_device(self, plugin: 'HW_PluginBase', handler: 'HardwareHandlerBase',
655671 # note: updated label/soft_device_id will be saved after pairing succeeds
656672 return info
657673
658- @with_scan_lock
674+ @runs_in_hwd_thread
659675 def _scan_devices_with_hid (self ) -> List ['Device' ]:
660676 try :
661677 import hid
662678 except ImportError :
663679 return []
664680
665- def hid_enumerate ():
666- with self .hid_lock :
667- return hid .enumerate (0 , 0 )
668-
669- hid_list_fut = _hid_executor .submit (hid_enumerate )
670- try :
671- hid_list = hid_list_fut .result ()
672- except (concurrent .futures .CancelledError , concurrent .futures .TimeoutError ) as e :
673- return []
674-
675681 devices = []
676- for d in hid_list :
682+ for d in hid . enumerate ( 0 , 0 ) :
677683 product_key = (d ['vendor_id' ], d ['product_id' ])
678684 if product_key in self ._recognised_hardware :
679685 plugin = self ._recognised_hardware [product_key ]
680686 device = plugin .create_device_from_hid_enumeration (d , product_key = product_key )
681687 devices .append (device )
682688 return devices
683689
684- @with_scan_lock
690+ @runs_in_hwd_thread
685691 @profiler
686692 def scan_devices (self ) -> Sequence ['Device' ]:
687693 self .logger .info ("scanning devices..." )
@@ -693,10 +699,8 @@ def scan_devices(self) -> Sequence['Device']:
693699 with self .lock :
694700 enumerate_funcs = list (self ._enumerate_func )
695701 for f in enumerate_funcs :
696- # custom enumerate functions might use hidapi, so use hid thread to be safe
697- new_devices_fut = _hid_executor .submit (f )
698702 try :
699- new_devices = new_devices_fut . result ()
703+ new_devices = f ()
700704 except BaseException as e :
701705 self .logger .error ('custom device enum failed. func {}, error {}'
702706 .format (str (f ), repr (e )))
0 commit comments