Resolving rebase conflict

Authored by user1 on 2023-04-29 19:32:19 -07:00; committed by Kent Keirsey
parent b59a749627
commit ca0669c337

@@ -4,7 +4,9 @@ from functools import partial
 from typing import Literal, Optional, Union, get_args
 import numpy as np
+from diffusers import ControlNetModel
 from torch import Tensor
+import torch
 from pydantic import BaseModel, Field
@@ -58,6 +60,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
+    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    control_model: Optional[str] = Field(default=None, description="The control model to use")
+    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
     # fmt: on
     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -78,17 +83,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)
+        # loading controlnet image (currently requires pre-processed image)
+        control_image = (
+            None if self.control_image is None
+            else context.services.images.get(
+                self.control_image.image_type, self.control_image.image_name
+            )
+        )
+        # loading controlnet model
+        if (self.control_model is None or self.control_model==''):
+            control_model = None
+        else:
+            # FIXME: change this to dropdown menu?
+            # FIXME: generalize so don't have to hardcode torch_dtype and device
+            control_model = ControlNetModel.from_pretrained(self.control_model,
+                                                            torch_dtype=torch.float16).to("cuda")
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-        outputs = Txt2Img(model).generate(
+        txt2img = Txt2Img(model, control_model=control_model)
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
+            control_image=control_image,
             **self.dict(
-                exclude={"prompt"}
+                exclude={"prompt", "control_image" }
             ), # Shorthand for passing all of the parameters above manually
         )
         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
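
For reference, the ControlNet wiring added above follows the standard diffusers usage pattern. A minimal standalone sketch is shown below; the model IDs, the pipeline class, and the input image file are illustrative assumptions and are not part of this commit.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# A pre-processed conditioning image (e.g. a Canny edge map), matching the
# "currently requires pre-processed image" note in the diff above.
control_image = Image.open("canny_edges.png")  # assumed example file

# Hardcoded dtype and device, mirroring the FIXME in the diff above.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

result = pipe(
    "a photo of a room", image=control_image, num_inference_steps=20
).images[0]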