fix(diffusers_pipeline): ensure cuda.get_mem_info always gets a specific device index.

Also tighten up the typing of `device` attributes in general.
This commit is contained in:
Kevin Turner
2023-02-17 16:29:03 -08:00
parent 07f9fa63d0
commit b8212e4dea
5 changed files with 41 additions and 24 deletions

View File

@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
from ldm.invoke.readline import generic_completer
warnings.filterwarnings("ignore")
import torch
transformers.logging.set_verbosity_error()
@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
precision = (
"float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device()))
else choose_precision(choose_torch_device())
)
if opt.yes_to_all: