refactor(diffusers_pipeline): remove unused ModelGroup 🚮

orphaned since #3550 removed the LazilyLoadedModelGroup code, probably unused since ModelCache took over responsibility for sequential offload somewhere around #3335.
This commit is contained in:
Kevin Turner 2023-08-05 21:19:29 -07:00
parent 77033eabd3
commit 6487e7d906
3 changed files with 7 additions and 318 deletions

View File

@ -190,7 +190,6 @@ class InpaintInvocation(BaseInvocation):
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
execution_device=device,
)
yield OldModelInfo(

View File

@ -4,7 +4,6 @@ import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
@ -41,7 +40,6 @@ from .diffusion import (
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, ModelGroup
from ..util import normalize_device
@ -286,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup
ID_LENGTH = 8
def __init__(
@ -301,7 +297,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
control_model: ControlNetModel = None,
execution_device: Optional[torch.device] = None,
):
super().__init__(
vae,
@ -326,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
self._model_group.install(*self._submodels)
self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@ -364,30 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
else:
self.disable_attention_slicing()
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
# overridden method; types match the superclass.
if torch_device is None:
return self
self._model_group.set_device(torch.device(torch_device))
self._model_group.ready()
@property
def device(self) -> torch.device:
return self._model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
submodels = []
for name in module_names.keys():
if hasattr(self, name):
value = getattr(self, name)
else:
value = getattr(self.config, name)
if isinstance(value, torch.nn.Module):
submodels.append(value)
return submodels
def latents_from_embeddings(
self,
latents: torch.Tensor,
@ -404,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)
scheduler_device = self.unet.device
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@ -458,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
(batch_size,),
timesteps[0],
dtype=timesteps.dtype,
device=self._model_group.device_for(self.unet),
device=self.unet.device,
)
latents = self.scheduler.add_noise(latents, noise, batched_t)
@ -675,7 +643,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# 6. Prepare latent variables
initial_latents = self.non_noised_latents_from_image(
init_image,
device=self._model_group.device_for(self.unet),
device=self.unet.device,
dtype=self.unet.dtype,
)
if seed is not None:
@ -725,7 +693,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
return output
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
@ -734,7 +702,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)
scheduler_device = self.unet.device
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@ -760,7 +728,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise_func=None,
seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self._model_group.device_for(self.unet)
device = self.unet.device
latents_dtype = self.unet.dtype
if isinstance(init_image, PIL.Image.Image):
@ -831,42 +799,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
return output
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
return init_latents
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
return InvokeAIStableDiffusionPipelineOutput(
screened_images,
has_nsfw_concept,
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver,
)
def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)
def debug_latents(self, latents, msg):
from invokeai.backend.image_util import debug_image

View File

@ -1,253 +0,0 @@
from __future__ import annotations
import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union
import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle
OFFLOAD_DEVICE = torch.device("cpu")
class _NoModel:
"""Symbol that indicates no model is loaded.
(We can't weakref.ref(None), so this was my best idea at the time to come up with something
type-checkable.)
"""
def __bool__(self):
return False
def to(self, device: torch.device):
pass
def __repr__(self):
return "<NO MODEL>"
NO_MODEL = _NoModel()
class ModelGroup(metaclass=ABCMeta):
"""
A group of models.
The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
e.g. its text encoder, U-net, VAE, etc.
Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
:py:class:`torch.nn.Module` here.
"""
def __init__(self, execution_device: torch.device):
self.execution_device = execution_device
@abstractmethod
def install(self, *models: torch.nn.Module):
"""Add models to this group."""
pass
@abstractmethod
def uninstall(self, models: torch.nn.Module):
"""Remove models from this group."""
pass
@abstractmethod
def uninstall_all(self):
"""Remove all models from this group."""
@abstractmethod
def load(self, model: torch.nn.Module):
"""Load this model to the execution device."""
pass
@abstractmethod
def offload_current(self):
"""Offload the current model(s) from the execution device."""
pass
@abstractmethod
def ready(self):
"""Ready this group for use."""
pass
@abstractmethod
def set_device(self, device: torch.device):
"""Change which device models from this group will execute on."""
pass
@abstractmethod
def device_for(self, model) -> torch.device:
"""Get the device the given model will execute on.
The model should already be a member of this group.
"""
pass
@abstractmethod
def __contains__(self, model):
"""Check if the model is a member of this group."""
pass
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
"""
Only one model from this group is loaded on the GPU at a time.
Running the forward method of a model will displace the previously-loaded model,
offloading it to CPU.
If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
you will need to explicitly load it with :py:method:`.load(model)`.
This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
to the appropriate execution device, as long as they are positional arguments and not keyword
arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._hooks = weakref.WeakKeyDictionary()
self._current_model_ref = weakref.ref(NO_MODEL)
def install(self, *models: torch.nn.Module):
for model in models:
self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
def uninstall(self, *models: torch.nn.Module):
for model in models:
hook = self._hooks.pop(model)
hook.remove()
if self.is_current_model(model):
# no longer hooked by this object, so don't claim to manage it
self.clear_current_model()
def uninstall_all(self):
self.uninstall(*self._hooks.keys())
def _pre_hook(self, module: torch.nn.Module, forward_input):
self.load(module)
if len(forward_input) == 0:
warnings.warn(
f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
stacklevel=3,
)
return send_to_device(forward_input, self.execution_device)
def load(self, module):
if not self.is_current_model(module):
self.offload_current()
self._load(module)
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
def is_current_model(self, model: torch.nn.Module) -> bool:
"""Is the given model the one currently loaded on the execution device?"""
return self._current_model_ref() is model
def is_empty(self):
"""Are none of this group's models loaded on the execution device?"""
return self._current_model_ref() is NO_MODEL
def set_current_model(self, value):
self._current_model_ref = weakref.ref(value)
def clear_current_model(self):
self._current_model_ref = weakref.ref(NO_MODEL)
def set_device(self, device: torch.device):
if device == self.execution_device:
return
self.execution_device = device
current = self._current_model_ref()
if current is not NO_MODEL:
current.to(device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
def __contains__(self, model):
return model in self._hooks
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} object at {id(self):x}: "
f"current_model={type(self._current_model_ref()).__name__} >"
)
class FullyLoadedModelGroup(ModelGroup):
"""
A group of models without any implicit loading or unloading.
:py:meth:`.ready` loads _all_ the models to the execution device at once.
"""
_models: weakref.WeakSet
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._models = weakref.WeakSet()
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
self._models.remove(model)
def uninstall_all(self):
self.uninstall(*self._models)
def load(self, model):
model.to(self.execution_device)
def offload_current(self):
for model in self._models:
model.to(OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
self.load(model)
def set_device(self, device: torch.device):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device)
def device_for(self, model):
if model not in self:
raise KeyError("This does not manage this model f{type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models