fix bug in persistent model scheme

This commit is contained in:
Lincoln Stein 2023-05-12 00:14:56 -04:00
parent 11ecf438f5
commit 2ef79b8bf3
4 changed files with 93 additions and 53 deletions

View File

@ -71,7 +71,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter # Handle invalid model parameter
model = context.services.model_manager.get_model(self.model) model = context.services.model_manager.get_model(self.model,node=self,context=context)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(

View File

@ -9,12 +9,10 @@ from diffusers import DiffusionPipeline
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.model_management.model_manager import SDModelType
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, StableDiffusionGeneratorPipeline, ConditioningData, StableDiffusionGeneratorPipeline,
@ -104,37 +102,11 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
# x = (1 - self.perlin) * x + self.perlin * perlin_noise # x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x return x
class ModelChooser: class ModelGetter:
def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline: def get_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id,
node=self.dict(),
source_node_id=source_node_id,
model_name=self.model,
submodel=SDModelType.diffusers
)
model_manager = context.services.model_manager model_manager = context.services.model_manager
model_info = model_manager.get_model(self.model) model_info = model_manager.get_model(self.model,node=self,context=context)
model_ctx: StableDiffusionGeneratorPipeline = model_info.context return model_info.context
context.services.events.emit_model_load_completed (
graph_execution_state_id=context.graph_execution_state_id,
node=self.dict(),
source_node_id=source_node_id,
model_name=self.model,
submodel=SDModelType.diffusers,
model_info=model_info
)
return model_ctx
class NoiseInvocation(BaseInvocation): class NoiseInvocation(BaseInvocation):
"""Generates latent noise.""" """Generates latent noise."""
@ -167,7 +139,7 @@ class NoiseInvocation(BaseInvocation):
# Text to image # Text to image
class TextToLatentsInvocation(BaseInvocation, ModelChooser): class TextToLatentsInvocation(BaseInvocation, ModelGetter):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""
type: Literal["t2l"] = "t2l" type: Literal["t2l"] = "t2l"
@ -236,7 +208,7 @@ class TextToLatentsInvocation(BaseInvocation, ModelChooser):
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
with self.choose_model(context) as model: with self.get_model(context) as model:
conditioning_data = self.get_conditioning_data(context, model) conditioning_data = self.get_conditioning_data(context, model)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
@ -257,8 +229,8 @@ class TextToLatentsInvocation(BaseInvocation, ModelChooser):
latents=LatentsField(latents_name=name) latents=LatentsField(latents_name=name)
) )
def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline: def get_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
model_ctx = super().choose_model(context) model_ctx = super().get_model(context)
with model_ctx as model: with model_ctx as model:
model.scheduler = get_scheduler( model.scheduler = get_scheduler(
@ -280,7 +252,7 @@ class TextToLatentsInvocation(BaseInvocation, ModelChooser):
return model_ctx return model_ctx
class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser): class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelGetter):
"""Generates latents using latents as base image.""" """Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l" type: Literal["l2l"] = "l2l"
@ -311,7 +283,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser):
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
with self.choose_model(context) as model: with self.get_model(context) as model:
conditioning_data = self.get_conditioning_data(model) conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
@ -346,7 +318,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser):
# Latent to image # Latent to image
class LatentsToImageInvocation(BaseInvocation, ModelChooser): class LatentsToImageInvocation(BaseInvocation, ModelGetter):
"""Generates an image from latents.""" """Generates an image from latents."""
type: Literal["l2i"] = "l2i" type: Literal["l2i"] = "l2i"
@ -371,7 +343,7 @@ class LatentsToImageInvocation(BaseInvocation, ModelChooser):
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae # TODO: this only really needs the vae
with self.choose_model(context) as model: with self.get_model(context) as model:
with torch.inference_mode(): with torch.inference_mode():
np_image = model.decode_latents(latents) np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0] image = model.numpy_to_pil(np_image)[0]
@ -458,7 +430,7 @@ class ScaleLatentsInvocation(BaseInvocation):
return LatentsOutput(latents=LatentsField(latents_name=name)) return LatentsOutput(latents=LatentsField(latents_name=name))
class ImageToLatentsInvocation(BaseInvocation, ModelChooser): class ImageToLatentsInvocation(BaseInvocation, ModelGetter):
"""Encodes an image into latents.""" """Encodes an image into latents."""
type: Literal["i2l"] = "i2l" type: Literal["i2l"] = "i2l"
@ -483,7 +455,7 @@ class ImageToLatentsInvocation(BaseInvocation, ModelChooser):
) )
# TODO: this only really needs the vae # TODO: this only really needs the vae
model_info = self.choose_model(context) model_info = self.get_model(context)
model: StableDiffusionGeneratorPipeline = model_info["model"] model: StableDiffusionGeneratorPipeline = model_info["model"]
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

View File

@ -109,6 +109,7 @@ class EventServiceBase:
node: dict, node: dict,
source_node_id: str, source_node_id: str,
model_name: str, model_name: str,
model_type: SDModelType,
submodel: SDModelType, submodel: SDModelType,
) -> None: ) -> None:
"""Emitted when a model is requested""" """Emitted when a model is requested"""
@ -119,6 +120,7 @@ class EventServiceBase:
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=str, model_name=str,
model_type=model_type,
submodel=submodel, submodel=submodel,
), ),
) )
@ -129,6 +131,7 @@ class EventServiceBase:
node: dict, node: dict,
source_node_id: str, source_node_id: str,
model_name: str, model_name: str,
model_type: SDModelType,
submodel: SDModelType, submodel: SDModelType,
model_info: SDModelInfo, model_info: SDModelInfo,
) -> None: ) -> None:
@ -140,6 +143,7 @@ class EventServiceBase:
node=node, node=node,
source_node_id=source_node_id, source_node_id=source_node_id,
model_name=str, model_name=str,
model_type=model_type,
submodel=submodel, submodel=submodel,
model_info=model_info, model_info=model_info,
), ),

View File

@ -10,16 +10,18 @@ from invokeai.backend.model_management.model_manager import (
ModelManager, ModelManager,
SDModelType, SDModelType,
SDModelInfo, SDModelInfo,
types,
torch, torch,
) )
from invokeai.app.models.exceptions import CanceledException
from ...backend import Args,Globals # this must go when pr 3340 merged from ...backend import Args,Globals # this must go when pr 3340 merged
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
@dataclass @dataclass
class LastUsedModel: class LastUsedModel:
model_name: str model_name: str=None
model_type: SDModelType model_type: SDModelType=None
last_used_model = LastUsedModel()
class ModelManagerServiceBase(ABC): class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory""" """Responsible for managing models on disk and in memory"""
@ -42,7 +44,9 @@ class ModelManagerServiceBase(ABC):
def get_model(self, def get_model(self,
model_name: str, model_name: str,
model_type: SDModelType=SDModelType.diffusers, model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None submodel: SDModelType=None,
node=None, # circular dependency issues, so untyped at moment
context=None,
)->SDModelInfo: )->SDModelInfo:
"""Retrieve the indicated model with name and type. """Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae) submodel can be used to get a part (such as the vae)
@ -274,6 +278,8 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str, model_name: str,
model_type: SDModelType=SDModelType.diffusers, model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None, submodel: SDModelType=None,
node=None,
context=None,
)->SDModelInfo: )->SDModelInfo:
""" """
Retrieve the indicated model. submodel can be used to get a Retrieve the indicated model. submodel can be used to get a
@ -287,20 +293,45 @@ class ModelManagerService(ModelManagerServiceBase):
# displaced by model loader mechanism. # displaced by model loader mechanism.
# This is to work around lack of model loader at current time, # This is to work around lack of model loader at current time,
# which was causing inconsistent model usage throughout graph. # which was causing inconsistent model usage throughout graph.
global last_used_model
if not model_name: if not model_name:
self.logger.debug('No model name provided, defaulting to last loaded model') self.logger.debug('No model name provided, defaulting to last loaded model')
model_name = LastUsedModel.name model_name = last_used_model.model_name
model_type = model_type or LastUsedModel.type model_type = model_type or last_used_model.model_type
else: else:
LastUsedModel.name = model_name last_used_model.model_name = model_name
LastUsedModel.model_type = model_type last_used_model.model_type = model_type
return self.mgr.get_model( # if we are called from within a node, then we get to emit
# load start and complete events
if node and context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
model_type=model_type,
submodel=submodel
)
model_info = self.mgr.get_model(
model_name, model_name,
model_type, model_type,
submodel, submodel,
) )
if node and context:
self._emit_load_event(
node=node,
context=context,
model_name=model_name,
model_type=model_type,
submodel=submodel,
model_info=model_info
)
return model_info
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool: def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid
@ -466,6 +497,39 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
return self.mgr.commit(conf_file) return self.mgr.commit(conf_file)
def _emit_load_event(
self,
node,
context,
model_name: str,
model_type: SDModelType,
submodel: SDModelType,
model_info: SDModelInfo=None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[node.id]
if context:
context.services.events.emit_model_load_started(
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
model_type=model_type,
submodel=submodel,
)
else:
context.services.events.emit_model_load_completed (
graph_execution_state_id=context.graph_execution_state_id,
node=node.dict(),
source_node_id=source_node_id,
model_name=model_name,
model_type=model_type,
submodel=submodel,
model_info=model_info
)
@property @property
def logger(self): def logger(self):
return self.mgr.logger return self.mgr.logger