mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
probe for availability of stable-fast compiler and triton at startup time
This commit is contained in:
parent
6cb3031c09
commit
e3ab074b95
@ -271,6 +271,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||
stable_fast : bool = Field(default=True, description="Enable stable-fast performance optimizations, if the library is installed and functional", json_schema_extra=Categories.Generation)
|
||||
|
||||
# QUEUE
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||
|
@ -39,13 +39,29 @@ from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
sfast_available = True
|
||||
if sfast_available:
|
||||
from sfast.compilers.diffusion_pipeline_compiler import (compile,
|
||||
compile_unet,
|
||||
compile_vae,
|
||||
CompilationConfig
|
||||
)
|
||||
SFAST_AVAILABLE = False
|
||||
TRITON_AVAILABLE = False
|
||||
SFAST_CONFIG = None
|
||||
|
||||
try:
|
||||
import triton
|
||||
|
||||
TRITON_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from sfast.compilers.diffusion_pipeline_compiler import compile_unet, compile_vae, CompilationConfig
|
||||
|
||||
SFAST_CONFIG = CompilationConfig.Default()
|
||||
SFAST_CONFIG.enable_xformers = True
|
||||
SFAST_CONFIG.enable_cuda_graph = True
|
||||
if TRITON_AVAILABLE:
|
||||
SFAST_CONFIG.enable_triton = True
|
||||
SFAST_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
@ -247,7 +263,7 @@ class ModelCache(object):
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
with skip_torch_weight_init():
|
||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||
if sfast_available and submodel:
|
||||
if SFAST_AVAILABLE and submodel:
|
||||
model = self._compile_model(model, submodel)
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
@ -333,18 +349,15 @@ class ModelCache(object):
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
def _compile_model(self, model, model_type):
|
||||
config = CompilationConfig.Default()
|
||||
config.enable_xformers = True
|
||||
config.enable_triton = True
|
||||
config.enable_cuda_graph = True
|
||||
def _compile_model(self, model: Any, model_type: SubModelType) -> Any:
|
||||
if model_type == SubModelType("unet"):
|
||||
return compile_unet(model, config)
|
||||
self.logger.info("SFast-compiling unet model")
|
||||
return compile_unet(model, SFAST_CONFIG)
|
||||
elif model_type == SubModelType("vae"):
|
||||
return compile_vae(model, config)
|
||||
self.logger.info("SFast-compiling vae model")
|
||||
return compile_vae(model, SFAST_CONFIG)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
class ModelLocker(object):
|
||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||
|
@ -113,6 +113,7 @@ dependencies = [
|
||||
"onnx" = ["onnxruntime"]
|
||||
"onnx-cuda" = ["onnxruntime-gpu"]
|
||||
"onnx-directml" = ["onnxruntime-directml"]
|
||||
"stable-fast" = ["stable-fast"]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user