Mirror of https://github.com/invoke-ai/InvokeAI
backend: more post-ldm-removal cleanup (#2911)
Commit: 12c7db3a16
@@ -495,18 +495,6 @@ class Generate:
         torch.cuda.reset_peak_memory_stats()

         results = list()
         init_image = None
         mask_image = None

-        try:
-            if (
-                self.free_gpu_mem
-                and self.model.cond_stage_model.device != self.model.device
-            ):
-                self.model.cond_stage_model.device = self.model.device
-                self.model.cond_stage_model.to(self.model.device)
-        except AttributeError:
-            pass
-
         try:
             uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
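For context on the deletion above: the guard synced the legacy ldm text encoder (cond_stage_model) onto the main device when free_gpu_mem was set. After this commit the pipeline no longer exposes cond_stage_model (see the hunk removing that property further down), so the block could only ever fall through to its AttributeError branch. A minimal, hypothetical illustration — DiffusersModel is a stand-in, not InvokeAI code:

# Hypothetical illustration, not InvokeAI code: after the ldm removal there is
# no cond_stage_model to sync, so the deleted guard always fell through.
class DiffusersModel:
    pass  # deliberately has no cond_stage_model attribute

model = DiffusersModel()
try:
    model.cond_stage_model  # the attribute the deleted block manipulated
except AttributeError:
    print("nothing to sync; the guard was dead code")  # always taken post-ldm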
@@ -104,7 +104,7 @@ class ModelManager(object):
         if model_name in self.models:
             requested_model = self.models[model_name]["model"]
             print(f">> Retrieving model {model_name} from system RAM cache")
-            self.models[model_name]["model"] = self._model_from_cpu(requested_model)
+            requested_model.ready()
             width = self.models[model_name]["width"]
             height = self.models[model_name]["height"]
             hash = self.models[model_name]["hash"]
@@ -499,7 +499,7 @@ class ModelManager(object):

         print(f">> Offloading {model_name} to CPU")
         model = self.models[model_name]["model"]
-        self.models[model_name]["model"] = self._model_to_cpu(model)
+        model.offload_all()

         gc.collect()
         if self._has_cuda():
@@ -557,7 +557,7 @@ class ModelManager(object):
         """
         model_name = model_name or Path(repo_or_path).stem
         model_description = (
-            model_description or f"Imported diffusers model {model_name}"
+            description or f"Imported diffusers model {model_name}"
         )
         new_config = dict(
             description=model_description,
@@ -1044,43 +1044,6 @@ class ModelManager(object):
             self.stack.remove(model_name)
         self.models.pop(model_name, None)

-    def _model_to_cpu(self, model):
-        if self.device == CPU_DEVICE:
-            return model
-
-        if isinstance(model, StableDiffusionGeneratorPipeline):
-            model.offload_all()
-            return model
-
-        model.cond_stage_model.device = CPU_DEVICE
-        model.to(CPU_DEVICE)
-
-        for submodel in ("first_stage_model", "cond_stage_model", "model"):
-            try:
-                getattr(model, submodel).to(CPU_DEVICE)
-            except AttributeError:
-                pass
-        return model
-
-    def _model_from_cpu(self, model):
-        if self.device == CPU_DEVICE:
-            return model
-
-        if isinstance(model, StableDiffusionGeneratorPipeline):
-            model.ready()
-            return model
-
-        model.to(self.device)
-        model.cond_stage_model.device = self.device
-
-        for submodel in ("first_stage_model", "cond_stage_model", "model"):
-            try:
-                getattr(model, submodel).to(self.device)
-            except AttributeError:
-                pass
-
-        return model
-
     def _pop_oldest_model(self):
         """
         Remove the first element of the FIFO, which ought
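The two removed helpers shuffled ldm submodels between devices one attribute at a time. Their diffusers replacements are the pipeline's own ready()/offload_all() hooks, as the call-site hunks above show. A simplified, hypothetical sketch of the resulting RAM-cache flow — ModelCache and FakePipeline are stand-ins for ModelManager's internals:

# Hypothetical sketch of the cache flow after this commit; both classes are
# simplified stand-ins, not InvokeAI's actual implementation.
class FakePipeline:
    def ready(self):
        print("moving pipeline modules to GPU")

    def offload_all(self):
        print("moving pipeline modules to CPU")

class ModelCache:
    def __init__(self):
        self.models = {}

    def get_model(self, name):
        entry = self.models[name]
        entry["model"].ready()  # replaces the old _model_from_cpu()
        return entry["model"]

    def offload_model(self, name):
        self.models[name]["model"].offload_all()  # replaces _model_to_cpu()

cache = ModelCache()
cache.models["sd-1.5"] = {"model": FakePipeline()}
cache.get_model("sd-1.5")
cache.offload_model("sd-1.5")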
@@ -3,7 +3,6 @@ Initialization file for invokeai.backend.prompting
 """
 from .conditioning import (
     get_prompt_structure,
-    get_tokenizer,
     get_tokens_for_prompt_object,
     get_uc_and_c_and_ec,
     split_weighted_subprompts,
@@ -7,7 +7,7 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an

 """
 import re
-from typing import Any, Optional, Union
+from typing import Optional, Union

 from compel import Compel
 from compel.prompt_parser import (
@@ -17,7 +17,6 @@ from compel.prompt_parser import (
     Fragment,
     PromptParser,
 )
-from transformers import CLIPTokenizer

 from invokeai.backend.globals import Globals

@@ -25,36 +24,6 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype


-def get_tokenizer(model) -> CLIPTokenizer:
-    # TODO remove legacy ckpt fallback handling
-    return (
-        getattr(model, "tokenizer", None)  # diffusers
-        or model.cond_stage_model.tokenizer
-    )  # ldm
-
-
-def get_text_encoder(model) -> Any:
-    # TODO remove legacy ckpt fallback handling
-    return getattr(
-        model, "text_encoder", None
-    ) or UnsqueezingLDMTransformer(  # diffusers
-        model.cond_stage_model.transformer
-    )  # ldm
-
-
-class UnsqueezingLDMTransformer:
-    def __init__(self, ldm_transformer):
-        self.ldm_transformer = ldm_transformer
-
-    @property
-    def device(self):
-        return self.ldm_transformer.device
-
-    def __call__(self, *args, **kwargs):
-        insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
-        return insufficiently_unsqueezed_tensor.unsqueeze(0)
-
-
 def get_uc_and_c_and_ec(
     prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
 ):
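The deleted UnsqueezingLDMTransformer shim papered over a shape mismatch: the legacy ldm text encoder returned embeddings without a leading batch axis, which compel expects. A small illustration, assuming the usual CLIP hidden-state shapes:

# Illustration only, assuming typical CLIP hidden-state shapes: the shim's
# whole job was the unsqueeze(0) that restores a batch dimension.
import torch

ldm_style_output = torch.randn(77, 768)  # [tokens, dim], no batch axis
batched = ldm_style_output.unsqueeze(0)  # what the shim added: [1, 77, 768]
print(batched.shape)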
@@ -64,11 +33,10 @@ def get_uc_and_c_and_ec(
         prompt_string
     )

-    tokenizer = get_tokenizer(model)
-    text_encoder = get_text_encoder(model)
+    tokenizer = model.tokenizer
     compel = Compel(
         tokenizer=tokenizer,
-        text_encoder=text_encoder,
+        text_encoder=model.text_encoder,
         textual_inversion_manager=model.textual_inversion_manager,
         dtype_for_device_getter=torch_dtype,
         truncate_long_prompts=False
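With the fallbacks gone, Compel is wired directly to the pipeline's tokenizer and text_encoder attributes. A standalone, hedged sketch of that wiring (the real call above also passes InvokeAI's textual_inversion_manager and dtype getter; the model ID is illustrative):

# Hedged sketch of building Compel straight from a diffusers pipeline; the
# checkpoint ID is illustrative, not something this commit prescribes.
from compel import Compel
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder,
                truncate_long_prompts=False)

c = compel.build_conditioning_tensor("a watercolor fox, highly detailed")
uc = compel.build_conditioning_tensor("")  # unconditioned counterpart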
@@ -54,16 +54,6 @@ class PipelineIntermediateState:
     attention_map_saver: Optional[AttentionMapSaver] = None


-# copied from configs/stable-diffusion/v1-inference.yaml
-_default_personalization_config_params = dict(
-    placeholder_strings=["*"],
-    initializer_wods=["sculpture"],
-    per_image_tokens=False,
-    num_vectors_per_token=1,
-    progressive_words=False,
-)
-
-
 @dataclass
 class AddsMaskLatents:
     """Add the channels required for inpainting model input.
@@ -175,7 +165,7 @@ def image_resized_to_grid_as_tensor(
     :param normalize: scale the range to [-1, 1] instead of [0, 1]
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
-    w, h = trim_to_multiple_of(*image.size)
+    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
     transformation = T.Compose(
         [
             T.Resize((h, w), T.InterpolationMode.LANCZOS),
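This one-line change fixes a real bug: the multiple_of argument was accepted but never forwarded, so every image was trimmed to the helper's default. A hedged sketch, assuming the default multiple is 8 as in InvokeAI's helper:

# Hedged sketch of the bug: without forwarding multiple_of, callers always
# got the default trim (assumed to be 8 here).
def trim_to_multiple_of(*args, multiple_of=8):
    return tuple((x - x % multiple_of) for x in args)

print(trim_to_multiple_of(1023, 769))                  # (1016, 768) - old path
print(trim_to_multiple_of(1023, 769, multiple_of=64))  # (960, 768)  - fixed path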
@@ -290,10 +280,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
         safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offsensive or harmful.
+            Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -436,11 +426,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         Ready this pipeline's models.

-        i.e. pre-load them to the GPU if appropriate.
+        i.e. preload them to the GPU if appropriate.
         """
         self._model_group.ready()

-    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         # overridden method; types match the superclass.
         if torch_device is None:
             return self
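The to() override gains silence_dtype_warnings to stay signature-compatible with the diffusers base class. A hedged sketch of the pattern — Base stands in for diffusers' DiffusionPipeline:

# Hedged sketch: keeping an override signature-compatible with its base
# class; Base is a stand-in, not the real diffusers class.
from typing import Optional, Union
import torch

class Base:
    def to(self, torch_device=None, silence_dtype_warnings: bool = False):
        return self

class Pipeline(Base):
    def to(self, torch_device: Optional[Union[str, torch.device]] = None,
           silence_dtype_warnings: bool = False):
        if torch_device is None:
            return self  # no-op, matching the base behaviour
        # custom model-group placement would happen here
        return super().to(torch_device, silence_dtype_warnings)

Pipeline().to("cpu")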
@@ -917,20 +907,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             device=self._model_group.device_for(self.unet),
         )

-    @property
-    def cond_stage_model(self):
-        return self.embeddings_provider
-
-    @torch.inference_mode()
-    def _tokenize(self, prompt: Union[str, List[str]]):
-        return self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
@@ -942,11 +918,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         return super().decode_latents(latents)

     def debug_latents(self, latents, msg):
-        from ldm.util import debug_image
-
-        decoded = self.numpy_to_pil(self.decode_latents(latents))
-        for i, img in enumerate(decoded):
-            debug_image(
-                img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
-            )
+        from invokeai.backend.image_util import debug_image
+        with torch.inference_mode():
+            decoded = self.numpy_to_pil(self.decode_latents(latents))
+            for i, img in enumerate(decoded):
+                debug_image(
+                    img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
+                )
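debug_latents now imports debug_image from its post-ldm home and runs the decode under torch.inference_mode(), so debugging records no autograd state. A runnable sketch of the pattern; the tensor math stands in for the real VAE decode:

# Hedged sketch of the pattern: decode-for-debugging under inference_mode.
import torch

def debug_latents(latents: torch.Tensor, msg: str) -> None:
    with torch.inference_mode():
        decoded = (latents.clamp(-1, 1) + 1) / 2  # stand-in for decode_latents
        for i, img in enumerate(decoded):
            print(f"latents {msg} {i+1}/{len(decoded)}: shape {tuple(img.shape)}")

debug_latents(torch.randn(2, 4, 64, 64), "step 10")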
@@ -29,7 +29,6 @@ from ..image_util import PngWriter, retrieve_metadata
 from ...frontend.merge.merge_diffusers import merge_diffusion_models
 from ..prompting import (
     get_prompt_structure,
-    get_tokenizer,
     get_tokens_for_prompt_object,
 )
 from ..stable_diffusion import PipelineIntermediateState
@@ -1274,7 +1273,7 @@ class InvokeAIWebServer:
                 None
                 if type(parsed_prompt) is Blend
                 else get_tokens_for_prompt_object(
-                    get_tokenizer(self.generate.model), parsed_prompt
+                    self.generate.model.tokenizer, parsed_prompt
                 )
             )
             attention_maps_image_base64_url = (
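As in the prompting module, the web server now reads the tokenizer straight off the loaded model instead of going through the removed get_tokenizer() helper. A hedged sketch; the checkpoint ID is illustrative:

# Hedged sketch: token inspection via the model's own tokenizer attribute.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
print(tokenizer.tokenize("a photo of an astronaut riding a horse"))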