From 7706d6aa349724e55443096eadb32e92125908e4 Mon Sep 17 00:00:00 2001 From: Gourieff <777@lovemet.ru> Date: Sun, 3 Dec 2023 01:36:46 +0700 Subject: [PATCH] FIX: 'last_device.txt is empty' behavior According to PR #227 Thanks @xdadrm for spotting the error --- install.py | 12 +++++++----- scripts/reactor_globals.py | 12 ++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/install.py b/install.py index 6b46a8c..7023b53 100644 --- a/install.py +++ b/install.py @@ -79,12 +79,14 @@ print("ReActor preheating...", end=' ') last_device = None first_run = False +available_devices = ["CPU", "CUDA"] try: last_device_log = os.path.join(BASE_PATH, "last_device.txt") with open(last_device_log) as f: - for el in f: - last_device = el.strip() + last_device = f.readline().strip() + if last_device not in available_devices: + last_device = None except: last_device = "CPU" first_run = True @@ -97,7 +99,7 @@ with open(req_file) as file: import torch try: if torch.cuda.is_available(): - if first_run: + if first_run or last_device is None: last_device = "CUDA" elif torch.backends.mps.is_available() or hasattr(torch,'dml'): ort = "onnxruntime" @@ -105,10 +107,10 @@ with open(req_file) as file: if first_run: pip_uninstall("onnxruntime", "onnxruntime-gpu") # just in case: - if last_device == "CUDA": + if last_device == "CUDA" or last_device is None: last_device = "CPU" else: - if last_device == "CUDA": + if last_device == "CUDA" or last_device is None: last_device = "CPU" with open(os.path.join(BASE_PATH, "last_device.txt"), "w") as txt: txt.write(last_device) diff --git a/scripts/reactor_globals.py b/scripts/reactor_globals.py index 7aea387..f828254 100644 --- a/scripts/reactor_globals.py +++ b/scripts/reactor_globals.py @@ -26,10 +26,14 @@ def updateDevice(): try: LAST_DEVICE_PATH = os.path.join(BASE_PATH, "last_device.txt") with open(LAST_DEVICE_PATH) as f: - for el in f: - device = el.strip() - except: - device = "CPU" + device = f.readline().strip() + if device not in DEVICE_LIST: + print(f"Error: Device {device} is not in DEVICE_LIST") + device = DEVICE_LIST[0] + print(f"Execution Provider has been set to {device}") + except Exception as e: + device = DEVICE_LIST[0] + print(f"Error: {e}\nExecution Provider has been set to {device}") return device DEVICE = updateDevice()