diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 5229ace95f..91320173ed 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,31 +1,32 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
-import random
 from typing import Literal, Optional, Union
+
+import diffusers
 import einops
-from pydantic import BaseModel, Field
 import torch
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import SchedulerMixin as Scheduler
+from pydantic import BaseModel, Field
 
+from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
-
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 
-from ...backend.model_management.model_manager import ModelManager
-from ...backend.util.devices import choose_torch_device, torch_dtype
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.prompting.conditioning import get_uc_and_c_and_ec
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-import numpy as np
-from ..services.image_storage import ImageType
-from .baseinvocation import BaseInvocation, InvocationContext
-from .image import ImageField, ImageOutput, build_image_output
-from .compel import ConditioningField
+from ...backend.model_management.model_manager import SDModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
-from diffusers.schedulers import SchedulerMixin as Scheduler
-import diffusers
-from diffusers import DiffusionPipeline
+from ...backend.stable_diffusion.diffusers_pipeline import (
+    ConditioningData, StableDiffusionGeneratorPipeline,
+    image_resized_to_grid_as_tensor)
+from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
+    PostprocessingSettings
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from ..services.image_storage import ImageType
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext)
+from .compel import ConditioningField
+from .image import ImageField, ImageOutput, build_image_output
 
 
 class LatentsField(BaseModel):
@@ -103,6 +104,37 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
    # x = (1 - self.perlin) * x + self.perlin * perlin_noise
     return x
 
+class ModelChooser:
+    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
+
+        if context.services.queue.is_canceled(context.graph_execution_state_id):
+            raise CanceledException
+
+        # Get the source node id (we are invoking the prepared node)
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
+        context.services.events.emit_model_load_started(
+            graph_execution_state_id=context.graph_execution_state_id,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            model_name=self.model,
+            submodel=SDModelType.diffusers
+        )
+
+        model_manager = context.services.model_manager
+        model_info = model_manager.get_model(self.model)
+        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
+        context.services.events.emit_model_load_completed(
+            graph_execution_state_id=context.graph_execution_state_id,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            model_name=self.model,
+            submodel=SDModelType.diffusers,
+            model_info=model_info
+        )
+
+        return model_ctx
 
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
@@ -135,7 +167,7 @@ class NoiseInvocation(BaseInvocation):
 
 
 # Text to image
-class TextToLatentsInvocation(BaseInvocation):
+class TextToLatentsInvocation(BaseInvocation, ModelChooser):
     """Generates latents from conditionings."""
 
     type: Literal["t2l"] = "t2l"
@@ -175,32 +207,6 @@ class TextToLatentsInvocation(BaseInvocation):
             source_node_id=source_node_id,
         )
 
-    def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
-        model_info = model_manager.get_model(self.model)
-        model_name = model_info.name
-        model_hash = model_info.hash
-        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
-        with model_ctx as model:
-            model.scheduler = get_scheduler(
-                model=model,
-                scheduler_name=self.scheduler
-            )
-
-            if isinstance(model, DiffusionPipeline):
-                for component in [model.unet, model.vae]:
-                    configure_model_padding(component,
-                                            self.seamless,
-                                            self.seamless_axes
-                                            )
-            else:
-                configure_model_padding(model,
-                                        self.seamless,
-                                        self.seamless_axes
-                                        )
-
-        return model_ctx
-
-
     def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
@@ -230,8 +236,8 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(context, model)
+        with self.choose_model(context) as model:
+            conditioning_data = self.get_conditioning_data(context, model)
 
         # TODO: Verify the noise is the right size
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -251,8 +257,30 @@ class TextToLatentsInvocation(BaseInvocation):
             latents=LatentsField(latents_name=name)
         )
 
+    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
+        model_ctx = super().choose_model(context)
 
-class LatentsToLatentsInvocation(TextToLatentsInvocation):
+        with model_ctx as model:
+            model.scheduler = get_scheduler(
+                model=model,
+                scheduler_name=self.scheduler
+            )
+
+            if isinstance(model, DiffusionPipeline):
+                for component in [model.unet, model.vae]:
+                    configure_model_padding(component,
+                                            self.seamless,
+                                            self.seamless_axes
+                                            )
+            else:
+                configure_model_padding(model,
+                                        self.seamless,
+                                        self.seamless_axes
+                                        )
+        return model_ctx
+
+
+class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser):
     """Generates latents using latents as base image."""
 
     type: Literal["l2l"] = "l2l"
@@ -283,7 +311,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        with self.get_model(context.services.model_manager) as model:
+        with self.choose_model(context) as model:
             conditioning_data = self.get_conditioning_data(model)
 
             # TODO: Verify the noise is the right size
@@ -318,7 +346,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
 
 # Latent to image
-class LatentsToImageInvocation(BaseInvocation):
+class LatentsToImageInvocation(BaseInvocation, ModelChooser):
     """Generates an image from latents."""
 
     type: Literal["l2i"] = "l2i"
@@ -343,9 +371,7 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)
 
         # TODO: this only really needs the vae
-        model_info = choose_model(context.services.model_manager, self.model)
-
-        with model_info.context as model:
+        with self.choose_model(context) as model:
             with torch.inference_mode():
                 np_image = model.decode_latents(latents)
                 image = model.numpy_to_pil(np_image)[0]
@@ -432,7 +458,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         return LatentsOutput(latents=LatentsField(latents_name=name))
 
 
-class ImageToLatentsInvocation(BaseInvocation):
+class ImageToLatentsInvocation(BaseInvocation, ModelChooser):
     """Encodes an image into latents."""
 
     type: Literal["i2l"] = "i2l"
@@ -457,7 +483,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         )
 
         # TODO: this only really needs the vae
-        model_info = choose_model(context.services.model_manager, self.model)
+        model_info = self.choose_model(context)
         model: StableDiffusionGeneratorPipeline = model_info["model"]
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@@ -474,3 +500,4 @@ class ImageToLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))
+
diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py
index a25549dfc5..dda5557315 100644
--- a/invokeai/app/services/events.py
+++ b/invokeai/app/services/events.py
@@ -4,6 +4,7 @@ from typing import Any
 from invokeai.app.api.models.images import ProgressImage
 from invokeai.app.util.misc import get_timestamp
 from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
+from invokeai.app.models.exceptions import CanceledException
 
 class EventServiceBase:
     session_event: str = "session_event"
diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index e9c959d5e0..3e245bc47e 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Union, Callable, types
+from dataclasses import dataclass
 
 from invokeai.backend.model_management.model_manager import (
     ModelManager,
@@ -15,6 +16,11 @@ from invokeai.backend.model_management.model_manager import (
 from ...backend import Args,Globals # this must go when pr 3340 merged
 from ...backend.util import choose_precision, choose_torch_device
 
+@dataclass
+class LastUsedModel:
+    model_name: str
+    model_type: SDModelType
+
 
 class ModelManagerServiceBase(ABC):
     """Responsible for managing models on disk and in memory"""
@@ -273,6 +279,22 @@ class ModelManagerService(ModelManagerServiceBase):
         Retrieve the indicated model. submodel can be used to get
         a part (such as the vae) of a diffusers mode.
         """
+
+        # Temporary hack here: we remember the last model fetched
+        # so that when executing a graph, the first node called gets
+        # to set default model for subsequent nodes in the event that
+        # they do not set the model explicitly. This should be
+        # displaced by model loader mechanism.
+        # This is to work around lack of model loader at current time,
+        # which was causing inconsistent model usage throughout graph.
+        if not model_name:
+            self.logger.debug('No model name provided, defaulting to last loaded model')
+            model_name = LastUsedModel.model_name
+            model_type = model_type or LastUsedModel.model_type
+        else:
+            LastUsedModel.model_name = model_name
+            LastUsedModel.model_type = model_type
+
         return self.mgr.get_model(
             model_name,
             model_type,
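
Note on the last-used-model fallback above: the pattern is easiest to see in isolation. Below is a minimal, self-contained sketch of the same remember-the-last-request behaviour; the names here (resolve_model, _LastUsed) are illustrative only and are not part of the InvokeAI API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _LastUsed:
    # Class-level store for the most recently requested model (illustrative).
    model_name: Optional[str] = None
    model_type: Optional[str] = None


def resolve_model(model_name: Optional[str], model_type: Optional[str] = None) -> tuple:
    """Return (name, type), falling back to the last explicit request."""
    if not model_name:
        # No model given: reuse whatever the previous caller asked for.
        model_name = _LastUsed.model_name
        model_type = model_type or _LastUsed.model_type
    else:
        # Remember this request so later callers can omit the model.
        _LastUsed.model_name = model_name
        _LastUsed.model_type = model_type
    return model_name, model_type


print(resolve_model("some-model-name", "diffusers"))  # explicit request, remembered
print(resolve_model(None))                            # falls back to the remembered model

The first explicit request seeds the class-level store, and later calls that omit the model reuse it, which is the per-graph behaviour the "temporary hack" comment in get_model() describes.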