diff --git a/README.md b/README.md
index e030165..5c7f7c6 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
- 
+ 
diff --git a/README_RU.md b/README_RU.md
index a039fe9..14cb95e 100644
--- a/README_RU.md
+++ b/README_RU.md
@@ -2,7 +2,7 @@
- 
+ 
diff --git a/install.py b/install.py
index c218e40..d7b9567 100644
--- a/install.py
+++ b/install.py
@@ -43,9 +43,12 @@ def get_sd_option(name: str, default: Any) -> Any:
assert shared.opts.data is not None
return shared.opts.data.get(name, default)
-def run_pip(*args):
+def pip_install(*args):
subprocess.run([sys.executable, "-m", "pip", "install", *args])
+def pip_uninstall(*args):
+ subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", *args])
+
def is_installed (
package: str, version: str | None = None, strict: bool = True
):
@@ -96,10 +99,18 @@ with open(req_file) as file:
install_count = 0
try:
ort = "onnxruntime-gpu"
- import torch.cuda as cuda
- if cuda.is_available():
+ import torch
+ if torch.cuda.is_available():
if first_run:
last_device = "CUDA"
+ elif torch.backends.mps.is_available() or hasattr(torch,'dml'):
+ ort = "onnxruntime"
+ # to prevent errors when ORT-GPU is installed but we want ORT instead:
+ if first_run:
+ pip_uninstall("onnxruntime", "onnxruntime-gpu")
+ # just in case:
+ if last_device == "CUDA":
+ last_device = "CPU"
else:
if last_device == "CUDA":
last_device = "CPU"
@@ -107,7 +118,7 @@ with open(req_file) as file:
txt.write(last_device)
if not is_installed(ort,"1.16.1",False):
install_count += 1
- run_pip(ort)
+ pip_install(ort)
except Exception as e:
print(e)
print(f"\nERROR: Failed to install {ort} - ReActor won't start")
@@ -125,7 +136,7 @@ with open(req_file) as file:
strict = False
if not is_installed(package,package_version,strict):
install_count += 1
- run_pip(package)
+ pip_install(package)
except Exception as e:
print(e)
print(f"\nERROR: Failed to install {package} - ReActor won't start")
diff --git a/scripts/reactor_version.py b/scripts/reactor_version.py
index 6a2396d..b9eb334 100644
--- a/scripts/reactor_version.py
+++ b/scripts/reactor_version.py
@@ -1,5 +1,5 @@
app_title = "ReActor"
-version_flag = "v0.5.0-a3"
+version_flag = "v0.5.0-a4"
from scripts.reactor_logger import logger, get_Run, set_Run