latents.py converted to use model manager service; events emitted

commit 11ecf438f5
parent df5b968954
Author: Lincoln Stein
Date:   2023-05-11 23:33:24 -04:00

3 changed files with 104 additions and 54 deletions

invokeai/app/invocations/latents.py

@@ -1,31 +1,32 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import random
 from typing import Literal, Optional, Union

+import diffusers
 import einops
-from pydantic import BaseModel, Field
 import torch
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import SchedulerMixin as Scheduler
+from pydantic import BaseModel, Field
+
+from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from ...backend.model_management.model_manager import ModelManager
-from ...backend.util.devices import choose_torch_device, torch_dtype
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
+
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.prompting.conditioning import get_uc_and_c_and_ec
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
-import numpy as np
-from ..services.image_storage import ImageType
-from .baseinvocation import BaseInvocation, InvocationContext
-from .image import ImageField, ImageOutput, build_image_output
-from .compel import ConditioningField
+from ...backend.model_management.model_manager import SDModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
-from diffusers.schedulers import SchedulerMixin as Scheduler
-import diffusers
-from diffusers import DiffusionPipeline
+from ...backend.stable_diffusion.diffusers_pipeline import (
+    ConditioningData, StableDiffusionGeneratorPipeline,
+    image_resized_to_grid_as_tensor)
+from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
+    PostprocessingSettings
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from ..services.image_storage import ImageType
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             InvocationConfig, InvocationContext)
+from .compel import ConditioningField
+from .image import ImageField, ImageOutput, build_image_output


 class LatentsField(BaseModel):
@@ -103,6 +104,37 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
     # x = (1 - self.perlin) * x + self.perlin * perlin_noise
     return x


+class ModelChooser:
+    """Mixin that fetches the requested model, emitting load events around the fetch."""
+
+    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
+        if context.services.queue.is_canceled(context.graph_execution_state_id):
+            raise CanceledException
+
+        # Get the source node id (we are invoking the prepared node)
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
+
+        context.services.events.emit_model_load_started(
+            graph_execution_state_id=context.graph_execution_state_id,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            model_name=self.model,
+            submodel=SDModelType.diffusers,
+        )
+        model_manager = context.services.model_manager
+        model_info = model_manager.get_model(self.model)
+        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
+        context.services.events.emit_model_load_completed(
+            graph_execution_state_id=context.graph_execution_state_id,
+            node=self.dict(),
+            source_node_id=source_node_id,
+            model_name=self.model,
+            submodel=SDModelType.diffusers,
+            model_info=model_info,
+        )
+        return model_ctx
+
+
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
@@ -135,7 +167,7 @@ class NoiseInvocation(BaseInvocation):

 # Text to image
-class TextToLatentsInvocation(BaseInvocation):
+class TextToLatentsInvocation(BaseInvocation, ModelChooser):
     """Generates latents from conditionings."""

     type: Literal["t2l"] = "t2l"
@@ -175,32 +207,6 @@ class TextToLatentsInvocation(BaseInvocation):
             source_node_id=source_node_id,
         )

-    def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
-        model_info = model_manager.get_model(self.model)
-        model_name = model_info.name
-        model_hash = model_info.hash
-        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
-        with model_ctx as model:
-            model.scheduler = get_scheduler(
-                model=model,
-                scheduler_name=self.scheduler
-            )
-
-            if isinstance(model, DiffusionPipeline):
-                for component in [model.unet, model.vae]:
-                    configure_model_padding(component,
-                                            self.seamless,
-                                            self.seamless_axes
-                                            )
-            else:
-                configure_model_padding(model,
-                                        self.seamless,
-                                        self.seamless_axes
-                                        )
-        return model_ctx
-
     def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
@@ -230,8 +236,8 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(context, model)
+        with self.choose_model(context) as model:
+            conditioning_data = self.get_conditioning_data(context, model)

-        # TODO: Verify the noise is the right size
-        result_latents, result_attention_map_saver = model.latents_from_embeddings(
+            # TODO: Verify the noise is the right size
+            result_latents, result_attention_map_saver = model.latents_from_embeddings(
@@ -251,8 +257,30 @@ class TextToLatentsInvocation(BaseInvocation):
             latents=LatentsField(latents_name=name)
         )

+    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
+        model_ctx = super().choose_model(context)
+        with model_ctx as model:
+            model.scheduler = get_scheduler(
+                model=model,
+                scheduler_name=self.scheduler
+            )
+
+            if isinstance(model, DiffusionPipeline):
+                for component in [model.unet, model.vae]:
+                    configure_model_padding(component,
+                                            self.seamless,
+                                            self.seamless_axes
+                                            )
+            else:
+                configure_model_padding(model,
+                                        self.seamless,
+                                        self.seamless_axes
+                                        )
+        return model_ctx
+

-class LatentsToLatentsInvocation(TextToLatentsInvocation):
+class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser):
     """Generates latents using latents as base image."""

     type: Literal["l2l"] = "l2l"
@@ -283,7 +311,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

-        with self.get_model(context.services.model_manager) as model:
+        with self.choose_model(context) as model:
             conditioning_data = self.get_conditioning_data(model)

             # TODO: Verify the noise is the right size
@@ -318,7 +346,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):

 # Latent to image
-class LatentsToImageInvocation(BaseInvocation):
+class LatentsToImageInvocation(BaseInvocation, ModelChooser):
     """Generates an image from latents."""

     type: Literal["l2i"] = "l2i"
@@ -343,9 +371,7 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)

         # TODO: this only really needs the vae
-        model_info = choose_model(context.services.model_manager, self.model)
-        with model_info.context as model:
+        with self.choose_model(context) as model:
             with torch.inference_mode():
                 np_image = model.decode_latents(latents)
                 image = model.numpy_to_pil(np_image)[0]
@@ -432,7 +458,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         return LatentsOutput(latents=LatentsField(latents_name=name))


-class ImageToLatentsInvocation(BaseInvocation):
+class ImageToLatentsInvocation(BaseInvocation, ModelChooser):
     """Encodes an image into latents."""

     type: Literal["i2l"] = "i2l"
@@ -457,7 +483,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         )

         # TODO: this only really needs the vae
-        model_info = choose_model(context.services.model_manager, self.model)
+        model_info = self.choose_model(context)
         model: StableDiffusionGeneratorPipeline = model_info["model"]

         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

@@ -474,3 +500,4 @@ class ImageToLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))
+
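One loose end in this last hunk: choose_model returns the model context object, yet the unchanged next line still subscripts it as model_info["model"]. The i2l path presumably needs the same `with self.choose_model(context) as model:` treatment that the l2i path received above.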

invokeai/app/services/events.py

@@ -4,6 +4,7 @@ from typing import Any
 from invokeai.app.api.models.images import ProgressImage
 from invokeai.app.util.misc import get_timestamp
 from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
+from invokeai.app.models.exceptions import CanceledException


 class EventServiceBase:
     session_event: str = "session_event"
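latents.py above calls two emitters this service is expected to grow, emit_model_load_started and emit_model_load_completed. Their keyword arguments can be read off the call sites; the sketch below mirrors them, but the dispatch helper and payload handling are assumptions, not the commit's actual implementation:

    class EventServiceSketch:
        def dispatch(self, event_name: str, payload: dict) -> None:
            print(event_name, payload)  # the real service queues a session event

        def emit_model_load_started(self, graph_execution_state_id, node,
                                    source_node_id, model_name, submodel) -> None:
            self.dispatch("model_load_started", dict(
                graph_execution_state_id=graph_execution_state_id, node=node,
                source_node_id=source_node_id, model_name=model_name,
                submodel=str(submodel)))

        def emit_model_load_completed(self, graph_execution_state_id, node,
                                      source_node_id, model_name, submodel,
                                      model_info) -> None:
            self.dispatch("model_load_completed", dict(
                graph_execution_state_id=graph_execution_state_id, node=node,
                source_node_id=source_node_id, model_name=model_name,
                submodel=str(submodel), model_info=repr(model_info)))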

invokeai/app/services/model_manager_service.py

@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Union, Callable, types
+from dataclasses import dataclass

 from invokeai.backend.model_management.model_manager import (
     ModelManager,
@@ -15,6 +16,11 @@ from invokeai.backend.model_management.model_manager import (
 from ...backend import Args,Globals # this must go when pr 3340 merged
 from ...backend.util import choose_precision, choose_torch_device

+
+@dataclass
+class LastUsedModel:
+    model_name: str = None          # defaults so the class attributes exist before first use
+    model_type: SDModelType = None
+

 class ModelManagerServiceBase(ABC):
     """Responsible for managing models on disk and in memory"""
@@ -273,6 +279,22 @@ class ModelManagerService(ModelManagerServiceBase):
         Retrieve the indicated model. submodel can be used to get a
         part (such as the vae) of a diffusers model.
         """
+        # Temporary hack: remember the last model fetched so that the first
+        # node executed in a graph sets the default model for subsequent
+        # nodes that do not set one explicitly. This works around the lack
+        # of a model loader mechanism, which was causing inconsistent model
+        # usage across a graph, and should be removed once a proper model
+        # loader lands.
+        if not model_name:
+            self.logger.debug('No model name provided, defaulting to last loaded model')
+            model_name = LastUsedModel.model_name
+            model_type = model_type or LastUsedModel.model_type
+        else:
+            LastUsedModel.model_name = model_name
+            LastUsedModel.model_type = model_type
+
         return self.mgr.get_model(
             model_name,
             model_type,
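Because LastUsedModel's fields are written and read on the class itself, never on an instance, the "last used model" is effectively a process-wide global. A toy reproduction of the fallback behavior (names hypothetical):

    from dataclasses import dataclass

    @dataclass
    class LastUsed:
        name: str = None  # class-level default doubles as a mutable global

    def get_model(name: str = None) -> str:
        if not name:
            name = LastUsed.name   # fall back to the last explicit choice
        else:
            LastUsed.name = name   # remember it for later callers
        return f"loaded {name}"

    print(get_model("sd-1.5"))  # loaded sd-1.5
    print(get_model())          # loaded sd-1.5 -- the default sticks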