Compare commits

...

7 Commits

Author SHA1 Message Date
4b9a46e4c2 make ip_adapters work with stable-fast 2023-12-21 17:29:28 -05:00
952b12abb7 resolve conflicts 2023-12-21 16:31:42 -05:00
2ff41afe8c ruff fixes 2023-12-21 16:29:32 -05:00
e22df59239 proof-of-principle support for stable-fast
only compile model the first time :-)

probe for availability of stable-fast compiler and triton at startup time

simplify config logic
2023-12-21 16:28:42 -05:00
e3ab074b95 probe for availability of stable-fast compiler and triton at startup time 2023-12-21 16:10:52 -05:00
6cb3031c09 only compile model the first time :-) 2023-12-20 22:40:56 -05:00
9c1d250665 hacked in stable-fast; can generate one image before crashing 2023-12-20 22:11:16 -05:00
6 changed files with 43 additions and 1 deletion

View File

@@ -271,6 +271,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
     force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
     png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
+    stable_fast : bool = Field(default=True, description="Enable stable-fast performance optimizations, if the library is installed and functional", json_schema_extra=Categories.Generation)

     # QUEUE
     max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)

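A note on how the new flag is consumed at runtime: the sketch below reads it through the application's config singleton (the get_config() accessor used in the ModelManager hunk further down) and mirrors the gating condition in the model cache. It is an illustration, not code from this PR.

from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()

# stable_fast defaults to True, but compilation only happens when the
# stable-fast library itself imports cleanly (see SFAST_AVAILABLE below).
if config.stable_fast:
    print("stable-fast optimizations requested")
else:
    print("stable-fast optimizations disabled")
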
View File

@@ -141,7 +141,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
             ip_hidden_states = ipa_embed

             # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)

             ip_key = ipa_weights.to_k_ip(ip_hidden_states)
             ip_value = ipa_weights.to_v_ip(ip_hidden_states)

View File

@@ -12,6 +12,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
         super().__init__()
         self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
         self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
+        for param in self.parameters():
+            param.requires_grad = False


 class IPAttentionWeights(torch.nn.Module):

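The freeze added above is presumably what lets stable-fast treat the IP-adapter projection layers as inference-only constants when it traces the attention processors. The standalone sketch below (a stand-in Linear layer, not the project's actual class) shows the same idiom and checks that nothing remains trainable.

import torch

proj = torch.nn.Linear(1280, 768, bias=False)  # stand-in for to_k_ip / to_v_ip; dimensions are arbitrary

# Freeze for inference-only use, mirroring the loop added in the diff.
for param in proj.parameters():
    param.requires_grad = False

assert not any(p.requires_grad for p in proj.parameters())
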
View File

@@ -24,12 +24,14 @@ import sys
 import time
 from contextlib import suppress
 from dataclasses import dataclass, field
+from importlib.util import find_spec
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union, types

 import torch

 import invokeai.backend.util.logging as logger
+from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
@@ -39,6 +41,26 @@ from .models import BaseModelType, ModelBase, ModelType, SubModelType
 if choose_torch_device() == torch.device("mps"):
     from torch import mps

+SFAST_AVAILABLE = False
+TRITON_AVAILABLE = False
+XFORMERS_AVAILABLE = False
+SFAST_CONFIG = None
+
+TRITON_AVAILABLE = find_spec("triton") is not None
+XFORMERS_AVAILABLE = find_spec("xformers") is not None
+
+try:
+    from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_unet, compile_vae
+
+    SFAST_CONFIG = CompilationConfig.Default()
+    SFAST_CONFIG.enable_cuda_graph = True
+    SFAST_CONFIG.enable_xformers = XFORMERS_AVAILABLE
+    SFAST_CONFIG.enable_triton = TRITON_AVAILABLE
+    SFAST_AVAILABLE = True
+except ImportError:
+    pass
+
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -110,6 +132,7 @@ class _CacheRecord:
 class ModelCache(object):
     def __init__(
         self,
+        app_config: InvokeAIAppConfig,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         execution_device: torch.device = torch.device("cuda"),
@@ -122,6 +145,7 @@ class ModelCache(object):
         log_memory_usage: bool = False,
     ):
         """
+        :param app_config: InvokeAIAppConfig for application
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
@@ -135,6 +159,7 @@ class ModelCache(object):
         behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = {}
+        self.app_config = app_config
         # allow lazy offloading only when vram cache enabled
         self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype = precision
@@ -239,6 +264,9 @@ class ModelCache(object):
             snapshot_before = self._capture_memory_snapshot()
             with skip_torch_weight_init():
                 model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
+            if SFAST_AVAILABLE and self.app_config.stable_fast and submodel:
+                model = self._compile_model(model, submodel)
+
             snapshot_after = self._capture_memory_snapshot()
             end_load_time = time.time()
@@ -322,6 +350,16 @@ class ModelCache(object):
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
         )

+    def _compile_model(self, model: Any, model_type: SubModelType) -> Any:
+        if model_type == SubModelType("unet"):
+            self.logger.info("SFast-compiling unet model")
+            return compile_unet(model, SFAST_CONFIG)
+        elif model_type == SubModelType("vae"):
+            self.logger.info("SFast-compiling vae model")
+            return compile_vae(model, SFAST_CONFIG)
+        else:
+            return model
+
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load, size_needed):
             """

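To try the same compilation path outside the model cache, the minimal sketch below reuses exactly the symbols this diff imports (CompilationConfig, compile_unet, compile_vae) and applies them to a diffusers pipeline. The model id and dtype are placeholders, and a CUDA device is assumed because enable_cuda_graph is turned on, as in the PR.

import torch
from diffusers import StableDiffusionPipeline
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_unet, compile_vae

config = CompilationConfig.Default()
config.enable_cuda_graph = True  # same setting as SFAST_CONFIG in the diff
config.enable_xformers = False   # flip on if xformers is installed
config.enable_triton = False     # flip on if triton is installed

# Placeholder model id; any diffusers SD pipeline should work the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Compile the two submodels that _compile_model() handles in the cache.
pipe.unet = compile_unet(pipe.unet, config)
pipe.vae = compile_vae(pipe.vae, config)

image = pipe("a lighthouse at dusk", num_inference_steps=20).images[0]
image.save("out.png")

As the commit messages note, the cache compiles a model only the first time it is loaded; in this standalone form the first pipe() call is the slow one while stable-fast warms up.
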
View File

@@ -344,6 +344,7 @@ class ModelManager(object):
         self.app_config = InvokeAIAppConfig.get_config()
         self.logger = logger
         self.cache = ModelCache(
+            app_config=self.app_config,
             max_cache_size=max_cache_size,
             max_vram_cache_size=self.app_config.vram_cache_size,
             lazy_offloading=self.app_config.lazy_offload,

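Because app_config is now a required constructor argument, any direct instantiation of ModelCache (in tests, for example) has to supply it too. A hypothetical sketch, with the module path assumed rather than taken from the diff:

import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.model_cache import ModelCache  # module path is an assumption

config = InvokeAIAppConfig.get_config()
cache = ModelCache(
    app_config=config,  # new in this PR
    max_cache_size=6.0,
    execution_device=torch.device("cuda"),
)
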
View File

@@ -113,6 +113,7 @@ dependencies = [
 "onnx" = ["onnxruntime"]
 "onnx-cuda" = ["onnxruntime-gpu"]
 "onnx-directml" = ["onnxruntime-directml"]
+"stable-fast" = ["stable-fast"]

 [project.scripts]
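
With this optional-dependency group in place, the library would presumably be installed via the new extra, e.g. pip install "InvokeAI[stable-fast]". Without it, the try/except probe in the model cache leaves SFAST_AVAILABLE at False and the new code paths are simply skipped.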