mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove more old logic
This commit is contained in:
parent
7b35162b9e
commit
a01998d095
@ -12,7 +12,7 @@ from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
|||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.generator.inpaint import infill_methods
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
|
|
||||||
from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
|
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||||
|
@ -5,7 +5,6 @@ from .generator import (
|
|||||||
InvokeAIGeneratorBasicParams,
|
InvokeAIGeneratorBasicParams,
|
||||||
InvokeAIGenerator,
|
InvokeAIGenerator,
|
||||||
InvokeAIGeneratorOutput,
|
InvokeAIGeneratorOutput,
|
||||||
Txt2Img,
|
|
||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint
|
Inpaint
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,6 @@ from .base import (
|
|||||||
InvokeAIGenerator,
|
InvokeAIGenerator,
|
||||||
InvokeAIGeneratorBasicParams,
|
InvokeAIGeneratorBasicParams,
|
||||||
InvokeAIGeneratorOutput,
|
InvokeAIGeneratorOutput,
|
||||||
Txt2Img,
|
|
||||||
Img2Img,
|
Img2Img,
|
||||||
Inpaint,
|
Inpaint,
|
||||||
Generator,
|
Generator,
|
||||||
|
@ -175,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
'''
|
'''
|
||||||
return Generator
|
return Generator
|
||||||
|
|
||||||
# ------------------------------------
|
|
||||||
class Txt2Img(InvokeAIGenerator):
|
|
||||||
@classmethod
|
|
||||||
def _generator_class(cls):
|
|
||||||
from .txt2img import Txt2Img
|
|
||||||
return Txt2Img
|
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
@ -235,24 +228,6 @@ class Inpaint(Img2Img):
|
|||||||
from .inpaint import Inpaint
|
from .inpaint import Inpaint
|
||||||
return Inpaint
|
return Inpaint
|
||||||
|
|
||||||
# ------------------------------------
|
|
||||||
class Embiggen(Txt2Img):
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
embiggen: list=None,
|
|
||||||
embiggen_tiles: list = None,
|
|
||||||
strength: float=0.75,
|
|
||||||
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
|
||||||
return super().generate(embiggen=embiggen,
|
|
||||||
embiggen_tiles=embiggen_tiles,
|
|
||||||
strength=strength,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _generator_class(cls):
|
|
||||||
from .embiggen import Embiggen
|
|
||||||
return Embiggen
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
latent_channels: int
|
latent_channels: int
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for the invokeai.backend.stable_diffusion package
|
Initialization file for the invokeai.backend.stable_diffusion package
|
||||||
"""
|
"""
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
|
||||||
from .diffusers_pipeline import (
|
from .diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
PipelineIntermediateState,
|
PipelineIntermediateState,
|
||||||
|
@ -1,275 +0,0 @@
|
|||||||
"""
|
|
||||||
Query and install embeddings from the HuggingFace SD Concepts Library
|
|
||||||
at https://huggingface.co/sd-concepts-library.
|
|
||||||
|
|
||||||
The interface is through the Concepts() object.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from typing import Callable
|
|
||||||
from urllib import error as ul_error
|
|
||||||
from urllib import request
|
|
||||||
|
|
||||||
from huggingface_hub import (
|
|
||||||
HfApi,
|
|
||||||
HfFolder,
|
|
||||||
ModelFilter,
|
|
||||||
hf_hub_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
logger = InvokeAILogger.getLogger()
|
|
||||||
|
|
||||||
class HuggingFaceConceptsLibrary(object):
|
|
||||||
def __init__(self, root=None):
|
|
||||||
"""
|
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
|
||||||
"""
|
|
||||||
self.config = InvokeAIAppConfig.get_config()
|
|
||||||
self.root = root or self.config.root
|
|
||||||
self.hf_api = HfApi()
|
|
||||||
self.local_concepts = dict()
|
|
||||||
self.concept_list = None
|
|
||||||
self.concepts_loaded = dict()
|
|
||||||
self.triggers = dict() # concept name to trigger phrase
|
|
||||||
self.concept_names = dict() # trigger phrase to concept name
|
|
||||||
self.match_trigger = re.compile(
|
|
||||||
"(<[\w\- >]+>)"
|
|
||||||
) # trigger is slightly less restrictive than HF concept name
|
|
||||||
self.match_concept = re.compile(
|
|
||||||
"<([\w\-]+)>"
|
|
||||||
) # HF concept name can only contain A-Za-z0-9_-
|
|
||||||
|
|
||||||
def list_concepts(self) -> list:
|
|
||||||
"""
|
|
||||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
|
||||||
Also adds local concepts in invokeai/embeddings folder.
|
|
||||||
"""
|
|
||||||
local_concepts_now = self.get_local_concepts(
|
|
||||||
os.path.join(self.root, "embeddings")
|
|
||||||
)
|
|
||||||
local_concepts_to_add = set(local_concepts_now).difference(
|
|
||||||
set(self.local_concepts)
|
|
||||||
)
|
|
||||||
self.local_concepts.update(local_concepts_now)
|
|
||||||
|
|
||||||
if self.concept_list is not None:
|
|
||||||
if local_concepts_to_add:
|
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
|
||||||
return self.concept_list
|
|
||||||
return self.concept_list
|
|
||||||
elif self.config.internet_available is True:
|
|
||||||
try:
|
|
||||||
models = self.hf_api.list_models(
|
|
||||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
|
||||||
)
|
|
||||||
self.concept_list = [a.id.split("/")[1] for a in models]
|
|
||||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
|
||||||
)
|
|
||||||
return self.concept_list
|
|
||||||
else:
|
|
||||||
return self.concept_list
|
|
||||||
|
|
||||||
def get_concept_model_path(self, concept_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Returns the path to the 'learned_embeds.bin' file in
|
|
||||||
the named concept. Returns None if invalid or cannot
|
|
||||||
be downloaded.
|
|
||||||
"""
|
|
||||||
if not concept_name in self.list_concepts():
|
|
||||||
logger.warning(
|
|
||||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
|
||||||
|
|
||||||
def concept_to_trigger(self, concept_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Given a concept name returns its trigger by looking in the
|
|
||||||
"token_identifier.txt" file.
|
|
||||||
"""
|
|
||||||
if concept_name in self.triggers:
|
|
||||||
return self.triggers[concept_name]
|
|
||||||
elif self.concept_is_local(concept_name):
|
|
||||||
trigger = f"<{concept_name}>"
|
|
||||||
self.triggers[concept_name] = trigger
|
|
||||||
self.concept_names[trigger] = concept_name
|
|
||||||
return trigger
|
|
||||||
|
|
||||||
file = self.get_concept_file(
|
|
||||||
concept_name, "token_identifier.txt", local_only=True
|
|
||||||
)
|
|
||||||
if not file:
|
|
||||||
return None
|
|
||||||
with open(file, "r") as f:
|
|
||||||
trigger = f.readline()
|
|
||||||
trigger = trigger.strip()
|
|
||||||
self.triggers[concept_name] = trigger
|
|
||||||
self.concept_names[trigger] = concept_name
|
|
||||||
return trigger
|
|
||||||
|
|
||||||
def trigger_to_concept(self, trigger: str) -> str:
|
|
||||||
"""
|
|
||||||
Given a trigger phrase, maps it to the concept library name.
|
|
||||||
Only works if concept_to_trigger() has previously been called
|
|
||||||
on this library. There needs to be a persistent database for
|
|
||||||
this.
|
|
||||||
"""
|
|
||||||
concept = self.concept_names.get(trigger, None)
|
|
||||||
return f"<{concept}>" if concept else f"{trigger}"
|
|
||||||
|
|
||||||
def replace_triggers_with_concepts(self, prompt: str) -> str:
|
|
||||||
"""
|
|
||||||
Given a prompt string that contains <trigger> tags, replace these
|
|
||||||
tags with the concept name. The reason for this is so that the
|
|
||||||
concept names get stored in the prompt metadata. There is no
|
|
||||||
controlling of colliding triggers in the SD library, so it is
|
|
||||||
better to store the concept name (unique) than the concept trigger
|
|
||||||
(not necessarily unique!)
|
|
||||||
"""
|
|
||||||
if not prompt:
|
|
||||||
return prompt
|
|
||||||
triggers = self.match_trigger.findall(prompt)
|
|
||||||
if not triggers:
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def do_replace(match) -> str:
|
|
||||||
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
|
|
||||||
|
|
||||||
return self.match_trigger.sub(do_replace, prompt)
|
|
||||||
|
|
||||||
def replace_concepts_with_triggers(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
load_concepts_callback: Callable[[list], any],
|
|
||||||
excluded_tokens: list[str],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Given a prompt string that contains `<concept_name>` tags, replace
|
|
||||||
these tags with the appropriate trigger.
|
|
||||||
|
|
||||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
|
||||||
of `concepts_name` strings.
|
|
||||||
|
|
||||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
|
||||||
are trigger tokens from a locally-loaded embedding.
|
|
||||||
"""
|
|
||||||
concepts = self.match_concept.findall(prompt)
|
|
||||||
if not concepts:
|
|
||||||
return prompt
|
|
||||||
load_concepts_callback(concepts)
|
|
||||||
|
|
||||||
def do_replace(match) -> str:
|
|
||||||
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
|
|
||||||
return f"<{match.group(1)}>"
|
|
||||||
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
|
|
||||||
|
|
||||||
return self.match_concept.sub(do_replace, prompt)
|
|
||||||
|
|
||||||
def get_concept_file(
|
|
||||||
self,
|
|
||||||
concept_name: str,
|
|
||||||
file_name: str = "learned_embeds.bin",
|
|
||||||
local_only: bool = False,
|
|
||||||
) -> str:
|
|
||||||
if not (
|
|
||||||
self.concept_is_downloaded(concept_name)
|
|
||||||
or self.concept_is_local(concept_name)
|
|
||||||
or local_only
|
|
||||||
):
|
|
||||||
self.download_concept(concept_name)
|
|
||||||
|
|
||||||
# get local path in invokeai/embeddings if local concept
|
|
||||||
if self.concept_is_local(concept_name):
|
|
||||||
concept_path = self._concept_local_path(concept_name)
|
|
||||||
path = concept_path
|
|
||||||
else:
|
|
||||||
concept_path = self._concept_path(concept_name)
|
|
||||||
path = os.path.join(concept_path, file_name)
|
|
||||||
return path if os.path.exists(path) else None
|
|
||||||
|
|
||||||
def concept_is_local(self, concept_name) -> bool:
|
|
||||||
return concept_name in self.local_concepts
|
|
||||||
|
|
||||||
def concept_is_downloaded(self, concept_name) -> bool:
|
|
||||||
concept_directory = self._concept_path(concept_name)
|
|
||||||
return os.path.exists(concept_directory)
|
|
||||||
|
|
||||||
def download_concept(self, concept_name) -> bool:
|
|
||||||
repo_id = self._concept_id(concept_name)
|
|
||||||
dest = self._concept_path(concept_name)
|
|
||||||
|
|
||||||
access_token = HfFolder.get_token()
|
|
||||||
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
|
|
||||||
opener = request.build_opener()
|
|
||||||
opener.addheaders = header
|
|
||||||
request.install_opener(opener)
|
|
||||||
|
|
||||||
os.makedirs(dest, exist_ok=True)
|
|
||||||
succeeded = True
|
|
||||||
|
|
||||||
bytes = 0
|
|
||||||
|
|
||||||
def tally_download_size(chunk, size, total):
|
|
||||||
nonlocal bytes
|
|
||||||
if chunk == 0:
|
|
||||||
bytes += total
|
|
||||||
|
|
||||||
logger.info(f"Downloading {repo_id}...", end="")
|
|
||||||
try:
|
|
||||||
for file in (
|
|
||||||
"README.md",
|
|
||||||
"learned_embeds.bin",
|
|
||||||
"token_identifier.txt",
|
|
||||||
"type_of_concept.txt",
|
|
||||||
):
|
|
||||||
url = hf_hub_url(repo_id, file)
|
|
||||||
request.urlretrieve(
|
|
||||||
url, os.path.join(dest, file), reporthook=tally_download_size
|
|
||||||
)
|
|
||||||
except ul_error.HTTPError as e:
|
|
||||||
if e.code == 404:
|
|
||||||
logger.warning(
|
|
||||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
|
||||||
)
|
|
||||||
os.rmdir(dest)
|
|
||||||
return False
|
|
||||||
except ul_error.URLError as e:
|
|
||||||
logger.error(
|
|
||||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
|
||||||
)
|
|
||||||
os.rmdir(dest)
|
|
||||||
return False
|
|
||||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
|
||||||
return succeeded
|
|
||||||
|
|
||||||
def _concept_id(self, concept_name: str) -> str:
|
|
||||||
return f"sd-concepts-library/{concept_name}"
|
|
||||||
|
|
||||||
def _concept_path(self, concept_name: str) -> str:
|
|
||||||
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
|
|
||||||
|
|
||||||
def _concept_local_path(self, concept_name: str) -> str:
|
|
||||||
filename = self.local_concepts[concept_name]
|
|
||||||
return os.path.join(self.root, "embeddings", filename)
|
|
||||||
|
|
||||||
def get_local_concepts(self, loc_dir: str):
|
|
||||||
locs_dic = dict()
|
|
||||||
if os.path.isdir(loc_dir):
|
|
||||||
for file in os.listdir(loc_dir):
|
|
||||||
f = os.path.splitext(file)
|
|
||||||
if f[1] == ".bin" or f[1] == ".pt":
|
|
||||||
locs_dic[f[0]] = file
|
|
||||||
return locs_dic
|
|
@ -340,7 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# control_model=control_model,
|
# control_model=control_model,
|
||||||
)
|
)
|
||||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||||
self.unet, self._unet_forward, is_running_diffusers=True
|
self.unet, self._unet_forward
|
||||||
)
|
)
|
||||||
|
|
||||||
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||||
|
@ -18,7 +18,6 @@ from .cross_attention_control import (
|
|||||||
CrossAttentionType,
|
CrossAttentionType,
|
||||||
SwapCrossAttnContext,
|
SwapCrossAttnContext,
|
||||||
get_cross_attention_modules,
|
get_cross_attention_modules,
|
||||||
restore_default_cross_attention,
|
|
||||||
setup_cross_attention_control_attention_processors,
|
setup_cross_attention_control_attention_processors,
|
||||||
)
|
)
|
||||||
from .cross_attention_map_saving import AttentionMapSaver
|
from .cross_attention_map_saving import AttentionMapSaver
|
||||||
@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
model_forward_callback: ModelForwardCallback,
|
model_forward_callback: ModelForwardCallback,
|
||||||
is_running_diffusers: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param model: the unet model to pass through to cross attention control
|
:param model: the unet model to pass through to cross attention control
|
||||||
@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.is_running_diffusers = is_running_diffusers
|
|
||||||
self.model_forward_callback = model_forward_callback
|
self.model_forward_callback = model_forward_callback
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
self.sequential_guidance = config.sequential_guidance
|
self.sequential_guidance = config.sequential_guidance
|
||||||
@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
# self.remove_attention_map_saving()
|
# self.remove_attention_map_saving()
|
||||||
|
|
||||||
# apparently unused code
|
|
||||||
# TODO: delete
|
|
||||||
# def override_cross_attention(
|
|
||||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
|
||||||
# ) -> Dict[str, AttentionProcessor]:
|
|
||||||
# """
|
|
||||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
|
||||||
# the previous attention processor is returned so that the caller can restore it later.
|
|
||||||
# """
|
|
||||||
# self.conditioning = conditioning
|
|
||||||
# self.cross_attention_control_context = Context(
|
|
||||||
# arguments=self.conditioning.cross_attention_control_args,
|
|
||||||
# step_count=step_count,
|
|
||||||
# )
|
|
||||||
# return override_cross_attention(
|
|
||||||
# self.model,
|
|
||||||
# self.cross_attention_control_context,
|
|
||||||
# is_running_diffusers=self.is_running_diffusers,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def restore_default_cross_attention(
|
|
||||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
|
||||||
):
|
|
||||||
self.conditioning = None
|
|
||||||
self.cross_attention_control_context = None
|
|
||||||
restore_default_cross_attention(
|
|
||||||
self.model,
|
|
||||||
is_running_diffusers=self.is_running_diffusers,
|
|
||||||
restore_attention_processor=restore_attention_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||||
def callback(slice, dim, offset, slice_size, key):
|
def callback(slice, dim, offset, slice_size, key):
|
||||||
if dim is not None:
|
if dim is not None:
|
||||||
@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
percent_through = self.calculate_percent_through(
|
percent_through = step_index / total_step_count
|
||||||
sigma, step_index, total_step_count
|
|
||||||
)
|
|
||||||
cross_attention_control_types_to_do = (
|
cross_attention_control_types_to_do = (
|
||||||
context.get_active_cross_attention_control_types_for_step(
|
context.get_active_cross_attention_control_types_for_step(
|
||||||
percent_through
|
percent_through
|
||||||
@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
total_step_count,
|
total_step_count,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if postprocessing_settings is not None:
|
if postprocessing_settings is not None:
|
||||||
percent_through = self.calculate_percent_through(
|
percent_through = step_index / total_step_count
|
||||||
sigma, step_index, total_step_count
|
|
||||||
)
|
|
||||||
latents = self.apply_threshold(
|
latents = self.apply_threshold(
|
||||||
postprocessing_settings, latents, percent_through
|
postprocessing_settings, latents, percent_through
|
||||||
)
|
)
|
||||||
@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
|
||||||
if step_index is not None and total_step_count is not None:
|
|
||||||
# 🧨diffusers codepath
|
|
||||||
percent_through = (
|
|
||||||
step_index / total_step_count
|
|
||||||
) # will never reach 1.0 - this is deliberate
|
|
||||||
else:
|
|
||||||
# legacy compvis codepath
|
|
||||||
# TODO remove when compvis codepath support is dropped
|
|
||||||
if step_index is None and sigma is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
|
||||||
)
|
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
|
||||||
return percent_through
|
|
||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
conditioned_next_x = conditioned_next_x.clone()
|
conditioned_next_x = conditioned_next_x.clone()
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
# TODO: looks unused
|
||||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
assert isinstance(conditioning, dict)
|
assert isinstance(conditioning, dict)
|
||||||
assert isinstance(unconditioning, dict)
|
assert isinstance(unconditioning, dict)
|
||||||
@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
|
||||||
if self.is_running_diffusers:
|
|
||||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning__compvis(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
sigma,
|
|
||||||
unconditioning,
|
|
||||||
conditioning,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
|
||||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
|
||||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
|
||||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
|
||||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
|
||||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
|
||||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
|
||||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
|
||||||
context: Context = self.cross_attention_control_context
|
|
||||||
|
|
||||||
try:
|
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
|
||||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
|
||||||
for ca_type in cross_attention_control_types_to_do:
|
|
||||||
context.request_save_attention_maps(ca_type)
|
|
||||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
|
||||||
context.clear_requests(cleanup=False)
|
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
|
||||||
# print("applying saved attention maps for", cross_attention_control_types_to_do)
|
|
||||||
for ca_type in cross_attention_control_types_to_do:
|
|
||||||
context.request_apply_saved_attention_maps(ca_type)
|
|
||||||
edited_conditioning = (
|
|
||||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
|
||||||
)
|
|
||||||
conditioned_next_x = self.model_forward_callback(
|
|
||||||
x, sigma, edited_conditioning, **kwargs,
|
|
||||||
)
|
|
||||||
context.clear_requests(cleanup=True)
|
|
||||||
|
|
||||||
except:
|
|
||||||
context.clear_requests(cleanup=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
|
||||||
|
|
||||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||||
|
Loading…
Reference in New Issue
Block a user