Compare commits

...

7 Commits

Author SHA1 Message Date
4b9a46e4c2 make ip_adapters work with stable-fast 2023-12-21 17:29:28 -05:00
952b12abb7 resolve conflicts 2023-12-21 16:31:42 -05:00
2ff41afe8c ruff fixes 2023-12-21 16:29:32 -05:00
e22df59239 proof-of-principle support for stable-fast 2023-12-21 16:28:42 -05:00
    only compile model the first time :-)
    probe for availability of stable-fast compiler and triton at startup time
    simplify config logic
e3ab074b95 probe for availability of stable-fast compiler and triton at startup time 2023-12-21 16:10:52 -05:00
6cb3031c09 only compile model the first time :-) 2023-12-20 22:40:56 -05:00
9c1d250665 hacked in stable-fast; can generate one image before crashing 2023-12-20 22:11:16 -05:00
6 changed files with 43 additions and 1 deletion

View File

@@ -271,6 +271,7 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
stable_fast : bool = Field(default=True, description="Enable stable-fast performance optimizations, if the library is installed and functional", json_schema_extra=Categories.Generation)
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
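
The new stable_fast flag defaults to True, so whether the optimized path runs is decided by the startup probe for the library rather than by explicit opt-in. A minimal sketch (not part of the diff, assuming an InvokeAI environment with this branch installed) of reading the flag the same way ModelManager does further down:

```python
from invokeai.app.services.config import InvokeAIAppConfig

# get_config() returns the application-wide config object used below by ModelManager;
# `stable_fast` is the field added in this hunk.
config = InvokeAIAppConfig.get_config()
print(f"stable-fast optimizations requested: {config.stable_fast}")
```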

View File

@@ -141,7 +141,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)

View File

@@ -12,6 +12,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
        for param in self.parameters():
            param.requires_grad = False
class IPAttentionWeights(torch.nn.Module):
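
The functional change here is the loop that freezes the IP-adapter projection weights. A standalone illustration of the same pattern (the Linear dimensions are arbitrary placeholders, not values from the diff):

```python
import torch

# Create a throwaway projection layer and mark every parameter as non-trainable,
# mirroring the loop added to IPAttentionProcessorWeights.__init__ above.
proj = torch.nn.Linear(768, 1280, bias=False)
for param in proj.parameters():
    param.requires_grad = False

assert not any(p.requires_grad for p in proj.parameters())
```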

View File

@@ -24,12 +24,14 @@ import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union, types
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
@@ -39,6 +41,26 @@ from .models import BaseModelType, ModelBase, ModelType, SubModelType
if choose_torch_device() == torch.device("mps"):
    from torch import mps
SFAST_AVAILABLE = False
TRITON_AVAILABLE = False
XFORMERS_AVAILABLE = False
SFAST_CONFIG = None
TRITON_AVAILABLE = find_spec("triton") is not None
XFORMERS_AVAILABLE = find_spec("xformers") is not None
try:
    from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_unet, compile_vae
    SFAST_CONFIG = CompilationConfig.Default()
    SFAST_CONFIG.enable_cuda_graph = True
    SFAST_CONFIG.enable_xformers = XFORMERS_AVAILABLE
    SFAST_CONFIG.enable_triton = TRITON_AVAILABLE
    SFAST_AVAILABLE = True
except ImportError:
    pass
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -110,6 +132,7 @@ class _CacheRecord:
class ModelCache(object):
    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
        execution_device: torch.device = torch.device("cuda"),
@@ -122,6 +145,7 @@ class ModelCache(object):
        log_memory_usage: bool = False,
    ):
        """
        :param app_config: InvokeAIAppConfig for application
        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
@@ -135,6 +159,7 @@ class ModelCache(object):
            behaviour.
        """
        self.model_infos: Dict[str, ModelBase] = {}
        self.app_config = app_config
        # allow lazy offloading only when vram cache enabled
        self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self.precision: torch.dtype = precision
@@ -239,6 +264,9 @@ class ModelCache(object):
            snapshot_before = self._capture_memory_snapshot()
            with skip_torch_weight_init():
                model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
            if SFAST_AVAILABLE and self.app_config.stable_fast and submodel:
                model = self._compile_model(model, submodel)
            snapshot_after = self._capture_memory_snapshot()
            end_load_time = time.time()
@@ -322,6 +350,16 @@ class ModelCache(object):
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
    def _compile_model(self, model: Any, model_type: SubModelType) -> Any:
        if model_type == SubModelType("unet"):
            self.logger.info("SFast-compiling unet model")
            return compile_unet(model, SFAST_CONFIG)
        elif model_type == SubModelType("vae"):
            self.logger.info("SFast-compiling vae model")
            return compile_vae(model, SFAST_CONFIG)
        else:
            return model
    class ModelLocker(object):
        def __init__(self, cache, key, model, gpu_load, size_needed):
            """

View File

@@ -344,6 +344,7 @@ class ModelManager(object):
        self.app_config = InvokeAIAppConfig.get_config()
        self.logger = logger
        self.cache = ModelCache(
            app_config=self.app_config,
            max_cache_size=max_cache_size,
            max_vram_cache_size=self.app_config.vram_cache_size,
            lazy_offloading=self.app_config.lazy_offload,

View File

@@ -113,6 +113,7 @@ dependencies = [
"onnx" = ["onnxruntime"]
"onnx-cuda" = ["onnxruntime-gpu"]
"onnx-directml" = ["onnxruntime-directml"]
"stable-fast" = ["stable-fast"]
[project.scripts]
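
With this optional-dependency group in place, the compiler would presumably be pulled in by installing the "stable-fast" extra (e.g. pip install "InvokeAI[stable-fast]", assuming the project name declared in pyproject.toml is InvokeAI); when the package is absent, the import probe in the model cache leaves SFAST_AVAILABLE set to False and the application runs unchanged.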