From 2ef79b8bf324802ef5b8df32bcdf276fc8e1704d Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Fri, 12 May 2023 00:14:56 -0400
Subject: [PATCH] fix bug in persistent model scheme

---
 invokeai/app/invocations/generate.py          |  2 +-
 invokeai/app/invocations/latent.py            | 56 ++++---------
 invokeai/app/services/events.py               |  4 +
 .../app/services/model_manager_service.py     | 84 ++++++++++++++++---
 4 files changed, 93 insertions(+), 53 deletions(-)

diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py
index 3d5dd7f5ce..1f0d99ba46 100644
--- a/invokeai/app/invocations/generate.py
+++ b/invokeai/app/invocations/generate.py
@@ -71,7 +71,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model)
+        model = context.services.model_manager.get_model(self.model,node=self,context=context)
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 91320173ed..3f99d08bd1 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -9,12 +9,10 @@ from diffusers import DiffusionPipeline
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field
 
-from invokeai.app.models.exceptions import CanceledException
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.model_management.model_manager import SDModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ConditioningData, StableDiffusionGeneratorPipeline,
@@ -104,37 +102,11 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
         # x = (1 - self.perlin) * x + self.perlin * perlin_noise
     return x
 
-class ModelChooser:
-    def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
-
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
-            raise CanceledException
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        context.services.events.emit_model_load_started(
-            graph_execution_state_id=context.graph_execution_state_id,
-            node=self.dict(),
-            source_node_id=source_node_id,
-            model_name=self.model,
-            submodel=SDModelType.diffusers
-        )
-
+class ModelGetter:
+    def get_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
         model_manager = context.services.model_manager
-        model_info = model_manager.get_model(self.model)
-        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
-        context.services.events.emit_model_load_completed (
-            graph_execution_state_id=context.graph_execution_state_id,
-            node=self.dict(),
-            source_node_id=source_node_id,
-            model_name=self.model,
-            submodel=SDModelType.diffusers,
-            model_info=model_info
-        )
-
-        return model_ctx
+        model_info = model_manager.get_model(self.model,node=self,context=context)
+        return model_info.context
 
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
noise.""" @@ -167,7 +139,7 @@ class NoiseInvocation(BaseInvocation): # Text to image -class TextToLatentsInvocation(BaseInvocation, ModelChooser): +class TextToLatentsInvocation(BaseInvocation, ModelGetter): """Generates latents from conditionings.""" type: Literal["t2l"] = "t2l" @@ -236,7 +208,7 @@ class TextToLatentsInvocation(BaseInvocation, ModelChooser): def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - with self.choose_model(context) as model: + with self.get_model(context) as model: conditioning_data = self.get_conditioning_data(context, model) # TODO: Verify the noise is the right size @@ -257,8 +229,8 @@ class TextToLatentsInvocation(BaseInvocation, ModelChooser): latents=LatentsField(latents_name=name) ) - def choose_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline: - model_ctx = super().choose_model(context) + def get_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline: + model_ctx = super().get_model(context) with model_ctx as model: model.scheduler = get_scheduler( @@ -280,7 +252,7 @@ class TextToLatentsInvocation(BaseInvocation, ModelChooser): return model_ctx -class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser): +class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelGetter): """Generates latents using latents as base image.""" type: Literal["l2l"] = "l2l" @@ -311,7 +283,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser): def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - with self.choose_model(context) as model: + with self.get_model(context) as model: conditioning_data = self.get_conditioning_data(model) # TODO: Verify the noise is the right size @@ -346,7 +318,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelChooser): # Latent to image -class LatentsToImageInvocation(BaseInvocation, ModelChooser): +class LatentsToImageInvocation(BaseInvocation, ModelGetter): """Generates an image from latents.""" type: Literal["l2i"] = "l2i" @@ -371,7 +343,7 @@ class LatentsToImageInvocation(BaseInvocation, ModelChooser): latents = context.services.latents.get(self.latents.latents_name) # TODO: this only really needs the vae - with self.choose_model(context) as model: + with self.get_model(context) as model: with torch.inference_mode(): np_image = model.decode_latents(latents) image = model.numpy_to_pil(np_image)[0] @@ -458,7 +430,7 @@ class ScaleLatentsInvocation(BaseInvocation): return LatentsOutput(latents=LatentsField(latents_name=name)) -class ImageToLatentsInvocation(BaseInvocation, ModelChooser): +class ImageToLatentsInvocation(BaseInvocation, ModelGetter): """Encodes an image into latents.""" type: Literal["i2l"] = "i2l" @@ -483,7 +455,7 @@ class ImageToLatentsInvocation(BaseInvocation, ModelChooser): ) # TODO: this only really needs the vae - model_info = self.choose_model(context) + model_info = self.get_model(context) model: StableDiffusionGeneratorPipeline = model_info["model"] image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index dda5557315..03a36962fc 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -109,6 +109,7 @@ class EventServiceBase: node: dict, source_node_id: str, model_name: str, + model_type: SDModelType, submodel: SDModelType, ) -> None: """Emitted when a model is requested""" @@ -119,6 
@@ -119,6 +120,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=str,
+                model_type=model_type,
                 submodel=submodel,
             ),
         )
@@ -129,6 +131,7 @@ class EventServiceBase:
         node: dict,
         source_node_id: str,
         model_name: str,
+        model_type: SDModelType,
         submodel: SDModelType,
         model_info: SDModelInfo,
     ) -> None:
@@ -140,6 +143,7 @@ class EventServiceBase:
                 node=node,
                 source_node_id=source_node_id,
                 model_name=str,
+                model_type=model_type,
                 submodel=submodel,
                 model_info=model_info,
             ),
diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index 3e245bc47e..3c50d2ba2d 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -10,16 +10,18 @@ from invokeai.backend.model_management.model_manager import (
     ModelManager,
     SDModelType,
     SDModelInfo,
-    types,
     torch,
 )
+from invokeai.app.models.exceptions import CanceledException
 from ...backend import Args,Globals # this must go when pr 3340 merged
 from ...backend.util import choose_precision, choose_torch_device
 
 @dataclass
 class LastUsedModel:
-    model_name: str
-    model_type: SDModelType
+    model_name: str=None
+    model_type: SDModelType=None
+
+last_used_model = LastUsedModel()
 
 class ModelManagerServiceBase(ABC):
     """Responsible for managing models on disk and in memory"""
@@ -42,7 +44,9 @@ class ModelManagerServiceBase(ABC):
     def get_model(self,
                   model_name: str,
                   model_type: SDModelType=SDModelType.diffusers,
-                  submodel: SDModelType=None
+                  submodel: SDModelType=None,
+                  node=None, # circular dependency issues, so untyped at moment
+                  context=None,
                   )->SDModelInfo:
         """Retrieve the indicated model with name and type.
         submodel can be used to get a part (such as the vae)
@@ -274,6 +278,8 @@ class ModelManagerService(ModelManagerServiceBase):
         model_name: str,
         model_type: SDModelType=SDModelType.diffusers,
         submodel: SDModelType=None,
+        node=None,
+        context=None,
     )->SDModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
@@ -287,20 +293,45 @@ class ModelManagerService(ModelManagerServiceBase):
         # displaced by model loader mechanism.
         # This is to work around lack of model loader at current time,
         # which was causing inconsistent model usage throughout graph.
+        global last_used_model
+
         if not model_name:
             self.logger.debug('No model name provided, defaulting to last loaded model')
-            model_name = LastUsedModel.name
-            model_type = model_type or LastUsedModel.type
+            model_name = last_used_model.model_name
+            model_type = model_type or last_used_model.model_type
         else:
-            LastUsedModel.name = model_name
-            LastUsedModel.model_type = model_type
-
-        return self.mgr.get_model(
+            last_used_model.model_name = model_name
+            last_used_model.model_type = model_type
+
+        # if we are called from within a node, then we get to emit
+        # load start and complete events
+        if node and context:
+            self._emit_load_event(
+                node=node,
+                context=context,
+                model_name=model_name,
+                model_type=model_type,
+                submodel=submodel
+            )
+
+        model_info = self.mgr.get_model(
             model_name,
             model_type,
             submodel,
         )
+        if node and context:
+            self._emit_load_event(
+                node=node,
+                context=context,
+                model_name=model_name,
+                model_type=model_type,
+                submodel=submodel,
+                model_info=model_info
+            )
+
+        return model_info
+
     def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
         """
         Given a model name, returns True if it is a valid
@@ -466,6 +497,39 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         return self.mgr.commit(conf_file)
 
+    def _emit_load_event(
+        self,
+        node,
+        context,
+        model_name: str,
+        model_type: SDModelType,
+        submodel: SDModelType,
+        model_info: SDModelInfo=None,
+    ):
+        if context.services.queue.is_canceled(context.graph_execution_state_id):
+            raise CanceledException
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
+        source_node_id = graph_execution_state.prepared_source_mapping[node.id]
+        if model_info is None:
+            context.services.events.emit_model_load_started(
+                graph_execution_state_id=context.graph_execution_state_id,
+                node=node.dict(),
+                source_node_id=source_node_id,
+                model_name=model_name,
+                model_type=model_type,
+                submodel=submodel,
+            )
+        else:
+            context.services.events.emit_model_load_completed(
+                graph_execution_state_id=context.graph_execution_state_id,
+                node=node.dict(),
+                source_node_id=source_node_id,
+                model_name=model_name,
+                model_type=model_type,
+                submodel=submodel,
+                model_info=model_info
+            )
+
     @property
     def logger(self):
         return self.mgr.logger
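
Usage note (informational, appended after the diff; not part of the patch): a minimal sketch of the calling convention this change introduces, assuming it runs inside a node's invoke(self, context) method where self.model holds the model name. Passing node and context lets ModelManagerService emit the model_load_started / model_load_completed events on the node's behalf.

    # Sketch only: inside an invocation's invoke(self, context) method.
    model_info = context.services.model_manager.get_model(
        self.model,       # falls back to the last-used model when empty
        node=self,        # the invoking node; kept untyped to avoid a circular import
        context=context,  # used to resolve the source node id for the load events
    )
    with model_info.context as model:
        ...               # model is a StableDiffusionGeneratorPipeline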