latents.py converted to use model manager service; events emitted

Lincoln Stein
2023-05-11 23:33:24 -04:00
parent df5b968954
commit 11ecf438f5
3 changed files with 104 additions and 54 deletions


@@ -1,31 +1,32 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional, Union
import diffusers
import einops
from pydantic import BaseModel, Field
import torch
from diffusers import DiffusionPipeline
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
from ...backend.model_management.model_manager import SDModelType
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
from ...backend.stable_diffusion.diffusers_pipeline import (
    ConditioningData, StableDiffusionGeneratorPipeline,
    image_resized_to_grid_as_tensor)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
    PostprocessingSettings
from ...backend.util.devices import choose_torch_device, torch_dtype
from ..services.image_storage import ImageType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .image import ImageField, ImageOutput, build_image_output
class LatentsField(BaseModel):
@@ -103,6 +104,37 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
    # x = (1 - self.perlin) * x + self.perlin * perlin_noise
    return x
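For orientation, the noise generated here lives in latent space rather than pixel space. A minimal, self-contained sketch of seeded latent-noise generation, assuming the standard Stable Diffusion layout of 4 latent channels at 1/8 image resolution (illustrative, not the commit's `get_noise`):

```python
import torch

def make_noise(width: int, height: int, seed: int,
               device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """Reproducible latent-space noise for a given image size."""
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(
        (1, 4, height // 8, width // 8),  # (batch, latent channels, H/8, W/8)
        generator=generator,
        device=device,
    )

# Same seed, same noise -- the property a noise invocation relies on:
assert torch.equal(make_noise(512, 512, seed=42), make_noise(512, 512, seed=42))
```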
class ModelChooser:
    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
        if context.services.queue.is_canceled(context.graph_execution_state_id):
            raise CanceledException

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        context.services.events.emit_model_load_started(
            graph_execution_state_id=context.graph_execution_state_id,
            node=self.dict(),
            source_node_id=source_node_id,
            model_name=self.model,
            submodel=SDModelType.diffusers
        )
        model_manager = context.services.model_manager
        model_info = model_manager.get_model(self.model)
        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
        context.services.events.emit_model_load_completed(
            graph_execution_state_id=context.graph_execution_state_id,
            node=self.dict(),
            source_node_id=source_node_id,
            model_name=self.model,
            submodel=SDModelType.diffusers,
            model_info=model_info
        )
        return model_ctx
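The mixin brackets each model load with a started/completed event pair so subscribers (e.g. the web frontend) can surface load progress per node. A stripped-down sketch of that bracketing pattern with stand-in classes (the names below are illustrative, not the InvokeAI services API):

```python
from contextlib import contextmanager

class EventBus:
    """Stand-in for the events service: records emitted events."""
    def __init__(self):
        self.events = []

    def emit(self, name: str, **payload) -> None:
        self.events.append((name, payload))

@contextmanager
def load_model(events: EventBus, model_name: str):
    events.emit("model_load_started", model_name=model_name)
    model = {"name": model_name}  # placeholder for the real pipeline load
    events.emit("model_load_completed", model_name=model_name)
    yield model  # callers use the model inside a `with` block, as above

bus = EventBus()
with load_model(bus, "stable-diffusion-1.5") as model:
    pass
assert [name for name, _ in bus.events] == ["model_load_started", "model_load_completed"]
```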
class NoiseInvocation(BaseInvocation):
    """Generates latent noise."""
@@ -135,7 +167,7 @@ class NoiseInvocation(BaseInvocation):
# Text to image
class TextToLatentsInvocation(BaseInvocation):
class TextToLatentsInvocation(BaseInvocation, ModelChooser):
    """Generates latents from conditionings."""
    type: Literal["t2l"] = "t2l"
@@ -175,32 +207,6 @@ class TextToLatentsInvocation(BaseInvocation):
            source_node_id=source_node_id,
        )
    def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
        model_info = model_manager.get_model(self.model)
        model_name = model_info.name
        model_hash = model_info.hash
        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
        with model_ctx as model:
            model.scheduler = get_scheduler(
                model=model,
                scheduler_name=self.scheduler
            )
            if isinstance(model, DiffusionPipeline):
                for component in [model.unet, model.vae]:
                    configure_model_padding(component,
                                            self.seamless,
                                            self.seamless_axes
                                            )
            else:
                configure_model_padding(model,
                                        self.seamless,
                                        self.seamless_axes
                                        )
        return model_ctx
    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
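The positive embedding (`c`) and negative embedding (`uc`) retrieved here drive classifier-free guidance: the sampler predicts noise under both conditionings and pushes the result away from the unconditioned prediction. A worked one-liner with stand-in tensors:

```python
import torch

# Stand-ins for the UNet's noise predictions under each conditioning.
noise_uncond = torch.tensor([0.10, 0.20])
noise_cond = torch.tensor([0.30, 0.60])
cfg_scale = 7.5

# Classifier-free guidance: move from the unconditioned prediction
# toward the conditioned one, scaled by the guidance strength.
guided = noise_uncond + cfg_scale * (noise_cond - noise_uncond)
print(guided)  # tensor([1.6000, 3.2000])
```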
@@ -230,8 +236,8 @@ class TextToLatentsInvocation(BaseInvocation):
        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, source_node_id, state)

        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)
        with self.choose_model(context) as model:
            conditioning_data = self.get_conditioning_data(context, model)

            # TODO: Verify the noise is the right size
            result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -251,8 +257,30 @@ class TextToLatentsInvocation(BaseInvocation):
            latents=LatentsField(latents_name=name)
        )
    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
        model_ctx = super().choose_model(context)
        with model_ctx as model:
            model.scheduler = get_scheduler(
                model=model,
                scheduler_name=self.scheduler
            )
            if isinstance(model, DiffusionPipeline):
                for component in [model.unet, model.vae]:
                    configure_model_padding(component,
                                            self.seamless,
                                            self.seamless_axes
                                            )
            else:
                configure_model_padding(model,
                                        self.seamless,
                                        self.seamless_axes
                                        )
        return model_ctx

class LatentsToLatentsInvocation(TextToLatentsInvocation):
class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser):
    """Generates latents using latents as base image."""
    type: Literal["l2l"] = "l2l"
@@ -283,7 +311,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, source_node_id, state)

        with self.get_model(context.services.model_manager) as model:
        with self.choose_model(context) as model:
            conditioning_data = self.get_conditioning_data(model)

            # TODO: Verify the noise is the right size
@@ -318,7 +346,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
class LatentsToImageInvocation(BaseInvocation, ModelChooser):
    """Generates an image from latents."""
    type: Literal["l2i"] = "l2i"
@@ -343,9 +371,7 @@ class LatentsToImageInvocation(BaseInvocation):
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO: this only really needs the vae
        model_info = choose_model(context.services.model_manager, self.model)
        with model_info.context as model:
        with self.choose_model(context) as model:
            with torch.inference_mode():
                np_image = model.decode_latents(latents)
                image = model.numpy_to_pil(np_image)[0]
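`decode_latents` and `numpy_to_pil` wrap the usual diffusers VAE round trip: unscale the latents, decode, then map [-1, 1] floats to 8-bit pixels. A hedged sketch using `AutoencoderKL` directly (0.18215 is the SD 1.x latent scaling convention; this is not the pipeline's internal code):

```python
import torch
from diffusers import AutoencoderKL
from PIL import Image

def latents_to_pil(vae: AutoencoderKL, latents: torch.Tensor) -> Image.Image:
    with torch.inference_mode():
        # Undo the scaling applied when the latents were produced.
        decoded = vae.decode(latents / 0.18215).sample
    # Map from [-1, 1] to [0, 255], then NCHW -> HWC for PIL.
    pixels = ((decoded / 2 + 0.5).clamp(0, 1) * 255).round()
    array = pixels.to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray(array)
```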
@@ -432,7 +458,7 @@ class ScaleLatentsInvocation(BaseInvocation):
        return LatentsOutput(latents=LatentsField(latents_name=name))

class ImageToLatentsInvocation(BaseInvocation):
class ImageToLatentsInvocation(BaseInvocation, ModelChooser):
    """Encodes an image into latents."""
    type: Literal["i2l"] = "i2l"
@@ -457,7 +483,7 @@
        )

        # TODO: this only really needs the vae
        model_info = choose_model(context.services.model_manager, self.model)
        model_info = self.choose_model(context)
        model: StableDiffusionGeneratorPipeline = model_info["model"]

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
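Encoding is the mirror of decoding: normalize the image tensor, snap its sides to a multiple of 8 so the 8x-downsampling VAE sees a clean grid (roughly what `image_resized_to_grid_as_tensor` is for), sample the latent distribution, and apply the same scaling factor. A hedged sketch under the same SD 1.x assumptions:

```python
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

def pil_to_latents(vae: AutoencoderKL, image: Image.Image, seed: int = 0) -> torch.Tensor:
    # Crop each side down to a multiple of 8.
    w, h = (d - d % 8 for d in image.size)
    pixels = np.asarray(image.convert("RGB").crop((0, 0, w, h)), dtype=np.float32)
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)  # HWC -> NCHW
    tensor = tensor / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    generator = torch.Generator().manual_seed(seed)
    with torch.inference_mode():
        return vae.encode(tensor).latent_dist.sample(generator=generator) * 0.18215
```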
@@ -474,3 +500,4 @@ class ImageToLatentsInvocation(BaseInvocation):
        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.set(name, latents)
        return LatentsOutput(latents=LatentsField(latents_name=name))
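The `{graph_execution_state_id}__{node_id}` key makes each node's output latents addressable per execution. A toy in-memory version of such a get/set latents service (illustrative only; InvokeAI's real storage layer differs):

```python
import torch

class InMemoryLatentsStorage:
    """Toy latents store keyed the same way as the invocations above."""
    def __init__(self) -> None:
        self._store: dict[str, torch.Tensor] = {}

    def set(self, name: str, latents: torch.Tensor) -> None:
        self._store[name] = latents

    def get(self, name: str) -> torch.Tensor:
        return self._store[name]

storage = InMemoryLatentsStorage()
key = "exec-123__node-7"  # f"{graph_execution_state_id}__{self.id}"
storage.set(key, torch.zeros(1, 4, 64, 64))
assert storage.get(key).shape == (1, 4, 64, 64)
```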