refactor redundant code and fix typechecking errors

Repository: https://github.com/invoke-ai/InvokeAI
Commit: f13427e3f4
Parent: e28737fc8b
@@ -50,6 +50,7 @@ from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput,
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
@@ -674,22 +675,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
-        ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
+        ip_adapters: List[IPAdapterField],
     ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
         """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
-        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-        if ip_adapter is None:
-            return []
-
-        if not isinstance(ip_adapter, list):
-            ip_adapter = [ip_adapter]
-
-        if len(ip_adapter) == 0:
-            return []
-
         image_prompts = []
-        for single_ip_adapter in ip_adapter:
+        for single_ip_adapter in ip_adapters:
             with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
+                assert isinstance(ip_adapter_model, IPAdapter)
                 image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
                 # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
                 single_ipa_image_fields = single_ip_adapter.image
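
This hunk narrows the parameter type rather than accepting None, a single field, or a list and normalizing inside the method; callers now perform that normalization once (see the final two hunks). A minimal, self-contained sketch of the pattern, using a stand-in AdapterField type rather than InvokeAI's real classes:

from typing import List, Optional, Union


class AdapterField:
    """Stand-in for IPAdapterField; the real class lives in InvokeAI."""


def normalize_to_list(
    value: Optional[Union[AdapterField, List[AdapterField]]],
) -> List[AdapterField]:
    """Collapse None / single item / list into one canonical shape."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def count_adapters(adapters: List[AdapterField]) -> int:
    # After normalization, no Optional/Union handling is needed here.
    return len(adapters)


print(count_adapters(normalize_to_list(None)))            # 0
print(count_adapters(normalize_to_list(AdapterField())))  # 1

Pushing the union handling to the boundary is what lets the three early-exit guards in this method (and the mirrored ones in prep_ip_adapter_data below) be deleted outright.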
@@ -710,36 +702,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def prep_ip_adapter_data(
         self,
         context: InvocationContext,
-        ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
+        ip_adapters: List[IPAdapterField],
+        image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
         exit_stack: ExitStack,
         latent_height: int,
         latent_width: int,
         dtype: torch.dtype,
-        image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> Optional[list[IPAdapterData]]:
-        """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
-        to the `conditioning_data` (in-place).
-        """
-        if ip_adapter is None:
-            return None
-
-        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-        if not isinstance(ip_adapter, list):
-            ip_adapter = [ip_adapter]
-
-        if len(ip_adapter) == 0:
-            return None
-
+    ) -> Optional[List[IPAdapterData]]:
+        """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
         ip_adapter_data_list = []
-        assert len(ip_adapter) == len(image_prompts)
-        for single_ip_adapter in ip_adapter:
+        assert len(ip_adapters) == len(image_prompts)
+        for single_ip_adapter in ip_adapters:
             ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))

             image_prompt_embeds, uncond_image_prompt_embeds = image_prompts.pop(0)

-            mask = single_ip_adapter.mask
-            if mask is not None:
-                mask = context.tensors.load(mask.tensor_name)
+            mask_field = single_ip_adapter.mask
+            mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
             mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)

             ip_adapter_data_list.append(
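
The mask change is the typechecking fix named in the commit message: the old code reused one variable first for the mask field and then for the loaded tensor, so its inferred type flip-flopped between the two. Binding the loaded tensor under a fresh name keeps every variable mono-typed. A self-contained sketch with stand-in types (not the real torch.Tensor or InvokeAI context API):

from typing import Optional


class MaskField:
    """Stand-in for a stored-tensor reference with a `tensor_name`."""

    def __init__(self, tensor_name: str) -> None:
        self.tensor_name = tensor_name


class Tensor:
    """Stand-in for torch.Tensor."""


def load_tensor(name: str) -> Tensor:
    # Pretend to fetch the named tensor from storage.
    return Tensor()


def resolve_mask(mask_field: Optional[MaskField]) -> Optional[Tensor]:
    # One name per type: `mask_field` stays Optional[MaskField],
    # `mask` stays Optional[Tensor], so a checker has nothing to flag.
    mask = load_tensor(mask_field.tensor_name) if mask_field is not None else None
    return mask


print(resolve_mask(MaskField("mask_abc")))  # a Tensor instance
print(resolve_mask(None))                   # None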
@@ -754,7 +733,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 )
             )

-        return ip_adapter_data_list
+        return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None

     def run_t2i_adapters(
         self,
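
With the early return None guards gone, the empty case is folded into the final return so the function keeps its Optional[List[IPAdapterData]] contract: callers still get None, not [], when no adapters were supplied. The shape of that idiom, sketched generically:

from typing import List, Optional


def collect_squares(items: List[int]) -> Optional[List[int]]:
    results = [i * i for i in items]
    # Preserve the old contract: None (not an empty list) for no input.
    return results if len(results) > 0 else None


print(collect_squares([1, 2, 3]))  # [1, 4, 9]
print(collect_squares([]))         # None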
@@ -932,7 +911,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
             do_classifier_free_guidance=True,
         )

-        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapter=self.ip_adapter)
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)

         # get the unet's config so that we can pass the base to dispatch_progress()
         unet_config = context.models.get_config(self.unet.unet.key)
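
This hunk is where the normalization removed from the two helpers now lives: invoke flattens the union-typed self.ip_adapter field once, and every downstream call sees a plain List[IPAdapterField]. A schematic, self-contained sketch of that flow, with stub functions standing in for the real InvokeAI helpers:

from typing import List, Optional, Union


class AdapterField:
    """Stand-in for IPAdapterField."""


def prep_image_prompts(ip_adapters: List[AdapterField]) -> List[str]:
    # Stub for prep_ip_adapter_image_prompts: one prompt per adapter.
    return [f"prompt-{i}" for i, _ in enumerate(ip_adapters)]


def prep_data(ip_adapters: List[AdapterField], image_prompts: List[str]) -> Optional[List[str]]:
    # Stub for prep_ip_adapter_data: lengths must stay in lock-step.
    assert len(ip_adapters) == len(image_prompts)
    data = [f"data({p})" for p in image_prompts]
    return data if len(data) > 0 else None


def invoke(ip_adapter: Optional[Union[AdapterField, List[AdapterField]]]) -> Optional[List[str]]:
    # Normalize the union-typed field exactly once, at the entry point.
    ip_adapters: List[AdapterField] = []
    if ip_adapter is not None:
        ip_adapters = ip_adapter if isinstance(ip_adapter, list) else [ip_adapter]

    image_prompts = prep_image_prompts(ip_adapters)
    return prep_data(ip_adapters, image_prompts)


print(invoke(None))                              # None
print(invoke(AdapterField()))                    # ['data(prompt-0)']
print(invoke([AdapterField(), AdapterField()]))  # ['data(prompt-0)', 'data(prompt-1)']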
@@ -992,12 +979,12 @@ class DenoiseLatentsInvocation(BaseInvocation):

             ip_adapter_data = self.prep_ip_adapter_data(
                 context=context,
-                ip_adapter=self.ip_adapter,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
                 exit_stack=exit_stack,
                 latent_height=latent_height,
                 latent_width=latent_width,
                 dtype=unet.dtype,
-                image_prompts=image_prompts,
             )

             num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(