@@ -38,12 +38,14 @@ def setup(self):
3838
3939 def onInstallTorch (self ):
4040 torch = PyTorchUtilsLogic ().torch
41- slicer .util .delayDisplay (f'PyTorch { torch .__version__ } installed correctly' )
41+ if torch is not None :
42+ slicer .util .delayDisplay (f'PyTorch { torch .__version__ } installed correctly' )
4243
4344
4445class PyTorchUtilsLogic (ScriptedLoadableModuleLogic ):
4546 def __init__ (self ):
4647 self ._torch = None
48+ self ._wheel = None
4749
4850 @property
4951 def torch (self ):
@@ -53,6 +55,14 @@ def torch(self):
5355 self ._torch = self .importTorch ()
5456 return self ._torch
5557
58+ @property
59+ def wheelURL (self ):
60+ """URL to the ``torch`` package wheel, retrieved using ``light-the-torch``."""
61+ if self ._wheel is None :
62+ logging .info ('Querying light-the-torch for torch wheel...' )
63+ self ._wheel = self .getTorchUrl ()
64+ return self ._wheel
65+
5666 @staticmethod
5767 def torchInstalled ():
5868 try :
@@ -68,21 +78,25 @@ def importTorch(self):
6878 import torch
6979 else :
7080 torch = self .installTorch ()
71- logging .info (f'PyTorch { torch .__version__ } imported correctly' )
72- logging .info (f'CUDA available: { torch .cuda .is_available ()} ' )
81+ if torch is None :
82+ logging .warning ('PyTorch was not installed' )
83+ else :
84+ logging .info (f'PyTorch { torch .__version__ } imported correctly' )
85+ logging .info (f'CUDA available: { torch .cuda .is_available ()} ' )
7386 return torch
7487
7588 def installTorch (self , askConfirmation = False ):
7689 """Install PyTorch and return the ``torch`` Python module."""
7790 if askConfirmation and not slicer .app .commandOptions ().testingEnabled :
7891 install = slicer .util .confirmOkCancelDisplay (
79- 'PyTorch will be downloaded and installed now. The process might take some minutes.'
92+ 'PyTorch will be downloaded and installed from the following URL:\n '
93+ f'{ self .wheelURL } '
94+ '\n The process might take some minutes.'
8095 )
8196 if not install :
8297 logging .info ('Installation of PyTorch aborted by user' )
8398 return None
84- wheelUrl = self .getTorchUrl ()
85- slicer .util .pip_install (wheelUrl )
99+ slicer .util .pip_install (self .wheelURL )
86100 import torch
87101 logging .info (f'PyTorch { torch .__version__ } installed correctly' )
88102 return torch
0 commit comments