diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py
index 44280c3b41..0385c6a9f0 100644
--- a/invokeai/app/invocations/generate.py
+++ b/invokeai/app/invocations/generate.py
@@ -4,7 +4,9 @@ from functools import partial
 from typing import Literal, Optional, Union, get_args
 
 import numpy as np
+from diffusers import ControlNetModel
 from torch import Tensor
+import torch
 from pydantic import BaseModel, Field
 
@@ -58,6 +60,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
+    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    control_model: Optional[str] = Field(default=None, description="The control model to use")
+    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
     # fmt: on
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -78,17 +83,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)
 
+        # loading controlnet image (currently requires pre-processed image)
+        control_image = (
+            None if self.control_image is None
+            else context.services.images.get(
+                self.control_image.image_type, self.control_image.image_name
+            )
+        )
+        # loading controlnet model
+        if (self.control_model is None or self.control_model == ''):
+            control_model = None
+        else:
+            # FIXME: change this to dropdown menu?
+            # FIXME: generalize so don't have to hardcode torch_dtype and device
+            control_model = ControlNetModel.from_pretrained(self.control_model,
+                                                            torch_dtype=torch.float16).to("cuda")
+
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        outputs = Txt2Img(model).generate(
+        txt2img = Txt2Img(model, control_model=control_model)
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
+            control_image=control_image,
             **self.dict(
-                exclude={"prompt"}
+                exclude={"prompt", "control_image"}
             ),  # Shorthand for passing all of the parameters above manually
         )
 
         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object