diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 152e079693..94ec9da7e8 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -445,8 +445,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels
 
     def image_from_embeddings(
         self,
@@ -544,7 +551,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         yield PipelineIntermediateState(
             run_id=run_id,
             step=-1,
-            timestep=self.scheduler.num_train_timesteps,
+            timestep=self.scheduler.config.num_train_timesteps,
             latents=latents,
         )
 
@@ -915,7 +922,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels
 
     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
index f933a11a6f..d6c90503fe 100644
--- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
+++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
@@ -10,8 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.cross_attention import AttnProcessor
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
 
 from ...util import torch_dtype
@@ -188,7 +187,7 @@ class Context:
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """
@@ -209,7 +208,7 @@ class InvokeAICrossAttentionMixin:
         """
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -345,11 +344,11 @@ class InvokeAICrossAttentionMixin:
 def restore_default_cross_attention(
     model,
     is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttnProcessor] = None,
+    restore_attention_processor: Optional[AttentionProcessor] = None,
 ):
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
     else:
         remove_attention_function(model)
 
@@ -408,12 +407,9 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
 def get_cross_attention_modules(
     model, which: CrossAttentionType
 ) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention  # avoid circular import
 
     cross_attention_class: type = (
         InvokeAIDiffusersCrossAttention
-        if isinstance(model, UNet2DConditionModel)
-        else CrossAttention
     )
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [
@@ -428,10 +424,10 @@ def get_cross_attention_modules(
         print(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
             + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + f"work properly until it is fixed."
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + "work properly until it is fixed."
         )
     return attention_module_tuples
 
@@ -550,7 +546,7 @@ def get_mem_free_total(device):
 
 
 class InvokeAIDiffusersCrossAttention(
-    diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
+    diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
 ):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -572,8 +568,8 @@ class InvokeAIDiffusersCrossAttention(
 
 """
 # base implementation
 
-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class AttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -601,9 +597,9 @@ class CrossAttnProcessor:
 from dataclasses import dataclass, field
 
 import torch
-from diffusers.models.cross_attention import (
-    CrossAttention,
-    CrossAttnProcessor,
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
     SlicedAttnProcessor,
 )
 
@@ -653,7 +649,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 7e3ab455b9..1137aa52e4 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
 
 from invokeai.backend.globals import Globals
@@ -101,7 +101,7 @@ class InvokeAIDiffuserComponent:
 
     def override_cross_attention(
         self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
+    ) -> Dict[str, AttentionProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
@@ -118,7 +118,7 @@ class InvokeAIDiffuserComponent:
         )
 
     def restore_default_cross_attention(
-        self, restore_attention_processor: Optional["AttnProcessor"] = None
+        self, restore_attention_processor: Optional["AttentionProcessor"] = None
     ):
         self.conditioning = None
         self.cross_attention_control_context = None
@@ -262,7 +262,7 @@ class InvokeAIDiffuserComponent:
         # TODO remove when compvis codepath support is dropped
         if step_index is None and sigma is None:
             raise ValueError(
-                f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                "Either step_index or sigma is required when doing cross attention control, but both are None."
             )
         percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
@@ -599,7 +599,6 @@ class InvokeAIDiffuserComponent:
         )
 
         # below is fugly omg
-        num_actual_conditionings = len(c_or_weighted_c_list)
         conditionings = [uc] + [c for c, weight in weighted_cond_list]
         weights = [1] + [weight for c, weight in weighted_cond_list]
         chunk_count = ceil(len(conditionings) / 2)
diff --git a/invokeai/frontend/install/invokeai_update.py b/invokeai/frontend/install/invokeai_update.py
index 040067cff9..781c66cddd 100644
--- a/invokeai/frontend/install/invokeai_update.py
+++ b/invokeai/frontend/install/invokeai_update.py
@@ -1,10 +1,9 @@
-"""
+'''
 Minimalist updater script. Prompts user for the tag or branch to update to and runs
 pip install .
-"""
+'''
 import os
 import platform
-
 import requests
 from rich import box, print
 from rich.console import Console, Group, group
@@ -16,8 +15,10 @@ from rich.text import Text
 
 from invokeai.version import __version__
 
-INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
-INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
+INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
+INVOKE_AI_TAG="https://github.com/invoke-ai/InvokeAI/archive/refs/tags"
+INVOKE_AI_BRANCH="https://github.com/invoke-ai/InvokeAI/archive/refs/heads"
+INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
 
 OS = platform.uname().system
 ARCH = platform.uname().machine
@@ -28,22 +29,22 @@ if OS == "Windows":
 else:
     console = Console(style=Style(color="grey74", bgcolor="grey19"))
 
-
-def get_versions() -> dict:
+def get_versions()->dict:
     return requests.get(url=INVOKE_AI_REL).json()
 
-
 def welcome(versions: dict):
+
     @group()
     def text():
-        yield f"InvokeAI Version: [bold yellow]{__version__}"
-        yield ""
-        yield "This script will update InvokeAI to the latest release, or to a development version of your choice."
-        yield ""
-        yield "[bold yellow]Options:"
-        yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
+        yield f'InvokeAI Version: [bold yellow]{__version__}'
+        yield ''
+        yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
+        yield ''
+        yield '[bold yellow]Options:'
+        yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
 [2] Update to the bleeding-edge development version ([italic]main[/italic])
-[3] Manually enter the tag or branch name you wish to update"""
+[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
+[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to'''
 
     console.rule()
     print(
@@ -59,33 +60,41 @@ def welcome(versions: dict):
     )
     console.line()
 
-
 def main():
     versions = get_versions()
     welcome(versions)
 
     tag = None
-    choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
+    branch = None
+    release = None
+    choice = Prompt.ask('Choice:',choices=['1','2','3','4'],default='1')
+
+    if choice=='1':
+        release = versions[0]['tag_name']
+    elif choice=='2':
+        release = 'main'
+    elif choice=='3':
+        tag = Prompt.ask('Enter an InvokeAI tag name')
+    elif choice=='4':
+        branch = Prompt.ask('Enter an InvokeAI branch name')
 
-    if choice == "1":
-        tag = versions[0]["tag_name"]
-    elif choice == "2":
-        tag = "main"
-    elif choice == "3":
-        tag = Prompt.ask("Enter an InvokeAI tag or branch name")
-
-    print(f":crossed_fingers: Upgrading to [yellow]{tag}[/yellow]")
-    cmd = f"pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517"
-    print("")
-    print("")
-    if os.system(cmd) == 0:
-        print(f":heavy_check_mark: Upgrade successful")
+    print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
+    if release:
+        cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
+    elif tag:
+        cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
     else:
-        print(f":exclamation: [bold red]Upgrade failed[/red bold]")
-
-
+        cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
+    print('')
+    print('')
+    if os.system(cmd)==0:
+        print(f':heavy_check_mark: Upgrade successful')
+    else:
+        print(f':exclamation: [bold red]Upgrade failed[/red bold]')
+
 if __name__ == "__main__":
     try:
         main()
     except KeyboardInterrupt:
         pass
+
diff --git a/pyproject.toml b/pyproject.toml
index ec6aabfb8b..0ab1e0ede6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
   "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel==1.0.5",
   "datasets",
-  "diffusers[torch]==0.14",
+  "diffusers[torch]==0.15.1",
   "dnspython==2.2.1",
   "einops",
   "eventlet",
@@ -109,7 +109,7 @@ dependencies = [
 "invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
 "invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
 "invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install"
-"invokeai-update" = "invokeai.frontend.config:invokeai_update"
+"invokeai-update" = "invokeai.frontend.install:invokeai_update"
 "invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
 "invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
 "invokeai-node-web" = "invokeai.app.api_app:invoke_api"
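
Reviewer note: most of the backend changes above track renames in the diffusers 0.14 -> 0.15 API. The attention classes moved from diffusers.models.cross_attention to diffusers.models.attention_processor (CrossAttention -> Attention, CrossAttnProcessor -> AttnProcessor), and init parameters such as in_channels and num_train_timesteps are now read through each module's .config rather than as direct attributes. A minimal sketch of the new usage, assuming a stock StableDiffusionPipeline (the checkpoint id below is only an example, not something this patch pins):

    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import Attention, AttnProcessor

    # example checkpoint; any SD 1.x diffusers checkpoint behaves the same way
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # diffusers 0.15 deprecates reading init kwargs straight off modules; go through .config
    num_train_timesteps = pipe.scheduler.config.num_train_timesteps
    in_channels = pipe.unet.config.in_channels

    # cross-attention layers are now diffusers Attention modules named "...attn2",
    # which is what get_cross_attention_modules() walks over
    cross_attn = [
        m for n, m in pipe.unet.named_modules()
        if n.endswith("attn2") and isinstance(m, Attention)
    ]

    # reset to the stock processor, as restore_default_cross_attention() now does for diffusers models
    pipe.unet.set_attn_processor(AttnProcessor())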
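
Reviewer note on the updater: the script now keeps separate archive roots for releases, tags, and branches, so an explicit tag and a branch with the same name resolve to different URLs. A rough sketch of the command it assembles (build_pip_command is a hypothetical helper for illustration, and the ref names are placeholders; the script itself builds the string inline):

    INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
    INVOKE_AI_TAG = "https://github.com/invoke-ai/InvokeAI/archive/refs/tags"
    INVOKE_AI_BRANCH = "https://github.com/invoke-ai/InvokeAI/archive/refs/heads"

    def build_pip_command(release=None, tag=None, branch=None) -> str:
        # options 1/2 resolve to a release name, option 3 to a tag, option 4 to a branch
        if release:
            url = f"{INVOKE_AI_SRC}/{release}.zip"
        elif tag:
            url = f"{INVOKE_AI_TAG}/{tag}.zip"
        else:
            url = f"{INVOKE_AI_BRANCH}/{branch}.zip"
        return f"pip install {url} --use-pep517 --upgrade"

    # e.g. a hypothetical branch named "bugfix/attention":
    # pip install https://github.com/invoke-ai/InvokeAI/archive/refs/heads/bugfix/attention.zip --use-pep517 --upgrade
    print(build_pip_command(branch="bugfix/attention"))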