From 6487e7d9064217dcc735db208f246d14477beff3 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 5 Aug 2023 21:19:29 -0700 Subject: [PATCH] =?UTF-8?q?refactor(diffusers=5Fpipeline):=20remove=20unus?= =?UTF-8?q?ed=20ModelGroup=20=F0=9F=9A=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit orphaned since #3550 removed the LazilyLoadedModelGroup code, probably unused since ModelCache took over responsibility for sequential offload somewhere around #3335. --- invokeai/app/invocations/generate.py | 1 - .../stable_diffusion/diffusers_pipeline.py | 71 +---- .../backend/stable_diffusion/offloading.py | 253 ------------------ 3 files changed, 7 insertions(+), 318 deletions(-) delete mode 100644 invokeai/backend/stable_diffusion/offloading.py diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index f8d240e45c..88a76e930c 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -190,7 +190,6 @@ class InpaintInvocation(BaseInvocation): safety_checker=None, feature_extractor=None, requires_safety_checker=False, - execution_device=device, ) yield OldModelInfo( diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 48182a6be2..6891c726dc 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -4,7 +4,6 @@ import dataclasses import inspect import math import secrets -from collections.abc import Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union @@ -41,7 +40,6 @@ from .diffusion import ( InvokeAIDiffuserComponent, PostprocessingSettings, ) -from .offloading import FullyLoadedModelGroup, ModelGroup from ..util import normalize_device @@ -286,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _model_group: ModelGroup - ID_LENGTH = 8 def __init__( @@ -301,7 +297,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, control_model: ControlNetModel = None, - execution_device: Optional[torch.device] = None, ): super().__init__( vae, @@ -326,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # control_model=control_model, ) self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) - - self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device) - self._model_group.install(*self._submodels) self.control_model = control_model def _adjust_memory_efficient_attention(self, latents: torch.Tensor): @@ -364,30 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): else: self.disable_attention_slicing() - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): - # overridden method; types match the superclass. - if torch_device is None: - return self - self._model_group.set_device(torch.device(torch_device)) - self._model_group.ready() - - @property - def device(self) -> torch.device: - return self._model_group.execution_device - - @property - def _submodels(self) -> Sequence[torch.nn.Module]: - module_names, _, _ = self.extract_init_dict(dict(self.config)) - submodels = [] - for name in module_names.keys(): - if hasattr(self, name): - value = getattr(self, name) - else: - value = getattr(self.config, name) - if isinstance(value, torch.nn.Module): - submodels.append(value) - return submodels - def latents_from_embeddings( self, latents: torch.Tensor, @@ -404,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device("cpu") else: - scheduler_device = self._model_group.device_for(self.unet) + scheduler_device = self.unet.device if timesteps is None: self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device) @@ -458,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): (batch_size,), timesteps[0], dtype=timesteps.dtype, - device=self._model_group.device_for(self.unet), + device=self.unet.device, ) latents = self.scheduler.add_noise(latents, noise, batched_t) @@ -675,7 +643,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # 6. Prepare latent variables initial_latents = self.non_noised_latents_from_image( init_image, - device=self._model_group.device_for(self.unet), + device=self.unet.device, dtype=self.unet.dtype, ) if seed is not None: @@ -725,7 +693,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): nsfw_content_detected=[], attention_map_saver=result_attention_maps, ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) + return output def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int): img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) @@ -734,7 +702,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device("cpu") else: - scheduler_device = self._model_group.device_for(self.unet) + scheduler_device = self.unet.device img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device) timesteps, adjusted_steps = img2img_pipeline.get_timesteps( @@ -760,7 +728,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise_func=None, seed=None, ) -> InvokeAIStableDiffusionPipelineOutput: - device = self._model_group.device_for(self.unet) + device = self.unet.device latents_dtype = self.unet.dtype if isinstance(init_image, PIL.Image.Image): @@ -831,42 +799,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): nsfw_content_detected=[], attention_map_saver=result_attention_maps, ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) + return output def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype): init_image = init_image.to(device=device, dtype=dtype) with torch.inference_mode(): - self._model_group.load(self.vae) init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! init_latents = 0.18215 * init_latents return init_latents - def check_for_safety(self, output, dtype): - with torch.inference_mode(): - screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype) - screened_attention_map_saver = None - if has_nsfw_concept is None or not has_nsfw_concept: - screened_attention_map_saver = output.attention_map_saver - return InvokeAIStableDiffusionPipelineOutput( - screened_images, - has_nsfw_concept, - # block the attention maps if NSFW content is detected - attention_map_saver=screened_attention_map_saver, - ) - - def run_safety_checker(self, image, device=None, dtype=None): - # overriding to use the model group for device info instead of requiring the caller to know. - if self.safety_checker is not None: - device = self._model_group.device_for(self.safety_checker) - return super().run_safety_checker(image, device, dtype) - - def decode_latents(self, latents): - # Explicit call to get the vae loaded, since `decode` isn't the forward method. - self._model_group.load(self.vae) - return super().decode_latents(latents) - def debug_latents(self, latents, msg): from invokeai.backend.image_util import debug_image diff --git a/invokeai/backend/stable_diffusion/offloading.py b/invokeai/backend/stable_diffusion/offloading.py deleted file mode 100644 index aa2426d514..0000000000 --- a/invokeai/backend/stable_diffusion/offloading.py +++ /dev/null @@ -1,253 +0,0 @@ -from __future__ import annotations - -import warnings -import weakref -from abc import ABCMeta, abstractmethod -from collections.abc import MutableMapping -from typing import Callable, Union - -import torch -from accelerate.utils import send_to_device -from torch.utils.hooks import RemovableHandle - -OFFLOAD_DEVICE = torch.device("cpu") - - -class _NoModel: - """Symbol that indicates no model is loaded. - - (We can't weakref.ref(None), so this was my best idea at the time to come up with something - type-checkable.) - """ - - def __bool__(self): - return False - - def to(self, device: torch.device): - pass - - def __repr__(self): - return "" - - -NO_MODEL = _NoModel() - - -class ModelGroup(metaclass=ABCMeta): - """ - A group of models. - - The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline, - e.g. its text encoder, U-net, VAE, etc. - - Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with - :py:class:`torch.nn.Module` here. - """ - - def __init__(self, execution_device: torch.device): - self.execution_device = execution_device - - @abstractmethod - def install(self, *models: torch.nn.Module): - """Add models to this group.""" - pass - - @abstractmethod - def uninstall(self, models: torch.nn.Module): - """Remove models from this group.""" - pass - - @abstractmethod - def uninstall_all(self): - """Remove all models from this group.""" - - @abstractmethod - def load(self, model: torch.nn.Module): - """Load this model to the execution device.""" - pass - - @abstractmethod - def offload_current(self): - """Offload the current model(s) from the execution device.""" - pass - - @abstractmethod - def ready(self): - """Ready this group for use.""" - pass - - @abstractmethod - def set_device(self, device: torch.device): - """Change which device models from this group will execute on.""" - pass - - @abstractmethod - def device_for(self, model) -> torch.device: - """Get the device the given model will execute on. - - The model should already be a member of this group. - """ - pass - - @abstractmethod - def __contains__(self, model): - """Check if the model is a member of this group.""" - pass - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >" - - -class LazilyLoadedModelGroup(ModelGroup): - """ - Only one model from this group is loaded on the GPU at a time. - - Running the forward method of a model will displace the previously-loaded model, - offloading it to CPU. - - If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``, - you will need to explicitly load it with :py:method:`.load(model)`. - - This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments - to the appropriate execution device, as long as they are positional arguments and not keyword - arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.) - """ - - _hooks: MutableMapping[torch.nn.Module, RemovableHandle] - _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]] - - def __init__(self, execution_device: torch.device): - super().__init__(execution_device) - self._hooks = weakref.WeakKeyDictionary() - self._current_model_ref = weakref.ref(NO_MODEL) - - def install(self, *models: torch.nn.Module): - for model in models: - self._hooks[model] = model.register_forward_pre_hook(self._pre_hook) - - def uninstall(self, *models: torch.nn.Module): - for model in models: - hook = self._hooks.pop(model) - hook.remove() - if self.is_current_model(model): - # no longer hooked by this object, so don't claim to manage it - self.clear_current_model() - - def uninstall_all(self): - self.uninstall(*self._hooks.keys()) - - def _pre_hook(self, module: torch.nn.Module, forward_input): - self.load(module) - if len(forward_input) == 0: - warnings.warn( - f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.", - stacklevel=3, - ) - return send_to_device(forward_input, self.execution_device) - - def load(self, module): - if not self.is_current_model(module): - self.offload_current() - self._load(module) - - def offload_current(self): - module = self._current_model_ref() - if module is not NO_MODEL: - module.to(OFFLOAD_DEVICE) - self.clear_current_model() - - def _load(self, module: torch.nn.Module) -> torch.nn.Module: - assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}" - module = module.to(self.execution_device) - self.set_current_model(module) - return module - - def is_current_model(self, model: torch.nn.Module) -> bool: - """Is the given model the one currently loaded on the execution device?""" - return self._current_model_ref() is model - - def is_empty(self): - """Are none of this group's models loaded on the execution device?""" - return self._current_model_ref() is NO_MODEL - - def set_current_model(self, value): - self._current_model_ref = weakref.ref(value) - - def clear_current_model(self): - self._current_model_ref = weakref.ref(NO_MODEL) - - def set_device(self, device: torch.device): - if device == self.execution_device: - return - self.execution_device = device - current = self._current_model_ref() - if current is not NO_MODEL: - current.to(device) - - def device_for(self, model): - if model not in self: - raise KeyError(f"This does not manage this model {type(model).__name__}", model) - return self.execution_device # this implementation only dispatches to one device - - def ready(self): - pass # always ready to load on-demand - - def __contains__(self, model): - return model in self._hooks - - def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__} object at {id(self):x}: " - f"current_model={type(self._current_model_ref()).__name__} >" - ) - - -class FullyLoadedModelGroup(ModelGroup): - """ - A group of models without any implicit loading or unloading. - - :py:meth:`.ready` loads _all_ the models to the execution device at once. - """ - - _models: weakref.WeakSet - - def __init__(self, execution_device: torch.device): - super().__init__(execution_device) - self._models = weakref.WeakSet() - - def install(self, *models: torch.nn.Module): - for model in models: - self._models.add(model) - model.to(self.execution_device) - - def uninstall(self, *models: torch.nn.Module): - for model in models: - self._models.remove(model) - - def uninstall_all(self): - self.uninstall(*self._models) - - def load(self, model): - model.to(self.execution_device) - - def offload_current(self): - for model in self._models: - model.to(OFFLOAD_DEVICE) - - def ready(self): - for model in self._models: - self.load(model) - - def set_device(self, device: torch.device): - self.execution_device = device - for model in self._models: - if model.device != OFFLOAD_DEVICE: - model.to(device) - - def device_for(self, model): - if model not in self: - raise KeyError("This does not manage this model f{type(model).__name__}", model) - return self.execution_device # this implementation only dispatches to one device - - def __contains__(self, model): - return model in self._models