Skip to content

Commit 21c3572

Browse files
authored
hardware devices: run all device communication on dedicated thread (spesmilo#6561)
hidapi/libusb etc are not thread-safe. related: spesmilo#6554
1 parent 53a5a21 commit 21c3572

12 files changed

Lines changed: 195 additions & 97 deletions

File tree

electrum/plugin.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@
2929
import threading
3030
import sys
3131
from typing import (NamedTuple, Any, Union, TYPE_CHECKING, Optional, Tuple,
32-
Dict, Iterable, List, Sequence)
32+
Dict, Iterable, List, Sequence, Callable, TypeVar)
3333
import concurrent
3434
from concurrent import futures
35+
from functools import wraps, partial
3536

3637
from .i18n import _
3738
from .util import (profiler, DaemonThread, UserCancelled, ThreadJob, UserFacingException)
@@ -334,11 +335,37 @@ class HardwarePluginToScan(NamedTuple):
334335
# https://github.com/signal11/hidapi/pull/414#issuecomment-445164238
335336
# It is not entirely clear to me, exactly what is safe and what isn't, when
336337
# using multiple threads...
337-
# For now, we use a dedicated thread to enumerate devices (_hid_executor),
338-
# and we synchronize all device opens/closes/enumeration (_hid_lock).
339-
# FIXME there are still probably threading issues with how we use hidapi...
340-
_hid_executor = None # type: Optional[concurrent.futures.Executor]
341-
_hid_lock = threading.Lock()
338+
# Hence, we use a single thread for all device communications, including
339+
# enumeration. Everything that uses hidapi, libusb, etc, MUST run on
340+
# the following thread:
341+
_hwd_comms_executor = concurrent.futures.ThreadPoolExecutor(
342+
max_workers=1,
343+
thread_name_prefix='hwd_comms_thread'
344+
)
345+
346+
347+
T = TypeVar('T')
348+
349+
350+
def run_in_hwd_thread(func: Callable[[], T]) -> T:
351+
if threading.current_thread().name.startswith("hwd_comms_thread"):
352+
return func()
353+
else:
354+
fut = _hwd_comms_executor.submit(func)
355+
return fut.result()
356+
#except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
357+
358+
359+
def runs_in_hwd_thread(func):
360+
@wraps(func)
361+
def wrapper(*args, **kwargs):
362+
return run_in_hwd_thread(partial(func, *args, **kwargs))
363+
return wrapper
364+
365+
366+
def assert_runs_in_hwd_thread():
367+
if not threading.current_thread().name.startswith("hwd_comms_thread"):
368+
raise Exception("must only be called from HWD communication thread")
342369

343370

344371
class DeviceMgr(ThreadJob):
@@ -384,24 +411,11 @@ def __init__(self, config: SimpleConfig):
384411
self._recognised_hardware = {} # type: Dict[Tuple[int, int], HW_PluginBase]
385412
# Custom enumerate functions for devices we don't know about.
386413
self._enumerate_func = set() # Needs self.lock.
387-
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
388-
self._scan_lock = threading.RLock()
414+
389415
self.lock = threading.RLock()
390-
self.hid_lock = _hid_lock
391416

392417
self.config = config
393418

394-
global _hid_executor
395-
if _hid_executor is None:
396-
_hid_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1,
397-
thread_name_prefix='hid_enumerate_thread')
398-
399-
def with_scan_lock(func):
400-
def func_wrapper(self: 'DeviceMgr', *args, **kwargs):
401-
with self._scan_lock:
402-
return func(self, *args, **kwargs)
403-
return func_wrapper
404-
405419
def thread_jobs(self):
406420
# Thread job to handle device timeouts
407421
return [self]
@@ -423,6 +437,7 @@ def register_enumerate_func(self, func):
423437
with self.lock:
424438
self._enumerate_func.add(func)
425439

440+
@runs_in_hwd_thread
426441
def create_client(self, device: 'Device', handler: Optional['HardwareHandlerBase'],
427442
plugin: 'HW_PluginBase') -> Optional['HardwareClientBase']:
428443
# Get from cache first
@@ -452,7 +467,7 @@ def unpair_xpub(self, xpub):
452467
if xpub not in self.xpub_ids:
453468
return
454469
_id = self.xpub_ids.pop(xpub)
455-
self._close_client(_id)
470+
self._close_client(_id)
456471

457472
def unpair_id(self, id_):
458473
xpub = self.xpub_by_id(id_)
@@ -462,8 +477,9 @@ def unpair_id(self, id_):
462477
self._close_client(id_)
463478

464479
def _close_client(self, id_):
465-
client = self._client_by_id(id_)
466-
self.clients.pop(client, None)
480+
with self.lock:
481+
client = self._client_by_id(id_)
482+
self.clients.pop(client, None)
467483
if client:
468484
client.close()
469485

