diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ef17962f89..7593b34142 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -171,7 +171,7 @@ class TextToLatentsInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( self, context: InvocationContext, intermediate_state: PipelineIntermediateState - ) -> None: + ) -> None: if (context.services.queue.is_canceled(context.graph_execution_state_id)): raise CanceledException @@ -185,7 +185,7 @@ class TextToLatentsInvocation(BaseInvocation): diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context) - + def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline: model_info = choose_model(model_manager, self.model) model_name = model_info['model_name'] @@ -195,7 +195,7 @@ class TextToLatentsInvocation(BaseInvocation): model=model, scheduler_name=self.scheduler ) - + if isinstance(model, DiffusionPipeline): for component in [model.unet, model.vae]: configure_model_padding(component, @@ -292,57 +292,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): initial_latents = latent if self.strength < 1.0 else torch.zeros_like( latent, device=model.device, dtype=latent.dtype ) - - timesteps, _ = model.get_img2img_timesteps( - self.steps, - self.strength, - device=model.device, - ) - result_latents, result_attention_map_saver = model.latents_from_embeddings( - latents=initial_latents, - timesteps=timesteps, - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - callback=step_callback - ) - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - torch.cuda.empty_cache() - - name = f'{context.graph_execution_state_id}__{self.id}' - context.services.latents.set(name, result_latents) - return LatentsOutput( - latents=LatentsField(latents_name=name) - ) - - -class LatentsToLatentsInvocation(TextToLatentsInvocation): - """Generates latents using latents as base image.""" - - type: Literal["l2l"] = "l2l" - - # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.5, description="The strength of the latents to use") - - def invoke(self, context: InvocationContext) -> LatentsOutput: - noise = context.services.latents.get(self.noise.latents_name) - latent = context.services.latents.get(self.latents.latents_name) - - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, state) - - model = self.get_model(context.services.model_manager) - conditioning_data = self.get_conditioning_data(model) - - # TODO: Verify the noise is the right size - - initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=model.device, dtype=latent.dtype - ) - timesteps, _ = model.get_img2img_timesteps( self.steps, self.strength,