Mirror of https://github.com/invoke-ai/InvokeAI
do not manage GPU for pipelines if sequential_offloading is True

parent 63e465eb5c
commit b9e9087dbe
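Context for the change: when sequential CPU offload is enabled on a diffusers pipeline, accelerate hooks move each submodule onto the GPU only for the duration of its forward pass and return it to CPU afterwards, so the model cache must not relocate the whole pipeline with .to() itself. A minimal illustration of that diffusers behaviour (standard diffusers API; the model ID and prompt are only examples, and this snippet is background, not part of the commit):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.enable_sequential_cpu_offload()   # accelerate hooks now own device placement
    # Calling pipe.to("cuda") after this would drag every submodule onto the GPU at once
    # and defeat the memory savings, which is what the cache below must avoid doing.
    image = pipe("a riverside cottage at dawn").images[0]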
@@ -21,13 +21,14 @@ import gc
 import hashlib
 import warnings
 from collections import Counter
+from contextlib import suppress
 from enum import Enum
 from pathlib import Path
-from typing import Dict, Sequence, Union, Tuple, types, Optional
+from typing import Dict, Sequence, Union, Set, Tuple, types, Optional
 
 import torch
 import safetensors.torch
 
 from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel, ConfigMixin
 from diffusers import logging as diffusers_logging
 from diffusers.pipelines.stable_diffusion.safety_checker import \
@@ -87,6 +88,16 @@ MODEL_CLASSES = {
     SDModelType.TextualInversion: TIType,
 }
 
+DIFFUSERS_PARTS = {
+    SDModelType.Vae,
+    SDModelType.TextEncoder,
+    SDModelType.Tokenizer,
+    SDModelType.UNet,
+    SDModelType.Scheduler,
+    SDModelType.SafetyChecker,
+    SDModelType.FeatureExtractor,
+}
+
 class ModelStatus(Enum):
     unknown='unknown'
     not_loaded='not loaded'
@@ -169,7 +180,7 @@ class ModelCache(object):
         subfolder: Path = None,
         submodel: SDModelType = None,
         revision: str = None,
-        attach_model_part: Tuple[SDModelType, str] = (None, None),
+        attach_model_parts: Optional[Set[Tuple[SDModelType, str]]] = None,
         gpu_load: bool = True,
     ) -> ModelLocker:  # ?? what does it return
         '''
@@ -213,15 +224,18 @@ class ModelCache(object):
 
         pipeline_context = cache.get_model(
                 'runwayml/stable-diffusion-v1-5',
-                attach_model_part=(SDModelType.Vae,'stabilityai/sd-vae-ft-mse')
+                attach_model_parts=set(
+                    [SDModelType.Vae,'stabilityai/sd-vae-ft-mse']
+                    [SDModelType.UNet,'runwayml/stable-diffusion-1.5','unet'] #type, ID, subfolder
+                )
         )
+        )
 
         The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param model_type: An SDModelType enum indicating the type of the (parent) model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.Vae
-        :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
+        :param attach_model_parts: load and attach a diffusers model component. Pass a set of tuple of format (SDModelType,repo_id_or_path,subfolder)
         :param revision: model revision
         :param gpu_load: load the model into GPU [default True]
         '''
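A note on the docstring example added above: as committed it would not run, because the two bracketed lists inside set(...) are juxtaposed, so the second acts as an index into the first rather than as a second element. Given the new attach_model_parts signature, the intended call is presumably a set of tuples, roughly as in this sketch (illustrative only, not what the commit contains):

    pipeline_context = cache.get_model(
        'runwayml/stable-diffusion-v1-5',
        attach_model_parts={
            (SDModelType.Vae, 'stabilityai/sd-vae-ft-mse'),
            (SDModelType.UNet, 'runwayml/stable-diffusion-v1-5', 'unet'),  # type, ID, subfolder
        },
    )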
@@ -274,8 +288,9 @@ class ModelCache(object):
             self.current_cache_size += mem_used   # increment size of the cache
 
             # this is a bit of legacy work needed to support the old-style "load this diffuser with custom VAE"
-            if model_type == SDModelType.Diffusers and attach_model_part[0]:
-                self.attach_part(model, *attach_model_part)
+            if model_type == SDModelType.Diffusers and attach_model_parts:
+                for attach_model_part in attach_model_parts:
+                    self.attach_part(model, *attach_model_part)
 
             self.stack.append(key)    # add to LRU cache
             self.models[key] = model  # keep copy of model in dict
@@ -320,11 +335,12 @@ class ModelCache(object):
                 if model.device != cache.execution_device:
                     cache.logger.debug(f'Moving {key} into {cache.execution_device}')
                     with VRAMUsage() as mem:
-                        model.to(cache.execution_device)  # move into GPU
+                        self._to(model, cache.execution_device)
+                        # model.to(cache.execution_device)  # move into GPU
+
                     cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
                     cache.model_sizes[key] = mem.vram_used  # more accurate size
 
-                cache.logger.debug(f'Locking {key} in {cache.execution_device}')
                 cache._print_cuda_stats()
 
             else:
@@ -332,7 +348,8 @@ class ModelCache(object):
                 # move it into CPU if it is in GPU and not locked
                 if hasattr(model, 'to') and (key in cache.loaded_models
                                              and cache.locked_models[key] == 0):
-                    model.to(cache.storage_device)
+                    self._to(model, cache.storage_device)
+                    # model.to(cache.storage_device)
                     cache.loaded_models.remove(key)
             return model
 
@@ -347,6 +364,18 @@ class ModelCache(object):
                 cache._offload_unlocked_models()
             cache._print_cuda_stats()
 
+        def _to(self, model, device):
+            # if set, sequential offload will take care of GPU management for diffusers
+            if self.cache.sequential_offload and isinstance(model, StableDiffusionGeneratorPipeline):
+                return
+
+            self.cache.logger.debug(f'Moving {self.key} into {device}')
+            model.to(device)
+            if isinstance(model, MODEL_CLASSES[SDModelType.Diffusers]):
+                for part in DIFFUSERS_PARTS:
+                    with suppress(Exception):
+                        getattr(model, part).to(device)
+
     def attach_part(
         self,
         diffusers_model: StableDiffusionPipeline,
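Both the new _to helper above and attach_part below use SDModelType members directly as attribute names (getattr(model, part), setattr(diffusers_model, part_type, part)), and the model manager hunks further down look the same members up as config keys. That only works if SDModelType behaves as a string whose value matches the diffusers component name, e.g. a str-valued Enum. A minimal sketch of that assumption (member names and values here are illustrative, not the project's actual definition):

    from enum import Enum

    class SDModelType(str, Enum):
        # Assumed: each value is the attribute name the diffusers pipeline uses.
        Vae = 'vae'
        TextEncoder = 'text_encoder'
        Tokenizer = 'tokenizer'
        UNet = 'unet'

    # Because members are str instances, they can be passed straight to getattr/setattr:
    # getattr(pipe, SDModelType.Vae) resolves to pipe.vae.
    # The suppress(Exception) and hasattr(part, 'to') guards then cover parts such as the
    # tokenizer and feature extractor, which are not torch modules and have no .to() method.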
@@ -366,7 +395,8 @@ class ModelCache(object):
             model_type=part_type,
             subfolder=subfolder,
         )
-        part.to(diffusers_model.device)
+        if hasattr(part, 'to'):
+            part.to(diffusers_model.device)
         setattr(diffusers_model, part_type, part)
         self.logger.debug(f'Attached {part_type} {part_id}')
 
(The hunks below come from a second file in the same commit: the model manager module, which imports ModelCache, SDModelType and DIFFUSERS_PARTS from model_cache.)
@@ -146,6 +146,7 @@ from typing import Callable, Optional, List, Tuple, Union, types
 import safetensors
 import safetensors.torch
 import torch
+from diffusers import AutoencoderKL
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -157,7 +158,7 @@ from invokeai.backend.util import download_with_resume
 
 from ..util import CUDA_DEVICE
 from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
-                          SilenceWarnings)
+                          SilenceWarnings, DIFFUSERS_PARTS)
 
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
@@ -375,12 +376,14 @@ class ModelManager(object):
         # to support the traditional way of attaching a VAE
         # to a model, we hacked in `attach_model_part`
         # TODO: generalize this
-        vae = (None, None)
+        external_parts = set()
         if model_type == SDModelType.Diffusers:
-            with suppress(Exception):
-                vae_id = mconfig.vae.get('path') or mconfig.vae.get('repo_id')
-                vae_subfolder = mconfig.vae.get('subfolder')
-                vae = (SDModelType.Vae, vae_id, vae_subfolder)
+            for part in DIFFUSERS_PARTS:
+                with suppress(Exception):
+                    if part_config := mconfig.get(part):
+                        id = part_config.get('path') or part_config.get('repo_id')
+                        subfolder = part_config.get('subfolder')
+                        external_parts.add((part, id, subfolder))
 
         model_context = self.cache.get_model(
             location,
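For reference, the loop above generalizes the old VAE-only override: any diffusers part that has its own stanza in the model's config is collected into external_parts as a (part, id, subfolder) tuple. A hypothetical stanza, built with OmegaConf purely for illustration (the field names follow the path/repo_id/subfolder lookups in the loop; the real models.yaml layout is not shown in this diff):

    from omegaconf import OmegaConf

    # Hypothetical diffusers model entry with an external VAE override.
    mconfig = OmegaConf.create({
        'format': 'diffusers',
        'repo_id': 'runwayml/stable-diffusion-v1-5',
        'vae': {'repo_id': 'stabilityai/sd-vae-ft-mse'},
    })

    # The loop would then produce a single tuple for the VAE part:
    #   (SDModelType.Vae, 'stabilityai/sd-vae-ft-mse', None)
    # which get_model() forwards to attach_part() via attach_model_parts.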
@@ -388,7 +391,7 @@ class ModelManager(object):
             revision = revision,
             subfolder = subfolder,
             submodel = submodel,
-            attach_model_part = vae,
+            attach_model_parts = external_parts,
         )
 
         # in case we need to communicate information about this
|
Loading…
x
Reference in New Issue
Block a user