Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
refactor(InvokeAIDiffuserComponent): rename internal methods
Prefix with `_` as is tradition.
This commit is contained in:
parent 7eafcd47a6
commit aca9d74489
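
The rename is purely mechanical: each helper called from `do_diffusion_step` gains a leading underscore (`apply_standard_conditioning` becomes `_apply_standard_conditioning`, and so on), marking it as internal to `InvokeAIDiffuserComponent`. As a minimal sketch of the convention, using hypothetical names rather than the actual InvokeAI code:

class DiffuserComponent:
    def do_step(self, x, cond):
        # Public entry point: dispatch to an underscore-prefixed helper.
        return self._apply_conditioning(x, cond)

    def _apply_conditioning(self, x, cond):
        # A single leading underscore signals "internal API" by convention only;
        # Python applies no access control or name mangling here, but readers
        # and linters treat the name as private.
        return x + cond

The underscore changes no behavior; it only documents intent, which is what this commit does for the conditioning helpers in the diff below.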
@@ -1,4 +1,3 @@
-import math
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
@@ -6,8 +5,8 @@ from typing import Callable, Optional, Union, Any, Dict
 
 import numpy as np
 import torch
 
 from diffusers.models.cross_attention import AttnProcessor
 
 from ldm.models.diffusion.cross_attention_control import Arguments, \
     restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
     CrossAttentionType, SwapCrossAttnContext
@@ -143,11 +142,16 @@ class InvokeAIDiffuserComponent:
         wants_hybrid_conditioning = isinstance(conditioning, dict)
 
         if wants_hybrid_conditioning:
-            unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
+                                                                                       conditioning)
         elif wants_cross_attention_control:
-            unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
+                                                                                                           unconditioning,
+                                                                                                           conditioning,
+                                                                                                           cross_attention_control_types_to_do)
         else:
-            unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
+            unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning,
+                                                                                         conditioning)
 
         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
 
@@ -181,7 +185,7 @@ class InvokeAIDiffuserComponent:
 
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
         # fast batched path
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
@@ -194,7 +198,7 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x
 
 
-    def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
         assert isinstance(conditioning, dict)
         assert isinstance(unconditioning, dict)
         x_twice = torch.cat([x] * 2)
@@ -212,18 +216,21 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x
 
 
-    def apply_cross_attention_controlled_conditioning(self,
+    def _apply_cross_attention_controlled_conditioning(self,
                                                       x: torch.Tensor,
                                                       sigma,
                                                       unconditioning,
                                                       conditioning,
                                                       cross_attention_control_types_to_do):
         if self.is_running_diffusers:
-            return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
+                                                                                  conditioning,
+                                                                                  cross_attention_control_types_to_do)
         else:
-            return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
+            return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
+                                                                                cross_attention_control_types_to_do)
 
-    def apply_cross_attention_controlled_conditioning__diffusers(self,
+    def _apply_cross_attention_controlled_conditioning__diffusers(self,
                                                                  x: torch.Tensor,
                                                                  sigma,
                                                                  unconditioning,
@@ -246,7 +253,7 @@ class InvokeAIDiffuserComponent:
         return unconditioned_next_x, conditioned_next_x
 
 
-    def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
+    def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
         # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
         # slower non-batched path (20% slower on mac MPS)
         # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
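
For readers tracing the dispatch in `do_diffusion_step` (the hunk at `-143,11` above): the standard path batches the unconditioned and conditioned inputs, runs the model once, and the two halves of the result are later merged by `_combine` with the guidance scale. A rough sketch of that flow, where `model_forward` is a hypothetical stand-in for the component's model callback and the conventional classifier-free guidance formula is assumed for the combine step:

import torch

def apply_standard_conditioning_sketch(model_forward, x, sigma, unconditioning, conditioning):
    # Fast batched path: evaluate the unconditioned and conditioned prompts
    # in a single forward pass by doubling the latents and sigmas.
    x_twice = torch.cat([x] * 2)
    sigma_twice = torch.cat([sigma] * 2)
    both_conditionings = torch.cat([unconditioning, conditioning])
    both_results = model_forward(x_twice, sigma_twice, both_conditionings)
    unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
    return unconditioned_next_x, conditioned_next_x

def combine_sketch(unconditioned_next_x, conditioned_next_x, guidance_scale):
    # Classifier-free guidance: push the unconditioned prediction toward the
    # conditioned one, scaled by the guidance strength.
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)

The hybrid and cross-attention-controlled branches renamed in the diff follow the same overall shape; they differ mainly in how the conditioning is prepared before the forward pass.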