FIX: 'last_device.txt is empty' behavior
According to PR #227 Thanks @xdadrm for spotting the error
This commit is contained in:
parent
20d2131e81
commit
7706d6aa34
12
install.py
12
install.py
@ -79,12 +79,14 @@ print("ReActor preheating...", end=' ')
|
||||
|
||||
last_device = None
|
||||
first_run = False
|
||||
available_devices = ["CPU", "CUDA"]
|
||||
|
||||
try:
|
||||
last_device_log = os.path.join(BASE_PATH, "last_device.txt")
|
||||
with open(last_device_log) as f:
|
||||
for el in f:
|
||||
last_device = el.strip()
|
||||
last_device = f.readline().strip()
|
||||
if last_device not in available_devices:
|
||||
last_device = None
|
||||
except:
|
||||
last_device = "CPU"
|
||||
first_run = True
|
||||
@ -97,7 +99,7 @@ with open(req_file) as file:
|
||||
import torch
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
if first_run:
|
||||
if first_run or last_device is None:
|
||||
last_device = "CUDA"
|
||||
elif torch.backends.mps.is_available() or hasattr(torch,'dml'):
|
||||
ort = "onnxruntime"
|
||||
@ -105,10 +107,10 @@ with open(req_file) as file:
|
||||
if first_run:
|
||||
pip_uninstall("onnxruntime", "onnxruntime-gpu")
|
||||
# just in case:
|
||||
if last_device == "CUDA":
|
||||
if last_device == "CUDA" or last_device is None:
|
||||
last_device = "CPU"
|
||||
else:
|
||||
if last_device == "CUDA":
|
||||
if last_device == "CUDA" or last_device is None:
|
||||
last_device = "CPU"
|
||||
with open(os.path.join(BASE_PATH, "last_device.txt"), "w") as txt:
|
||||
txt.write(last_device)
|
||||
|
||||
@ -26,10 +26,14 @@ def updateDevice():
|
||||
try:
|
||||
LAST_DEVICE_PATH = os.path.join(BASE_PATH, "last_device.txt")
|
||||
with open(LAST_DEVICE_PATH) as f:
|
||||
for el in f:
|
||||
device = el.strip()
|
||||
except:
|
||||
device = "CPU"
|
||||
device = f.readline().strip()
|
||||
if device not in DEVICE_LIST:
|
||||
print(f"Error: Device {device} is not in DEVICE_LIST")
|
||||
device = DEVICE_LIST[0]
|
||||
print(f"Execution Provider has been set to {device}")
|
||||
except Exception as e:
|
||||
device = DEVICE_LIST[0]
|
||||
print(f"Error: {e}\nExecution Provider has been set to {device}")
|
||||
return device
|
||||
|
||||
DEVICE = updateDevice()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user