From 1e2db3a17f6a96af76a7c91cc366df87386c8075 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 May 2023 23:28:15 -0400 Subject: [PATCH] hook tiled_decode up to configuration --- invokeai/app/invocations/latent.py | 13 ++---------- invokeai/app/services/config.py | 2 +- .../app/services/model_manager_service.py | 21 ++++++++++--------- 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 6437c0d675..6b759c23d7 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -25,7 +25,7 @@ from ..services.model_manager_service import ModelManagerService from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField -from .image import ImageField, ImageOutput +from .image import ImageCategory, ImageField, ImageOutput from .model import ModelInfo, UNetField, VaeField @@ -390,7 +390,7 @@ class LatentsToImageInvocation(BaseInvocation): ) with vae_info as vae: - if self.tiled: + if self.tiled or context.services.configuration.tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -408,15 +408,6 @@ class LatentsToImageInvocation(BaseInvocation): image = VaeImageProcessor.numpy_to_pil(np_image)[0] - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id - ) - - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - torch.cuda.empty_cache() image_dto = context.services.images.create( diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index 9cee093af9..559528d6ae 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -352,7 +352,7 @@ setting environment variables INVOKEAI_. precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') - + tiled_decode : bool = Field(default=False, description"Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths') diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 36696affce..2374f50d44 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -549,16 +549,7 @@ class ModelManagerService(ModelManagerServiceBase): raise CanceledException() graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[node.id] - if context: - context.services.events.emit_model_load_started( - graph_execution_state_id=context.graph_execution_state_id, - node=node.dict(), - source_node_id=source_node_id, - model_name=model_name, - model_type=model_type, - submodel=submodel, - ) - else: + if model_info: context.services.events.emit_model_load_completed( graph_execution_state_id=context.graph_execution_state_id, node=node.dict(), @@ -568,6 +559,16 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, model_info=model_info ) + else: + context.services.events.emit_model_load_started( + graph_execution_state_id=context.graph_execution_state_id, + node=node.dict(), + source_node_id=source_node_id, + model_name=model_name, + model_type=model_type, + submodel=submodel, + ) + @property def logger(self):