refactor(InvokeAIDiffuserComponent): rename internal methods

Prefix with `_` as is tradition.
This commit is contained in:
Kevin Turner 2023-02-19 15:33:16 -08:00
parent 7eafcd47a6
commit aca9d74489

View File

@ -1,4 +1,3 @@
import math
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
@ -6,8 +5,8 @@ from typing import Callable, Optional, Union, Any, Dict
import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor
from ldm.models.diffusion.cross_attention_control import Arguments, \
restore_default_cross_attention, override_cross_attention, Context, get_cross_attention_modules, \
CrossAttentionType, SwapCrossAttnContext
@ -143,11 +142,16 @@ class InvokeAIDiffuserComponent:
wants_hybrid_conditioning = isinstance(conditioning, dict)
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self.apply_hybrid_conditioning(x, sigma, unconditioning, conditioning)
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(x, sigma, unconditioning,
conditioning)
elif wants_cross_attention_control:
unconditioned_next_x, conditioned_next_x = self.apply_cross_attention_controlled_conditioning(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
unconditioned_next_x, conditioned_next_x = self._apply_cross_attention_controlled_conditioning(x, sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do)
else:
unconditioned_next_x, conditioned_next_x = self.apply_standard_conditioning(x, sigma, unconditioning, conditioning)
unconditioned_next_x, conditioned_next_x = self._apply_standard_conditioning(x, sigma, unconditioning,
conditioning)
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
@ -181,7 +185,7 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
@ -194,7 +198,7 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
@ -212,18 +216,21 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning(self,
def _apply_cross_attention_controlled_conditioning(self,
x: torch.Tensor,
sigma,
unconditioning,
conditioning,
cross_attention_control_types_to_do):
if self.is_running_diffusers:
return self.apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
return self._apply_cross_attention_controlled_conditioning__diffusers(x, sigma, unconditioning,
conditioning,
cross_attention_control_types_to_do)
else:
return self.apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do)
return self._apply_cross_attention_controlled_conditioning__compvis(x, sigma, unconditioning, conditioning,
cross_attention_control_types_to_do)
def apply_cross_attention_controlled_conditioning__diffusers(self,
def _apply_cross_attention_controlled_conditioning__diffusers(self,
x: torch.Tensor,
sigma,
unconditioning,
@ -246,7 +253,7 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
def _apply_cross_attention_controlled_conditioning__compvis(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of