Mirror of https://github.com/invoke-ai/InvokeAI

commit c26e1a9271
parent f312e1448f

    Rewrite inpaint node to new model manager, remove TextToImage and ImageToImage nodes
invokeai/app/invocations/generate.py
@@ -18,6 +18,12 @@ from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
 from .image import ImageOutput
 
+import re
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from .model import UNetField, ClipField, VaeField
+from contextlib import contextmanager, ExitStack, ContextDecorator
+
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
 DEFAULT_INFILL_METHOD = (
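Note: the contextlib additions (contextmanager, ExitStack, ContextDecorator) all serve the compatibility shim introduced further down. A minimal, self-contained sketch of the ExitStack pattern the new loader relies on (the names here are illustrative, not from this commit):

from contextlib import ExitStack, contextmanager

@contextmanager
def checked_out(name):
    print(f"checkout {name}")   # stand-in for a model-manager checkout
    try:
        yield name
    finally:
        print(f"release {name}")

with ExitStack() as stack:
    unet = stack.enter_context(checked_out("unet"))
    vae = stack.enter_context(checked_out("vae"))
    # everything entered on the stack stays live until the with-block
    # ends, then is released in reverse order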
@@ -25,30 +31,38 @@ DEFAULT_INFILL_METHOD = (
 )
 
 
-class SDImageInvocation(BaseModel):
-    """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["stable-diffusion", "image"],
-                "type_hints": {
-                    "model": "model",
-                },
-            },
-        }
-
-
-# Text to image
-class TextToImageInvocation(BaseInvocation, SDImageInvocation):
-    """Generates an image using text2img."""
-
-    type: Literal["txt2img"] = "txt2img"
-
-    # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
-    # fmt: off
+from .latent import get_scheduler
+
+class OldModelContext(ContextDecorator):
+    model: StableDiffusionGeneratorPipeline
+
+    def __init__(self, model):
+        self.model = model
+
+    def __enter__(self):
+        return self.model
+
+    def __exit__(self, *exc):
+        return False
+
+class OldModelInfo:
+    name: str
+    hash: str
+    context: OldModelContext
+
+    def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
+        self.name = name
+        self.hash = hash
+        self.context = OldModelContext(
+            model=model,
+        )
+
+
+class InpaintInvocation(BaseInvocation):
+    """Generates an image using inpaint."""
+
+    type: Literal["inpaint"] = "inpaint"
+
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
     seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
     steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
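Note: OldModelContext/OldModelInfo form an adapter: the legacy Inpaint generator expects a model_info object whose .context can be entered like a model-manager checkout, so the new code wraps an already-assembled pipeline in an object with that interface. A self-contained sketch of the idea (Pipeline is a stand-in for StableDiffusionGeneratorPipeline):

from contextlib import ContextDecorator

class Pipeline:  # stand-in for StableDiffusionGeneratorPipeline
    pass

class OldModelContext(ContextDecorator):
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        return self.model  # nothing to acquire: the model is already built

    def __exit__(self, *exc):
        return False       # do not swallow exceptions

with OldModelContext(Pipeline()) as model:
    assert isinstance(model, Pipeline)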
@@ -56,83 +70,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
-    model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
-    control_model: Optional[str] = Field(default=None, description="The control model to use")
-    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
-    # fmt: on
-
-    # TODO: pass this an emitter method or something? or a session for dispatching?
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model,node=self,context=context)
-
-        # loading controlnet image (currently requires pre-processed image)
-        control_image = (
-            None if self.control_image is None
-            else context.services.images.get_pil_image(self.control_image.image_name)
-        )
-        # loading controlnet model
-        if (self.control_model is None or self.control_model==''):
-            control_model = None
-        else:
-            # FIXME: change this to dropdown menu?
-            # FIXME: generalize so don't have to hardcode torch_dtype and device
-            control_model = ControlNetModel.from_pretrained(self.control_model,
-                                                            torch_dtype=torch.float16).to("cuda")
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        txt2img = Txt2Img(model, control_model=control_model)
-        outputs = txt2img.generate(
-            prompt=self.prompt,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            control_image=control_image,
-            **self.dict(
-                exclude={"prompt", "control_image" }
-            ), # Shorthand for passing all of the parameters above manually
-        )
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generate_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generate_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
-
-
-class ImageToImageInvocation(TextToImageInvocation):
-    """Generates an image using img2img."""
-
-    type: Literal["img2img"] = "img2img"
-
+    #model: str = Field(default="", description="The model to use (currently ignored)")
+    #progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    #control_model: Optional[str] = Field(default=None, description="The control model to use")
+    #control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
+    unet: UNetField = Field(default=None, description="UNet model")
+    clip: ClipField = Field(default=None, description="Clip model")
+    vae: VaeField = Field(default=None, description="Vae model")
+
     # Inputs
     image: Union[ImageField, None] = Field(description="The input image")
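Note: the node no longer takes a free-form model name; it takes typed submodel references imported from .model. Their definitions are not part of this diff; judging purely from how they are dereferenced below (self.unet.unet, self.unet.scheduler, self.unet.loras, self.clip.tokenizer, self.clip.text_encoder.base_model, self.vae.vae), a rough hypothetical sketch of their shape would be:

from typing import List, Optional
from pydantic import BaseModel

class ModelInfoRef(BaseModel):   # hypothetical stand-in, not the real definition
    model_name: str
    base_model: str
    model_type: str
    submodel: Optional[str] = None

class LoraRef(BaseModel):        # lora.dict(exclude={"weight"}) feeds get_model()
    model_name: str
    base_model: str
    model_type: str
    weight: float

class UNetField(BaseModel):
    unet: ModelInfoRef
    scheduler: ModelInfoRef
    loras: List[LoraRef]

class ClipField(BaseModel):
    tokenizer: ModelInfoRef
    text_encoder: ModelInfoRef

class VaeField(BaseModel):
    vae: ModelInfoRef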
@@ -144,72 +88,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         description="Whether or not the result should be fit to the aspect ratio of the input image",
     )
 
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.dict(),
-            source_node_id=source_node_id,
-        )
-
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = (
-            None
-            if self.image is None
-            else context.services.images.get_pil_image(self.image.image_name)
-        )
-
-        if self.fit:
-            image = image.resize((self.width, self.height))
-
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model,node=self,context=context)
-
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
-        outputs = Img2Img(model).generate(
-            prompt=self.prompt,
-            init_image=image,
-            step_callback=partial(self.dispatch_progress, context, source_node_id),
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ), # Shorthand for passing all of the parameters above manually
-        )
-
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generator_output = next(outputs)
-
-        image_dto = context.services.images.create(
-            image=generator_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
-
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
-
-
-class InpaintInvocation(ImageToImageInvocation):
-    """Generates an image using inpaint."""
-
-    type: Literal["inpaint"] = "inpaint"
-
     # Inputs
     mask: Union[ImageField, None] = Field(description="The mask")
     seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
@@ -252,6 +130,14 @@ class InpaintInvocation(ImageToImageInvocation):
         description="The amount by which to replace masked areas with latent noise",
     )
 
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["stable-diffusion", "image"],
+            },
+        }
+
     def dispatch_progress(
         self,
         context: InvocationContext,
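Note: the Config.schema_extra block is how an invocation attaches UI hints. Under pydantic v1 (which this code appears to assume), a schema_extra dict is merged verbatim into the model's JSON schema, where the frontend can pick up the tags. A quick illustration:

import json
from pydantic import BaseModel

class Example(BaseModel):
    class Config:
        schema_extra = {"ui": {"tags": ["stable-diffusion", "image"]}}

print(json.dumps(Example.schema(), indent=2))  # the "ui" key appears in the schema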
@@ -265,6 +151,79 @@ class InpaintInvocation(ImageToImageInvocation):
             source_node_id=source_node_id,
         )
 
+    @contextmanager
+    def load_model_old_way(self, context):
+        with ExitStack() as stack:
+            unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
+            tokenizer_info = context.services.model_manager.get_model(**self.clip.tokenizer.dict())
+            text_encoder_info = context.services.model_manager.get_model(**self.clip.text_encoder.dict())
+            vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+
+            #unet = stack.enter_context(unet_info)
+            #tokenizer = stack.enter_context(tokenizer_info)
+            #text_encoder = stack.enter_context(text_encoder_info)
+            #vae = stack.enter_context(vae_info)
+            with vae_info as vae:
+                device = vae.device
+                dtype = vae.dtype
+
+            # not load models to gpu as it should be handled by pipeline
+            unet = unet_info.context.model
+            tokenizer = tokenizer_info.context.model
+            text_encoder = text_encoder_info.context.model
+            vae = vae_info.context.model
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+            )
+
+            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            ti_list = []
+            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
+                name = trigger[1:-1]
+                try:
+                    ti_list.append(
+                        stack.enter_context(
+                            context.services.model_manager.get_model(
+                                model_name=name,
+                                base_model=self.clip.text_encoder.base_model,
+                                model_type=ModelType.TextualInversion,
+                            )
+                        )
+                    )
+                except Exception:
+                    #print(e)
+                    #import traceback
+                    #print(traceback.format_exc())
+                    print(f"Warn: trigger: \"{trigger}\" not found")
+
+
+            with ModelPatcher.apply_lora_unet(unet, loras),\
+                 ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
+                 ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (ti_tokenizer, ti_manager):
+
+                pipeline = StableDiffusionGeneratorPipeline(
+                    # TODO: ti_manager
+                    vae=vae,
+                    text_encoder=text_encoder,
+                    tokenizer=ti_tokenizer,
+                    unet=unet,
+                    scheduler=scheduler,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                    precision="float16" if dtype == torch.float16 else "float32",
+                    execution_device=device,
+                )
+
+                yield OldModelInfo(
+                    name=self.unet.unet.model_name,
+                    hash="<NO-HASH>",
+                    model=pipeline,
+                )
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
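Note: load_model_old_way finds textual-inversion triggers by scanning the prompt with the regex above; each match, stripped of its angle brackets, becomes the model_name of a TextualInversion get_model() lookup. The extraction step in isolation:

import re

prompt = "a portrait in the style of <my ti-token.v2>, highly detailed"
triggers = re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)
print(triggers)                     # ['<my ti-token.v2>']
print([t[1:-1] for t in triggers])  # ['my ti-token.v2'] -> model_name for get_model()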
@@ -277,15 +236,13 @@ class InpaintInvocation(ImageToImageInvocation):
             else context.services.images.get_pil_image(self.mask.image_name)
         )
 
-        # Handle invalid model parameter
-        model = context.services.model_manager.get_model(self.model,node=self,context=context)
-
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        outputs = Inpaint(model).generate(
-            prompt=self.prompt,
-            init_image=image,
+        with self.load_model_old_way(context) as model:
+            outputs = Inpaint(model).generate(
+                prompt=self.prompt,
+                init_image=image,
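Note: generate() now runs inside the with-block because the LoRA/textual-inversion patches applied by load_model_old_way are reverted as soon as the block exits. A self-contained sketch of that lifetime (the patching itself is simulated with a flag):

from contextlib import contextmanager

@contextmanager
def loaded(model):
    model["patched"] = True        # stand-in for LoRA/TI patching
    try:
        yield model
    finally:
        model["patched"] = False   # patches revert when the block exits

model = {"patched": False}
with loaded(model) as m:
    assert m["patched"]            # generation must happen here
assert not model["patched"]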
invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -317,6 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         requires_safety_checker: bool = False,
         precision: str = "float32",
         control_model: ControlNetModel = None,
+        execution_device: Optional[torch.device] = None,
     ):
         super().__init__(
             vae,
@@ -356,7 +357,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             textual_inversion_manager=self.textual_inversion_manager,
         )
 
-        self._model_group = FullyLoadedModelGroup(self.unet.device)
+        self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
         self._model_group.install(*self._submodels)
         self.control_model = control_model
 
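Note: with the new execution_device parameter the model group is placed on the caller-supplied device, falling back to the unet's own device when none is given; this is what lets load_model_old_way pass the VAE's device through. The fallback idiom in isolation:

import torch

def group_device(execution_device=None, unet_device=torch.device("cpu")):
    # mirrors FullyLoadedModelGroup(execution_device or self.unet.device)
    return execution_device or unet_device

print(group_device())                        # cpu -- the previous behaviour
print(group_device(torch.device("cuda:0")))  # cuda:0 -- caller override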