Rewrite inpaint node to new model manager, remove TextToImage and ImageToImage nodes

This commit is contained in:
Sergey Borisov 2023-06-16 23:17:18 +03:00 committed by psychedelicious
parent f312e1448f
commit c26e1a9271
2 changed files with 133 additions and 175 deletions

View File

@ -18,6 +18,12 @@ from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput from .image import ImageOutput
import re
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, ClipField, VaeField
from contextlib import contextmanager, ExitStack, ContextDecorator
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())] INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = ( DEFAULT_INFILL_METHOD = (
@ -25,30 +31,38 @@ DEFAULT_INFILL_METHOD = (
) )
class SDImageInvocation(BaseModel): from .latent import get_scheduler
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
# Schema customisation class OldModelContext(ContextDecorator):
class Config(InvocationConfig): model: StableDiffusionGeneratorPipeline
schema_extra = {
"ui": { def __init__(self, model):
"tags": ["stable-diffusion", "image"], self.model = model
"type_hints": {
"model": "model", def __enter__(self):
}, return self.model
},
} def __exit__(self, *exc):
return False
class OldModelInfo:
name: str
hash: str
context: OldModelContext
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
self.name = name
self.hash = hash
self.context = OldModelContext(
model=model,
)
# Text to image class InpaintInvocation(BaseInvocation):
class TextToImageInvocation(BaseInvocation, SDImageInvocation): """Generates an image using inpaint."""
"""Generates an image using text2img."""
type: Literal["txt2img"] = "txt2img" type: Literal["inpaint"] = "inpaint"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from") prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed) seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
@ -56,83 +70,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)") #model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) #progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use") #control_model: Optional[str] = Field(default=None, description="The control model to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image") #control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on unet: UNetField = Field(default=None, description="UNet model")
clip: ClipField = Field(default=None, description="Clip model")
# TODO: pass this an emitter method or something? or a session for dispatching? vae: VaeField = Field(default=None, description="Vae model")
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = context.services.model_manager.get_model(self.model,node=self,context=context)
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(self.control_image.image_name)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
control_model = None
else:
# FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(self.control_model,
torch_dtype=torch.float16).to("cuda")
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
txt2img = Txt2Img(model, control_model=control_model)
outputs = txt2img.generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
control_image=control_image,
**self.dict(
exclude={"prompt", "control_image" }
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generate_output = next(outputs)
image_dto = context.services.images.create(
image=generate_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ImageToImageInvocation(TextToImageInvocation):
"""Generates an image using img2img."""
type: Literal["img2img"] = "img2img"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image") image: Union[ImageField, None] = Field(description="The input image")
@ -144,72 +88,6 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image", description="Whether or not the result should be fit to the aspect ratio of the input image",
) )
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
if self.image is None
else context.services.images.get_pil_image(self.image.image_name)
)
if self.fit:
image = image.resize((self.width, self.height))
# Handle invalid model parameter
model = context.services.model_manager.get_model(self.model,node=self,context=context)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
image_dto = context.services.images.create(
image=generator_output.image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
type: Literal["inpaint"] = "inpaint"
# Inputs # Inputs
mask: Union[ImageField, None] = Field(description="The mask") mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
@ -252,6 +130,14 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise", description="The amount by which to replace masked areas with latent noise",
) )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
},
}
def dispatch_progress( def dispatch_progress(
self, self,
context: InvocationContext, context: InvocationContext,
@ -265,6 +151,79 @@ class InpaintInvocation(ImageToImageInvocation):
source_node_id=source_node_id, source_node_id=source_node_id,
) )
@contextmanager
def load_model_old_way(self, context):
with ExitStack() as stack:
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
tokenizer_info = context.services.model_manager.get_model(**self.clip.tokenizer.dict())
text_encoder_info = context.services.model_manager.get_model(**self.clip.text_encoder.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
#unet = stack.enter_context(unet_info)
#tokenizer = stack.enter_context(tokenizer_info)
#text_encoder = stack.enter_context(text_encoder_info)
#vae = stack.enter_context(vae_info)
with vae_info as vae:
device = vae.device
dtype = vae.dtype
# not load models to gpu as it should be handled by pipeline
unet = unet_info.context.model
tokenizer = tokenizer_info.context.model
text_encoder = text_encoder_info.context.model
vae = vae_info.context.model
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
)
)
)
except Exception:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_unet(unet, loras),\
ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (ti_tokenizer, ti_manager):
pipeline = StableDiffusionGeneratorPipeline(
# TODO: ti_manager
vae=vae,
text_encoder=text_encoder,
tokenizer=ti_tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if dtype == torch.float16 else "float32",
execution_device=device,
)
yield OldModelInfo(
name=self.unet.unet.model_name,
hash="<NO-HASH>",
model=pipeline,
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = (
None None
@ -277,15 +236,13 @@ class InpaintInvocation(ImageToImageInvocation):
else context.services.images.get_pil_image(self.mask.image_name) else context.services.images.get_pil_image(self.mask.image_name)
) )
# Handle invalid model parameter
model = context.services.model_manager.get_model(self.model,node=self,context=context)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id context.graph_execution_state_id
) )
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
with self.load_model_old_way(context) as model:
outputs = Inpaint(model).generate( outputs = Inpaint(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_image=image, init_image=image,

View File

@ -317,6 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
requires_safety_checker: bool = False, requires_safety_checker: bool = False,
precision: str = "float32", precision: str = "float32",
control_model: ControlNetModel = None, control_model: ControlNetModel = None,
execution_device: Optional[torch.device] = None,
): ):
super().__init__( super().__init__(
vae, vae,
@ -356,7 +357,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
textual_inversion_manager=self.textual_inversion_manager, textual_inversion_manager=self.textual_inversion_manager,
) )
self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
self._model_group.install(*self._submodels) self._model_group.install(*self._submodels)
self.control_model = control_model self.control_model = control_model