add sd3 to starter models

Lincoln Stein 2024-06-20 18:13:46 -04:00
parent 66260fd345
commit 445561e3a4
5 changed files with 37 additions and 12 deletions

View File

@@ -105,7 +105,6 @@ class InvokeAIAppConfig(BaseSettings):
         vram: Amount of VRAM reserved for model storage (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
         lazy_offload: Keep models in VRAM until their space is needed.
-        load_sd3_encoder_3: Load the memory-intensive SD3 text_encoder_3.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
@@ -176,7 +175,6 @@ class InvokeAIAppConfig(BaseSettings):
     vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
     convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
     lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
-    load_sd3_encoder_3: bool = Field(default=False, description="Load the memory-intensive SD3 text_encoder_3.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")

     # DEVICE
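The deleted load_sd3_encoder_3 field followed the same pydantic-settings pattern as the surrounding InvokeAIAppConfig fields. A minimal, standalone sketch of that pattern, with an illustrative class name and env prefix that are not InvokeAI's real ones:

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class DemoConfig(BaseSettings):
    # Hypothetical prefix for illustration only; with it, the env var
    # DEMO_LOAD_BIG_ENCODER=1 overrides the default below.
    model_config = SettingsConfigDict(env_prefix="DEMO_")

    load_big_encoder: bool = Field(
        default=False,
        description="Load the memory-intensive text encoder.",
    )

print(DemoConfig().load_big_encoder)  # False unless overridden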

View File

@@ -241,9 +241,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
             if vram_in_use <= reserved:
                 break
-            # only way to remove a quantized model from VRAM is to
+            # Special handling of the stable-diffusion-3:text_encoder_3
+            # submodel, when the user has loaded a quantized model.
+            # The only way to remove the quantized version of this model from VRAM is to
             # delete it completely - it can't be moved from device to device
+            # This also contains a workaround for quantized models that
+            # persist indefinitely in VRAM
             if cache_entry.is_quantized:
+                self._empty_quantized_state_dict(cache_entry.model)
+                cache_entry.model = None
                 self._delete_cache_entry(cache_entry)
                 vram_in_use = torch.cuda.memory_allocated() + size_required
                 continue
@@ -426,3 +432,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
         del cache_entry
         gc.collect()
         TorchDevice.empty_cache()
+
+    def _empty_quantized_state_dict(self, model: AnyModel) -> None:
+        """Set all keys of a model's state dict to None.
+
+        This is a partial workaround for a poorly-understood bug in
+        transformers' support for quantized T5EncoderModels (text_encoder_3
+        of SD3). This allows most of the model to be unloaded from VRAM, but
+        still leaks 8K of VRAM each time the model is unloaded. Using the quantized
+        version of stable-diffusion-3-medium is NOT recommended.
+        """
+        assert isinstance(model, torch.nn.Module)
+        sd = model.state_dict()
+        for k in sd.keys():
+            sd[k] = None
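The eviction sequence this file relies on can be reproduced outside InvokeAI. A minimal, self-contained sketch of the delete-then-empty-cache pattern, using an ordinary (unquantized) module purely for illustration:

import gc
import torch

def report(tag: str) -> None:
    # memory_allocated() counts live tensor storage, not the allocator's cached reserve
    print(f"{tag}: {torch.cuda.memory_allocated() / 1e6:.1f} MB")

if torch.cuda.is_available():
    model = torch.nn.Linear(4096, 4096, device="cuda")
    report("after load")
    # Quantized weights can't be migrated with .to(), so the cache drops
    # every reference instead; the same sequence works for any module.
    del model
    gc.collect()               # free the Python-side objects
    torch.cuda.empty_cache()   # hand cached blocks back to the driver
    report("after delete")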

View File

@@ -40,7 +40,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     model_base_to_model_type = {
         BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
         BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
-        BaseModelType.StableDiffusion3: "SD3",  # non-functional, for completeness only
+        BaseModelType.StableDiffusion3: "SD3",
         BaseModelType.StableDiffusionXL: "SDXL",
         BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
     }

View File

@@ -122,6 +122,13 @@ STARTER_MODELS: list[StarterModel] = [
         type=ModelType.Main,
         dependencies=[sdxl_fp16_vae_fix],
     ),
+    StarterModel(
+        name="Stable Diffusion 3",
+        base=BaseModelType.StableDiffusion3,
+        source="stabilityai/stable-diffusion-3-medium-diffusers",
+        description="The OG Stable Diffusion 3 base model (beta).",
+        type=ModelType.Main,
+    ),
     # endregion
     # region VAE
     sdxl_fp16_vae_fix,
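The new entry's source is a Hugging Face diffusers repo id, so the same model the installer fetches can be exercised directly with diffusers 0.29.0 (pinned below in this commit). A rough usage sketch, not InvokeAI's own loading path:

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a watercolor fox in a snowy forest",  # example prompt
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_test.png")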

View File

@@ -34,26 +34,26 @@ classifiers = [
 dependencies = [
     # Core generation dependencies, pinned for reproducible builds.
     "accelerate==0.30.1",
-    "bitsandbytes",
+    "bitsandbytes==0.43.1",
     "clip_anytorch==2.6.0",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel==2.0.2",
     "controlnet-aux==0.0.7",
-    "diffusers[torch]",
+    "diffusers[torch]==0.29.0",
     "invisible-watermark==0.2.0",  # needed to install SDXL base and refiner using their repo_ids
     "mediapipe==0.10.7",  # needed for "mediapipeface" controlnet model
-    "numpy==1.26.4",  # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
+    "numpy==1.23.5",  # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
     "onnx==1.15.0",
     "onnxruntime==1.16.3",
     "opencv-python==4.9.0.80",
-    "pytorch-lightning==2.1.3",
+    "pytorch-lightning",
     "safetensors==0.4.3",
     "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-    "torch",
+    "torch==2.2.2",
     "torchmetrics==0.11.4",
     "torchsde==0.2.6",
-    "torchvision",
-    "transformers",
-    "sentencepiece",
+    "torchvision==0.17.2",
+    "transformers==4.41.1",
+    "sentencepiece==0.1.99",
     # Core application dependencies, pinned for reproducible builds.
     "fastapi-events==0.11.0",