mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(InvokeAIDiffuserComponent): rename internal methods
Prefix with `_` as is tradition.
This commit is contained in:
parent
7eafcd47a6
commit
aca9d74489
@ -1,4 +1,3 @@
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
@ -6,8 +5,8 @@ from typing import Callable, Optional, Union, Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.models.cross_attention import AttnProcessor
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import Arguments, \
|
||||
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
|
||||
CrossAttentionType, SwapCrossAttnContext
|
||||
@ -143,11 +142,16 @@ class InvokeAIDiffuserComponent:
|
||||
wants_hybrid_conditioning = isinstance(conditioning, dict)
|
||||
|
||||
if wants_hybrid_conditioning:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
|
||||
conditioning)
|
||||
elif wants_cross_attention_control:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do)
|
||||
else:
|
||||
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
|
||||
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning,
|
||||
conditioning)
|
||||
|
||||
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
|
||||
|
||||
@ -181,7 +185,7 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
# fast batched path
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
@ -194,7 +198,7 @@ class InvokeAIDiffuserComponent:
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
x_twice = torch.cat([x] * 2)
|
||||
@ -212,18 +216,21 @@ class InvokeAIDiffuserComponent:
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def apply_cross_attention_controlled_conditioning(self,
|
||||
def _apply_cross_attention_controlled_conditioning(self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do):
|
||||
if self.is_running_diffusers:
|
||||
return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do)
|
||||
else:
|
||||
return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
|
||||
cross_attention_control_types_to_do)
|
||||
|
||||
def apply_cross_attention_controlled_conditioning__diffusers(self,
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
@ -246,7 +253,7 @@ class InvokeAIDiffuserComponent:
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
|
||||
def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
|
Loading…
Reference in New Issue
Block a user