From c26e1a9271f800b5f332066f29d9867ee6aabcdd Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 16 Jun 2023 23:17:18 +0300 Subject: [PATCH] Rewrite inpaint node to new model manager, remove TextToImage and ImageToImage nodes --- invokeai/app/invocations/generate.py | 305 ++++++++---------- .../stable_diffusion/diffusers_pipeline.py | 3 +- 2 files changed, 133 insertions(+), 175 deletions(-) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 21574c7323..83220d89ef 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -18,6 +18,12 @@ from ..util.step_callback import stable_diffusion_step_callback from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .image import ImageOutput +import re +from ...backend.model_management.lora import ModelPatcher +from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline +from .model import UNetField, ClipField, VaeField +from contextlib import contextmanager, ExitStack, ContextDecorator + SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] INFILL_METHODS = Literal[tuple(infill_methods())] DEFAULT_INFILL_METHOD = ( @@ -25,30 +31,38 @@ DEFAULT_INFILL_METHOD = ( ) -class SDImageInvocation(BaseModel): - """Helper class to provide all Stable Diffusion raster image invocations with additional config""" +from .latent import get_scheduler - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["stable-diffusion", "image"], - "type_hints": { - "model": "model", - }, - }, - } +class OldModelContext(ContextDecorator): + model: StableDiffusionGeneratorPipeline + + def __init__(self, model): + self.model = model + + def __enter__(self): + return self.model + + def __exit__(self, *exc): + return False + +class OldModelInfo: + name: str + hash: str + context: OldModelContext + + def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline): + self.name = name + self.hash = hash + self.context = OldModelContext( + model=model, + ) -# Text to image -class TextToImageInvocation(BaseInvocation, SDImageInvocation): - """Generates an image using text2img.""" +class InpaintInvocation(BaseInvocation): + """Generates an image using inpaint.""" - type: Literal["txt2img"] = "txt2img" + type: Literal["inpaint"] = "inpaint" - # Inputs - # TODO: consider making prompt optional to enable providing prompt through a link - # fmt: off prompt: Optional[str] = Field(description="The prompt to generate an image from") seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed) steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image") @@ -56,83 +70,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) - model: str = Field(default="", description="The model to use (currently ignored)") - progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) - control_model: Optional[str] = Field(default=None, description="The control model to use") - control_image: Optional[ImageField] = Field(default=None, description="The processed control image") - # fmt: on - - # TODO: pass this an emitter method or something? or a session for dispatching? - def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.dict(), - source_node_id=source_node_id, - ) - - def invoke(self, context: InvocationContext) -> ImageOutput: - # Handle invalid model parameter - model = context.services.model_manager.get_model(self.model,node=self,context=context) - - # loading controlnet image (currently requires pre-processed image) - control_image = ( - None if self.control_image is None - else context.services.images.get_pil_image(self.control_image.image_name) - ) - # loading controlnet model - if (self.control_model is None or self.control_model==''): - control_model = None - else: - # FIXME: change this to dropdown menu? - # FIXME: generalize so don't have to hardcode torch_dtype and device - control_model = ControlNetModel.from_pretrained(self.control_model, - torch_dtype=torch.float16).to("cuda") - - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - - txt2img = Txt2Img(model, control_model=control_model) - outputs = txt2img.generate( - prompt=self.prompt, - step_callback=partial(self.dispatch_progress, context, source_node_id), - control_image=control_image, - **self.dict( - exclude={"prompt", "control_image" } - ), # Shorthand for passing all of the parameters above manually - ) - # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object - # each time it is called. We only need the first one. - generate_output = next(outputs) - - image_dto = context.services.images.create( - image=generate_output.image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -class ImageToImageInvocation(TextToImageInvocation): - """Generates an image using img2img.""" - - type: Literal["img2img"] = "img2img" + #model: str = Field(default="", description="The model to use (currently ignored)") + #progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) + #control_model: Optional[str] = Field(default=None, description="The control model to use") + #control_image: Optional[ImageField] = Field(default=None, description="The processed control image") + unet: UNetField = Field(default=None, description="UNet model") + clip: ClipField = Field(default=None, description="Clip model") + vae: VaeField = Field(default=None, description="Vae model") # Inputs image: Union[ImageField, None] = Field(description="The input image") @@ -144,72 +88,6 @@ class ImageToImageInvocation(TextToImageInvocation): description="Whether or not the result should be fit to the aspect ratio of the input image", ) - def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.dict(), - source_node_id=source_node_id, - ) - - def invoke(self, context: InvocationContext) -> ImageOutput: - image = ( - None - if self.image is None - else context.services.images.get_pil_image(self.image.image_name) - ) - - if self.fit: - image = image.resize((self.width, self.height)) - - # Handle invalid model parameter - model = context.services.model_manager.get_model(self.model,node=self,context=context) - - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - - outputs = Img2Img(model).generate( - prompt=self.prompt, - init_image=image, - step_callback=partial(self.dispatch_progress, context, source_node_id), - **self.dict( - exclude={"prompt", "image", "mask"} - ), # Shorthand for passing all of the parameters above manually - ) - - # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object - # each time it is called. We only need the first one. - generator_output = next(outputs) - - image_dto = context.services.images.create( - image=generator_output.image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -class InpaintInvocation(ImageToImageInvocation): - """Generates an image using inpaint.""" - - type: Literal["inpaint"] = "inpaint" - # Inputs mask: Union[ImageField, None] = Field(description="The mask") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") @@ -252,6 +130,14 @@ class InpaintInvocation(ImageToImageInvocation): description="The amount by which to replace masked areas with latent noise", ) + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["stable-diffusion", "image"], + }, + } + def dispatch_progress( self, context: InvocationContext, @@ -265,6 +151,79 @@ class InpaintInvocation(ImageToImageInvocation): source_node_id=source_node_id, ) + @contextmanager + def load_model_old_way(self, context): + with ExitStack() as stack: + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) + tokenizer_info = context.services.model_manager.get_model(**self.clip.tokenizer.dict()) + text_encoder_info = context.services.model_manager.get_model(**self.clip.text_encoder.dict()) + vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) + + #unet = stack.enter_context(unet_info) + #tokenizer = stack.enter_context(tokenizer_info) + #text_encoder = stack.enter_context(text_encoder_info) + #vae = stack.enter_context(vae_info) + with vae_info as vae: + device = vae.device + dtype = vae.dtype + + # not load models to gpu as it should be handled by pipeline + unet = unet_info.context.model + tokenizer = tokenizer_info.context.model + text_encoder = text_encoder_info.context.model + vae = vae_info.context.model + + scheduler = get_scheduler( + context=context, + scheduler_info=self.unet.scheduler, + scheduler_name=self.scheduler, + ) + + loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + stack.enter_context( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ) + ) + ) + except Exception: + #print(e) + #import traceback + #print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") + + + with ModelPatcher.apply_lora_unet(unet, loras),\ + ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (ti_tokenizer, ti_manager): + + pipeline = StableDiffusionGeneratorPipeline( + # TODO: ti_manager + vae=vae, + text_encoder=text_encoder, + tokenizer=ti_tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + precision="float16" if dtype == torch.float16 else "float32", + execution_device=device, + ) + + yield OldModelInfo( + name=self.unet.unet.model_name, + hash="", + model=pipeline, + ) + def invoke(self, context: InvocationContext) -> ImageOutput: image = ( None @@ -277,24 +236,22 @@ class InpaintInvocation(ImageToImageInvocation): else context.services.images.get_pil_image(self.mask.image_name) ) - # Handle invalid model parameter - model = context.services.model_manager.get_model(self.model,node=self,context=context) - # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( context.graph_execution_state_id ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - outputs = Inpaint(model).generate( - prompt=self.prompt, - init_image=image, - mask_image=mask, - step_callback=partial(self.dispatch_progress, context, source_node_id), - **self.dict( - exclude={"prompt", "image", "mask"} - ), # Shorthand for passing all of the parameters above manually - ) + with self.load_model_old_way(context) as model: + outputs = Inpaint(model).generate( + prompt=self.prompt, + init_image=image, + mask_image=mask, + step_callback=partial(self.dispatch_progress, context, source_node_id), + **self.dict( + exclude={"prompt", "image", "mask"} + ), # Shorthand for passing all of the parameters above manually + ) # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object # each time it is called. We only need the first one. diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index f4afd880d3..2922238af9 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -317,6 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): requires_safety_checker: bool = False, precision: str = "float32", control_model: ControlNetModel = None, + execution_device: Optional[torch.device] = None, ): super().__init__( vae, @@ -356,7 +357,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): textual_inversion_manager=self.textual_inversion_manager, ) - self._model_group = FullyLoadedModelGroup(self.unet.device) + self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device) self._model_group.install(*self._submodels) self.control_model = control_model