diff --git a/launch.py b/launch.py index a357b917..72d2a5dc 100644 --- a/launch.py +++ b/launch.py @@ -23,6 +23,16 @@ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HAS codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") +args = shlex.split(commandline_args) + + +def extract_arg(args, name): + return [x for x in args if x != name], name in args + + +args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') + + def repo_dir(name): return os.path.join(dir_repos, name) @@ -95,7 +105,8 @@ print(f"Commit hash: {commit}") if not is_installed("torch"): run(f'"{python}" -m {torch_command}', "Installing torch", "Couldn't install torch") -run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU'") +if not skip_torch_cuda_test: + run_python("import torch; assert not torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDINE_ARGS variable to disable this check'") if not is_installed("k_diffusion.sampling"): run_pip(f"install {k_diffusion_package}", "k-diffusion") @@ -115,7 +126,7 @@ if not is_installed("lpips"): run_pip(f"install -r {requirements_file}", "requirements for Web UI") -sys.argv += shlex.split(commandline_args) +sys.argv += args def start_webui():