2023-04-06 04:06:05 +00:00
|
|
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
2024-02-28 17:15:39 +00:00
|
|
|
import inspect
|
2023-07-08 09:28:26 +00:00
|
|
|
from contextlib import ExitStack
|
2024-06-05 18:48:32 +00:00
|
|
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
2023-05-12 03:33:24 +00:00
|
|
|
|
2023-04-06 04:06:05 +00:00
|
|
|
import torch
|
2024-03-08 18:42:35 +00:00
|
|
|
import torchvision
|
2023-08-11 10:20:37 +00:00
|
|
|
import torchvision.transforms as T
|
2024-02-10 23:09:45 +00:00
|
|
|
from diffusers.configuration_utils import ConfigMixin
|
2023-11-09 20:06:01 +00:00
|
|
|
from diffusers.models.adapter import T2IAdapter
|
2024-02-10 23:09:45 +00:00
|
|
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
2024-05-01 07:00:06 +00:00
|
|
|
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
|
|
|
|
from diffusers.schedulers.scheduling_tcd import TCDScheduler
|
|
|
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
|
2024-04-27 19:12:06 +00:00
|
|
|
from pydantic import field_validator
|
2023-08-11 10:20:37 +00:00
|
|
|
from torchvision.transforms.functional import resize as tv_resize
|
2024-03-06 08:37:15 +00:00
|
|
|
from transformers import CLIPVisionModelWithProjection
|
2023-04-06 04:06:05 +00:00
|
|
|
|
2024-06-06 13:30:49 +00:00
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
2024-04-13 06:43:50 +00:00
|
|
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
2024-06-06 13:30:49 +00:00
|
|
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
2024-04-13 06:43:50 +00:00
|
|
|
from invokeai.app.invocations.fields import (
|
|
|
|
ConditioningField,
|
|
|
|
DenoiseMaskField,
|
|
|
|
FieldDescriptions,
|
|
|
|
Input,
|
|
|
|
InputField,
|
|
|
|
LatentsField,
|
|
|
|
UIType,
|
|
|
|
)
|
2023-09-06 17:36:00 +00:00
|
|
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
2024-06-06 13:30:49 +00:00
|
|
|
from invokeai.app.invocations.model import ModelIdentifierField, UNetField
|
2024-06-05 18:59:45 +00:00
|
|
|
from invokeai.app.invocations.primitives import LatentsOutput
|
2023-10-05 05:29:16 +00:00
|
|
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
2024-02-05 06:16:35 +00:00
|
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
2023-08-06 03:41:47 +00:00
|
|
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
2024-05-29 14:29:54 +00:00
|
|
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
2024-02-17 16:45:32 +00:00
|
|
|
from invokeai.backend.lora import LoRAModelRaw
|
2024-06-05 18:48:32 +00:00
|
|
|
from invokeai.backend.model_manager import BaseModelType
|
2024-02-17 16:45:32 +00:00
|
|
|
from invokeai.backend.model_patcher import ModelPatcher
|
2024-04-13 06:43:50 +00:00
|
|
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
2024-06-06 13:30:49 +00:00
|
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
|
|
|
ControlNetData,
|
|
|
|
StableDiffusionGeneratorPipeline,
|
|
|
|
T2IAdapterData,
|
|
|
|
)
|
2024-03-08 16:49:32 +00:00
|
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
2024-04-13 06:43:50 +00:00
|
|
|
BasicConditioningInfo,
|
|
|
|
IPAdapterConditioningInfo,
|
|
|
|
IPAdapterData,
|
|
|
|
Range,
|
|
|
|
SDXLConditioningInfo,
|
|
|
|
TextConditioningData,
|
|
|
|
TextConditioningRegions,
|
|
|
|
)
|
2024-06-06 13:30:49 +00:00
|
|
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
|
|
|
from invokeai.backend.util.devices import TorchDevice
|
2024-06-12 15:48:07 +00:00
|
|
|
from invokeai.backend.util.hotfixes import ControlNetModel
|
2024-04-08 19:07:49 +00:00
|
|
|
from invokeai.backend.util.mask import to_standard_float_mask
|
2024-02-06 03:56:32 +00:00
|
|
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
2023-08-11 10:20:37 +00:00
|
|
|
|
2023-04-06 04:06:05 +00:00
|
|
|
|
2023-05-13 13:08:03 +00:00
|
|
|
def get_scheduler(
    context: InvocationContext,
    scheduler_info: ModelIdentifierField,
    scheduler_name: str,
    seed: int,
) -> Scheduler:
    """Instantiate the requested scheduler from the stored config of ``scheduler_info``.

    The config of the loaded scheduler model is merged with the per-scheduler extra
    config registered in ``SCHEDULER_MAP``. The un-merged config is stashed under the
    ``"_backup"`` key. If the loaded config already carries a ``"_backup"`` entry, that
    entry is used as the starting point instead, so overrides are not stacked on top
    of previous overrides.

    Args:
        context: Invocation context used to load the scheduler model.
        scheduler_info: Identifier of the model whose scheduler config is used.
        scheduler_name: Key into ``SCHEDULER_MAP``. Unknown names fall back to ``"ddim"``.
        seed: Injected as ``noise_sampler_seed`` for ``DPMSolverSDEScheduler`` only.

    Returns:
        A freshly constructed scheduler instance.
    """
    # TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
    # possible.
    fallback_entry = SCHEDULER_MAP["ddim"]
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, fallback_entry)

    # Only the config is needed from the loaded model; it is read while the model is held.
    with context.models.load(scheduler_info) as loaded_scheduler:
        base_config = loaded_scheduler.config

    # Start from the pristine config if this one was produced by an earlier override.
    if "_backup" in base_config:
        base_config = base_config["_backup"]

    merged_config = {
        **base_config,
        **scheduler_extra_config,  # FIXME
        "_backup": base_config,
    }

    # DPM++ SDE only accepts its noise seed in the initializer, so pass it through the
    # config to make the results reproducible.
    if scheduler_class is DPMSolverSDEScheduler:
        merged_config["noise_sampler_seed"] = seed

    new_scheduler = scheduler_class.from_config(merged_config)

    # Legacy hack carried over from generate.py: make sure the attribute exists,
    # presumably because some caller probes it — TODO confirm which consumers rely on it.
    if not hasattr(new_scheduler, "uses_inpainting_model"):
        new_scheduler.uses_inpainting_model = lambda: False

    assert isinstance(new_scheduler, Scheduler)
    return new_scheduler
