FIX: 'last_device.txt is empty' behavior

According to PR #227
Thanks @xdadrm for spotting the error
This commit is contained in:
Gourieff 2023-12-03 01:36:46 +07:00
parent 20d2131e81
commit 7706d6aa34
2 changed files with 15 additions and 9 deletions

View File

@ -79,12 +79,14 @@ print("ReActor preheating...", end=' ')
last_device = None last_device = None
first_run = False first_run = False
available_devices = ["CPU", "CUDA"]
try: try:
last_device_log = os.path.join(BASE_PATH, "last_device.txt") last_device_log = os.path.join(BASE_PATH, "last_device.txt")
with open(last_device_log) as f: with open(last_device_log) as f:
for el in f: last_device = f.readline().strip()
last_device = el.strip() if last_device not in available_devices:
last_device = None
except: except:
last_device = "CPU" last_device = "CPU"
first_run = True first_run = True
@ -97,7 +99,7 @@ with open(req_file) as file:
import torch import torch
try: try:
if torch.cuda.is_available(): if torch.cuda.is_available():
if first_run: if first_run or last_device is None:
last_device = "CUDA" last_device = "CUDA"
elif torch.backends.mps.is_available() or hasattr(torch,'dml'): elif torch.backends.mps.is_available() or hasattr(torch,'dml'):
ort = "onnxruntime" ort = "onnxruntime"
@ -105,10 +107,10 @@ with open(req_file) as file:
if first_run: if first_run:
pip_uninstall("onnxruntime", "onnxruntime-gpu") pip_uninstall("onnxruntime", "onnxruntime-gpu")
# just in case: # just in case:
if last_device == "CUDA": if last_device == "CUDA" or last_device is None:
last_device = "CPU" last_device = "CPU"
else: else:
if last_device == "CUDA": if last_device == "CUDA" or last_device is None:
last_device = "CPU" last_device = "CPU"
with open(os.path.join(BASE_PATH, "last_device.txt"), "w") as txt: with open(os.path.join(BASE_PATH, "last_device.txt"), "w") as txt:
txt.write(last_device) txt.write(last_device)

View File

@ -26,10 +26,14 @@ def updateDevice():
try: try:
LAST_DEVICE_PATH = os.path.join(BASE_PATH, "last_device.txt") LAST_DEVICE_PATH = os.path.join(BASE_PATH, "last_device.txt")
with open(LAST_DEVICE_PATH) as f: with open(LAST_DEVICE_PATH) as f:
for el in f: device = f.readline().strip()
device = el.strip() if device not in DEVICE_LIST:
except: print(f"Error: Device {device} is not in DEVICE_LIST")
device = "CPU" device = DEVICE_LIST[0]
print(f"Execution Provider has been set to {device}")
except Exception as e:
device = DEVICE_LIST[0]
print(f"Error: {e}\nExecution Provider has been set to {device}")
return device return device
DEVICE = updateDevice() DEVICE = updateDevice()