make ip_adapters work with stable-fast

Lincoln Stein
2023-12-21 17:29:28 -05:00
parent 952b12abb7
commit 4b9a46e4c2
4 changed files with 12 additions and 16 deletions

View File

@@ -141,7 +141,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
             ip_hidden_states = ipa_embed
             # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
             ip_key = ipa_weights.to_k_ip(ip_hidden_states)
             ip_value = ipa_weights.to_v_ip(ip_hidden_states)
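For orientation, here is a minimal sketch of how a decoupled IP-Adapter attention path consumes the to_k_ip / to_v_ip projections touched by this hunk. The function name, shapes, and head handling below are illustrative assumptions, not InvokeAI's implementation:

import torch
import torch.nn.functional as F

def ip_attention(query, ip_hidden_states, to_k_ip, to_v_ip, num_heads):
    # query: (batch, query_len, inner_dim); ip_hidden_states: (batch, ip_seq_len, ip_embed_dim)
    b, q_len, inner_dim = query.shape
    head_dim = inner_dim // num_heads

    # Project the image-prompt embeddings into the attention's key/value space.
    ip_key = to_k_ip(ip_hidden_states)    # (batch, ip_seq_len, inner_dim)
    ip_value = to_v_ip(ip_hidden_states)  # (batch, ip_seq_len, inner_dim)

    def split_heads(t):
        return t.view(b, -1, num_heads, head_dim).transpose(1, 2)

    # Standard scaled-dot-product attention over the image-prompt tokens.
    out = F.scaled_dot_product_attention(split_heads(query), split_heads(ip_key), split_heads(ip_value))
    return out.transpose(1, 2).reshape(b, q_len, inner_dim)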

View File

@@ -12,6 +12,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
         super().__init__()
         self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
         self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
+        for param in self.parameters():
+            param.requires_grad = False
 class IPAttentionWeights(torch.nn.Module):
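The two added lines mark the adapter projections as inference-only, so autograd state is kept out of the module that stable-fast later traces. A minimal sketch of the same pattern (the class name and dimensions below are illustrative stand-ins, not the InvokeAI module):

import torch

class FrozenProjections(torch.nn.Module):
    """Toy stand-in for an inference-only projection module."""

    def __init__(self, in_dim: int = 768, out_dim: int = 320):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
        # Same effect as the loop added above: requires_grad_(False)
        # flips the flag on every parameter of the module in one call.
        self.requires_grad_(False)

assert all(not p.requires_grad for p in FrozenProjections().parameters())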

View File

@@ -24,12 +24,14 @@ import sys
 import time
 from contextlib import suppress
 from dataclasses import dataclass, field
+from importlib.util import find_spec
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union, types
 import torch
 import invokeai.backend.util.logging as logger
+from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
@@ -44,22 +46,11 @@ TRITON_AVAILABLE = False
 XFORMERS_AVAILABLE = False
 SFAST_CONFIG = None
-try:
-    import triton
-    TRITON_AVAILABLE = True
-except ImportError:
-    pass
+TRITON_AVAILABLE = find_spec("triton") is not None
+XFORMERS_AVAILABLE = find_spec("xformers") is not None
 try:
-    import xformers
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    pass
-try:
-    from sfast.compilers.diffusion_pipeline_compiler import compile_unet, compile_vae, CompilationConfig
+    from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_unet, compile_vae
     SFAST_CONFIG = CompilationConfig.Default()
     SFAST_CONFIG.enable_cuda_graph = True
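The availability probes now use importlib.util.find_spec, which checks whether a package can be located without importing it, so triton and xformers no longer have to be imported just to set a flag. A small sketch of the pattern with a placeholder package name (not a real dependency here):

from importlib.util import find_spec

# find_spec returns a ModuleSpec if the package is locatable, else None.
SOME_PACKAGE_AVAILABLE = find_spec("some_optional_package") is not None

if SOME_PACKAGE_AVAILABLE:
    import some_optional_package  # deferred import, only paid when actually needed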
@@ -141,6 +132,7 @@ class _CacheRecord:
 class ModelCache(object):
     def __init__(
         self,
+        app_config: InvokeAIAppConfig,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         execution_device: torch.device = torch.device("cuda"),
@@ -153,6 +145,7 @@ class ModelCache(object):
         log_memory_usage: bool = False,
     ):
         """
+        :param app_config: InvokeAIAppConfig for application
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
@@ -166,6 +159,7 @@ class ModelCache(object):
             behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = {}
+        self.app_config = app_config
         # allow lazy offloading only when vram cache enabled
         self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype = precision
@@ -270,7 +264,7 @@ class ModelCache(object):
             snapshot_before = self._capture_memory_snapshot()
             with skip_torch_weight_init():
                 model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
-            if SFAST_AVAILABLE and submodel:
+            if SFAST_AVAILABLE and self.app_config.stable_fast and submodel:
                 model = self._compile_model(model, submodel)
             snapshot_after = self._capture_memory_snapshot()
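Compilation is now gated on both stable-fast being installed (SFAST_AVAILABLE) and the new app_config.stable_fast setting. The _compile_model body is not part of this diff; the sketch below is an assumption built from the stable-fast helpers imported above (compile_unet, compile_vae, CompilationConfig) and InvokeAI's SubModelType enum, whose import path and members are likewise assumed:

from invokeai.backend.model_management.models import SubModelType  # assumed path
from sfast.compilers.diffusion_pipeline_compiler import compile_unet, compile_vae

def _compile_model(self, model, submodel_type):
    # Hypothetical dispatch: only the UNet and VAE submodels have
    # stable-fast compilers; anything else is returned unchanged.
    # SFAST_CONFIG is the module-level CompilationConfig set up earlier in this file.
    if submodel_type == SubModelType.UNet:
        return compile_unet(model, SFAST_CONFIG)
    if submodel_type == SubModelType.Vae:
        return compile_vae(model, SFAST_CONFIG)
    return model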

View File

@@ -344,6 +344,7 @@ class ModelManager(object):
         self.app_config = InvokeAIAppConfig.get_config()
         self.logger = logger
         self.cache = ModelCache(
+            app_config=self.app_config,
             max_cache_size=max_cache_size,
             max_vram_cache_size=self.app_config.vram_cache_size,
             lazy_offloading=self.app_config.lazy_offload,
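With the new keyword argument, ModelManager is the only call site that has to change; the cache receives the full application config and reads app_config.stable_fast at model-load time. A hedged usage sketch (field names follow the diff; the ModelCache import path and the cache size value are assumptions):

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.model_cache import ModelCache  # assumed path

config = InvokeAIAppConfig.get_config()
cache = ModelCache(
    app_config=config,                        # new required argument from this commit
    max_cache_size=6.0,                       # RAM cache size in GB, per the docstring default
    max_vram_cache_size=config.vram_cache_size,
    lazy_offloading=config.lazy_offload,
)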