|
|
|
|
|
|
|
|
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
@invocation(
|
|
|
|
"denoise_latents",
|
|
|
|
title="Denoise Latents",
|
|
|
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
|
|
|
category="latents",
|
2024-03-19 11:08:16 +00:00
|
|
|
version="1.5.3",
|
feat(nodes): move all invocation metadata (type, title, tags, category) to decorator
All invocation metadata (type, title, tags and category) are now defined in decorators.
The decorators add the `type: Literal["invocation_type"] = "invocation_type"` field to the invocation.
Category is a new invocation metadata, but it is not used by the frontend just yet.
- `@invocation()` decorator for invocations
```py
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
...
```
- `@invocation_output()` decorator for invocation outputs
```py
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
...
```
- update invocation docs
- add category to decorator
- regen frontend types
2023-08-30 08:35:12 +00:00
|
|
|
)
|
2023-08-11 10:20:37 +00:00
|
|
|
class DenoiseLatentsInvocation(BaseInvocation):
|
|
|
|
"""Denoises noisy latents to decodable images"""
|
2023-04-06 04:06:05 +00:00
|
|
|
|
2024-03-08 18:42:35 +00:00
|
|
|
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
2023-08-22 06:23:20 +00:00
|
|
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
2023-08-14 03:23:09 +00:00
|
|
|
)
|
2024-03-08 18:42:35 +00:00
|
|
|
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
2023-08-22 06:23:20 +00:00
|
|
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
2023-08-14 03:23:09 +00:00
|
|
|
)
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
noise: Optional[LatentsField] = InputField(
|
|
|
|
default=None,
|
|
|
|
description=FieldDescriptions.noise,
|
|
|
|
input=Input.Connection,
|
|
|
|
ui_order=3,
|
|
|
|
)
|
2023-08-14 03:23:09 +00:00
|
|
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
|
|
|
cfg_scale: Union[float, List[float]] = InputField(
|
2024-04-27 18:40:52 +00:00
|
|
|
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
|
2023-08-11 10:20:37 +00:00
|
|
|
)
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
denoising_start: float = InputField(
|
|
|
|
default=0.0,
|
|
|
|
ge=0,
|
|
|
|
le=1,
|
|
|
|
description=FieldDescriptions.denoising_start,
|
|
|
|
)
|
2023-08-14 03:23:09 +00:00
|
|
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
2024-02-10 22:51:25 +00:00
|
|
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
default="euler",
|
|
|
|
description=FieldDescriptions.scheduler,
|
|
|
|
ui_type=UIType.Scheduler,
|
|
|
|
)
|
|
|
|
unet: UNetField = InputField(
|
|
|
|
description=FieldDescriptions.unet,
|
|
|
|
input=Input.Connection,
|
|
|
|
title="UNet",
|
|
|
|
ui_order=2,
|
2023-08-17 08:58:01 +00:00
|
|
|
)
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
feat: polymorphic fields
Initial support for polymorphic field types. A polymorphic type accepts either a single value of a specific type or a list of that type. For example, `Union[str, list[str]]`.
Polymorphics do not yet have support for direct input in the UI (will come in the future). They will be forcibly set as Connection-only fields, in which case users will not be able to provide direct input to the field.
If a polymorphic should present as a singleton type - which would allow direct input - the node must provide an explicit type hint.
For example, `DenoiseLatents`' `CFG Scale` is polymorphic, but in the node editor, we want to present this as a number input. In the node definition, the field is given `ui_type=UIType.Float`, which tells the UI to treat this as a `float` field.
The connection validation logic will prevent connecting a collection to `CFG Scale` in this situation, because it is typed as `float`. The workaround is to disable validation from the settings to make this specific connection. A future improvement will resolve this.
This also introduces better support for collection field types. Like polymorphics, collection types are parsed automatically by the client and do not need any specific type hints.
Also like polymorphics, there is no support yet for direct input of collection types in the UI.
- Disabling validation in workflow editor now displays the visual hints for valid connections, but lets you connect to anything.
- Added `ui_order: int` to `InputField` and `OutputField`. The UI will use this, if present, to order fields in a node UI. See usage in `DenoiseLatents` for an example.
- Updated the field colors - duplicate colors have just been lightened a bit. It's not perfect but it was a quick fix.
- Field handles for collections are the same color as their single counterparts, but have a dark dot in the center of them.
- Field handles for polymorphics are a rounded square with dot in the middle.
- Removed all fields that just render `null` from `InputFieldRenderer`, replaced with a single fallback
- Removed logic in `zValidatedWorkflow`, which checked for existence of node templates for each node in a workflow. This logic introduced a circular dependency, due to importing the global redux `store` in order to get the node templates within a zod schema. It's actually fine to just leave this out entirely; The case of a missing node template is handled by the UI. Fixing it otherwise would introduce a substantial headache.
- Fixed the `ControlNetInvocation.control_model` field default, which was a string when it shouldn't have one.
2023-09-01 09:40:27 +00:00
|
|
|
default=None,
|
|
|
|
input=Input.Connection,
|
|
|
|
ui_order=5,
|
2023-08-22 06:23:20 +00:00
|
|
|
)
|
2023-09-21 21:46:05 +00:00
|
|
|
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
description=FieldDescriptions.ip_adapter,
|
|
|
|
title="IP-Adapter",
|
|
|
|
default=None,
|
|
|
|
input=Input.Connection,
|
|
|
|
ui_order=6,
|
|
|
|
)
|
|
|
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
|
|
|
|
description=FieldDescriptions.t2i_adapter,
|
|
|
|
title="T2I-Adapter",
|
|
|
|
default=None,
|
|
|
|
input=Input.Connection,
|
|
|
|
ui_order=7,
|
2023-09-06 17:36:00 +00:00
|
|
|
)
|
2023-11-30 09:55:20 +00:00
|
|
|
cfg_rescale_multiplier: float = InputField(
|
2024-01-02 20:41:59 +00:00
|
|
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
2023-11-30 09:55:20 +00:00
|
|
|
)
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
latents: Optional[LatentsField] = InputField(
|
feat(ui): add support for custom field types
Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported.
Two notes:
1. Your field type's class name must be unique.
Suggest prefixing fields with something related to the node pack as a kind of namespace.
2. Custom field types function as connection-only fields.
For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type.
This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection.
feat(ui): fix tooltips for custom types
We need to hold onto the original type of the field so they don't all just show up as "Unknown".
fix(ui): fix ts error with custom fields
feat(ui): custom field types connection validation
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic.
*Actually, it was `"Unknown"`, but I changed it to custom for clarity.
Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields.
To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`.
This ended up needing a bit of fanagling:
- If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property.
While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer.
(Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.)
- Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future.
- We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`.
Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`.
This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
fix(ui): typo
feat(ui): add CustomCollection and CustomPolymorphic field types
feat(ui): add validation for CustomCollection & CustomPolymorphic types
- Update connection validation for custom types
- Use simple string parsing to determine if a field is a collection or polymorphic type.
- No longer need to keep a list of collection and polymorphic types.
- Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing
chore(ui): remove errant console.log
fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType'
This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type.
fix(ui): fix ts error
feat(nodes): add runtime check for custom field names
"Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names.
chore(ui): add TODO for revising field type names
wip refactor fieldtype structured
wip refactor field types
wip refactor types
wip refactor types
fix node layout
refactor field types
chore: mypy
organisation
organisation
organisation
fix(nodes): fix field orig_required, field_kind and input statuses
feat(nodes): remove broken implementation of default_factory on InputField
Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args.
Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used.
Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`.
fix(nodes): fix InputField name validation
workflow validation
validation
chore: ruff
feat(nodes): fix up baseinvocation comments
fix(ui): improve typing & logic of buildFieldInputTemplate
improved error handling in parseFieldType
fix: back compat for deprecated default_factory and UIType
feat(nodes): do not show node packs loaded log if none loaded
chore(ui): typegen
2023-11-17 00:32:35 +00:00
|
|
|
default=None,
|
|
|
|
description=FieldDescriptions.latents,
|
|
|
|
input=Input.Connection,
|
|
|
|
ui_order=4,
|
2023-10-05 05:29:16 +00:00
|
|
|
)
|
2023-08-26 17:50:13 +00:00
|
|
|
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
default=None,
|
|
|
|
description=FieldDescriptions.mask,
|
|
|
|
input=Input.Connection,
|
|
|
|
ui_order=8,
|
2023-08-11 10:20:37 +00:00
|
|
|
)
|
2023-04-06 04:06:05 +00:00
|
|
|
|
feat(api): chore: pydantic & fastapi upgrade
Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialized and deserialize models, but there are a few more complex changes.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the rust pydantic-core, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the python wrapper do not propagate down to the rust validation logic - you have to rebuild the model. This causes problem with concurrent access to the invocation classes and is not a free operation.
This logic has been totally refactored and we do not need to change the model any more. The details are in `baseinvocation.py`, in the `InputField` function and `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Forunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must tell set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specific aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very very often - a couple times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very larges graphs - like with massive iterators - are much, much faster.
2023-09-24 08:11:07 +00:00
|
|
|
@field_validator("cfg_scale")
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
    """Validate that every cfg_scale value is >= 1.

    Args:
        v: Either a single CFG scale, or a list of CFG scale values.

    Returns:
        ``v`` unchanged, if every value passes the check.

    Raises:
        ValueError: If any value is less than 1.
    """
    # Normalise the scalar case so both forms share one check.
    values = v if isinstance(v, list) else [v]
    if any(i < 1 for i in values):
        # The check rejects only values below 1, so exactly 1 is valid;
        # the message states the bound accurately.
        raise ValueError("cfg_scale must be greater than or equal to 1")
    return v
|
|
|
|
|
2024-06-06 21:39:04 +00:00
|
|
|
@staticmethod
def _get_text_embeddings_and_masks(
    cond_list: list[ConditioningField],
    context: InvocationContext,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
    """Load the text embedding and optional mask for each conditioning field.

    For each field in ``cond_list`` (in order):
    - Load the stored conditioning through ``context.conditioning`` and move its
      first ``conditionings`` entry to ``device``/``dtype``.
    - If the field has a mask, load the mask tensor through ``context.tensors``.
      Otherwise, record ``None`` in its place.

    Returns:
        ``(embeddings, masks)``. Each list gets exactly one entry per field, so
        both are index-aligned with ``cond_list``.
    """
    embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
    masks: list[Optional[torch.Tensor]] = []

    for field in cond_list:
        loaded = context.conditioning.load(field.conditioning_name)
        first_conditioning = loaded.conditionings[0]
        embeddings.append(first_conditioning.to(device=device, dtype=dtype))

        mask_ref = field.mask
        masks.append(None if mask_ref is None else context.tensors.load(mask_ref.tensor_name))

    return embeddings, masks
|
|
|
|
|
2024-06-06 21:39:04 +00:00
|
|
|
@staticmethod
|
2024-03-08 18:42:35 +00:00
|
|
|
def _preprocess_regional_prompt_mask(
|
2024-06-06 21:39:04 +00:00
|
|
|
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
2024-03-08 18:42:35 +00:00
|
|
|
) -> torch.Tensor:
|
|
|
|
"""Preprocess a regional prompt mask to match the target height and width.
|
|
|
|
If mask is None, returns a mask of all ones with the target height and width.
|
|
|
|
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
|
|
|
|
|
|
|
Returns:
|
2024-04-08 19:07:49 +00:00
|
|
|
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
|
2024-03-08 18:42:35 +00:00
|
|
|
"""
|
2024-04-08 19:07:49 +00:00
|
|
|
|
2024-03-08 18:42:35 +00:00
|
|
|
if mask is None:
|
2024-04-08 19:07:49 +00:00
|
|
|
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
|
|
|
|
|
|
|
|
mask = to_standard_float_mask(mask, out_dtype=dtype)
|
2024-03-08 18:42:35 +00:00
|
|
|
|
|
|
|
tf = torchvision.transforms.Resize(
|
|
|
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
|
|
|
)
|
2024-04-08 16:27:57 +00:00
|
|
|
|
|
|
|
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
|
2024-03-08 18:42:35 +00:00
|
|
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
|
|
|
resized_mask = tf(mask)
|
|
|
|
return resized_mask
|
|
|
|
|
2024-06-06 21:39:04 +00:00
|
|
|
@staticmethod
def _concat_regional_text_embeddings(
    text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
    masks: Optional[list[Optional[torch.Tensor]]],
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
    """Merge several (optionally masked) text conditionings into one and record each prompt's token span.

    The `embeds` of every conditioning are concatenated along dim 1. If at least one mask is given, every prompt also
    contributes a `Range` (its token span within the merged embedding) and a mask resized to
    (`latent_height`, `latent_width`). These are bundled into a `TextConditioningRegions`. If no masks are given,
    no regions are produced.

    For SDXL, only a single `pooled_embeds` / `add_time_ids` pair can be returned. The pair from the LAST prompt
    without a mask is used; if every prompt has a mask, the pair from the first prompt is used.

    Args:
        text_conditionings: Per-prompt conditionings. The type of the first entry selects SDXL vs. basic handling.
        masks: Per-prompt masks aligned with `text_conditionings`, or `None` to treat every prompt as unmasked.
        latent_height: Height that region masks are resized to.
        latent_width: Width that region masks are resized to.
        dtype: dtype of the processed region masks.

    Returns:
        The merged conditioning, and the regions (or `None` when no masks were provided).
    """
    if masks is None:
        masks = [None] * len(text_conditionings)
    assert len(text_conditionings) == len(masks)

    sdxl_mode = type(text_conditionings[0]) is SDXLConditioningInfo
    track_regions = any(m is not None for m in masks)

    embeds_to_concat = []
    chosen_pooled = None
    chosen_time_ids = None
    token_offset = 0
    region_masks = []
    region_ranges = []

    for idx, info in enumerate(text_conditionings):
        prompt_mask = masks[idx]

        if sdxl_mode:
            # HACK(ryand): SDXL needs exactly one pooled_embeds / add_time_ids pair, but regional prompting supplies
            # one per prompt. Unmasked prompts are preferred because they are more likely to carry global prompt
            # information (ideally there is exactly one unmasked prompt, but this is not enforced). The clean fix is
            # for DenoiseLatents to accept a single pooled_embeds tensor plus a list of masked text embeds. That would
            # be a major breaking change to a popular node, so this hack is used for now.
            if prompt_mask is None or chosen_pooled is None:
                chosen_pooled = info.pooled_embeds
            if prompt_mask is None or chosen_time_ids is None:
                chosen_time_ids = info.add_time_ids

        embeds_to_concat.append(info.embeds)
        num_tokens = info.embeds.shape[1]
        if track_regions:
            region_ranges.append(Range(start=token_offset, end=token_offset + num_tokens))
            region_masks.append(
                DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
                    prompt_mask, latent_height, latent_width, dtype=dtype
                )
            )
        token_offset += num_tokens

    merged = torch.cat(embeds_to_concat, dim=1)
    assert len(merged.shape) == 3  # batch_size, seq_len, token_len

    regions = (
        TextConditioningRegions(masks=torch.cat(region_masks, dim=1), ranges=region_ranges) if track_regions else None
    )

    if sdxl_mode:
        merged_info = SDXLConditioningInfo(embeds=merged, pooled_embeds=chosen_pooled, add_time_ids=chosen_time_ids)
        return merged_info, regions
    return BasicConditioningInfo(embeds=merged), regions
|
2024-03-08 18:42:35 +00:00
|
|
|
|
2024-06-06 21:39:04 +00:00
|
|
|
@staticmethod
def get_conditioning_data(
    context: InvocationContext,
    positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
    negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
    unet: UNet2DConditionModel,
    latent_height: int,
    latent_width: int,
    cfg_scale: float | list[float],
    steps: int,
    cfg_rescale_multiplier: float,
) -> TextConditioningData:
    """Load the positive and negative text conditioning and package it for the denoising pipeline.

    Each conditioning input may be a single field or a list of fields. Each side's embeddings are loaded onto the
    UNet's device/dtype and concatenated along the sequence dimension. Any region masks are resized to
    (`latent_height`, `latent_width`).

    Args:
        context: Invocation context used to read stored conditionings and tensors.
        positive_conditioning_field: Positive prompt conditioning (single field or list).
        negative_conditioning_field: Negative prompt conditioning (single field or list).
        unet: The UNet whose device and dtype the embeddings are converted to.
        latent_height: Latent height that region masks are resized to.
        latent_width: Latent width that region masks are resized to.
        cfg_scale: A single guidance scale, or one value per denoising step.
        steps: Number of denoising steps; the required length of `cfg_scale` when it is a list.
        cfg_rescale_multiplier: Passed through as `guidance_rescale_multiplier`.

    Returns:
        The assembled `TextConditioningData`.

    Raises:
        ValueError: If `cfg_scale` is a list whose length is not `steps`.
    """
    # Validate before loading anything. This is user-controlled input, so use an explicit exception rather than
    # `assert`, which is stripped when Python runs with -O.
    if isinstance(cfg_scale, list) and len(cfg_scale) != steps:
        raise ValueError("cfg_scale (list) must have the same length as the number of steps")

    # Normalize both conditioning inputs to lists.
    cond_list = (
        positive_conditioning_field
        if isinstance(positive_conditioning_field, list)
        else [positive_conditioning_field]
    )
    uncond_list = (
        negative_conditioning_field
        if isinstance(negative_conditioning_field, list)
        else [negative_conditioning_field]
    )

    cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
        cond_list, context, unet.device, unet.dtype
    )
    uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
        uncond_list, context, unet.device, unet.dtype
    )

    cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
        text_conditionings=cond_text_embeddings,
        masks=cond_text_embedding_masks,
        latent_height=latent_height,
        latent_width=latent_width,
        dtype=unet.dtype,
    )
    uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
        text_conditionings=uncond_text_embeddings,
        masks=uncond_text_embedding_masks,
        latent_height=latent_height,
        latent_width=latent_width,
        dtype=unet.dtype,
    )

    return TextConditioningData(
        uncond_text=uncond_text_embedding,
        cond_text=cond_text_embedding,
        uncond_regions=uncond_regions,
        cond_regions=cond_regions,
        guidance_scale=cfg_scale,
        guidance_rescale_multiplier=cfg_rescale_multiplier,
    )
|
|
|
|
|
2024-06-06 21:39:04 +00:00
|
|
|
@staticmethod
def create_pipeline(
    unet: UNet2DConditionModel,
    scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
    """Wrap `unet` and `scheduler` in a StableDiffusionGeneratorPipeline.

    No text encoder, tokenizer, safety checker, or feature extractor is supplied (all `None`), and the safety checker
    is marked as not required. No VAE model is loaded here either: a minimal placeholder object exposing only
    `config.block_out_channels` is passed in its place.
    """

    class FakeVae:
        class FakeVaeConfig:
            def __init__(self) -> None:
                self.block_out_channels = [0]

        def __init__(self) -> None:
            self.config = FakeVae.FakeVaeConfig()

    unused_components = {
        "text_encoder": None,
        "tokenizer": None,
        "safety_checker": None,
        "feature_extractor": None,
    }
    return StableDiffusionGeneratorPipeline(
        vae=FakeVae(),  # TODO: placeholder VAE; replace once the pipeline no longer needs one.
        unet=unet,
        scheduler=scheduler,
        requires_safety_checker=False,
        **unused_components,
    )
|
2023-07-05 02:37:16 +00:00
|
|
|
|
2024-06-12 15:34:12 +00:00
|
|
|
@staticmethod
def prep_control_data(
    context: InvocationContext,
    control_input: ControlField | list[ControlField] | None,
    latents_shape: List[int],
    exit_stack: ExitStack,
    do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
    """Load the requested ControlNets and prepare their control images.

    Args:
        context: Invocation context used to load models and images.
        control_input: A single ControlField, a list of them, or None.
        latents_shape: Shape of the latents being denoised, unpacked as (_, _, height, width).
        exit_stack: Each ControlNet's load context is entered on this stack, so it stays open until the caller
            closes the stack.
        do_classifier_free_guidance: Forwarded to `prepare_control_image()`.

    Returns:
        One ControlNetData per control input, or None if there are none.

    Raises:
        ValueError: If `control_input` is not a ControlField, a list, or None.
    """
    if control_input is None:
        control_list: list[ControlField] = []
    elif isinstance(control_input, ControlField):
        control_list = [control_input]
    elif isinstance(control_input, list):
        control_list = control_input
    else:
        raise ValueError(f"Unexpected control_input type: {type(control_input)}")

    if not control_list:
        return None

    # Control images are prepared at pixel resolution, assuming a fixed LATENT_SCALE_FACTOR between latents and pixels.
    _, _, latent_height, latent_width = latents_shape
    pixel_height = latent_height * LATENT_SCALE_FACTOR
    pixel_width = latent_width * LATENT_SCALE_FACTOR

    prepared: list[ControlNetData] = []
    for control_info in control_list:
        control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
        assert isinstance(control_model, ControlNetModel)

        control_pil = context.images.get_pil(control_info.image.image_name)
        # FIXME: still untested with varying widths, heights, devices, and dtypes. batch_size and
        # num_images_per_prompt are not wired in yet, and a real classifier-free-guidance check is still pending.
        control_tensor = prepare_control_image(
            image=control_pil,
            do_classifier_free_guidance=do_classifier_free_guidance,
            width=pixel_width,
            height=pixel_height,
            device=control_model.device,
            dtype=control_model.dtype,
            control_mode=control_info.control_mode,
            resize_mode=control_info.resize_mode,
        )
        prepared.append(
            ControlNetData(
                model=control_model,
                image_tensor=control_tensor,
                weight=control_info.control_weight,
                begin_step_percent=control_info.begin_step_percent,
                end_step_percent=control_info.end_step_percent,
                control_mode=control_info.control_mode,
                # Resizing currently happens in prepare_control_image(); resize_mode is kept here for future use.
                resize_mode=control_info.resize_mode,
            )
        )
    return prepared
|
|
|
|
|
2024-05-29 02:41:44 +00:00
|
|
|
def prep_ip_adapter_image_prompts(
    self,
    context: InvocationContext,
    ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Encode each IP-Adapter's image(s) with its CLIP vision encoder.

    For every adapter field, the IP-Adapter model and its image encoder are loaded. The adapter's image (a single
    ImageField or a list) is loaded as PIL, and `get_image_embeds` produces the conditional/unconditional embeddings.

    Returns:
        One `(image_prompt_embeds, uncond_image_prompt_embeds)` pair per entry of `ip_adapters`, in order.
    """
    embeddings_per_adapter = []
    for adapter_field in ip_adapters:
        with context.models.load(adapter_field.ip_adapter_model) as adapter_model:
            assert isinstance(adapter_model, IPAdapter)
            encoder_info = context.models.load(adapter_field.image_encoder_model)

            image_fields = adapter_field.image if isinstance(adapter_field.image, list) else [adapter_field.image]
            pil_images = [context.images.get_pil(image_field.image_name) for image_field in image_fields]

            with encoder_info as encoder_model:
                assert isinstance(encoder_model, CLIPVisionModelWithProjection)
                cond_embeds, uncond_embeds = adapter_model.get_image_embeds(pil_images, encoder_model)
                embeddings_per_adapter.append((cond_embeds, uncond_embeds))

    return embeddings_per_adapter
|
|
|
|
|
2023-09-06 17:36:00 +00:00
|
|
|
def prep_ip_adapter_data(
    self,
    context: InvocationContext,
    ip_adapters: List[IPAdapterField],
    image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
    exit_stack: ExitStack,
    latent_height: int,
    latent_width: int,
    dtype: torch.dtype,
) -> Optional[List[IPAdapterData]]:
    """Build IPAdapterData for each adapter from its precomputed image prompt embeddings.

    `ip_adapters` and `image_prompts` are paired positionally with `zip(..., strict=True)`, which raises ValueError
    if their lengths differ. Each adapter model's load context is entered on `exit_stack`. An optional mask is
    loaded and passed through `_preprocess_regional_prompt_mask` at (`latent_height`, `latent_width`).

    Returns:
        The list of IPAdapterData, or None if it is empty.
    """
    adapter_data: List[IPAdapterData] = []
    for adapter_field, (cond_embeds, uncond_embeds) in zip(ip_adapters, image_prompts, strict=True):
        adapter_model = exit_stack.enter_context(context.models.load(adapter_field.ip_adapter_model))

        raw_mask = None if adapter_field.mask is None else context.tensors.load(adapter_field.mask.tensor_name)
        region_mask = self._preprocess_regional_prompt_mask(raw_mask, latent_height, latent_width, dtype=dtype)

        adapter_data.append(
            IPAdapterData(
                ip_adapter_model=adapter_model,
                weight=adapter_field.weight,
                target_blocks=adapter_field.target_blocks,
                begin_step_percent=adapter_field.begin_step_percent,
                end_step_percent=adapter_field.end_step_percent,
                ip_adapter_conditioning=IPAdapterConditioningInfo(cond_embeds, uncond_embeds),
                mask=region_mask,
            )
        )

    return adapter_data or None
|
2023-04-06 04:06:05 +00:00
|
|
|
|
2023-10-05 05:29:16 +00:00
|
|
|
def run_t2i_adapters(
|
|
|
|
self,
|
2024-02-05 06:16:35 +00:00
|
|
|
context: InvocationContext,
|
2023-10-05 05:29:16 +00:00
|
|
|
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
|
|
|
latents_shape: list[int],
|
|
|
|
do_classifier_free_guidance: bool,
|
|
|
|
) -> Optional[list[T2IAdapterData]]:
|
|
|
|
if t2i_adapter is None:
|
|
|
|
return None
|
|
|
|
|
|
|
|
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
|
|
|
|
if isinstance(t2i_adapter, T2IAdapterField):
|
|
|
|
t2i_adapter = [t2i_adapter]
|
|
|
|
|
|
|
|
if len(t2i_adapter) == 0:
|
|
|
|
return None
|
|
|
|
|
|
|
|
t2i_adapter_data = []
|
|
|
|
for t2i_adapter_field in t2i_adapter:
|
2024-03-06 08:37:15 +00:00
|
|
|
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
|
|
|
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
2024-01-13 12:23:16 +00:00
|
|
|
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
2023-10-05 05:29:16 +00:00
|
|
|
|
|
|
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
2024-02-16 11:51:47 +00:00
|
|
|
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
|
2023-10-05 05:29:16 +00:00
|
|
|
max_unet_downscale = 8
|
2024-02-16 11:51:47 +00:00
|
|
|
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
|
2023-10-05 05:29:16 +00:00
|
|
|
max_unet_downscale = 4
|
|
|
|
else:
|
2024-02-16 11:51:47 +00:00
|
|
|
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
|
2023-10-05 05:29:16 +00:00
|
|
|
|
|
|
|
t2i_adapter_model: T2IAdapter
|
2024-02-16 11:51:47 +00:00
|
|
|
with t2i_adapter_loaded_model as t2i_adapter_model:
|
2023-10-05 05:29:16 +00:00
|
|
|
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
|
|
|
|
|
|
|
# Resize the T2I-Adapter input image.
|
|
|
|
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
|
|
|
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
|
|
|
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
|
|
|
|
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
|
|
|
|
|
|
|
|
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
|
|
|
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
|
|
|
# T2I-Adapter model.
|
|
|
|
#
|
|
|
|
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
|
|
|
# of the same requirements (e.g. preserving binary masks during resize).
|
|
|
|
t2i_image = prepare_control_image(
|
|
|
|
image=image,
|
|
|
|
do_classifier_free_guidance=False,
|
|
|
|
width=t2i_input_width,
|
|
|
|
height=t2i_input_height,
|
2024-02-06 03:56:32 +00:00
|
|
|
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
|
2023-10-05 05:29:16 +00:00
|
|
|
device=t2i_adapter_model.device,
|
|
|
|
dtype=t2i_adapter_model.dtype,
|
|
|
|
resize_mode=t2i_adapter_field.resize_mode,
|
|
|
|
)
|
|
|
|
|
|
|
|
adapter_state = t2i_adapter_model(t2i_image)
|
|
|
|
|
|
|
|
if do_classifier_free_guidance:
|
|
|
|
for idx, value in enumerate(adapter_state):
|
|
|
|
adapter_state[idx] = torch.cat([value] * 2, dim=0)
|
|
|
|
|
|
|
|
t2i_adapter_data.append(
|
|
|
|
T2IAdapterData(
|
|
|
|
adapter_state=adapter_state,
|
|
|
|
weight=t2i_adapter_field.weight,
|
|
|
|
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
|
|
|
end_step_percent=t2i_adapter_field.end_step_percent,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
return t2i_adapter_data
|
|
|
|
|
2023-08-12 00:19:49 +00:00
|
|
|
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
    scheduler: Union[Scheduler, ConfigMixin],
    device: torch.device,
    steps: int,
    denoising_start: float,
    denoising_end: float,
    seed: int,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """Configure `scheduler` for `steps` steps and select the sub-range of timesteps to denoise.

    `denoising_start` / `denoising_end` are fractions of the scheduler's `num_train_timesteps`. They are converted
    to timestep cutoffs and counted against every `scheduler.order`-th timestep. The counts are then multiplied back
    by the order to index the full timestep list.

    Schedulers whose `step()` accepts a `generator` get one seeded with `seed ^ 0xFFFFFFFF`. The seed should be the
    one used to create the noise; a different seed (e.g. a fallback of 0) changes the denoised output for these
    schedulers. `TCDScheduler` additionally gets `eta=1.0`.

    Returns:
        `(timesteps, init_timestep, scheduler_step_kwargs)`: the timesteps to run, a slice holding at most the first
        of them, and extra keyword arguments for `scheduler.step()`.
    """
    assert isinstance(scheduler, ConfigMixin)
    cpu_only = scheduler.config.get("cpu_only", False)
    scheduler.set_timesteps(steps, device="cpu" if cpu_only else device)
    timesteps = scheduler.timesteps.to(device=device) if cpu_only else scheduler.timesteps

    order = scheduler.order
    num_train_timesteps = scheduler.config["num_train_timesteps"]

    # Only consider one timestep per scheduler step (skip higher-order sub-steps) when counting.
    first_order_timesteps = timesteps[::order]

    start_cutoff = int(round(num_train_timesteps * (1 - denoising_start)))
    start_idx = sum(1 for ts in first_order_timesteps if ts >= start_cutoff)

    end_cutoff = int(round(num_train_timesteps * (1 - denoising_end)))
    end_count = sum(1 for ts in first_order_timesteps[start_idx:] if ts >= end_cutoff)

    # Convert the counts back to indexes into the full (all-order) timestep list.
    start_idx *= order
    end_count *= order

    init_timestep = timesteps[start_idx : start_idx + 1]
    timesteps = timesteps[start_idx : start_idx + end_count]

    step_kwargs: Dict[str, Any] = {}
    if "generator" in inspect.signature(scheduler.step).parameters:
        # At some point, someone decided that schedulers that accept a generator should use the original seed with
        # all bits flipped. The original rationale is unknown, but it must be kept for reproducibility.
        #
        # These Invoke-supported schedulers accept a generator as of 2024-06-04:
        # - DDIMScheduler
        # - DDPMScheduler
        # - DPMSolverMultistepScheduler
        # - EulerAncestralDiscreteScheduler
        # - EulerDiscreteScheduler
        # - KDPM2AncestralDiscreteScheduler
        # - LCMScheduler
        # - TCDScheduler
        step_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)
    if isinstance(scheduler, TCDScheduler):
        step_kwargs["eta"] = 1.0

    return timesteps, init_timestep, step_kwargs
|
2023-08-07 16:57:11 +00:00
|
|
|
|
2024-02-10 23:09:45 +00:00
|
|
|
def prep_inpaint_mask(
    self, context: InvocationContext, latents: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
    """Load `self.denoise_mask` and prepare it at the spatial resolution of `latents`.

    The stored mask is resized to `latents.shape[-2:]` with bilinear interpolation (no antialiasing). Stored masked
    latents are used when present. Otherwise they are derived from `latents` by zeroing positions where the resized
    mask is below 0.5.

    Returns:
        `(1 - mask, masked_latents, gradient)`, where the returned mask is inverted relative to the stored one, or
        `(None, None, False)` when there is no denoise mask.
    """
    denoise_mask = self.denoise_mask
    if denoise_mask is None:
        return None, None, False

    resized_mask = tv_resize(
        context.tensors.load(denoise_mask.mask_name),
        latents.shape[-2:],
        T.InterpolationMode.BILINEAR,
        antialias=False,
    )
    masked_latents = (
        torch.where(resized_mask < 0.5, 0.0, latents)
        if denoise_mask.masked_latents_name is None
        else context.tensors.load(denoise_mask.masked_latents_name)
    )
    return 1 - resized_mask, masked_latents, denoise_mask.gradient
|
2023-08-08 15:50:36 +00:00
|
|
|
|
2024-06-06 19:10:04 +00:00
|
|
|
@staticmethod
def prepare_noise_and_latents(
    context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
    """Load the noise and/or initial latents and resolve the seed to use for denoising.

    If no latents are given, zeros shaped like the noise are used as the initial latents.

    Args:
        context: Invocation context used to load stored tensors.
        noise_field: Optional noise tensor reference (may carry a seed).
        latents_field: Optional initial latents reference (may carry a seed).

    Returns:
        `(seed, noise, latents)`. The seed comes from, in order of priority: the noise field, the latents field,
        or 0.

    Raises:
        ValueError: If neither latents nor noise is provided, or if their shapes differ beyond the first dimension.
    """
    noise = context.tensors.load(noise_field.latents_name) if noise_field is not None else None

    if latents_field is not None:
        latents = context.tensors.load(latents_field.latents_name)
    elif noise is not None:
        latents = torch.zeros_like(noise)
    else:
        raise ValueError("'latents' or 'noise' must be provided!")

    # Only dimensions after the first are compared.
    if noise is not None and noise.shape[1:] != latents.shape[1:]:
        raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

    if noise_field is not None and noise_field.seed is not None:
        seed = noise_field.seed
    elif latents_field is not None and latents_field.seed is not None:
        seed = latents_field.seed
    else:
        seed = 0

    return seed, noise, latents
|
|
|
|
|
|
|
|
@torch.no_grad()
@SilenceWarnings()  # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
    """Run the UNet denoising pipeline and save the resulting latents.

    Steps, in order:
      1. Resolve the seed, noise and starting latents, then build the inpainting inputs
         (mask, masked latents, gradient-mask flag).
      2. Run any T2I adapters, and encode IP-adapter image prompts *before* the UNet is
         loaded (done outside the model context to use less VRAM on low-VRAM systems).
      3. Load the UNet and, for the duration of the ``with`` block, keep it on its target
         device with FreeU, seamless-axes and LoRA patches applied.
      4. Build the scheduler, pipeline, conditioning, ControlNet and IP-adapter data and
         timesteps, then call ``pipeline.latents_from_embeddings``.
      5. Move the result to CPU, empty the device cache, save the tensor and return it.

    Args:
        context: Invocation context used to load models/tensors, report step progress
            and save the output tensor.

    Returns:
        A ``LatentsOutput`` referencing the saved latents tensor. Its ``seed`` is
        ``None`` (not the seed used for denoising).
    """
    # Resolve the seed plus the noise / starting-latents tensors for this run.
    seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

    # Inpainting inputs; `mask` and `masked_latents` may be None (they are None-checked
    # before being moved to the UNet's device below).
    mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)

    # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
    # below. Investigate whether this is appropriate.
    t2i_adapter_data = self.run_t2i_adapters(
        context,
        self.t2i_adapter,
        latents.shape,
        do_classifier_free_guidance=True,
    )

    ip_adapters: List[IPAdapterField] = []
    if self.ip_adapter is not None:
        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
        if isinstance(self.ip_adapter, list):
            ip_adapters = self.ip_adapter
        else:
            ip_adapters = [self.ip_adapter]

    # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
    # a series of image conditioning embeddings. This is being done here rather than in the
    # big model context below in order to use less VRAM on low-VRAM systems.
    # The image prompts are then passed to prep_ip_adapter_data().
    image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)

    # Get the unet's config so that step_callback can pass its base model type to
    # context.util.sd_step_callback() for progress reporting.
    unet_config = context.models.get_config(self.unet.unet.key)

    def step_callback(state: PipelineIntermediateState) -> None:
        # Passed to the pipeline as `callback` (presumably called once per denoising step);
        # forwards the intermediate state together with the UNet's base model type.
        context.util.sd_step_callback(state, unet_config.base)

    def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
        # Lazily load each LoRA listed on the UNet field and yield (model, weight) pairs,
        # so this generator only holds a reference to the LoRA currently being consumed.
        for lora in self.unet.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, LoRAModelRaw)
            yield (lora_info.model, lora.weight)
            # Drop our reference once the consumer requests the next item.
            del lora_info
        return

    unet_info = context.models.load(self.unet.unet)
    assert isinstance(unet_info.model, UNet2DConditionModel)
    # Context managers are entered top-to-bottom and exited in reverse order, so the order
    # below matters. The ExitStack is entered first (and therefore closed last). It is handed
    # to prep_control_data() / prep_ip_adapter_data() so they can register their own context
    # managers that stay open for the whole denoising run.
    with (
        ExitStack() as exit_stack,
        unet_info.model_on_device() as (model_state_dict, unet),
        ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
        set_seamless(unet, self.unet.seamless_axes),  # FIXME
        # Apply the LoRA after unet has been moved to its target device for faster patching.
        ModelPatcher.apply_lora_unet(
            unet,
            loras=_lora_loader(),
            model_state_dict=model_state_dict,
        ),
    ):
        assert isinstance(unet, UNet2DConditionModel)
        # Move every input tensor onto the UNet's device and dtype.
        latents = latents.to(device=unet.device, dtype=unet.dtype)
        if noise is not None:
            noise = noise.to(device=unet.device, dtype=unet.dtype)
        if mask is not None:
            mask = mask.to(device=unet.device, dtype=unet.dtype)
        if masked_latents is not None:
            masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
            seed=seed,
        )

        pipeline = self.create_pipeline(unet, scheduler)

        # `latents` is unpacked as a 4-D tensor; only the trailing height/width are needed here.
        _, _, latent_height, latent_width = latents.shape
        conditioning_data = self.get_conditioning_data(
            context=context,
            positive_conditioning_field=self.positive_conditioning,
            negative_conditioning_field=self.negative_conditioning,
            unet=unet,
            latent_height=latent_height,
            latent_width=latent_width,
            cfg_scale=self.cfg_scale,
            steps=self.steps,
            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
        )

        controlnet_data = self.prep_control_data(
            context=context,
            control_input=self.control,
            latents_shape=latents.shape,
            # Hard-coded to True; the T2I adapter call above does the same (see its TODO).
            # Commented-out alternative kept for reference:
            # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
            do_classifier_free_guidance=True,
            exit_stack=exit_stack,
        )

        ip_adapter_data = self.prep_ip_adapter_data(
            context=context,
            ip_adapters=ip_adapters,
            image_prompts=image_prompts,
            exit_stack=exit_stack,
            latent_height=latent_height,
            latent_width=latent_width,
            dtype=unet.dtype,
        )

        # Compute the timesteps for the [denoising_start, denoising_end] range, the initial
        # timestep, and (presumably) extra keyword arguments for the scheduler's step call.
        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
            scheduler,
            device=unet.device,
            steps=self.steps,
            denoising_start=self.denoising_start,
            denoising_end=self.denoising_end,
            seed=seed,
        )

        # Main denoising call, run while the patched UNet and all auxiliary models are active.
        result_latents = pipeline.latents_from_embeddings(
            latents=latents,
            timesteps=timesteps,
            init_timestep=init_timestep,
            noise=noise,
            seed=seed,
            mask=mask,
            masked_latents=masked_latents,
            gradient_mask=gradient_mask,
            scheduler_step_kwargs=scheduler_step_kwargs,
            conditioning_data=conditioning_data,
            control_data=controlnet_data,
            ip_adapter_data=ip_adapter_data,
            t2i_adapter_data=t2i_adapter_data,
            callback=step_callback,
        )

    # Move the result off the device and release cached device memory before saving.
    # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
    result_latents = result_latents.to("cpu")
    TorchDevice.empty_cache()

    name = context.tensors.save(tensor=result_latents)
    # NOTE(review): the output carries seed=None rather than the `seed` used above. A downstream
    # node that takes its seed from these latents (falling back to 0 when absent) will therefore
    # not reuse it. Confirm this is intentional.
    return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|