fix(diffusers_pipeline): ensure cuda.get_mem_info always gets a specific device index. (#2700)

Also tighten up the typing of `device` attributes in general.

Fixes
> ValueError: Expected a torch.device with a specified index or an integer, but got:cuda
blessedcoolant 2023-02-19 04:33:16 +13:00 committed by GitHub
commit e3d1c64b77
5 changed files with 41 additions and 24 deletions
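For context, a minimal sketch (not part of the commit) of the failure and the repair. It assumes a CUDA machine and the torch behavior current when this was written, where `torch.cuda.mem_get_info` required a device with an explicit index:

import torch

device = torch.device("cuda")  # bare device: no index attached
assert device.index is None

# In the torch release this commit targets, the un-indexed device raised:
#   ValueError: Expected a torch.device with a specified index or an
#   integer, but got:cuda
# torch.cuda.mem_get_info(device)

# The repair: resolve the current device index first, then query memory.
indexed = torch.device(device.type, torch.cuda.current_device())
free_bytes, total_bytes = torch.cuda.mem_get_info(indexed)
print(f"free: {free_bytes / 2**30:.1f} GiB of {total_bytes / 2**30:.1f} GiB")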

View File

@@ -191,9 +191,8 @@ class Generate:
         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
         # it wasn't actually doing anything. This logic could be reinstated.
-        device_type = choose_torch_device()
-        print(f'>> Using device_type {device_type}')
-        self.device = torch.device(device_type)
+        self.device = torch.device(choose_torch_device())
+        print(f'>> Using device_type {self.device.type}')
         if full_precision:
             if self.precision != 'auto':
                 raise ValueError('Remove --full_precision / -F if using --precision')

View File

@@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
 from ldm.invoke.readline import generic_completer

 warnings.filterwarnings("ignore")
-import torch

 transformers.logging.set_verbosity_error()
@@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
    precision = (
        "float32"
        if opt.full_precision
-        else choose_precision(torch.device(choose_torch_device()))
+        else choose_precision(choose_torch_device())
    )
    if opt.yes_to_all:

View File

@@ -1,19 +1,25 @@
+from __future__ import annotations
+
+from contextlib import nullcontext
+
 import torch
 from torch import autocast
-from contextlib import nullcontext

 from ldm.invoke.globals import Globals

+CPU_DEVICE = torch.device("cpu")
+
-def choose_torch_device() -> str:
+def choose_torch_device() -> torch.device:
     '''Convenience routine for guessing which GPU device to run model on'''
     if Globals.always_use_cpu:
-        return "cpu"
+        return CPU_DEVICE
     if torch.cuda.is_available():
-        return 'cuda'
+        return torch.device('cuda')
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        return 'mps'
+        return torch.device('mps')
-    return 'cpu'
+    return CPU_DEVICE

-def choose_precision(device) -> str:
+def choose_precision(device: torch.device) -> str:
@@ -21,7 +27,7 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'

-def torch_dtype(device) -> torch.dtype:
+def torch_dtype(device: torch.device) -> torch.dtype:
     if Globals.full_precision:
         return torch.float32
     if choose_precision(device) == 'float16':
@@ -36,3 +42,13 @@ def choose_autocast(precision):
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
+
+def normalize_device(device: str | torch.device) -> torch.device:
+    """Ensure device has a device index defined, if appropriate."""
+    device = torch.device(device)
+    if device.index is None:
+        # cuda might be the only torch backend that currently uses the device index?
+        # I don't see anything like `current_device` for cpu or mps.
+        if device.type == 'cuda':
+            device = torch.device(device.type, torch.cuda.current_device())
+    return device
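A quick illustration of how the new helper behaves (hypothetical REPL output; the cuda lines assume a machine whose current CUDA device is 0):

import torch
from ldm.invoke.devices import normalize_device

print(normalize_device("cuda"))               # device(type='cuda', index=0)
print(normalize_device("cuda:1"))             # device(type='cuda', index=1), already indexed
print(normalize_device(torch.device("cpu")))  # device(type='cpu'), no index needed
print(normalize_device("mps"))                # device(type='mps'), left as-is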

View File

@@ -28,6 +28,7 @@ from typing_extensions import ParamSpec
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ..devices import normalize_device, CPU_DEVICE
 from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
@@ -319,7 +320,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.device.type == 'cpu' or self.device.type == 'mps':
             mem_free = psutil.virtual_memory().free
         elif self.device.type == 'cuda':
-            mem_free, _ = torch.cuda.mem_get_info(self.device)
+            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
         else:
             raise ValueError(f"unrecognized device {self.device}")
         # input tensor of [1, 4, h/8, w/8]
@@ -380,9 +381,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self._model_group.ready()

     def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+        # overridden method; types match the superclass.
         if torch_device is None:
             return self
-        self._model_group.set_device(torch_device)
+        self._model_group.set_device(torch.device(torch_device))
         self._model_group.ready()

     @property
@@ -689,8 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if device.type == 'mps':
             # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
             # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
-            self.vae.to('cpu')
-            init_image = init_image.to('cpu')
+            self.vae.to(CPU_DEVICE)
+            init_image = init_image.to(CPU_DEVICE)
         else:
             self._model_group.load(self.vae)
         init_latent_dist = self.vae.encode(init_image).latent_dist
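The `torch.device(torch_device)` wrapping in `to()` is safe for both accepted spellings, since the constructor parses strings and passes existing device objects through unchanged. A small sketch:

import torch

# torch.device() accepts a string or an existing device, so wrapping
# the argument once normalizes set_device's input to a single type.
assert torch.device("cuda:1") == torch.device(torch.device("cuda:1"))
assert torch.device("mps") == torch.device("mps")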

View File

@@ -30,6 +30,7 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path
+from ldm.invoke.devices import CPU_DEVICE
 from ldm.invoke.generator.diffusers_pipeline import \
     StableDiffusionGeneratorPipeline
 from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
@@ -47,7 +48,7 @@ class ModelManager(object):
     def __init__(
         self,
         config: OmegaConf,
-        device_type: str | torch.device = "cpu",
+        device_type: torch.device = CPU_DEVICE,
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
         sequential_offload = False
@@ -996,25 +997,25 @@ class ModelManager(object):
         self.models.pop(model_name, None)

     def _model_to_cpu(self, model):
-        if self.device == "cpu":
+        if self.device == CPU_DEVICE:
             return model

         if isinstance(model, StableDiffusionGeneratorPipeline):
             model.offload_all()
             return model

-        model.cond_stage_model.device = "cpu"
-        model.to("cpu")
+        model.cond_stage_model.device = CPU_DEVICE
+        model.to(CPU_DEVICE)

         for submodel in ("first_stage_model", "cond_stage_model", "model"):
             try:
-                getattr(model, submodel).to("cpu")
+                getattr(model, submodel).to(CPU_DEVICE)
             except AttributeError:
                 pass

         return model

     def _model_from_cpu(self, model):
-        if self.device == "cpu":
+        if self.device == CPU_DEVICE:
             return model

         if isinstance(model, StableDiffusionGeneratorPipeline):
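One reason to compare against `CPU_DEVICE` rather than a raw string: equality between two `torch.device` objects is well defined and exact on both type and index, so keeping `self.device` a real device makes these checks predictable. A brief sketch of the semantics:

import torch

# Equality requires type AND index to match.
assert torch.device("cpu") == torch.device("cpu")
assert torch.device("cuda") != torch.device("cuda:0")  # index None vs 0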