# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) from functools import partial from typing import Literal, Optional, get_args import torch from pydantic import Field from invokeai.app.models.image import (ColorField, ImageCategory, ImageField, ResourceOrigin) from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods from ...backend.generator import Inpaint, InvokeAIGenerator from ...backend.stable_diffusion import PipelineIntermediateState from ..util.step_callback import stable_diffusion_step_callback from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .image import ImageOutput from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from .model import UNetField, VaeField from .compel import ConditioningField from contextlib import contextmanager, ExitStack, ContextDecorator SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] INFILL_METHODS = Literal[tuple(infill_methods())] DEFAULT_INFILL_METHOD = ( "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" ) from .latent import get_scheduler class OldModelContext(ContextDecorator): model: StableDiffusionGeneratorPipeline def __init__(self, model): self.model = model def __enter__(self): return self.model def __exit__(self, *exc): return False class OldModelInfo: name: str hash: str context: OldModelContext def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline): self.name = name self.hash = hash self.context = OldModelContext( model=model, ) class InpaintInvocation(BaseInvocation): """Generates an image using inpaint.""" type: Literal["inpaint"] = "inpaint" positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed) steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image") width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", ) height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) unet: UNetField = Field(default=None, description="UNet model") vae: VaeField = Field(default=None, description="Vae model") # Inputs image: Optional[ImageField] = Field(description="The input image") strength: float = Field( default=0.75, gt=0, le=1, description="The strength of the original image" ) fit: bool = Field( default=True, description="Whether or not the result should be fit to the aspect ratio of the input image", ) # Inputs mask: Optional[ImageField] = Field(description="The mask") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") seam_blur: int = Field( default=16, ge=0, description="The seam inpaint blur radius (px)" ) seam_strength: float = Field( default=0.75, gt=0, le=1, description="The seam inpaint strength" ) seam_steps: int = Field( default=30, ge=1, description="The number of steps to use for seam inpaint" ) tile_size: int = Field( default=32, ge=1, description="The tile infill method size (px)" ) infill_method: INFILL_METHODS = Field( default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)", ) inpaint_width: Optional[int] = Field( default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)", ) inpaint_height: Optional[int] = Field( default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)", ) inpaint_fill: Optional[ColorField] = Field( default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color", ) inpaint_replace: float = Field( default=0.0, ge=0.0, le=1.0, description="The amount by which to replace masked areas with latent noise", ) # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { "tags": ["stable-diffusion", "image"], }, } def dispatch_progress( self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState, ) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, node=self.dict(), source_node_id=source_node_id, ) def get_conditioning(self, context): c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) return (uc, c, extra_conditioning_info) @contextmanager def load_model_old_way(self, context, scheduler): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( **lora.dict(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) with vae_info as vae,\ ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: device = context.services.model_manager.mgr.cache.execution_device dtype = context.services.model_manager.mgr.cache.precision pipeline = StableDiffusionGeneratorPipeline( vae=vae, text_encoder=None, tokenizer=None, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, precision="float16" if dtype == torch.float16 else "float32", execution_device=device, ) yield OldModelInfo( name=self.unet.unet.model_name, hash="", model=pipeline, ) def invoke(self, context: InvocationContext) -> ImageOutput: image = ( None if self.image is None else context.services.images.get_pil_image(self.image.image_name) ) mask = ( None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name) ) # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( context.graph_execution_state_id ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] conditioning = self.get_conditioning(context) scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, ) with self.load_model_old_way(context, scheduler) as model: outputs = Inpaint(model).generate( conditioning=conditioning, scheduler=scheduler, init_image=image, mask_image=mask, step_callback=partial(self.dispatch_progress, context, source_node_id), **self.dict( exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"} ), # Shorthand for passing all of the parameters above manually ) # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object # each time it is called. We only need the first one. generator_output = next(outputs) image_dto = context.services.images.create( image=generator_output.image, image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, )