Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
add config variable to suppress loading of sd3 text_encoder_3 T5 model
This commit is contained in: parent f65d50a4dd, commit 423057a2e8
@@ -1,3 +1,4 @@
+from contextlib import ExitStack
 from typing import cast
 
 import torch
@@ -23,7 +24,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.model_manager.config import SubModelType
 from invokeai.backend.model_manager.load.load_base import LoadedModel
-from invokeai.backend.util.devices import TorchDevice
 
 sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
 transformer_info: Optional[LoadedModel] = None
@@ -148,39 +148,35 @@ class StableDiffusion3Invocation(BaseInvocation):
         return v % (SEED_MAX + 1)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        global sd3_pipeline, transformer_info, tokenizer_1_info, tokenizer_2_info, tokenizer_3_info, text_encoder_1_info, text_encoder_2_info, text_encoder_3_info
+        app_config = context.config.get()
+        load_te3 = app_config.load_sd3_encoder_3
 
-        if not transformer_info:
-            transformer_info = context.models.load(self.transformer.transformer)
-        if not tokenizer_1_info:
-            tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
-        if not tokenizer_2_info:
-            tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
-        if not tokenizer_3_info:
-            tokenizer_3_info = context.models.load(self.clip.tokenizer_3)
-        if not text_encoder_1_info:
-            text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
-        if not text_encoder_2_info:
-            text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
-        if not text_encoder_3_info:
-            text_encoder_3_info = context.models.load(self.clip.text_encoder_3)
+        transformer_info = context.models.load(self.transformer.transformer)
+        tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
+        tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
+        text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
+        text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
 
-        with (
-            tokenizer_1_info as tokenizer_1,
-            tokenizer_2_info as tokenizer_2,
-            tokenizer_3_info as tokenizer_3,
-            text_encoder_1_info as text_encoder_1,
-            text_encoder_2_info as text_encoder_2,
-            text_encoder_3_info as text_encoder_3,
-            transformer_info as transformer,
-        ):
+        with ExitStack() as stack:
+            tokenizer_1 = stack.enter_context(tokenizer_1_info)
+            tokenizer_2 = stack.enter_context(tokenizer_2_info)
+            text_encoder_1 = stack.enter_context(text_encoder_1_info)
+            text_encoder_2 = stack.enter_context(text_encoder_2_info)
+            transformer = stack.enter_context(transformer_info)
+
             assert isinstance(transformer, SD3Transformer2DModel)
             assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
             assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
-            assert isinstance(text_encoder_3, T5EncoderModel)
             assert isinstance(tokenizer_1, CLIPTokenizer)
             assert isinstance(tokenizer_2, CLIPTokenizer)
-            assert isinstance(tokenizer_3, T5TokenizerFast)
+            if load_te3:
+                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
+                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
+                assert isinstance(text_encoder_3, T5EncoderModel)
+                assert isinstance(tokenizer_3, T5TokenizerFast)
+            else:
+                tokenizer_3 = None
+                text_encoder_3 = None
 
             scheduler = get_scheduler(
                 context=context,
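The rewrite above swaps the parenthesized multi-context `with` statement for contextlib.ExitStack, which is what lets the third (T5) encoder be entered only when the config flag asks for it. A minimal, self-contained sketch of that pattern; DummyLoader is hypothetical and merely stands in for the loaded-model context managers used in the diff.

from contextlib import ExitStack

class DummyLoader:
    """Hypothetical stand-in for a loaded-model context manager."""
    def __init__(self, name: str) -> None:
        self.name = name
    def __enter__(self) -> str:
        print(f"entering {self.name}")
        return self.name
    def __exit__(self, *exc) -> None:
        print(f"exiting {self.name}")

def run(load_te3: bool) -> None:
    with ExitStack() as stack:
        # These are always entered.
        encoder_1 = stack.enter_context(DummyLoader("text_encoder_1"))
        encoder_2 = stack.enter_context(DummyLoader("text_encoder_2"))
        # The third encoder is only entered (and therefore only loaded)
        # when the flag asks for it; otherwise it stays None.
        encoder_3 = stack.enter_context(DummyLoader("text_encoder_3")) if load_te3 else None
        print(encoder_1, encoder_2, encoder_3)
    # Everything registered on the stack is exited here, in reverse order,
    # whether or not the optional encoder was entered.

run(load_te3=False)
run(load_te3=True)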
@@ -189,21 +185,17 @@ class StableDiffusion3Invocation(BaseInvocation):
                 seed=self.seed,
             )
 
-            if not isinstance(sd3_pipeline, StableDiffusion3Pipeline):
-                sd3_pipeline = StableDiffusion3Pipeline(
-                    transformer=transformer,
-                    vae=FakeVae(),
-                    text_encoder=text_encoder_1,
-                    text_encoder_2=text_encoder_2,
-                    text_encoder_3=text_encoder_3,
-                    tokenizer=tokenizer_1,
-                    tokenizer_2=tokenizer_2,
-                    tokenizer_3=tokenizer_3,
-                    scheduler=scheduler,
-                )
-
-            sd3_pipeline.components["scheduler"] = scheduler
-            sd3_pipeline.to(TorchDevice.choose_torch_device().type)
+            sd3_pipeline = StableDiffusion3Pipeline(
+                transformer=transformer,
+                vae=FakeVae(),
+                text_encoder=text_encoder_1,
+                text_encoder_2=text_encoder_2,
+                text_encoder_3=text_encoder_3,
+                tokenizer=tokenizer_1,
+                tokenizer_2=tokenizer_2,
+                tokenizer_3=tokenizer_3,
+                scheduler=scheduler,
+            )
 
             results = sd3_pipeline(
                 self.positive_prompt,
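For comparison, when using diffusers directly rather than InvokeAI's model manager, passing text_encoder_3=None and tokenizer_3=None is the documented way to run SD3 without the T5 encoder; the pipeline then relies on the two CLIP encoders alone. A hedged usage sketch follows; the model id, dtype, prompt and step count are illustrative, not taken from this commit.

import torch
from diffusers import StableDiffusion3Pipeline

# Drop the memory-hungry T5 encoder: diffusers accepts None for the
# third text encoder/tokenizer pair and falls back to the two CLIP encoders.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative model id
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)
pipe.to("cuda")
image = pipe("a photo of a cat", num_inference_steps=28).images[0]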
@@ -104,6 +104,7 @@ class InvokeAIAppConfig(BaseSettings):
         vram: Amount of VRAM reserved for model storage (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
         lazy_offload: Keep models in VRAM until their space is needed.
+        load_sd3_encoder_3: Load the memory-intensive SD3 text_encoder_3.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
         precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
@@ -173,6 +174,7 @@ class InvokeAIAppConfig(BaseSettings):
     vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
     convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
     lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
+    load_sd3_encoder_3: bool = Field(default=False, description="Load the memory-intensive SD3 text_encoder_3.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
 
     # DEVICE
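The new field follows the usual pydantic-settings shape: a typed attribute whose Field default can be overridden by whichever settings sources the application wires up (config file, environment, constructor arguments). A minimal, self-contained sketch of that mechanism; this is an illustration, not InvokeAI's actual settings class or environment wiring.

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class AppConfigSketch(BaseSettings):
    """Illustrative settings class; the real InvokeAIAppConfig defines many more fields."""
    model_config = SettingsConfigDict(env_prefix="INVOKEAI_")  # assumed prefix for this sketch

    load_sd3_encoder_3: bool = Field(
        default=False,
        description="Load the memory-intensive SD3 text_encoder_3.",
    )

# Off by default; an environment variable such as INVOKEAI_LOAD_SD3_ENCODER_3=1
# (or the corresponding key in a config file source) would flip it on in this sketch.
print(AppConfigSketch().load_sd3_encoder_3)  # False unless overridden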
@@ -84,6 +84,8 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass
 
+        self._logger.info(f"Loading {config.key}:{submodel_type}")
+
         cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
@@ -161,11 +161,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self.make_room(size)
 
         is_quantized = hasattr(model, "is_quantized") and model.is_quantized
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None
         cache_record = CacheRecord(
             key=key,
             model=model,
-            device=self._storage_device,
+            device=self._execution_device
+            if is_quantized
+            else self._storage_device,  # quantized models are loaded directly into CUDA
             is_quantized=is_quantized,
             state_dict=state_dict,
             size=size,
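The device choice above encodes an assumption about quantized (e.g. bitsandbytes-style) models: they are materialized directly on the GPU at load time and cannot be shuttled between devices, so the cache records them as already on the execution device and skips keeping a CPU-side state dict. A small sketch of that decision in isolation; RecordSketch and make_record are hypothetical stand-ins, not InvokeAI's CacheRecord API.

from dataclasses import dataclass
from typing import Any, Optional

import torch

@dataclass
class RecordSketch:  # hypothetical stand-in for a cache record
    key: str
    model: Any
    device: torch.device
    is_quantized: bool
    state_dict: Optional[dict]

def make_record(key: str, model: Any, storage: torch.device, execution: torch.device) -> RecordSketch:
    # transformers sets `is_quantized` on models loaded with a quantization config;
    # the hasattr guard keeps plain torch modules working too.
    is_quantized = hasattr(model, "is_quantized") and model.is_quantized
    return RecordSketch(
        key=key,
        model=model,
        # Quantized weights live on the GPU from the start and cannot be
        # moved between devices, so record them as already "in VRAM".
        device=execution if is_quantized else storage,
        is_quantized=is_quantized,
        # Keeping a CPU-side state dict only makes sense for movable models.
        state_dict=model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None,
    )

record = make_record("demo", torch.nn.Linear(4, 4), torch.device("cpu"), torch.device("cuda"))
print(record.device, record.is_quantized)  # cpu False for this unquantized module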
@@ -235,26 +237,28 @@ class ModelCache(ModelCacheBase[AnyModel]):
         reserved = self._max_vram_cache_size * GIG
         vram_in_use = torch.cuda.memory_allocated() + size_required
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        delete_it = False
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
+
+            # only way to remove a quantized model from VRAM is to
+            # delete it completely - it can't be moved from device to device
+            if cache_entry.is_quantized:
+                self._delete_cache_entry(cache_entry)
+                vram_in_use = torch.cuda.memory_allocated() + size_required
+                continue
+
             if not cache_entry.loaded:
                 continue
 
             if not cache_entry.locked:
-                if cache_entry.is_quantized:
-                    self._delete_cache_entry(cache_entry)
-                    delete_it = True
-                else:
-                    self.move_model_to_device(cache_entry, self.storage_device)
-                    cache_entry.loaded = False
+                self.move_model_to_device(cache_entry, self.storage_device)
+                cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
                 self.logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
                 )
-        if delete_it:
-            del cache_entry
+        gc.collect()
         TorchDevice.empty_cache()
 
     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
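The eviction loop walks cache entries smallest-first, stops as soon as the projected usage fits under the reserved budget, and frees quantized entries by deleting them outright since they cannot be offloaded. A compact, CPU-safe sketch of the same control flow over plain records; EntrySketch and make_room are illustrative, and the sketch decrements a counter where the real cache re-reads torch.cuda.memory_allocated().

from dataclasses import dataclass

GIG = 2**30

@dataclass
class EntrySketch:  # hypothetical stand-in for a cache entry
    key: str
    size: int
    loaded: bool = True
    locked: bool = False
    is_quantized: bool = False

def make_room(entries: dict, vram_in_use: int, reserved: int) -> int:
    """Evict smallest entries first until usage fits the budget (sketch)."""
    for key, entry in sorted(entries.items(), key=lambda kv: kv[1].size):
        if vram_in_use <= reserved:
            break
        if entry.is_quantized:
            # A quantized model can't be moved off the GPU; drop it entirely.
            del entries[key]
            vram_in_use -= entry.size
            continue
        if not entry.loaded or entry.locked:
            continue
        # A movable model is offloaded to the storage device instead.
        entry.loaded = False
        vram_in_use -= entry.size
        print(f"Removing {entry.key} to free {entry.size / GIG:.2f}GB")
    return vram_in_use

entries = {
    "a": EntrySketch("a", size=1 * GIG),
    "b": EntrySketch("b", size=2 * GIG, is_quantized=True),
    "c": EntrySketch("c", size=4 * GIG, locked=True),
}
print(make_room(entries, vram_in_use=6 * GIG, reserved=3 * GIG) / GIG)  # 3.0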
@@ -268,7 +272,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
         source_device = cache_entry.device
 
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
+        # Note: We compare device types so that 'cuda' == 'cuda:0'.
         # This would need to be revised to support multi-GPU.
         if torch.device(source_device).type == torch.device(target_device).type:
             return
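The comment about comparing device types is worth a concrete illustration: torch.device("cuda") and torch.device("cuda:0") compare unequal as objects because equality also considers the index, but their .type fields match, which is exactly the "already on this kind of device" check the cache wants.

import torch

a = torch.device("cuda")    # no index
b = torch.device("cuda:0")  # explicit index 0

print(a == b)                              # False: equality compares type AND index
print(a.type == b.type)                    # True:  both are "cuda"
print(torch.device("cpu").type == b.type)  # False: a real device change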
@@ -277,9 +281,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if not hasattr(cache_entry.model, "to"):
             return
 
-        if cache_entry.is_quantized:  # can't move quantized models around
-            return
-
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
         # When moving to VRAM, we copy (not move) each element of the state dict from
@@ -422,5 +423,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
+        del cache_entry
         gc.collect()
         TorchDevice.empty_cache()
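Dropping the local reference with `del cache_entry` before gc.collect() matters because CUDA memory is only returned once no Python references keep the tensors alive and the caching allocator is flushed. A tiny sketch of that sequence, guarded so it is a no-op on machines without a GPU:

import gc
import torch

def free_cuda_block() -> None:
    if not torch.cuda.is_available():
        return  # nothing to demonstrate on CPU-only machines
    x = torch.empty(256, 1024, 1024, device="cuda")  # ~1GB of float32
    print(f"allocated: {torch.cuda.memory_allocated() / 2**30:.2f}GB")
    del x                      # drop the last reference ...
    gc.collect()               # ... let Python reclaim the wrapper object ...
    torch.cuda.empty_cache()   # ... and hand cached blocks back to the driver
    print(f"after free: {torch.cuda.memory_allocated() / 2**30:.2f}GB")

free_cuda_block()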