latents.py converted to use model manager service; events emitted

Lincoln Stein 2023-05-11 23:33:24 -04:00
parent df5b968954
commit 11ecf438f5
3 changed files with 104 additions and 54 deletions

invokeai/app/invocations/latents.py
@@ -1,31 +1,32 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 import random
 from typing import Literal, Optional, Union

+import diffusers
 import einops
-from pydantic import BaseModel, Field
 import torch
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import SchedulerMixin as Scheduler
+from pydantic import BaseModel, Field

 from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback

-from ...backend.model_management.model_manager import ModelManager
-from ...backend.util.devices import choose_torch_device, torch_dtype
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
-from ..services.image_storage import ImageType
-from .baseinvocation import BaseInvocation, InvocationContext
-from .image import ImageField, ImageOutput, build_image_output
-from .compel import ConditioningField
+from ...backend.model_management.model_manager import SDModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
-from diffusers.schedulers import SchedulerMixin as Scheduler
-import diffusers
-from diffusers import DiffusionPipeline
+from ...backend.stable_diffusion.diffusers_pipeline import (
+    ConditioningData, StableDiffusionGeneratorPipeline,
+    image_resized_to_grid_as_tensor)
+from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
+    PostprocessingSettings
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from ..services.image_storage import ImageType
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext)
+from .compel import ConditioningField
+from .image import ImageField, ImageOutput, build_image_output

 class LatentsField(BaseModel):
@@ -103,6 +104,37 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
         # x = (1 - self.perlin) * x + self.perlin * perlin_noise
     return x

+
+class ModelChooser:
+    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
+        if context.services.queue.is_canceled(context.graph_execution_state_id):
+            raise CanceledException
+
+        # Get the source node id (we are invoking the prepared node)
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
+        context.services.events.emit_model_load_started(
+            graph_execution_state_id=context.graph_execution_state_id,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            model_name=self.model,
+            submodel=SDModelType.diffusers
+        )
+
+        model_manager = context.services.model_manager
+        model_info = model_manager.get_model(self.model)
+        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
+
+        context.services.events.emit_model_load_completed(
+            graph_execution_state_id=context.graph_execution_state_id,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            model_name=self.model,
+            submodel=SDModelType.diffusers,
+            model_info=model_info
+        )
+        return model_ctx
+

 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
@@ -135,7 +167,7 @@ class NoiseInvocation(BaseInvocation):

 # Text to image
-class TextToLatentsInvocation(BaseInvocation):
+class TextToLatentsInvocation(BaseInvocation, ModelChooser):
     """Generates latents from conditionings."""

     type: Literal["t2l"] = "t2l"
@@ -175,32 +207,6 @@ class TextToLatentsInvocation(BaseInvocation):
             source_node_id=source_node_id,
         )

-    def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
-        model_info = model_manager.get_model(self.model)
-        model_name = model_info.name
-        model_hash = model_info.hash
-        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
-        with model_ctx as model:
-            model.scheduler = get_scheduler(
-                model=model,
-                scheduler_name=self.scheduler
-            )
-
-            if isinstance(model, DiffusionPipeline):
-                for component in [model.unet, model.vae]:
-                    configure_model_padding(component,
-                                            self.seamless,
-                                            self.seamless_axes
-                    )
-            else:
-                configure_model_padding(model,
-                                        self.seamless,
-                                        self.seamless_axes
-                )
-
-        return model_ctx
-
     def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
@@ -230,8 +236,8 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(context, model)
+        with self.choose_model(context) as model:
+            conditioning_data = self.get_conditioning_data(context, model)

             # TODO: Verify the noise is the right size
             result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -251,8 +257,30 @@ class TextToLatentsInvocation(BaseInvocation):
             latents=LatentsField(latents_name=name)
         )

+    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
+        model_ctx = super().choose_model(context)
+        with model_ctx as model:
+            model.scheduler = get_scheduler(
+                model=model,
+                scheduler_name=self.scheduler
+            )
+
+            if isinstance(model, DiffusionPipeline):
+                for component in [model.unet, model.vae]:
+                    configure_model_padding(component,
+                                            self.seamless,
+                                            self.seamless_axes
+                    )
+            else:
+                configure_model_padding(model,
+                                        self.seamless,
+                                        self.seamless_axes
+                )
+
+        return model_ctx
+

-class LatentsToLatentsInvocation(TextToLatentsInvocation):
+class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser):
     """Generates latents using latents as base image."""

     type: Literal["l2l"] = "l2l"
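Editor's note: the get_scheduler helper called in the override above is defined elsewhere in latents.py and is not part of this diff. Roughly, it maps a scheduler name onto a diffusers scheduler built from the pipeline's current scheduler config. The following is an illustrative sketch only; the SCHEDULER_MAP name and its entries are assumptions, not the file's actual table:

    # Illustrative sketch; the real mapping lives elsewhere in latents.py.
    from diffusers import DDIMScheduler, LMSDiscreteScheduler

    SCHEDULER_MAP = dict(   # hypothetical subset
        ddim=DDIMScheduler,
        lms=LMSDiscreteScheduler,
    )

    def get_scheduler(model: DiffusionPipeline, scheduler_name: str) -> Scheduler:
        scheduler_class = SCHEDULER_MAP.get(scheduler_name, DDIMScheduler)
        # Build the requested scheduler from the pipeline's current config
        return scheduler_class.from_config(model.scheduler.config)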
@@ -283,7 +311,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

-        with self.get_model(context.services.model_manager) as model:
+        with self.choose_model(context) as model:
             conditioning_data = self.get_conditioning_data(model)

             # TODO: Verify the noise is the right size
@@ -318,7 +346,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):

 # Latent to image
-class LatentsToImageInvocation(BaseInvocation):
+class LatentsToImageInvocation(BaseInvocation, ModelChooser):
     """Generates an image from latents."""

     type: Literal["l2i"] = "l2i"
@@ -343,9 +371,7 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)

         # TODO: this only really needs the vae
-        model_info = choose_model(context.services.model_manager, self.model)
-        with model_info.context as model:
+        with self.choose_model(context) as model:
             with torch.inference_mode():
                 np_image = model.decode_latents(latents)
                 image = model.numpy_to_pil(np_image)[0]
@@ -432,7 +458,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         return LatentsOutput(latents=LatentsField(latents_name=name))

-class ImageToLatentsInvocation(BaseInvocation):
+class ImageToLatentsInvocation(BaseInvocation, ModelChooser):
     """Encodes an image into latents."""

     type: Literal["i2l"] = "i2l"
@@ -457,7 +483,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         )

         # TODO: this only really needs the vae
-        model_info = choose_model(context.services.model_manager, self.model)
+        model_info = self.choose_model(context)
         model: StableDiffusionGeneratorPipeline = model_info["model"]

         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

@@ -474,3 +500,4 @@ class ImageToLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))

invokeai/app/services/events.py
@@ -4,6 +4,7 @@ from typing import Any

 from invokeai.app.api.models.images import ProgressImage
 from invokeai.app.util.misc import get_timestamp
+from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
 from invokeai.app.models.exceptions import CanceledException


 class EventServiceBase:
     session_event: str = "session_event"
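Editor's note: the emitter methods this commit adds to EventServiceBase are not visible in this extract. Inferred from the call sites in latents.py, they presumably look something like the sketch below; the __emit_session_event helper is assumed to be the class's existing private dispatch pattern, and the exact payload shape is an assumption:

    # Sketch inferred from the latents.py call sites; not the verbatim commit.
    class EventServiceBase:
        def emit_model_load_started(self, graph_execution_state_id: str, node: dict,
                                    source_node_id: str, model_name: str,
                                    submodel: SDModelType) -> None:
            # Fan out through the session-event channel like the other emitters
            self.__emit_session_event(
                event_name="model_load_started",
                payload=dict(graph_execution_state_id=graph_execution_state_id,
                             node=node, source_node_id=source_node_id,
                             model_name=model_name, submodel=str(submodel)))

        def emit_model_load_completed(self, graph_execution_state_id: str, node: dict,
                                      source_node_id: str, model_name: str,
                                      submodel: SDModelType,
                                      model_info: SDModelInfo) -> None:
            self.__emit_session_event(
                event_name="model_load_completed",
                payload=dict(graph_execution_state_id=graph_execution_state_id,
                             node=node, source_node_id=source_node_id,
                             model_name=model_name, submodel=str(submodel),
                             model_info=str(model_info)))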

invokeai/app/services/model_manager_service.py
@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Union, Callable, types
+from dataclasses import dataclass

 from invokeai.backend.model_management.model_manager import (
     ModelManager,

@@ -15,6 +16,11 @@ from invokeai.backend.model_management.model_manager import (
 from ...backend import Args,Globals # this must go when pr 3340 merged
 from ...backend.util import choose_precision, choose_torch_device

+@dataclass
+class LastUsedModel:
+    model_name: str = None   # defaults so the class attributes exist before first write
+    model_type: SDModelType = None
+

 class ModelManagerServiceBase(ABC):
     """Responsible for managing models on disk and in memory"""
@@ -273,6 +279,22 @@ class ModelManagerService(ModelManagerServiceBase):
         Retrieve the indicated model. submodel can be used to get a
         part (such as the vae) of a diffusers model.
         """
+        # Temporary hack: remember the last model fetched so that, while a
+        # graph is executing, the first node to request a model sets the
+        # default for subsequent nodes that do not name one explicitly.
+        # This works around the current lack of a model-loader mechanism,
+        # which was causing inconsistent model usage across a graph, and
+        # should be removed once a real model loader lands.
+        if not model_name:
+            self.logger.debug('No model name provided, defaulting to last loaded model')
+            model_name = LastUsedModel.model_name
+            model_type = model_type or LastUsedModel.model_type
+        else:
+            LastUsedModel.model_name = model_name
+            LastUsedModel.model_type = model_type
+
         return self.mgr.get_model(
             model_name,
             model_type,
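Editor's note: the intended effect, assuming two calls within one graph execution (the constructor arguments and model name below are placeholders, not the service's documented API):

    # Hypothetical call sequence showing the last-used-model fallback.
    mm = ModelManagerService(config, logger)        # assumed constructor args
    first = mm.get_model("stable-diffusion-1.5")    # explicit name: remembered as default
    second = mm.get_model(None)                     # no name: falls back to "stable-diffusion-1.5"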