diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 3f99d08bd1..9ca1c81a3d 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -7,6 +7,7 @@ import einops
 import torch
 from diffusers import DiffusionPipeline
 from diffusers.schedulers import SchedulerMixin as Scheduler
+from diffusers.image_processor import VaeImageProcessor
 from pydantic import BaseModel, Field
 
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
@@ -26,6 +27,9 @@ from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
 from .compel import ConditioningField
 from .image import ImageField, ImageOutput, build_image_output
+from .model import ModelInfo, UNetField, VaeField
+from ...backend.model_management import SDModelType
+
 
 
 class LatentsField(BaseModel):
     """A latents field used for passing latents between invocations"""
@@ -70,9 +74,21 @@ SAMPLER_NAME_VALUES = Literal[
 ]
 
 
-def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
+def get_scheduler(
+    context: InvocationContext,
+    scheduler_info: ModelInfo,
+    scheduler_name: str,
+) -> Scheduler:
+    orig_scheduler_info = context.services.model_manager.get_model(
+        model_name=scheduler_info.model_name,
+        model_type=SDModelType[scheduler_info.model_type],
+        submodel=SDModelType[scheduler_info.submodel],
+    )
+    with orig_scheduler_info.context as orig_scheduler:
+        scheduler_config = orig_scheduler.config
+
     scheduler_class = scheduler_map.get(scheduler_name,'ddim')
-    scheduler = scheduler_class.from_config(model.scheduler.config)
+    scheduler = scheduler_class.from_config(scheduler_config)
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
@@ -102,12 +118,6 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
         # x = (1 - self.perlin) * x + self.perlin * perlin_noise
     return x
 
-class ModelGetter:
-    def get_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
-        model_manager = context.services.model_manager
-        model_info = model_manager.get_model(self.model,node=self,context=context)
-        return model_info.context
-
 
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
@@ -139,7 +149,7 @@ class NoiseInvocation(BaseInvocation):
 
 
 # Text to image
-class TextToLatentsInvocation(BaseInvocation, ModelGetter):
+class TextToLatentsInvocation(BaseInvocation):
     """Generates latents from conditionings."""
 
     type: Literal["t2l"] = "t2l"
@@ -152,9 +162,10 @@ class TextToLatentsInvocation(BaseInvocation, ModelGetter):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
-    model: str = Field(default="", description="The model to use (currently ignored)")
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
+
+    unet: UNetField = Field(default=None, description="UNet submodel")
     # fmt: on
 
     # Schema customisation
@@ -162,9 +173,6 @@ class TextToLatentsInvocation(BaseInvocation, ModelGetter):
         schema_extra = {
             "ui": {
                 "tags": ["latents", "image"],
-                "type_hints": {
-                    "model": "model"
-                }
             },
         }
 
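Review note: the rewritten `get_scheduler` above loads only the scheduler submodel and rebuilds the requested sampler from its config, so no UNet or VAE weights are touched when swapping samplers. A minimal sketch of the underlying diffusers pattern; the model id and the two scheduler classes are illustrative, not part of this diff:

```python
# Any diffusers scheduler can be reconstructed from another scheduler's
# config dict, which is all the code above needs from the original model.
from diffusers import DDIMScheduler, EulerDiscreteScheduler

base = DDIMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative model id
    subfolder="scheduler",
)
swapped = EulerDiscreteScheduler.from_config(base.config)
print(type(swapped).__name__)  # EulerDiscreteScheduler
```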
@@ -179,7 +187,7 @@ class TextToLatentsInvocation(BaseInvocation, ModelGetter):
             source_node_id=source_node_id,
         )
 
-    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
+    def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
 
@@ -194,9 +202,36 @@ class TextToLatentsInvocation(BaseInvocation, ModelGetter):
                 h_symmetry_time_pct=None,#h_symmetry_time_pct,
                 v_symmetry_time_pct=None#v_symmetry_time_pct,
             ),
-        ).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
+        ).add_scheduler_args_if_applicable(scheduler, eta=None)#ddim_eta)
         return conditioning_data
 
+    def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
+        configure_model_padding(
+            unet,
+            self.seamless,
+            self.seamless_axes,
+        )
+
+        class FakeVae:
+            class FakeVaeConfig:
+                def __init__(self):
+                    self.block_out_channels = [0]
+
+            def __init__(self):
+                self.config = FakeVae.FakeVaeConfig()
+
+        return StableDiffusionGeneratorPipeline(
+            vae=FakeVae(), # TODO: oh...
+            text_encoder=None,
+            tokenizer=None,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+            precision="float16" if unet.dtype == torch.float16 else "float32",
+            #precision="float16", # TODO:
+        )
+
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
 
@@ -207,19 +242,33 @@ class TextToLatentsInvocation(BaseInvocation, ModelGetter):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-
-        with self.get_model(context) as model:
-            conditioning_data = self.get_conditioning_data(context, model)
-
-            # TODO: Verify the noise is the right size
-            result_latents, result_attention_map_saver = model.latents_from_embeddings(
-                latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
-                noise=noise,
-                num_inference_steps=self.steps,
-                conditioning_data=conditioning_data,
-                callback=step_callback
+        #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        unet_info = context.services.model_manager.get_model(
+            model_name=self.unet.unet.model_name,
+            model_type=SDModelType[self.unet.unet.model_type],
+            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
         )
+
+        with unet_info.context as unet:
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+            )
+
+            pipeline = self.create_pipeline(unet, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler)
+
+            # TODO: Verify the noise is the right size
+            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
+                latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                callback=step_callback
+            )
+
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -229,30 +278,8 @@ class TextToLatentsInvocation(BaseInvocation, ModelGetter):
             latents=LatentsField(latents_name=name)
         )
 
-    def get_model(self, context: InvocationContext) -> StableDiffusionGeneratorPipeline:
-        model_ctx = super().get_model(context)
-        with model_ctx as model:
-            model.scheduler = get_scheduler(
-                model=model,
-                scheduler_name=self.scheduler
-            )
-
-            if isinstance(model, DiffusionPipeline):
-                for component in [model.unet, model.vae]:
-                    configure_model_padding(component,
-                                            self.seamless,
-                                            self.seamless_axes
-                                            )
-            else:
-                configure_model_padding(model,
-                                        self.seamless,
-                                        self.seamless_axes
-                                        )
-        return model_ctx
-
-
-class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelGetter):
+class LatentsToLatentsInvocation(TextToLatentsInvocation):
     """Generates latents using latents as base image."""
 
     type: Literal["l2l"] = "l2l"
 
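Review note: the `FakeVae` stub works because diffusers' pipeline constructor only reads `vae.config.block_out_channels` to derive its latent scale factor; actual decoding now lives in the separate l2i node. A rough sketch of that assumption (not the actual diffusers source):

```python
# Sketch of what StableDiffusionPipeline.__init__ derives from the VAE
# config; with block_out_channels = [0] the scale factor collapses to 1,
# which is harmless for a pipeline that never decodes latents itself.
class FakeVaeConfig:
    def __init__(self) -> None:
        self.block_out_channels = [0]

class FakeVae:
    def __init__(self) -> None:
        self.config = FakeVaeConfig()

vae_scale_factor = 2 ** (len(FakeVae().config.block_out_channels) - 1)
print(vae_scale_factor)  # 1
```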
@@ -266,9 +293,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelGetter):
         schema_extra = {
             "ui": {
                 "tags": ["latents"],
-                "type_hints": {
-                    "model": "model"
-                }
             },
         }
 
@@ -283,22 +307,35 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelGetter):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        with self.get_model(context) as model:
-            conditioning_data = self.get_conditioning_data(model)
+        #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+        unet_info = context.services.model_manager.get_model(
+            model_name=self.unet.unet.model_name,
+            model_type=SDModelType[self.unet.unet.model_type],
+            submodel=SDModelType[self.unet.unet.submodel] if self.unet.unet.submodel else None,
+        )
+
+        with unet_info.context as unet:
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+            )
+
+            pipeline = self.create_pipeline(unet, scheduler)
+            conditioning_data = self.get_conditioning_data(context, scheduler)
 
             # TODO: Verify the noise is the right size
-
             initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-                latent, device=model.device, dtype=latent.dtype
+                latent, device=unet.device, dtype=latent.dtype
             )
 
-            timesteps, _ = model.get_img2img_timesteps(
+            timesteps, _ = pipeline.get_img2img_timesteps(
                 self.steps,
                 self.strength,
-                device=model.device,
+                device=unet.device,
             )
 
-            result_latents, result_attention_map_saver = model.latents_from_embeddings(
+            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                 latents=initial_latents,
                 timesteps=timesteps,
                 noise=noise,
@@ -318,23 +355,21 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation, ModelGetter):
 
 
 # Latent to image
-class LatentsToImageInvocation(BaseInvocation, ModelGetter):
+class LatentsToImageInvocation(BaseInvocation):
     """Generates an image from latents."""
 
     type: Literal["l2i"] = "l2i"
 
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
-    model: str = Field(default="", description="The model to use")
+    vae: VaeField = Field(default=None, description="Vae submodel")
+    tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")
 
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
                 "tags": ["latents", "image"],
-                "type_hints": {
-                    "model": "model"
-                }
             },
         }
 
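Review note: in the l2l path above, `strength` only has an effect through `pipeline.get_img2img_timesteps`, which trims the schedule so that roughly `strength * steps` denoising steps run on the init latents. A sketch of that trimming, assumed to mirror diffusers' img2img pipelines (the helper below is illustrative, not this repo's implementation):

```python
def img2img_timesteps(timesteps: list, steps: int, strength: float) -> list:
    # Keep only the tail of the schedule: higher strength -> more steps,
    # strength == 1.0 -> the full schedule (pure generation from noise).
    init_timestep = min(int(steps * strength), steps)
    t_start = max(steps - init_timestep, 0)
    return timesteps[t_start:]

schedule = list(range(950, -1, -50))  # a 20-entry descending schedule
print(len(img2img_timesteps(schedule, 20, 0.75)))  # 15
```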
@@ -342,27 +377,45 @@ class LatentsToImageInvocation(BaseInvocation, ModelGetter):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
-        # TODO: this only really needs the vae
-        with self.get_model(context) as model:
+        #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        vae_info = context.services.model_manager.get_model(
+            model_name=self.vae.vae.model_name,
+            model_type=SDModelType[self.vae.vae.model_type],
+            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+        )
+
+        with vae_info.context as vae:
+            # TODO: check if it works
+            if self.tiled:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+
             with torch.inference_mode():
-                np_image = model.decode_latents(latents)
-                image = model.numpy_to_pil(np_image)[0]
+                # copied from diffusers pipeline
+                latents = latents / vae.config.scaling_factor
+                image = vae.decode(latents, return_dict=False)[0]
+                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
+                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
-        image_type = ImageType.RESULT
-        image_name = context.services.images.create_name(
-            context.graph_execution_state_id, self.id
-        )
+            image = VaeImageProcessor.numpy_to_pil(np_image)[0]
 
-        metadata = context.services.metadata.build_metadata(
-            session_id=context.graph_execution_state_id, node=self
-        )
+            image_type = ImageType.RESULT
+            image_name = context.services.images.create_name(
+                context.graph_execution_state_id, self.id
+            )
 
-        torch.cuda.empty_cache()
+            metadata = context.services.metadata.build_metadata(
+                session_id=context.graph_execution_state_id, node=self
+            )
 
-        context.services.images.save(image_type, image_name, image, metadata)
-        return build_image_output(
-            image_type=image_type, image_name=image_name, image=image
-        )
+            torch.cuda.empty_cache()
+
+            context.services.images.save(image_type, image_name, image, metadata)
+            return build_image_output(
+                image_type=image_type, image_name=image_name, image=image
+            )
 
 
 LATENTS_INTERPOLATION_MODE = Literal[
@@ -430,21 +483,21 @@ class ScaleLatentsInvocation(BaseInvocation):
         return LatentsOutput(latents=LatentsField(latents_name=name))
 
 
-class ImageToLatentsInvocation(BaseInvocation, ModelGetter):
+class ImageToLatentsInvocation(BaseInvocation):
     """Encodes an image into latents."""
 
     type: Literal["i2l"] = "i2l"
 
     # Inputs
     image: Union[ImageField, None] = Field(description="The image to encode")
-    model: str = Field(default="", description="The model to use")
+    vae: VaeField = Field(default=None, description="Vae submodel")
+    tiled: bool = Field(default=False, description="Encode latents by overlapping tiles (less memory consumption)")
 
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
                 "tags": ["latents", "image"],
-                "type_hints": {"model": "model"},
             },
         }
 
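Review note: the decode path above replaces the old `model.decode_latents` helper with inline VAE calls. The same steps as a standalone function, for reference; it assumes a loaded `AutoencoderKL` whose config carries the SD-1.x `scaling_factor` of 0.18215:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

def decode_to_pil(vae, latents: torch.Tensor):
    latents = latents / vae.config.scaling_factor      # undo encode-time scaling
    image = vae.decode(latents, return_dict=False)[0]  # (B, 3, H, W) in [-1, 1]
    image = (image / 2 + 0.5).clamp(0, 1)              # denormalize to [0, 1]
    np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return VaeImageProcessor.numpy_to_pil(np_image)[0]
```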
@@ -454,22 +507,38 @@ class ImageToLatentsInvocation(BaseInvocation, ModelGetter):
             self.image.image_type, self.image.image_name
         )
 
-        # TODO: this only really needs the vae
-        model_info = self.get_model(context)
-        model: StableDiffusionGeneratorPipeline = model_info["model"]
+        #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        vae_info = context.services.model_manager.get_model(
+            model_name=self.vae.vae.model_name,
+            model_type=SDModelType[self.vae.vae.model_type],
+            submodel=SDModelType[self.vae.vae.submodel] if self.vae.vae.submodel else None,
+        )
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
-
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
-        latents = model.non_noised_latents_from_image(
-            image_tensor,
-            device=model._model_group.device_for(model.unet),
-            dtype=model.unet.dtype,
-        )
+        with vae_info.context as vae:
+            # TODO: check if it works
+            if self.tiled:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+
+            latents = self.non_noised_latents_from_image(vae, image_tensor)
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, latents)
         return LatentsOutput(latents=LatentsField(latents_name=name))
+
+    def non_noised_latents_from_image(self, vae, init_image):
+        init_image = init_image.to(device=vae.device, dtype=vae.dtype)
+        with torch.inference_mode():
+            init_latent_dist = vae.encode(init_image).latent_dist
+            init_latents = init_latent_dist.sample().to(
+                dtype=vae.dtype
+            )  # FIXME: uses torch.randn. make reproducible!
+
+            init_latents = 0.18215 * init_latents
+        return init_latents
\ No newline at end of file
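Review note on the FIXME in `non_noised_latents_from_image`: diffusers' `DiagonalGaussianDistribution.sample` accepts a `torch.Generator`, so the encode could be made reproducible by threading a seeded generator through. A possible follow-up, sketched under that assumption (not part of this diff):

```python
import torch

def seeded_latents_from_image(vae, init_image: torch.Tensor, seed: int) -> torch.Tensor:
    init_image = init_image.to(device=vae.device, dtype=vae.dtype)
    generator = torch.Generator(device=vae.device).manual_seed(seed)
    with torch.inference_mode():
        dist = vae.encode(init_image).latent_dist
        latents = dist.sample(generator=generator).to(dtype=vae.dtype)
    return 0.18215 * latents  # SD-1.x latent scaling, as in the diff above
```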