mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(diffusers_pipeline): ensure cuda.get_mem_info
always gets a specific device index.
Also tighten up the typing of `device` attributes in general.
This commit is contained in:
parent
07f9fa63d0
commit
b8212e4dea
@ -191,9 +191,8 @@ class Generate:
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
# it wasn't actually doing anything. This logic could be reinstated.
|
||||
device_type = choose_torch_device()
|
||||
print(f'>> Using device_type {device_type}')
|
||||
self.device = torch.device(device_type)
|
||||
self.device = torch.device(choose_torch_device())
|
||||
print(f'>> Using device_type {self.device.type}')
|
||||
if full_precision:
|
||||
if self.precision != 'auto':
|
||||
raise ValueError('Remove --full_precision / -F if using --precision')
|
||||
|
@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
|
||||
from ldm.invoke.readline import generic_completer
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import torch
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
|
||||
precision = (
|
||||
"float32"
|
||||
if opt.full_precision
|
||||
else choose_precision(torch.device(choose_torch_device()))
|
||||
else choose_precision(choose_torch_device())
|
||||
)
|
||||
|
||||
if opt.yes_to_all:
|
||||
|
@ -1,19 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
def choose_torch_device() -> str:
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
def choose_torch_device() -> torch.device:
|
||||
'''Convenience routine for guessing which GPU device to run model on'''
|
||||
if Globals.always_use_cpu:
|
||||
return "cpu"
|
||||
return CPU_DEVICE
|
||||
if torch.cuda.is_available():
|
||||
return 'cuda'
|
||||
return torch.device('cuda')
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
return torch.device('mps')
|
||||
return CPU_DEVICE
|
||||
|
||||
def choose_precision(device) -> str:
|
||||
def choose_precision(device: torch.device) -> str:
|
||||
'''Returns an appropriate precision for the given torch device'''
|
||||
if device.type == 'cuda':
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
@ -21,7 +27,7 @@ def choose_precision(device) -> str:
|
||||
return 'float16'
|
||||
return 'float32'
|
||||
|
||||
def torch_dtype(device) -> torch.dtype:
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
if Globals.full_precision:
|
||||
return torch.float32
|
||||
if choose_precision(device) == 'float16':
|
||||
@ -36,3 +42,13 @@ def choose_autocast(precision):
|
||||
if precision == 'autocast' or precision == 'float16':
|
||||
return autocast
|
||||
return nullcontext
|
||||
|
||||
def normalize_device(device: str | torch.device) -> torch.device:
|
||||
"""Ensure device has a device index defined, if appropriate."""
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
# cuda might be the only torch backend that currently uses the device index?
|
||||
# I don't see anything like `current_device` for cpu or mps.
|
||||
if device.type == 'cuda':
|
||||
device = torch.device(device.type, torch.cuda.current_device())
|
||||
return device
|
||||
|
@ -28,6 +28,7 @@ from typing_extensions import ParamSpec
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
|
||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
from ..devices import normalize_device, CPU_DEVICE
|
||||
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
|
||||
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
|
||||
@ -319,7 +320,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if self.device.type == 'cpu' or self.device.type == 'mps':
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.device.type == 'cuda':
|
||||
mem_free, _ = torch.cuda.mem_get_info(self.device)
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
@ -380,9 +381,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self._model_group.ready()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
# overridden method; types match the superclass.
|
||||
if torch_device is None:
|
||||
return self
|
||||
self._model_group.set_device(torch_device)
|
||||
self._model_group.set_device(torch.device(torch_device))
|
||||
self._model_group.ready()
|
||||
|
||||
@property
|
||||
@ -689,8 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if device.type == 'mps':
|
||||
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
|
||||
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
|
||||
self.vae.to('cpu')
|
||||
init_image = init_image.to('cpu')
|
||||
self.vae.to(CPU_DEVICE)
|
||||
init_image = init_image.to(CPU_DEVICE)
|
||||
else:
|
||||
self._model_group.load(self.vae)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
|
@ -30,6 +30,7 @@ from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from ldm.invoke.devices import CPU_DEVICE
|
||||
from ldm.invoke.generator.diffusers_pipeline import \
|
||||
StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
||||
@ -47,7 +48,7 @@ class ModelManager(object):
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf,
|
||||
device_type: str | torch.device = "cpu",
|
||||
device_type: torch.device = CPU_DEVICE,
|
||||
precision: str = "float16",
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload = False
|
||||
@ -675,7 +676,7 @@ class ModelManager(object):
|
||||
"""
|
||||
if str(weights).startswith(("http:", "https:")):
|
||||
model_name = model_name or url_attachment_name(weights)
|
||||
|
||||
|
||||
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
|
||||
config_path = self._resolve_path(config, "configs/stable-diffusion")
|
||||
|
||||
@ -996,25 +997,25 @@ class ModelManager(object):
|
||||
self.models.pop(model_name, None)
|
||||
|
||||
def _model_to_cpu(self, model):
|
||||
if self.device == "cpu":
|
||||
if self.device == CPU_DEVICE:
|
||||
return model
|
||||
|
||||
if isinstance(model, StableDiffusionGeneratorPipeline):
|
||||
model.offload_all()
|
||||
return model
|
||||
|
||||
model.cond_stage_model.device = "cpu"
|
||||
model.to("cpu")
|
||||
model.cond_stage_model.device = CPU_DEVICE
|
||||
model.to(CPU_DEVICE)
|
||||
|
||||
for submodel in ("first_stage_model", "cond_stage_model", "model"):
|
||||
try:
|
||||
getattr(model, submodel).to("cpu")
|
||||
getattr(model, submodel).to(CPU_DEVICE)
|
||||
except AttributeError:
|
||||
pass
|
||||
return model
|
||||
|
||||
def _model_from_cpu(self, model):
|
||||
if self.device == "cpu":
|
||||
if self.device == CPU_DEVICE:
|
||||
return model
|
||||
|
||||
if isinstance(model, StableDiffusionGeneratorPipeline):
|
||||
|
Loading…
Reference in New Issue
Block a user