@@ -486,7 +502,7 @@ def client_by_id(self, id_, *, scan_now: bool = True) -> Optional['HardwareClien
486502
self.scan_devices()
487503
return self._client_by_id(id_)
488504

489-
@with_scan_lock
505+
@runs_in_hwd_thread
490506
def client_for_keystore(self, plugin: 'HW_PluginBase', handler: Optional['HardwareHandlerBase'],
491507
keystore: 'Hardware_KeyStore',
492508
force_pair: bool, *,
@@ -655,33 +671,23 @@ def select_device(self, plugin: 'HW_PluginBase', handler: 'HardwareHandlerBase',
655671
# note: updated label/soft_device_id will be saved after pairing succeeds
656672
return info
657673

658-
@with_scan_lock
674+
@runs_in_hwd_thread
659675
def _scan_devices_with_hid(self) -> List['Device']:
660676
try:
661677
import hid
662678
except ImportError:
663679
return []
664680

665-
def hid_enumerate():
666-
with self.hid_lock:
667-
return hid.enumerate(0, 0)
668-
669-
hid_list_fut = _hid_executor.submit(hid_enumerate)
670-
try:
671-
hid_list = hid_list_fut.result()
672-
except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e:
673-
return []
674-
675681
devices = []
676-
for d in hid_list:
682+
for d in hid.enumerate(0, 0):
677683
product_key = (d['vendor_id'], d['product_id'])
678684
if product_key in self._recognised_hardware:
679685
plugin = self._recognised_hardware[product_key]
680686
device = plugin.create_device_from_hid_enumeration(d, product_key=product_key)
681687
devices.append(device)
682688
return devices
683689

684-
@with_scan_lock
690+
@runs_in_hwd_thread
685691
@profiler
686692
def scan_devices(self) -> Sequence['Device']:
687693
self.logger.info("scanning devices...")
@@ -693,10 +699,8 @@ def scan_devices(self) -> Sequence['Device']:
693699
with self.lock:
694700
enumerate_funcs = list(self._enumerate_func)
695701
for f in enumerate_funcs:
696-
# custom enumerate functions might use hidapi, so use hid thread to be safe
697-
new_devices_fut = _hid_executor.submit(f)
698702
try:
699-
new_devices = new_devices_fut.result()
703+
new_devices = f()
700704
except BaseException as e:
701705
self.logger.error('custom device enum failed. func {}, error {}'
702706
.format(str(f), repr(e)))

electrum/plugins/bitbox02/bitbox02.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from electrum.util import bh2u, UserFacingException
1414
from electrum.base_wizard import ScriptTypeNotSupported, BaseWizard
1515
from electrum.logging import get_logger
16-
from electrum.plugin import Device, DeviceInfo
16+
from electrum.plugin import Device, DeviceInfo, runs_in_hwd_thread
1717
from electrum.simple_config import SimpleConfig
1818
from electrum.json_db import StoredDict
1919
from electrum.storage import get_derivation_used_for_hw_device_encryption
@@ -73,18 +73,19 @@ def __init__(self, handler: Any, device: Device, config: SimpleConfig, *, plugin
7373
def is_initialized(self) -> bool:
7474
return True
7575

76+
@runs_in_hwd_thread
7677
def close(self):
77-
with self.device_manager().hid_lock:
78-
try:
79-
self.bitbox02_device.close()
80-
except:
81-
pass
78+
try:
79+
self.bitbox02_device.close()
80+
except:
81+
pass
8282

8383
def has_usable_connection_with_device(self) -> bool:
8484
if self.bitbox_hid_info is None:
8585
return False
8686
return True
8787

88+
@runs_in_hwd_thread
8889
def get_soft_device_id(self) -> Optional[str]:
8990
if self.handler is None:
9091
# Can't do the pairing without the handler. This happens at wallet creation time, when
@@ -94,6 +95,7 @@ def get_soft_device_id(self) -> Optional[str]:
9495
self.pairing_dialog()
9596
return self.bitbox02_device.root_fingerprint().hex()
9697

98+
@runs_in_hwd_thread
9799
def pairing_dialog(self):
98100
def pairing_step(code: str, device_response: Callable[[], bool]) -> bool:
99101
msg = "Please compare and confirm the pairing code on your BitBox02:\n" + code
@@ -102,8 +104,7 @@ def pairing_step(code: str, device_response: Callable[[], bool]) -> bool:
102104
res = device_response()
103105
except:
104106
# Close the hid device on exception
105-
with self.device_manager().hid_lock:
106-
hid_device.close()
107+
hid_device.close()
107108
raise
108109
finally:
109110
self.handler.finished()
@@ -167,10 +168,8 @@ def set_app_static_privkey(self, privkey: bytes) -> None:
167168
return set_noise_privkey(privkey)
168169

169170
if self.bitbox02_device is None:
170-
with self.device_manager().hid_lock:
171-
hid_device = hid.device()
172-
hid_device.open_path(self.bitbox_hid_info["path"])
173-
171+
hid_device = hid.device()
172+
hid_device.open_path(self.bitbox_hid_info["path"])
174173

175174
bitbox02_device = bitbox02.BitBox02(
176175
transport=u2fhid.U2FHid(hid_device),
@@ -197,13 +196,15 @@ def coin_network_from_electrum_network(self) -> int:
197196
return bitbox02.btc.TBTC
198197
return bitbox02.btc.BTC
199198

199+
@runs_in_hwd_thread
200200
def get_password_for_storage_encryption(self) -> str:
201201
derivation = get_derivation_used_for_hw_device_encryption()
202202
derivation_list = bip32.convert_bip32_path_to_list_of_uint32(derivation)
203203
xpub = self.bitbox02_device.electrum_encryption_key(derivation_list)
204204
node = bip32.BIP32Node.from_xkey(xpub, net = constants.BitcoinMainnet()).subkey_at_public_derivation(())
205205
return node.eckey.get_public_key_bytes(compressed=True).hex()
206206

207+
@runs_in_hwd_thread
207208
def get_xpub(self, bip32_path: str, xtype: str, *, display: bool = False) -> str:
208209
if self.bitbox02_device is None:
209210
self.pairing_dialog()
@@ -244,6 +245,7 @@ def get_xpub(self, bip32_path: str, xtype: str, *, display: bool = False) -> str
244245
display=display,
245246
)
246247

248+
@runs_in_hwd_thread
247249
def label(self) -> str:
248250
if self.handler is None:
249251
# Can't do the pairing without the handler. This happens at wallet creation time, when
@@ -258,6 +260,7 @@ def label(self) -> str:
258260
self.bitbox02_device.root_fingerprint().hex(),
259261
)
260262

263+
@runs_in_hwd_thread
261264
def request_root_fingerprint_from_device(self) -> str:
262265
if self.bitbox02_device is None:
263266
raise Exception(
@@ -271,6 +274,7 @@ def is_pairable(self) -> bool:
271274
return False
272275
return True
273276

277+
@runs_in_hwd_thread
274278
def btc_multisig_config(
275279
self, coin, bip32_path: List[int], wallet: Multisig_Wallet
276280
):
@@ -316,6 +320,7 @@ def btc_multisig_config(
316320
raise UserFacingException("Failed to register multisig\naccount configuration on BitBox02")
317321
return multisig_config
318322

323+
@runs_in_hwd_thread
319324
def show_address(
320325
self, bip32_path: str, address_type: str, wallet: Deterministic_Wallet
321326
) -> str:
@@ -357,6 +362,7 @@ def show_address(
357362
display=True,
358363
)
359364

365+
@runs_in_hwd_thread
360366
def sign_transaction(
361367
self,
362368
keystore: Hardware_KeyStore,
@@ -553,6 +559,7 @@ def sign_message(self, sequence, message, password):
553559
).format(self.device)
554560
)
555561

562+
@runs_in_hwd_thread
556563
def sign_transaction(self, tx: PartialTransaction, password: str):
557564
if tx.is_complete():
558565
return
@@ -572,6 +579,7 @@ def sign_transaction(self, tx: PartialTransaction, password: str):
572579
self.give_error(e, True)
573580
return
574581

582+
@runs_in_hwd_thread
575583
def show_address(
576584
self, sequence: Tuple[int, int], txin_type: str, wallet: Deterministic_Wallet
577585
):
@@ -616,6 +624,7 @@ def get_library_version(self):
616624
raise ImportError()
617625

618626
# handler is a BitBox02_Handler
627+
@runs_in_hwd_thread
619628
def create_client(self, device: Device, handler: Any) -> BitBox02Client:
620629
if not handler:
621630
self.handler = handler
@@ -645,6 +654,7 @@ def get_xpub(
645654
assert client.bitbox02_device is not None
646655
return client.get_xpub(derivation, xtype)
647656

657+
@runs_in_hwd_thread
648658
def show_address(
649659
self,
650660
wallet: Deterministic_Wallet,
@@ -660,6 +670,7 @@ def show_address(
660670
sequence = wallet.get_address_index(address)
661671
keystore.show_address(sequence, txin_type, wallet)
662672

673+
@runs_in_hwd_thread
663674
def show_xpub(self, keystore: BitBox02_KeyStore):
664675
client = keystore.get_client()
665676
assert isinstance(client, BitBox02Client)

0 commit comments

Comments
 (0)