backend: more post-ldm-removal cleanup (#2911)

Lincoln Stein 2023-03-09 23:11:10 -05:00 committed by GitHub
commit 12c7db3a16
6 changed files with 17 additions and 125 deletions

File 1 of 6

@@ -495,18 +495,6 @@ class Generate:
         torch.cuda.reset_peak_memory_stats()
         results = list()
-        init_image = None
-        mask_image = None
-        try:
-            if (
-                self.free_gpu_mem
-                and self.model.cond_stage_model.device != self.model.device
-            ):
-                self.model.cond_stage_model.device = self.model.device
-                self.model.cond_stage_model.to(self.model.device)
-        except AttributeError:
-            pass
         try:
             uc, c, extra_conditioning_info = get_uc_and_c_and_ec(

File 2 of 6

@@ -104,7 +104,7 @@ class ModelManager(object):
         if model_name in self.models:
             requested_model = self.models[model_name]["model"]
             print(f">> Retrieving model {model_name} from system RAM cache")
-            self.models[model_name]["model"] = self._model_from_cpu(requested_model)
+            requested_model.ready()
             width = self.models[model_name]["width"]
             height = self.models[model_name]["height"]
             hash = self.models[model_name]["hash"]
@@ -499,7 +499,7 @@ class ModelManager(object):
             print(f">> Offloading {model_name} to CPU")
             model = self.models[model_name]["model"]
-            self.models[model_name]["model"] = self._model_to_cpu(model)
+            model.offload_all()
             gc.collect()
             if self._has_cuda():
@@ -557,7 +557,7 @@ class ModelManager(object):
         """
         model_name = model_name or Path(repo_or_path).stem
         model_description = (
-            model_description or f"Imported diffusers model {model_name}"
+            description or f"Imported diffusers model {model_name}"
         )
         new_config = dict(
             description=model_description,
@@ -1044,43 +1044,6 @@ class ModelManager(object):
             self.stack.remove(model_name)
         self.models.pop(model_name, None)
-
-    def _model_to_cpu(self, model):
-        if self.device == CPU_DEVICE:
-            return model
-
-        if isinstance(model, StableDiffusionGeneratorPipeline):
-            model.offload_all()
-            return model
-
-        model.cond_stage_model.device = CPU_DEVICE
-        model.to(CPU_DEVICE)
-        for submodel in ("first_stage_model", "cond_stage_model", "model"):
-            try:
-                getattr(model, submodel).to(CPU_DEVICE)
-            except AttributeError:
-                pass
-        return model
-
-    def _model_from_cpu(self, model):
-        if self.device == CPU_DEVICE:
-            return model
-
-        if isinstance(model, StableDiffusionGeneratorPipeline):
-            model.ready()
-            return model
-
-        model.to(self.device)
-        model.cond_stage_model.device = self.device
-        for submodel in ("first_stage_model", "cond_stage_model", "model"):
-            try:
-                getattr(model, submodel).to(self.device)
-            except AttributeError:
-                pass
-        return model
-
     def _pop_oldest_model(self):
         """
         Remove the first element of the FIFO, which ought
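
Note: the two helpers deleted above are exactly what requested_model.ready() and model.offload_all() replace in the earlier ModelManager hunks. With the ldm/ckpt path gone, every cached model is a StableDiffusionGeneratorPipeline, so only the isinstance branch could ever run. A minimal sketch of the resulting cache round trip (hypothetical helper names, not code from this commit):

import gc

def activate(cache: dict, model_name: str):
    # cache entries are assumed to hold StableDiffusionGeneratorPipeline objects
    pipeline = cache[model_name]["model"]
    pipeline.ready()        # replaces _model_from_cpu(): the pipeline loads its submodels onto the execution device
    return pipeline

def deactivate(cache: dict, model_name: str):
    pipeline = cache[model_name]["model"]
    pipeline.offload_all()  # replaces _model_to_cpu(): the pipeline pushes its submodels back to CPU
    gc.collect()            # reclaim host memory, as the offload hunk above does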

File 3 of 6

@@ -3,7 +3,6 @@ Initialization file for invokeai.backend.prompting
 """
 from .conditioning import (
     get_prompt_structure,
-    get_tokenizer,
     get_tokens_for_prompt_object,
     get_uc_and_c_and_ec,
     split_weighted_subprompts,

File 4 of 6

@@ -7,7 +7,7 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
 """
 import re
-from typing import Any, Optional, Union
+from typing import Optional, Union

 from compel import Compel
 from compel.prompt_parser import (
@@ -17,7 +17,6 @@ from compel.prompt_parser import (
     Fragment,
     PromptParser,
 )
-from transformers import CLIPTokenizer

 from invokeai.backend.globals import Globals
@@ -25,36 +24,6 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype

-def get_tokenizer(model) -> CLIPTokenizer:
-    # TODO remove legacy ckpt fallback handling
-    return (
-        getattr(model, "tokenizer", None)  # diffusers
-        or model.cond_stage_model.tokenizer
-    )  # ldm
-
-def get_text_encoder(model) -> Any:
-    # TODO remove legacy ckpt fallback handling
-    return getattr(
-        model, "text_encoder", None
-    ) or UnsqueezingLDMTransformer(  # diffusers
-        model.cond_stage_model.transformer
-    )  # ldm
-
-class UnsqueezingLDMTransformer:
-    def __init__(self, ldm_transformer):
-        self.ldm_transformer = ldm_transformer
-
-    @property
-    def device(self):
-        return self.ldm_transformer.device
-
-    def __call__(self, *args, **kwargs):
-        insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
-        return insufficiently_unsqueezed_tensor.unsqueeze(0)
-
 def get_uc_and_c_and_ec(
     prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
 ):
@@ -64,11 +33,10 @@ def get_uc_and_c_and_ec(
         prompt_string
     )

-    tokenizer = get_tokenizer(model)
-    text_encoder = get_text_encoder(model)
+    tokenizer = model.tokenizer
     compel = Compel(
         tokenizer=tokenizer,
-        text_encoder=text_encoder,
+        text_encoder=model.text_encoder,
         textual_inversion_manager=model.textual_inversion_manager,
         dtype_for_device_getter=torch_dtype,
         truncate_long_prompts=False
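
Note: with get_tokenizer(), get_text_encoder() and the UnsqueezingLDMTransformer shim gone, get_uc_and_c_and_ec() now assumes it is handed a diffusers pipeline that exposes .tokenizer, .text_encoder and .textual_inversion_manager directly. A hedged usage sketch (the prompt text and the `pipeline` variable are illustrative only):

# Sketch only: calling the trimmed-down conditioning helper.
# `pipeline` is assumed to be InvokeAI's StableDiffusionGeneratorPipeline.
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
    "a lighthouse at dusk, oil painting",  # illustrative prompt
    model=pipeline,
    log_tokens=True,
)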

File 5 of 6

@@ -54,16 +54,6 @@ class PipelineIntermediateState:
     attention_map_saver: Optional[AttentionMapSaver] = None

-# copied from configs/stable-diffusion/v1-inference.yaml
-_default_personalization_config_params = dict(
-    placeholder_strings=["*"],
-    initializer_wods=["sculpture"],
-    per_image_tokens=False,
-    num_vectors_per_token=1,
-    progressive_words=False,
-)

 @dataclass
 class AddsMaskLatents:
     """Add the channels required for inpainting model input.
@@ -175,7 +165,7 @@ def image_resized_to_grid_as_tensor(
     :param normalize: scale the range to [-1, 1] instead of [0, 1]
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
-    w, h = trim_to_multiple_of(*image.size)
+    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
     transformation = T.Compose(
         [
             T.Resize((h, w), T.InterpolationMode.LANCZOS),
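
Note: the one-line change above is a behavioural fix rather than pure cleanup: the multiple_of argument accepted by image_resized_to_grid_as_tensor() was previously dropped, so trim_to_multiple_of() always used its own default. A small sketch of the effect, with trim_to_multiple_of() re-implemented here as an assumed helper (its real definition lives elsewhere in the backend; a default of 8 is assumed):

# Sketch only: why forwarding multiple_of matters.
def trim_to_multiple_of(*dims, multiple_of=8):
    # assumed behaviour: round each dimension down to the nearest multiple
    return tuple(d - d % multiple_of for d in dims)

print(trim_to_multiple_of(700, 517))                  # (696, 512): what the old call always produced
print(trim_to_multiple_of(700, 517, multiple_of=64))  # (640, 512): what callers asking for multiples of 64 now get
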
@@ -290,10 +280,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
         safety_checker ([`StableDiffusionSafetyChecker`]):
-            Classification module that estimates whether generated images could be considered offsensive or harmful.
+            Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -436,11 +426,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         Ready this pipeline's models.

-        i.e. pre-load them to the GPU if appropriate.
+        i.e. preload them to the GPU if appropriate.
         """
         self._model_group.ready()

-    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
         # overridden method; types match the superclass.
         if torch_device is None:
             return self
@@ -917,20 +907,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             device=self._model_group.device_for(self.unet),
         )

-    @property
-    def cond_stage_model(self):
-        return self.embeddings_provider
-
-    @torch.inference_mode()
-    def _tokenize(self, prompt: Union[str, List[str]]):
-        return self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
@@ -942,11 +918,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         return super().decode_latents(latents)

     def debug_latents(self, latents, msg):
+        from invokeai.backend.image_util import debug_image
         with torch.inference_mode():
-            from ldm.util import debug_image
             decoded = self.numpy_to_pil(self.decode_latents(latents))
             for i, img in enumerate(decoded):
                 debug_image(
                     img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
                 )

File 6 of 6

@@ -29,7 +29,6 @@ from ..image_util import PngWriter, retrieve_metadata
 from ...frontend.merge.merge_diffusers import merge_diffusion_models
 from ..prompting import (
     get_prompt_structure,
-    get_tokenizer,
     get_tokens_for_prompt_object,
 )
 from ..stable_diffusion import PipelineIntermediateState
@@ -1274,7 +1273,7 @@ class InvokeAIWebServer:
                 None
                 if type(parsed_prompt) is Blend
                 else get_tokens_for_prompt_object(
-                    get_tokenizer(self.generate.model), parsed_prompt
+                    self.generate.model.tokenizer, parsed_prompt
                 )
             )
             attention_maps_image_base64_url = (