Skip to content

Commit 92b9e67

Browse files
authored
Add confirmation dialog and optimize wheel retrieval (#2)
* Add confirmation dialog and optimize wheel retrieval * Fix typo
1 parent 3d80e40 commit 92b9e67

1 file changed

Lines changed: 20 additions & 6 deletions

File tree

PyTorchUtils/PyTorchUtils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ def setup(self):
3838

3939
def onInstallTorch(self):
4040
torch = PyTorchUtilsLogic().torch
41-
slicer.util.delayDisplay(f'PyTorch {torch.__version__} installed correctly')
41+
if torch is not None:
42+
slicer.util.delayDisplay(f'PyTorch {torch.__version__} installed correctly')
4243

4344

4445
class PyTorchUtilsLogic(ScriptedLoadableModuleLogic):
4546
def __init__(self):
4647
self._torch = None
48+
self._wheel = None
4749

4850
@property
4951
def torch(self):
@@ -53,6 +55,14 @@ def torch(self):
5355
self._torch = self.importTorch()
5456
return self._torch
5557

58+
@property
59+
def wheelURL(self):
60+
"""URL to the ``torch`` package wheel, retrieved using ``light-the-torch``."""
61+
if self._wheel is None:
62+
logging.info('Querying light-the-torch for torch wheel...')
63+
self._wheel = self.getTorchUrl()
64+
return self._wheel
65+
5666
@staticmethod
5767
def torchInstalled():
5868
try:
@@ -68,21 +78,25 @@ def importTorch(self):
6878
import torch
6979
else:
7080
torch = self.installTorch()
71-
logging.info(f'PyTorch {torch.__version__} imported correctly')
72-
logging.info(f'CUDA available: {torch.cuda.is_available()}')
81+
if torch is None:
82+
logging.warning('PyTorch was not installed')
83+
else:
84+
logging.info(f'PyTorch {torch.__version__} imported correctly')
85+
logging.info(f'CUDA available: {torch.cuda.is_available()}')
7386
return torch
7487

7588
def installTorch(self, askConfirmation=False):
7689
"""Install PyTorch and return the ``torch`` Python module."""
7790
if askConfirmation and not slicer.app.commandOptions().testingEnabled:
7891
install = slicer.util.confirmOkCancelDisplay(
79-
'PyTorch will be downloaded and installed now. The process might take some minutes.'
92+
'PyTorch will be downloaded and installed from the following URL:\n'
93+
f'{self.wheelURL}'
94+
'\nThe process might take some minutes.'
8095
)
8196
if not install:
8297
logging.info('Installation of PyTorch aborted by user')
8398
return None
84-
wheelUrl = self.getTorchUrl()
85-
slicer.util.pip_install(wheelUrl)
99+
slicer.util.pip_install(self.wheelURL)
86100
import torch
87101
logging.info(f'PyTorch {torch.__version__} installed correctly')
88102
return torch

0 commit comments

Comments
 (0)