feat(nodes): update invocation context for mm2, update nodes model usage

Author: psychedelicious
Date: 2024-02-15 20:43:41 +11:00
parent 88d6de4101
commit 539570cc7a
9 changed files with 141 additions and 147 deletions
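
Every hunk below applies the same migration: nodes stop reaching through `context.services.*` and instead use the slimmer facades the reworked `InvocationContext` exposes (`context.images`, `context.tensors`, `context.models`, `context.util`). A minimal before/after sketch reconstructed from the hunks in this diff (`field` is a hypothetical placeholder):

    # before: raw services, and model loads pass the context back in
    image = context.services.images.get_pil_image(self.image.image_name)
    model_info = context.services.model_manager.load.load_model_by_key(
        key=field.key,
        context=context,
    )

    # after: purpose-built facades, no context round-trip
    image = context.images.get_pil(self.image.image_name)
    model_info = context.models.load(key=field.key)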


@@ -141,7 +141,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
         if self.image is not None:
-            image = context.services.images.get_pil_image(self.image.image_name)
+            image = context.images.get_pil(self.image.image_name)
             image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
             if image_tensor.dim() == 3:
                 image_tensor = image_tensor.unsqueeze(0)
@ -153,10 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.services.model_manager.load.load_model_by_key(
**self.vae.vae.model_dump(),
context=context,
)
vae_info = context.models.load(**self.vae.vae.model_dump())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
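
The new one-liner works because the VAE field is a pydantic model whose `model_dump()` produces exactly the keyword arguments `context.models.load` expects. A minimal sketch of the mechanics, with a hypothetical single-field `VaeField` (the real field may carry more attributes):

    from pydantic import BaseModel

    class VaeField(BaseModel):
        # hypothetical minimal field for illustration
        key: str

    vae = VaeField(key="some-vae-key")
    assert vae.model_dump() == {"key": "some-vae-key"}
    # so load(**vae.model_dump()) is just load(key="some-vae-key")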
@@ -182,10 +179,7 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])

-    orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
-        **scheduler_info.model_dump(),
-        context=context,
-    )
+    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
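
`context.models.load` returns an info object rather than the model itself; the model is only materialized inside the `with orig_scheduler_info as orig_scheduler:` block. A sketch of that shape, assuming (not shown in this diff) that entering the context is what locks the model into memory:

    from contextlib import contextmanager

    @contextmanager
    def loaded(model):
        # hypothetical stand-in for the loaded-model info object:
        # enter = lock the model in memory, exit = release it
        try:
            yield model
        finally:
            pass  # release would happen here

    with loaded("scheduler") as scheduler:
        print(scheduler)  # safe to use only inside the block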
@@ -399,12 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         # and if weight is None, populate with default 1.0?
         controlnet_data = []
         for control_info in control_list:
-            control_model = exit_stack.enter_context(
-                context.services.model_manager.load.load_model_by_key(
-                    key=control_info.control_model.key,
-                    context=context,
-                )
-            )
+            control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))

             # control_models.append(control_model)
             control_image_field = control_info.image
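
`exit_stack.enter_context(...)` is what lets the loop keep every ControlNet resident at once: each loaded model's context is handed off to the enclosing `ExitStack`, so nothing is released until the whole stack unwinds. A generic, runnable sketch of the idiom (names are placeholders):

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def fake_load(name: str):  # stands in for context.models.load(...)
        print("lock", name)
        try:
            yield name
        finally:
            print("release", name)

    with ExitStack() as exit_stack:
        models = [exit_stack.enter_context(fake_load(n)) for n in ("cn-a", "cn-b")]
        print("both locked:", models)
    # both are released together, in reverse order, when the stack closes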
@@ -466,25 +455,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
         conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_manager.load.load_model_by_key(
-                    key=single_ip_adapter.ip_adapter_model.key,
-                    context=context,
-                )
+                context.models.load(key=single_ip_adapter.ip_adapter_model.key)
             )

-            image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
-                key=single_ip_adapter.image_encoder_model.key,
-                context=context,
-            )
+            image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)

             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
             single_ipa_image_fields = single_ip_adapter.image
             if not isinstance(single_ipa_image_fields, list):
                 single_ipa_image_fields = [single_ipa_image_fields]

-            single_ipa_images = [
-                context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
-            ]
+            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]

             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
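
The single-or-list normalization above is a small but load-bearing detail: downstream code can then treat `single_ipa_image_fields` uniformly. The same idiom as a standalone helper (hypothetical name):

    def as_list(value):
        # accept one item or a list of items; always return a list
        return value if isinstance(value, list) else [value]

    assert as_list("img") == ["img"]
    assert as_list(["a", "b"]) == ["a", "b"]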
@@ -528,10 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
-                key=t2i_adapter_field.t2i_adapter_model.key,
-                context=context,
-            )
+            t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
             image = context.images.get_pil(t2i_adapter_field.image.image_name)

             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -676,30 +654,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
             do_classifier_free_guidance=True,
         )

-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
         # get the unet's config so that we can pass the base to dispatch_progress()
-        unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
+        unet_config = context.models.get_config(self.unet.unet.key)

         def step_callback(state: PipelineIntermediateState) -> None:
-            self.dispatch_progress(context, source_node_id, state, unet_config.base)
+            context.util.sd_step_callback(state, unet_config.base)

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}),
-                    context=context,
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return

-        unet_info = context.services.model_manager.load.load_model_by_key(
-            **self.unet.unet.model_dump(),
-            context=context,
-        )
+        unet_info = context.models.load(**self.unet.unet.model_dump())
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
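
`_lora_loader` stays a generator on purpose: each `(model, weight)` pair is produced on demand, and `del lora_info` drops the reference to one loaded LoRA before the next is pulled in, keeping at most one extra model alive at a time. The shape of the pattern, with hypothetical stand-in types:

    from dataclasses import dataclass
    from typing import Iterator, List, Tuple

    @dataclass
    class LoadedLora:  # hypothetical stand-in for the loaded-model info
        model: object

    def lora_loader(loras: List[Tuple[object, float]]) -> Iterator[Tuple[object, float]]:
        for raw, weight in loras:
            info = LoadedLora(model=raw)  # stands in for context.models.load(...)
            yield (info.model, weight)
            del info  # drop the reference before loading the next LoRA

    for model, weight in lora_loader([("lora-a", 0.75)]):
        print(model, weight)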
@@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        vae_info = context.services.model_manager.load.load_model_by_key(
-            **self.vae.vae.model_dump(),
-            context=context,
-        )
+        vae_info = context.models.load(**self.vae.vae.model_dump())

         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, torch.nn.Module)
@@ -1032,10 +997,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)

-        vae_info = context.services.model_manager.load.load_model_by_key(
-            **self.vae.vae.model_dump(),
-            context=context,
-        )
+        vae_info = context.models.load(**self.vae.vae.model_dump())

         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
@@ -1239,10 +1201,7 @@ class IdealSizeInvocation(BaseInvocation):
         return tuple((x - x % multiple_of) for x in args)

     def invoke(self, context: InvocationContext) -> IdealSizeOutput:
-        unet_config = context.services.model_manager.load.load_model_by_key(
-            **self.unet.unet.model_dump(),
-            context=context,
-        )
+        unet_config = context.models.get_config(**self.unet.unet.model_dump())
         aspect = self.width / self.height
         dimension: float = 512
         if unet_config.base == BaseModelType.StableDiffusion2:
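
Unlike the other call sites, this one switches from `load_model_by_key` to `get_config`: `IdealSizeInvocation` only needs the model's base type, so fetching the config record avoids pulling the UNet weights into memory at all. A sketch of the distinction, assuming the API split shown in this diff:

    # cheap: metadata only, no weights loaded
    unet_config = context.models.get_config(self.unet.unet.key)
    if unet_config.base == BaseModelType.StableDiffusion2:
        ...

    # expensive: weights loaded and locked, model usable inside the block
    unet_info = context.models.load(**self.unet.unet.model_dump())
    with unet_info as unet:
        ...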