diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index d48c9f922e..88a76e930c 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -1,26 +1,23 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from contextlib import contextmanager, ContextDecorator from functools import partial from typing import Literal, Optional, get_args -import torch from pydantic import Field from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods - -from ...backend.generator import Inpaint, InvokeAIGenerator -from ...backend.stable_diffusion import PipelineIntermediateState -from ..util.step_callback import stable_diffusion_step_callback from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext -from .image import ImageOutput - -from ...backend.model_management.lora import ModelPatcher -from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline -from .model import UNetField, VaeField from .compel import ConditioningField -from contextlib import contextmanager, ExitStack, ContextDecorator +from .image import ImageOutput +from .model import UNetField, VaeField +from ..util.step_callback import stable_diffusion_step_callback +from ...backend.generator import Inpaint, InvokeAIGenerator +from ...backend.model_management.lora import ModelPatcher +from ...backend.stable_diffusion import PipelineIntermediateState +from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] INFILL_METHODS = Literal[tuple(infill_methods())] @@ -193,8 +190,6 @@ class InpaintInvocation(BaseInvocation): safety_checker=None, feature_extractor=None, requires_safety_checker=False, - precision="float16" if dtype == torch.float16 else "float32", - execution_device=device, ) yield OldModelInfo( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 6e2e0838bc..c15f84ddd0 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -5,15 +5,26 @@ from typing import List, Literal, Optional, Union import einops import torch -from diffusers import ControlNetModel from diffusers.image_processor import VaeImageProcessor +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import BaseModel, Field, validator from invokeai.app.invocations.metadata import CoreMetadata +from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings - +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from .compel import ConditioningField +from .controlnet_image_processors import ControlField +from .image import ImageOutput +from .model import ModelInfo, UNetField, VaeField +from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.model_management import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( @@ -24,23 +35,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( ) from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.model_management import ModelPatcher from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision -from ..models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from .compel import ConditioningField -from .controlnet_image_processors import ControlField -from .image import ImageOutput -from .model import ModelInfo, UNetField, VaeField -from invokeai.app.util.controlnet_utils import prepare_control_image - -from diffusers.models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) - DEFAULT_PRECISION = choose_precision(choose_torch_device()) @@ -231,7 +226,6 @@ class TextToLatentsInvocation(BaseInvocation): safety_checker=None, feature_extractor=None, requires_safety_checker=False, - precision="float16" if unet.dtype == torch.float16 else "float32", ) def prep_control_data( diff --git a/invokeai/backend/generator/img2img.py b/invokeai/backend/generator/img2img.py index 5490b2325c..8aaaff5deb 100644 --- a/invokeai/backend/generator/img2img.py +++ b/invokeai/backend/generator/img2img.py @@ -1,25 +1,11 @@ """ invokeai.backend.generator.img2img descends from .generator """ -from typing import Optional -import torch -from accelerate.utils import set_seed -from diffusers import logging - -from ..stable_diffusion import ( - ConditioningData, - PostprocessingSettings, - StableDiffusionGeneratorPipeline, -) from .base import Generator class Img2Img(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) - self.init_latent = None # by get_noise() - def get_make_image( self, sampler, @@ -42,51 +28,4 @@ class Img2Img(Generator): Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it. """ - self.perlin = perlin - - # noinspection PyTypeChecker - pipeline: StableDiffusionGeneratorPipeline = self.model - pipeline.scheduler = sampler - - uc, c, extra_conditioning_info = conditioning - conditioning_data = ConditioningData( - uc, - c, - cfg_scale, - extra_conditioning_info, - postprocessing_settings=PostprocessingSettings( - threshold=threshold, - warmup=warmup, - h_symmetry_time_pct=h_symmetry_time_pct, - v_symmetry_time_pct=v_symmetry_time_pct, - ), - ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) - - def make_image(x_T: torch.Tensor, seed: int): - # FIXME: use x_T for initial seeded noise - # We're not at the moment because the pipeline automatically resizes init_image if - # necessary, which the x_T input might not match. - # In the meantime, reset the seed prior to generating pipeline output so we at least get the same result. - logging.set_verbosity_error() # quench safety check warnings - pipeline_output = pipeline.img2img_from_embeddings( - init_image, - strength, - steps, - conditioning_data, - noise_func=self.get_noise_like, - callback=step_callback, - seed=seed, - ) - if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: - attention_maps_callback(pipeline_output.attention_map_saver) - return pipeline.numpy_to_pil(pipeline_output.images)[0] - - return make_image - - def get_noise_like(self, like: torch.Tensor): - device = like.device - x = torch.randn_like(like, device=device) - if self.perlin > 0.0: - shape = like.shape - x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2]) - return x + raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation") diff --git a/invokeai/backend/generator/inpaint.py b/invokeai/backend/generator/inpaint.py index 7aeb3d4809..494f213d11 100644 --- a/invokeai/backend/generator/inpaint.py +++ b/invokeai/backend/generator/inpaint.py @@ -377,3 +377,11 @@ class Inpaint(Img2Img): ) return corrected_result + + def get_noise_like(self, like: torch.Tensor): + device = like.device + x = torch.randn_like(like, device=device) + if self.perlin > 0.0: + shape = like.shape + x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2]) + return x diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 624d47ff64..c2c8165d02 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -4,25 +4,21 @@ import dataclasses import inspect import math import secrets -from collections.abc import Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union -from pydantic import Field -import einops import PIL.Image -import numpy as np -from accelerate.utils import set_seed +import einops import psutil import torch import torchvision.transforms as T +from accelerate.utils import set_seed from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models.controlnet import ControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) - from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( StableDiffusionImg2ImgPipeline, ) @@ -31,21 +27,20 @@ from diffusers.pipelines.stable_diffusion.safety_checker import ( ) from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -from diffusers.utils import PIL_INTERPOLATION from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.outputs import BaseOutput +from pydantic import Field from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from typing_extensions import ParamSpec from invokeai.app.services.config import InvokeAIAppConfig -from ..util import CPU_DEVICE, normalize_device from .diffusion import ( AttentionMapSaver, InvokeAIDiffuserComponent, PostprocessingSettings, ) -from .offloading import FullyLoadedModelGroup, ModelGroup +from ..util import normalize_device @dataclass @@ -289,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _model_group: ModelGroup - ID_LENGTH = 8 def __init__( @@ -303,9 +296,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): safety_checker: Optional[StableDiffusionSafetyChecker], feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, - precision: str = "float32", control_model: ControlNetModel = None, - execution_device: Optional[torch.device] = None, ): super().__init__( vae, @@ -330,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # control_model=control_model, ) self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) - - self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device) - self._model_group.install(*self._submodels) self.control_model = control_model def _adjust_memory_efficient_attention(self, latents: torch.Tensor): @@ -368,72 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): else: self.disable_attention_slicing() - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): - # overridden method; types match the superclass. - if torch_device is None: - return self - self._model_group.set_device(torch.device(torch_device)) - self._model_group.ready() - - @property - def device(self) -> torch.device: - return self._model_group.execution_device - - @property - def _submodels(self) -> Sequence[torch.nn.Module]: - module_names, _, _ = self.extract_init_dict(dict(self.config)) - submodels = [] - for name in module_names.keys(): - if hasattr(self, name): - value = getattr(self, name) - else: - value = getattr(self.config, name) - if isinstance(value, torch.nn.Module): - submodels.append(value) - return submodels - - def image_from_embeddings( - self, - latents: torch.Tensor, - num_inference_steps: int, - conditioning_data: ConditioningData, - *, - noise: torch.Tensor, - callback: Callable[[PipelineIntermediateState], None] = None, - run_id=None, - ) -> InvokeAIStableDiffusionPipelineOutput: - r""" - Function invoked when calling the pipeline for generation. - - :param conditioning_data: - :param latents: Pre-generated un-noised latents, to be used as inputs for - image generation. Can be used to tweak the same generation with different prompts. - :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - :param noise: Noise to add to the latents, sampled from a Gaussian distribution. - :param callback: - :param run_id: - """ - result_latents, result_attention_map_saver = self.latents_from_embeddings( - latents, - num_inference_steps, - conditioning_data, - noise=noise, - run_id=run_id, - callback=callback, - ) - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput( - images=image, - nsfw_content_detected=[], - attention_map_saver=result_attention_map_saver, - ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - def latents_from_embeddings( self, latents: torch.Tensor, @@ -450,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device("cpu") else: - scheduler_device = self._model_group.device_for(self.unet) + scheduler_device = self.unet.device if timesteps is None: self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device) @@ -504,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): (batch_size,), timesteps[0], dtype=timesteps.dtype, - device=self._model_group.device_for(self.unet), + device=self.unet.device, ) latents = self.scheduler.add_noise(latents, noise, batched_t) @@ -700,79 +622,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): **kwargs, ).sample - def img2img_from_embeddings( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float, - num_inference_steps: int, - conditioning_data: ConditioningData, - *, - callback: Callable[[PipelineIntermediateState], None] = None, - run_id=None, - noise_func=None, - seed=None, - ) -> InvokeAIStableDiffusionPipelineOutput: - if isinstance(init_image, PIL.Image.Image): - init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB")) - - if init_image.dim() == 3: - init_image = einops.rearrange(init_image, "c h w -> 1 c h w") - - # 6. Prepare latent variables - initial_latents = self.non_noised_latents_from_image( - init_image, - device=self._model_group.device_for(self.unet), - dtype=self.unet.dtype, - ) - if seed is not None: - set_seed(seed) - noise = noise_func(initial_latents) - - return self.img2img_from_latents_and_embeddings( - initial_latents, - num_inference_steps, - conditioning_data, - strength, - noise, - run_id, - callback, - ) - - def img2img_from_latents_and_embeddings( - self, - initial_latents, - num_inference_steps, - conditioning_data: ConditioningData, - strength, - noise: torch.Tensor, - run_id=None, - callback=None, - ) -> InvokeAIStableDiffusionPipelineOutput: - timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength) - result_latents, result_attention_maps = self.latents_from_embeddings( - latents=initial_latents - if strength < 1.0 - else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype), - num_inference_steps=num_inference_steps, - conditioning_data=conditioning_data, - timesteps=timesteps, - noise=noise, - run_id=run_id, - callback=callback, - ) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - with torch.inference_mode(): - image = self.decode_latents(result_latents) - output = InvokeAIStableDiffusionPipelineOutput( - images=image, - nsfw_content_detected=[], - attention_map_saver=result_attention_maps, - ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) - def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int): img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) assert img2img_pipeline.scheduler is self.scheduler @@ -780,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device("cpu") else: - scheduler_device = self._model_group.device_for(self.unet) + scheduler_device = self.unet.device img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device) timesteps, adjusted_steps = img2img_pipeline.get_timesteps( @@ -806,7 +655,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise_func=None, seed=None, ) -> InvokeAIStableDiffusionPipelineOutput: - device = self._model_group.device_for(self.unet) + device = self.unet.device latents_dtype = self.unet.dtype if isinstance(init_image, PIL.Image.Image): @@ -877,42 +726,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): nsfw_content_detected=[], attention_map_saver=result_attention_maps, ) - return self.check_for_safety(output, dtype=conditioning_data.dtype) + return output def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype): init_image = init_image.to(device=device, dtype=dtype) with torch.inference_mode(): - self._model_group.load(self.vae) init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! init_latents = 0.18215 * init_latents return init_latents - def check_for_safety(self, output, dtype): - with torch.inference_mode(): - screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype) - screened_attention_map_saver = None - if has_nsfw_concept is None or not has_nsfw_concept: - screened_attention_map_saver = output.attention_map_saver - return InvokeAIStableDiffusionPipelineOutput( - screened_images, - has_nsfw_concept, - # block the attention maps if NSFW content is detected - attention_map_saver=screened_attention_map_saver, - ) - - def run_safety_checker(self, image, device=None, dtype=None): - # overriding to use the model group for device info instead of requiring the caller to know. - if self.safety_checker is not None: - device = self._model_group.device_for(self.safety_checker) - return super().run_safety_checker(image, device, dtype) - - def decode_latents(self, latents): - # Explicit call to get the vae loaded, since `decode` isn't the forward method. - self._model_group.load(self.vae) - return super().decode_latents(latents) - def debug_latents(self, latents, msg): from invokeai.backend.image_util import debug_image diff --git a/invokeai/backend/stable_diffusion/offloading.py b/invokeai/backend/stable_diffusion/offloading.py deleted file mode 100644 index aa2426d514..0000000000 --- a/invokeai/backend/stable_diffusion/offloading.py +++ /dev/null @@ -1,253 +0,0 @@ -from __future__ import annotations - -import warnings -import weakref -from abc import ABCMeta, abstractmethod -from collections.abc import MutableMapping -from typing import Callable, Union - -import torch -from accelerate.utils import send_to_device -from torch.utils.hooks import RemovableHandle - -OFFLOAD_DEVICE = torch.device("cpu") - - -class _NoModel: - """Symbol that indicates no model is loaded. - - (We can't weakref.ref(None), so this was my best idea at the time to come up with something - type-checkable.) - """ - - def __bool__(self): - return False - - def to(self, device: torch.device): - pass - - def __repr__(self): - return "" - - -NO_MODEL = _NoModel() - - -class ModelGroup(metaclass=ABCMeta): - """ - A group of models. - - The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline, - e.g. its text encoder, U-net, VAE, etc. - - Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with - :py:class:`torch.nn.Module` here. - """ - - def __init__(self, execution_device: torch.device): - self.execution_device = execution_device - - @abstractmethod - def install(self, *models: torch.nn.Module): - """Add models to this group.""" - pass - - @abstractmethod - def uninstall(self, models: torch.nn.Module): - """Remove models from this group.""" - pass - - @abstractmethod - def uninstall_all(self): - """Remove all models from this group.""" - - @abstractmethod - def load(self, model: torch.nn.Module): - """Load this model to the execution device.""" - pass - - @abstractmethod - def offload_current(self): - """Offload the current model(s) from the execution device.""" - pass - - @abstractmethod - def ready(self): - """Ready this group for use.""" - pass - - @abstractmethod - def set_device(self, device: torch.device): - """Change which device models from this group will execute on.""" - pass - - @abstractmethod - def device_for(self, model) -> torch.device: - """Get the device the given model will execute on. - - The model should already be a member of this group. - """ - pass - - @abstractmethod - def __contains__(self, model): - """Check if the model is a member of this group.""" - pass - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >" - - -class LazilyLoadedModelGroup(ModelGroup): - """ - Only one model from this group is loaded on the GPU at a time. - - Running the forward method of a model will displace the previously-loaded model, - offloading it to CPU. - - If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``, - you will need to explicitly load it with :py:method:`.load(model)`. - - This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments - to the appropriate execution device, as long as they are positional arguments and not keyword - arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.) - """ - - _hooks: MutableMapping[torch.nn.Module, RemovableHandle] - _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]] - - def __init__(self, execution_device: torch.device): - super().__init__(execution_device) - self._hooks = weakref.WeakKeyDictionary() - self._current_model_ref = weakref.ref(NO_MODEL) - - def install(self, *models: torch.nn.Module): - for model in models: - self._hooks[model] = model.register_forward_pre_hook(self._pre_hook) - - def uninstall(self, *models: torch.nn.Module): - for model in models: - hook = self._hooks.pop(model) - hook.remove() - if self.is_current_model(model): - # no longer hooked by this object, so don't claim to manage it - self.clear_current_model() - - def uninstall_all(self): - self.uninstall(*self._hooks.keys()) - - def _pre_hook(self, module: torch.nn.Module, forward_input): - self.load(module) - if len(forward_input) == 0: - warnings.warn( - f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.", - stacklevel=3, - ) - return send_to_device(forward_input, self.execution_device) - - def load(self, module): - if not self.is_current_model(module): - self.offload_current() - self._load(module) - - def offload_current(self): - module = self._current_model_ref() - if module is not NO_MODEL: - module.to(OFFLOAD_DEVICE) - self.clear_current_model() - - def _load(self, module: torch.nn.Module) -> torch.nn.Module: - assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}" - module = module.to(self.execution_device) - self.set_current_model(module) - return module - - def is_current_model(self, model: torch.nn.Module) -> bool: - """Is the given model the one currently loaded on the execution device?""" - return self._current_model_ref() is model - - def is_empty(self): - """Are none of this group's models loaded on the execution device?""" - return self._current_model_ref() is NO_MODEL - - def set_current_model(self, value): - self._current_model_ref = weakref.ref(value) - - def clear_current_model(self): - self._current_model_ref = weakref.ref(NO_MODEL) - - def set_device(self, device: torch.device): - if device == self.execution_device: - return - self.execution_device = device - current = self._current_model_ref() - if current is not NO_MODEL: - current.to(device) - - def device_for(self, model): - if model not in self: - raise KeyError(f"This does not manage this model {type(model).__name__}", model) - return self.execution_device # this implementation only dispatches to one device - - def ready(self): - pass # always ready to load on-demand - - def __contains__(self, model): - return model in self._hooks - - def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__} object at {id(self):x}: " - f"current_model={type(self._current_model_ref()).__name__} >" - ) - - -class FullyLoadedModelGroup(ModelGroup): - """ - A group of models without any implicit loading or unloading. - - :py:meth:`.ready` loads _all_ the models to the execution device at once. - """ - - _models: weakref.WeakSet - - def __init__(self, execution_device: torch.device): - super().__init__(execution_device) - self._models = weakref.WeakSet() - - def install(self, *models: torch.nn.Module): - for model in models: - self._models.add(model) - model.to(self.execution_device) - - def uninstall(self, *models: torch.nn.Module): - for model in models: - self._models.remove(model) - - def uninstall_all(self): - self.uninstall(*self._models) - - def load(self, model): - model.to(self.execution_device) - - def offload_current(self): - for model in self._models: - model.to(OFFLOAD_DEVICE) - - def ready(self): - for model in self._models: - self.load(model) - - def set_device(self, device: torch.device): - self.execution_device = device - for model in self._models: - if model.device != OFFLOAD_DEVICE: - model.to(device) - - def device_for(self, model): - if model not in self: - raise KeyError("This does not manage this model f{type(model).__name__}", model) - return self.execution_device # this implementation only dispatches to one device - - def __contains__(self, model): - return model in self._models