mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: make ip_adapters work with stable-fast
@@ -141,7 +141,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
             ip_hidden_states = ipa_embed
 
             # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
-
             ip_key = ipa_weights.to_k_ip(ip_hidden_states)
             ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
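Note: the to_k_ip/to_v_ip projections in the context above are the IP-Adapter's decoupled cross-attention: image-prompt embeddings get their own key/value projections, and their attention result is blended into the text cross-attention output. A minimal sketch of that computation (the shapes, the folding of num_ip_images into the sequence dimension, and the scale handling are illustrative assumptions, not this file's exact code):

    import torch
    import torch.nn.functional as F

    def ip_attention(query, ip_hidden_states, ipa_weights, scale=1.0):
        # query: (batch, heads, seq_len, head_dim), already projected and reshaped
        # ip_hidden_states: (batch, ip_seq_len, embed_dim); num_ip_images folded into ip_seq_len
        b, heads, _, head_dim = query.shape
        ip_key = ipa_weights.to_k_ip(ip_hidden_states)  # (b, ip_seq_len, heads * head_dim)
        ip_value = ipa_weights.to_v_ip(ip_hidden_states)
        ip_key = ip_key.view(b, -1, heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(b, -1, heads, head_dim).transpose(1, 2)
        # attend from the latent queries to the image-prompt keys/values
        out = F.scaled_dot_product_attention(query, ip_key, ip_value)
        return scale * out  # the caller adds this to the text cross-attention output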
@@ -12,6 +12,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
         super().__init__()
         self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
         self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
+        for param in self.parameters():
+            param.requires_grad = False
 
 
 class IPAttentionWeights(torch.nn.Module):
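Freezing the projections at construction time guarantees the module is pure inference-side state, which is what makes it safe for stable-fast to trace and compile. A quick sanity check (the constructor dims here are hypothetical):

    import torch

    weights = IPAttentionProcessorWeights(in_dim=768, out_dim=320)  # assumed signature

    # every parameter is frozen at construction
    assert all(not p.requires_grad for p in weights.parameters())

    # so forward passes build no autograd graph
    x = torch.randn(1, 4, 768)
    assert weights.to_k_ip(x).requires_grad is False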
@@ -24,12 +24,14 @@ import sys
 import time
 from contextlib import suppress
 from dataclasses import dataclass, field
+from importlib.util import find_spec
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union, types
 
 import torch
 
 import invokeai.backend.util.logging as logger
+from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
 
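The find_spec import added here powers the availability probes in the next hunk: find_spec reports whether a package is installed without executing its import-time code, so probing for triton/xformers no longer pays their (slow) import cost just to set a flag. For reference:

    from importlib.util import find_spec

    # returns a ModuleSpec if the package is importable, else None;
    # the package itself is never imported by this call
    TRITON_AVAILABLE = find_spec("triton") is not None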
@@ -44,22 +46,11 @@ TRITON_AVAILABLE = False
 XFORMERS_AVAILABLE = False
 SFAST_CONFIG = None
 
-try:
-    import triton
-
-    TRITON_AVAILABLE = True
-except ImportError:
-    pass
-
-try:
-    import xformers
-
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    pass
-
-try:
-    from sfast.compilers.diffusion_pipeline_compiler import compile_unet, compile_vae, CompilationConfig
+TRITON_AVAILABLE = find_spec("triton") is not None
+XFORMERS_AVAILABLE = find_spec("xformers") is not None
+
+try:
+    from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_unet, compile_vae
 
     SFAST_CONFIG = CompilationConfig.Default()
     SFAST_CONFIG.enable_cuda_graph = True
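SFAST_CONFIG is stable-fast's CompilationConfig; enable_cuda_graph turns on CUDA graph capture to cut kernel-launch overhead. stable-fast also exposes opt-in Triton and xFormers backends on the same config, which is plausibly where the two availability flags are consumed (the field names below come from stable-fast's documented config and should be verified against the installed version):

    config = CompilationConfig.Default()
    config.enable_cuda_graph = True              # capture CUDA graphs for replay
    config.enable_triton = TRITON_AVAILABLE      # Triton kernels only if installed
    config.enable_xformers = XFORMERS_AVAILABLE  # xformers attention only if installed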
@@ -141,6 +132,7 @@ class _CacheRecord:
 class ModelCache(object):
     def __init__(
         self,
+        app_config: InvokeAIAppConfig,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
         max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
         execution_device: torch.device = torch.device("cuda"),
@@ -153,6 +145,7 @@ class ModelCache(object):
         log_memory_usage: bool = False,
     ):
         """
+        :param app_config: InvokeAIAppConfig for application
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
@@ -166,6 +159,7 @@ class ModelCache(object):
         behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = {}
+        self.app_config = app_config
         # allow lazy offloading only when vram cache enabled
         self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self.precision: torch.dtype = precision
@@ -270,7 +264,7 @@ class ModelCache(object):
         snapshot_before = self._capture_memory_snapshot()
         with skip_torch_weight_init():
             model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
-        if SFAST_AVAILABLE and submodel:
+        if SFAST_AVAILABLE and self.app_config.stable_fast and submodel:
             model = self._compile_model(model, submodel)
 
         snapshot_after = self._capture_memory_snapshot()
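The gate now requires three things: stable-fast importable, the user's stable_fast config flag set, and a submodel (compilation targets the UNet and VAE submodels, not whole pipelines). _compile_model itself is not part of this diff; a hypothetical sketch of such a dispatcher, built on the compile_unet/compile_vae entry points imported above (the SubModelType member names are assumptions):

    def _compile_model(self, model, submodel):
        # Hypothetical dispatcher -- the real method is not shown in this diff.
        if submodel == SubModelType.UNet:
            return compile_unet(model, SFAST_CONFIG)
        if submodel == SubModelType.Vae:
            return compile_vae(model, SFAST_CONFIG)
        return model  # text encoders, tokenizers, etc. stay uncompiled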
@@ -344,6 +344,7 @@ class ModelManager(object):
         self.app_config = InvokeAIAppConfig.get_config()
         self.logger = logger
         self.cache = ModelCache(
+            app_config=self.app_config,
             max_cache_size=max_cache_size,
             max_vram_cache_size=self.app_config.vram_cache_size,
             lazy_offloading=self.app_config.lazy_offload,
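With app_config threaded through, the cache reads the stable_fast toggle from the same config object the rest of the app uses. A hedged usage sketch (the stable_fast attribute is inferred from the gate earlier in this diff; the exact config plumbing may differ):

    from invokeai.app.services.config import InvokeAIAppConfig

    app_config = InvokeAIAppConfig.get_config()
    app_config.stable_fast = True  # opt in to stable-fast compilation (assumed field)

    cache = ModelCache(
        app_config=app_config,
        max_cache_size=6.0,  # GB; matches the documented default
    )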