mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(diffusers_pipeline): ensure cuda.get_mem_info
always gets a specific device index.
Also tighten up the typing of `device` attributes in general.
This commit is contained in:
parent
07f9fa63d0
commit
b8212e4dea
@ -191,9 +191,8 @@ class Generate:
|
|||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
# it wasn't actually doing anything. This logic could be reinstated.
|
# it wasn't actually doing anything. This logic could be reinstated.
|
||||||
device_type = choose_torch_device()
|
self.device = torch.device(choose_torch_device())
|
||||||
print(f'>> Using device_type {device_type}')
|
print(f'>> Using device_type {self.device.type}')
|
||||||
self.device = torch.device(device_type)
|
|
||||||
if full_precision:
|
if full_precision:
|
||||||
if self.precision != 'auto':
|
if self.precision != 'auto':
|
||||||
raise ValueError('Remove --full_precision / -F if using --precision')
|
raise ValueError('Remove --full_precision / -F if using --precision')
|
||||||
|
@ -40,7 +40,6 @@ from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
|
|||||||
from ldm.invoke.readline import generic_completer
|
from ldm.invoke.readline import generic_completer
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
import torch
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@ -764,7 +763,7 @@ def download_weights(opt: dict) -> Union[str, None]:
|
|||||||
precision = (
|
precision = (
|
||||||
"float32"
|
"float32"
|
||||||
if opt.full_precision
|
if opt.full_precision
|
||||||
else choose_precision(torch.device(choose_torch_device()))
|
else choose_precision(choose_torch_device())
|
||||||
)
|
)
|
||||||
|
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
|
@ -1,19 +1,25 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import nullcontext
|
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
|
|
||||||
def choose_torch_device() -> str:
|
CPU_DEVICE = torch.device("cpu")
|
||||||
|
|
||||||
|
def choose_torch_device() -> torch.device:
|
||||||
'''Convenience routine for guessing which GPU device to run model on'''
|
'''Convenience routine for guessing which GPU device to run model on'''
|
||||||
if Globals.always_use_cpu:
|
if Globals.always_use_cpu:
|
||||||
return "cpu"
|
return CPU_DEVICE
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return 'cuda'
|
return torch.device('cuda')
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
return 'mps'
|
return torch.device('mps')
|
||||||
return 'cpu'
|
return CPU_DEVICE
|
||||||
|
|
||||||
def choose_precision(device) -> str:
|
def choose_precision(device: torch.device) -> str:
|
||||||
'''Returns an appropriate precision for the given torch device'''
|
'''Returns an appropriate precision for the given torch device'''
|
||||||
if device.type == 'cuda':
|
if device.type == 'cuda':
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
@ -21,7 +27,7 @@ def choose_precision(device) -> str:
|
|||||||
return 'float16'
|
return 'float16'
|
||||||
return 'float32'
|
return 'float32'
|
||||||
|
|
||||||
def torch_dtype(device) -> torch.dtype:
|
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||||
if Globals.full_precision:
|
if Globals.full_precision:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
if choose_precision(device) == 'float16':
|
if choose_precision(device) == 'float16':
|
||||||
@ -36,3 +42,13 @@ def choose_autocast(precision):
|
|||||||
if precision == 'autocast' or precision == 'float16':
|
if precision == 'autocast' or precision == 'float16':
|
||||||
return autocast
|
return autocast
|
||||||
return nullcontext
|
return nullcontext
|
||||||
|
|
||||||
|
def normalize_device(device: str | torch.device) -> torch.device:
|
||||||
|
"""Ensure device has a device index defined, if appropriate."""
|
||||||
|
device = torch.device(device)
|
||||||
|
if device.index is None:
|
||||||
|
# cuda might be the only torch backend that currently uses the device index?
|
||||||
|
# I don't see anything like `current_device` for cpu or mps.
|
||||||
|
if device.type == 'cuda':
|
||||||
|
device = torch.device(device.type, torch.cuda.current_device())
|
||||||
|
return device
|
||||||
|
@ -28,6 +28,7 @@ from typing_extensions import ParamSpec
|
|||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
|
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
|
||||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||||
|
from ..devices import normalize_device, CPU_DEVICE
|
||||||
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
|
from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
|
||||||
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
|
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
|
||||||
@ -319,7 +320,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if self.device.type == 'cpu' or self.device.type == 'mps':
|
if self.device.type == 'cpu' or self.device.type == 'mps':
|
||||||
mem_free = psutil.virtual_memory().free
|
mem_free = psutil.virtual_memory().free
|
||||||
elif self.device.type == 'cuda':
|
elif self.device.type == 'cuda':
|
||||||
mem_free, _ = torch.cuda.mem_get_info(self.device)
|
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unrecognized device {self.device}")
|
raise ValueError(f"unrecognized device {self.device}")
|
||||||
# input tensor of [1, 4, h/8, w/8]
|
# input tensor of [1, 4, h/8, w/8]
|
||||||
@ -380,9 +381,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self._model_group.ready()
|
self._model_group.ready()
|
||||||
|
|
||||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||||
|
# overridden method; types match the superclass.
|
||||||
if torch_device is None:
|
if torch_device is None:
|
||||||
return self
|
return self
|
||||||
self._model_group.set_device(torch_device)
|
self._model_group.set_device(torch.device(torch_device))
|
||||||
self._model_group.ready()
|
self._model_group.ready()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -689,8 +691,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if device.type == 'mps':
|
if device.type == 'mps':
|
||||||
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
|
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
|
||||||
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
|
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
|
||||||
self.vae.to('cpu')
|
self.vae.to(CPU_DEVICE)
|
||||||
init_image = init_image.to('cpu')
|
init_image = init_image.to(CPU_DEVICE)
|
||||||
else:
|
else:
|
||||||
self._model_group.load(self.vae)
|
self._model_group.load(self.vae)
|
||||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||||
|
@ -30,6 +30,7 @@ from omegaconf import OmegaConf
|
|||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
|
from ldm.invoke.devices import CPU_DEVICE
|
||||||
from ldm.invoke.generator.diffusers_pipeline import \
|
from ldm.invoke.generator.diffusers_pipeline import \
|
||||||
StableDiffusionGeneratorPipeline
|
StableDiffusionGeneratorPipeline
|
||||||
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
||||||
@ -47,7 +48,7 @@ class ModelManager(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OmegaConf,
|
config: OmegaConf,
|
||||||
device_type: str | torch.device = "cpu",
|
device_type: torch.device = CPU_DEVICE,
|
||||||
precision: str = "float16",
|
precision: str = "float16",
|
||||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||||
sequential_offload = False
|
sequential_offload = False
|
||||||
@ -996,25 +997,25 @@ class ModelManager(object):
|
|||||||
self.models.pop(model_name, None)
|
self.models.pop(model_name, None)
|
||||||
|
|
||||||
def _model_to_cpu(self, model):
|
def _model_to_cpu(self, model):
|
||||||
if self.device == "cpu":
|
if self.device == CPU_DEVICE:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
if isinstance(model, StableDiffusionGeneratorPipeline):
|
if isinstance(model, StableDiffusionGeneratorPipeline):
|
||||||
model.offload_all()
|
model.offload_all()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
model.cond_stage_model.device = "cpu"
|
model.cond_stage_model.device = CPU_DEVICE
|
||||||
model.to("cpu")
|
model.to(CPU_DEVICE)
|
||||||
|
|
||||||
for submodel in ("first_stage_model", "cond_stage_model", "model"):
|
for submodel in ("first_stage_model", "cond_stage_model", "model"):
|
||||||
try:
|
try:
|
||||||
getattr(model, submodel).to("cpu")
|
getattr(model, submodel).to(CPU_DEVICE)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _model_from_cpu(self, model):
|
def _model_from_cpu(self, model):
|
||||||
if self.device == "cpu":
|
if self.device == CPU_DEVICE:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
if isinstance(model, StableDiffusionGeneratorPipeline):
|
if isinstance(model, StableDiffusionGeneratorPipeline):
|
||||||
|
Loading…
Reference in New Issue
Block a user