fix(diffusers_pipeline): ensure cuda.get_mem_info always gets a specific device index. (#2700)

Also tighten up the typing of `device` attributes in general.

Fixes 
> ValueError: Expected a torch.device with a specified index or an
integer, but got:cuda
This commit is contained in:
blessedcoolant 2023-02-19 04:33:16 +13:00 committed by GitHub
commit e3d1c64b77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 24 deletions

View File

@ -191,9 +191,8 @@ class Generate:
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
# it wasn't actually doing anything. This logic could be reinstated.
device_type = choose_torch_device()
print(f'>> Using device_type {device_type}')
self.device = torch.device(device_type)
self.device = torch.device(choose_torch_device())
print(f'>> Using device_type {self.device.type}')
if full_precision:
if self.precision != 'auto':
raise ValueError('Remove --full_precision / -F if using --precision')

View File

@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
from ldm.invoke.readline import generic_completer
warnings.filterwarnings("ignore")
import torch
transformers.logging.set_verbosity_error()
@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
precision = (
"float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device()))
else choose_precision(choose_torch_device())
)
if opt.yes_to_all:

View File

@ -1,19 +1,25 @@
from __future__ import annotations
from contextlib import nullcontext
import torch
from torch import autocast
from contextlib import nullcontext
from ldm.invoke.globals import Globals
def choose_torch_device() -> str:
CPU_DEVICE = torch.device("cpu")
def choose_torch_device() -> torch.device:
'''Convenience routine for guessing which GPU device to run model on'''
if Globals.always_use_cpu:
return "cpu"
return CPU_DEVICE
if torch.cuda.is_available():
return 'cuda'
return torch.device('cuda')
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
return torch.device('mps')
return CPU_DEVICE
def choose_precision(device) -> str:
def choose_precision(device: torch.device) -> str:
'''Returns an appropriate precision for the given torch device'''
if device.type == 'cuda':
device_name = torch.cuda.get_device_name(device)
@ -21,7 +27,7 @@ def choose_precision(device) -> str:
return 'float16'
return 'float32'
def torch_dtype(device) -> torch.dtype:
def torch_dtype(device: torch.device) -> torch.dtype:
if Globals.full_precision:
return torch.float32
if choose_precision(device) == 'float16':
@ -36,3 +42,13 @@ def choose_autocast(precision):
if precision == 'autocast' or precision == 'float16':
return autocast
return nullcontext
def normalize_device(device: str | torch.device) -> torch.device:
"""Ensure device has a device index defined, if appropriate."""
device = torch.device(device)
if device.index is None:
# cuda might be the only torch backend that currently uses the device index?
# I don't see anything like `current_device` for cpu or mps.
if device.type == 'cuda':
device = torch.device(device.type, torch.cuda.current_device())
return device

View File

@ -28,6 +28,7 @@ from typing_extensions import ParamSpec
from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
from ..devices import normalize_device, CPU_DEVICE
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@ -319,7 +320,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.device.type == 'cpu' or self.device.type == 'mps':
mem_free = psutil.virtual_memory().free
elif self.device.type == 'cuda':
mem_free, _ = torch.cuda.mem_get_info(self.device)
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else:
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
@ -380,9 +381,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._model_group.ready()
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
# overridden method; types match the superclass.
if torch_device is None:
return self
self._model_group.set_device(torch_device)
self._model_group.set_device(torch.device(torch_device))
self._model_group.ready()
@property
@ -689,8 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if device.type == 'mps':
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to('cpu')
init_image = init_image.to('cpu')
self.vae.to(CPU_DEVICE)
init_image = init_image.to(CPU_DEVICE)
else:
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist

View File

@ -30,6 +30,7 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from ldm.invoke.devices import CPU_DEVICE
from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
@ -47,7 +48,7 @@ class ModelManager(object):
def __init__(
self,
config: OmegaConf,
device_type: str | torch.device = "cpu",
device_type: torch.device = CPU_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload = False
@ -675,7 +676,7 @@ class ModelManager(object):
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
@ -996,25 +997,25 @@ class ModelManager(object):
self.models.pop(model_name, None)
def _model_to_cpu(self, model):
if self.device == "cpu":
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):
model.offload_all()
return model
model.cond_stage_model.device = "cpu"
model.to("cpu")
model.cond_stage_model.device = CPU_DEVICE
model.to(CPU_DEVICE)
for submodel in ("first_stage_model", "cond_stage_model", "model"):
try:
getattr(model, submodel).to("cpu")
getattr(model, submodel).to(CPU_DEVICE)
except AttributeError:
pass
return model
def _model_from_cpu(self, model):
if self.device == "cpu":
if self.device == CPU_DEVICE:
return model
if isinstance(model, StableDiffusionGeneratorPipeline):