Merge branch 'main' into feat/ip-adapter
commit 6bb378a101
@@ -244,8 +244,12 @@ copy-paste the template above.
 We can use the `@invocation` decorator to provide some additional info to the
 UI, like a custom title, tags and category.

+We also encourage providing a version. This must be a
+[semver](https://semver.org/) version string ("$MAJOR.$MINOR.$PATCH"). The UI
+will let users know if their workflow is using a mismatched version of the node.
+
 ```python
-@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
+@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations", version="1.0.0")
 class ResizeInvocation(BaseInvocation):
     """Resizes an image"""

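The paragraph added above says the UI will warn when a workflow was saved with a different version of a node. A minimal sketch of what such a check can look like, assuming the same `semver` package the backend uses (illustrative only, not the UI's actual logic):

```python
# Illustrative sketch, not InvokeAI's real version-mismatch check.
import semver


def classify_mismatch(workflow_version: str, installed_version: str) -> str:
    """Compare the version stored in a workflow against the installed node's version."""
    wf = semver.Version.parse(workflow_version)
    node = semver.Version.parse(installed_version)
    if wf == node:
        return "ok"
    if wf.major != node.major:
        return "breaking"  # major bump: inputs/outputs may have changed shape
    return "drift"  # minor/patch difference: usually safe, but worth a warning


print(classify_mismatch("1.0.0", "1.0.0"))  # ok
print(classify_mismatch("1.2.0", "2.0.0"))  # breaking
```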
@@ -279,8 +283,6 @@ take a look a at our [contributing nodes overview](contributingNodes).

 ## Advanced

--->
-
 ### Custom Output Types

 Like with custom inputs, sometimes you might find yourself needing custom
@@ -22,12 +22,26 @@ To use a community node graph, download the the `.json` node graph file and load

+--------------------------------
 ### Ideal Size

 **Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.

 **Node Link:** https://github.com/JPPhoto/ideal-size-node

+--------------------------------
+### Film Grain
+
+**Description:** This node adds a film grain effect to the input image based on the weights, seeds, and blur radii parameters. It works with RGB input images only.
+
+**Node Link:** https://github.com/JPPhoto/film-grain-node
+
+--------------------------------
+### Image Picker
+
+**Description:** This InvokeAI node takes in a collection of images and randomly chooses one. This can be useful when you have a number of poses to choose from for a ControlNet node, or a number of input images for another purpose.
+
+**Node Link:** https://github.com/JPPhoto/film-grain-node
+
 --------------------------------
 ### Retroize
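The two node listings added above describe behaviour rather than show code. As a rough, hypothetical sketch of the idea behind an image-picker-style node (this is not the linked repository's implementation, only the random-choice behaviour it describes):

```python
# Hypothetical sketch of "pick one image from a collection", seeded so that a
# re-run of the same workflow makes the same choice.
import random
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def pick_one(items: Sequence[T], seed: Optional[int] = None) -> T:
    if not items:
        raise ValueError("cannot pick from an empty collection")
    rng = random.Random(seed)
    return rng.choice(list(items))


poses = ["pose_a.png", "pose_b.png", "pose_c.png"]
print(pick_one(poses, seed=42))  # same seed -> same pose every run
```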
@@ -26,11 +26,16 @@ from typing import (
 from pydantic import BaseModel, Field, validator
 from pydantic.fields import Undefined, ModelField
 from pydantic.typing import NoArgAnyCallable
+import semver

 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices


+class InvalidVersionError(ValueError):
+    pass
+
+
 class FieldDescriptions:
     denoising_start = "When to start denoising, expressed a percentage of total steps"
     denoising_end = "When to stop denoising, expressed a percentage of total steps"
@@ -105,24 +110,39 @@ class UIType(str, Enum):
     """

     # region Primitives
-    Integer = "integer"
-    Float = "float"
-    Boolean = "boolean"
-    String = "string"
-    Array = "array"
-    Image = "ImageField"
-    Latents = "LatentsField"
-    Conditioning = "ConditioningField"
-    Control = "ControlField"
-    Color = "ColorField"
-    ImageCollection = "ImageCollection"
-    ConditioningCollection = "ConditioningCollection"
-    ColorCollection = "ColorCollection"
-    LatentsCollection = "LatentsCollection"
-    IntegerCollection = "IntegerCollection"
-    FloatCollection = "FloatCollection"
-    StringCollection = "StringCollection"
-    BooleanCollection = "BooleanCollection"
+    Boolean = "boolean"
+    Color = "ColorField"
+    Conditioning = "ConditioningField"
+    Control = "ControlField"
+    Float = "float"
+    Image = "ImageField"
+    Integer = "integer"
+    Latents = "LatentsField"
+    String = "string"
+    # endregion
+
+    # region Collection Primitives
+    BooleanCollection = "BooleanCollection"
+    ColorCollection = "ColorCollection"
+    ConditioningCollection = "ConditioningCollection"
+    ControlCollection = "ControlCollection"
+    FloatCollection = "FloatCollection"
+    ImageCollection = "ImageCollection"
+    IntegerCollection = "IntegerCollection"
+    LatentsCollection = "LatentsCollection"
+    StringCollection = "StringCollection"
+    # endregion
+
+    # region Polymorphic Primitives
+    BooleanPolymorphic = "BooleanPolymorphic"
+    ColorPolymorphic = "ColorPolymorphic"
+    ConditioningPolymorphic = "ConditioningPolymorphic"
+    ControlPolymorphic = "ControlPolymorphic"
+    FloatPolymorphic = "FloatPolymorphic"
+    ImagePolymorphic = "ImagePolymorphic"
+    IntegerPolymorphic = "IntegerPolymorphic"
+    LatentsPolymorphic = "LatentsPolymorphic"
+    StringPolymorphic = "StringPolymorphic"
     # endregion

     # region Models
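The reordered enum above is purely organisational: every member is a string hint telling the UI how to render a field, and the `(str, Enum)` mixin is what lets those hints serialise straight into JSON schemas. A small standalone sketch of that pattern (the three members are copied from the diff; the rest is illustrative):

```python
# Sketch of the (str, Enum) pattern UIType relies on: members behave like plain
# strings, so they can be compared against and embedded in JSON schemas directly.
from enum import Enum


class UITypeSketch(str, Enum):
    Float = "float"
    FloatCollection = "FloatCollection"      # a list of floats
    FloatPolymorphic = "FloatPolymorphic"    # UI hint: a single float or a list


print(UITypeSketch.FloatPolymorphic == "FloatPolymorphic")  # True
print(UITypeSketch.Float.value)                             # "float"
```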
@@ -176,6 +196,7 @@ class _InputField(BaseModel):
     ui_type: Optional[UIType]
     ui_component: Optional[UIComponent]
     ui_order: Optional[int]
+    item_default: Optional[Any]


 class _OutputField(BaseModel):
@@ -223,6 +244,7 @@ def InputField(
     ui_component: Optional[UIComponent] = None,
     ui_hidden: bool = False,
     ui_order: Optional[int] = None,
+    item_default: Optional[Any] = None,
     **kwargs: Any,
 ) -> Any:
     """
@@ -249,6 +271,11 @@ def InputField(
     For this case, you could provide `UIComponent.Textarea`.

     :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
+
+    :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
+
+    :param bool item_default: [None] Specifies the default item value, if this is a collection input. \
+    Ignored for non-collection fields..
     """
     return Field(
         *args,
@@ -282,6 +309,7 @@ def InputField(
         ui_component=ui_component,
         ui_hidden=ui_hidden,
         ui_order=ui_order,
+        item_default=item_default,
         **kwargs,
     )

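`item_default` is threaded from `_InputField` through the `InputField()` signature and into the underlying `pydantic.Field()` call above. The mechanism it rides on is pydantic v1's handling of unknown keyword arguments to `Field()`, which are copied into the generated JSON schema where the UI can read them. A hedged, standalone illustration (the model below is made up; it is not an InvokeAI node):

```python
# pydantic v1 style, matching the `from pydantic import BaseModel, Field` import above.
from typing import List

from pydantic import BaseModel, Field


class CollectionDemo(BaseModel):
    # item_default is not a pydantic-native argument; it is carried through to the
    # schema as an extra key, which is how a UI can learn the default for new items.
    weights: List[float] = Field(default=[], item_default=0.5, description="per-item weights")


print(CollectionDemo.schema()["properties"]["weights"]["item_default"])  # 0.5
```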
@@ -332,6 +360,8 @@ def OutputField(
     `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.

     :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
+
+    :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
     """
     return Field(
         *args,
@@ -376,6 +406,9 @@ class UIConfigBase(BaseModel):
     tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
     title: Optional[str] = Field(default=None, description="The node's display name")
     category: Optional[str] = Field(default=None, description="The node's category")
+    version: Optional[str] = Field(
+        default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
+    )


 class InvocationContext:
@@ -474,6 +507,8 @@ class BaseInvocation(ABC, BaseModel):
                 schema["tags"] = uiconfig.tags
             if uiconfig and hasattr(uiconfig, "category"):
                 schema["category"] = uiconfig.category
+            if uiconfig and hasattr(uiconfig, "version"):
+                schema["version"] = uiconfig.version
             if "required" not in schema or not isinstance(schema["required"], list):
                 schema["required"] = list()
             schema["required"].extend(["type", "id"])
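With the `schema_extra` hook extended above, a node's schema now exposes `version` alongside `title`, `tags`, and `category`. A sketch of how a client could use it (the schema dict here is hand-written for illustration; a real client would read it from the generated OpenAPI document):

```python
# Hand-written stand-in for one node's schema, mirroring the keys schema_extra
# copies out of UIConfig above.
node_schema = {
    "title": "My Resizer",
    "tags": ["resize", "image"],
    "category": "My Invocations",
    "version": "1.0.0",
    "required": ["type", "id"],
}

saved_with = "1.0.0"  # version recorded in a workflow file
if node_schema.get("version") != saved_with:
    print("workflow was saved against a different version of this node")
else:
    print("node version matches the workflow")
```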
@@ -542,7 +577,11 @@ GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)


 def invocation(
-    invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
+    invocation_type: str,
+    title: Optional[str] = None,
+    tags: Optional[list[str]] = None,
+    category: Optional[str] = None,
+    version: Optional[str] = None,
 ) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
     """
     Adds metadata to an invocation.
@@ -569,6 +608,12 @@ def invocation(
             cls.UIConfig.tags = tags
         if category is not None:
             cls.UIConfig.category = category
+        if version is not None:
+            try:
+                semver.Version.parse(version)
+            except ValueError as e:
+                raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
+            cls.UIConfig.version = version

         # Add the invocation type to the pydantic model of the invocation
         invocation_type_annotation = Literal[invocation_type]  # type: ignore
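The `try`/`except` added above is easiest to see with a couple of inputs. A standalone sketch using the same `semver.Version.parse` call (the `InvalidVersionError` class is re-declared here only so the snippet runs outside InvokeAI):

```python
import semver


class InvalidVersionError(ValueError):
    pass


def check_version(invocation_type: str, version: str) -> None:
    try:
        semver.Version.parse(version)
    except ValueError as e:
        raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e


check_version("resize", "1.0.0")  # passes: full MAJOR.MINOR.PATCH string
check_version("resize", "1.0")    # raises InvalidVersionError: incomplete semver string
```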
@@ -580,8 +625,9 @@ def invocation(
             config=cls.__config__,
         )
         cls.__fields__.update({"type": invocation_type_field})
-        cls.__annotations__.update({"type": invocation_type_annotation})
+        # to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
+        if annotations := cls.__dict__.get("__annotations__", None):
+            annotations.update({"type": invocation_type_annotation})
         return cls

     return wrapper
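The comment added in both decorators points at the annotations HOWTO for a reason: on Python 3.9, reading `cls.__annotations__` on a class that defines no annotations of its own returns the nearest ancestor's dict, so updating it in place would also rewrite the base class. Going through `cls.__dict__` touches only the class's own dict. A small demonstration, independent of InvokeAI:

```python
class Base:
    x: int


class Child(Base):  # defines no annotations of its own
    pass


# On Python 3.9 this prints True (Child falls back to Base's dict); on 3.10+ it
# prints False because classes lazily create their own empty annotations dict.
print(Child.__annotations__ is Base.__annotations__)

# The pattern used in the diff only touches the class's own dict, if it has one:
if annotations := Child.__dict__.get("__annotations__", None):
    annotations["type"] = str

print(Base.__annotations__)  # still just {'x': <class 'int'>}
```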
@@ -615,7 +661,10 @@ def invocation_output(
             config=cls.__config__,
         )
         cls.__fields__.update({"type": output_type_field})
-        cls.__annotations__.update({"type": output_type_annotation})
+        # to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
+        if annotations := cls.__dict__.get("__annotations__", None):
+            annotations.update({"type": output_type_annotation})
+
         return cls

@@ -10,7 +10,9 @@ from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


-@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
+@invocation(
+    "range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
+)
 class RangeInvocation(BaseInvocation):
     """Creates a range of numbers from start to stop with step"""

@@ -33,6 +35,7 @@ class RangeInvocation(BaseInvocation):
     title="Integer Range of Size",
     tags=["collection", "integer", "size", "range"],
     category="collections",
+    version="1.0.0",
 )
 class RangeOfSizeInvocation(BaseInvocation):
     """Creates a range from start to start + size with step"""
@@ -50,6 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     title="Random Range",
     tags=["range", "integer", "random", "collection"],
     category="collections",
+    version="1.0.0",
 )
 class RandomRangeInvocation(BaseInvocation):
     """Creates a collection of random numbers"""
@@ -44,7 +44,7 @@ class ConditioningFieldData:
     # PerpNeg = "perp_neg"


-@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
+@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""

@@ -267,6 +267,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
+    version="1.0.0",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -279,8 +280,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     crop_left: int = InputField(default=0, description="")
     target_width: int = InputField(default=1024, description="")
     target_height: int = InputField(default=1024, description="")
-    clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
-    clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
+    clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
+    clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -351,6 +352,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
+    version="1.0.0",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -403,7 +405,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
     clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")


-@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
+@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
 class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""

@@ -31,8 +31,8 @@ from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     FieldDescriptions,
-    InputField,
     Input,
+    InputField,
     InvocationContext,
     OutputField,
     UIType,
@@ -40,7 +40,9 @@ from .baseinvocation import (
 )


-@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
+@invocation(
+    "image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
+)
 class ImageProcessorInvocation(BaseInvocation):
     """Base class for invocations that preprocess images for ControlNet"""

@@ -84,6 +86,7 @@ class ImageProcessorInvocation(BaseInvocation):
     title="Canny Processor",
     tags=["controlnet", "canny"],
     category="controlnet",
+    version="1.0.0",
 )
 class CannyImageProcessorInvocation(ImageProcessorInvocation):
     """Canny edge detection for ControlNet"""
@@ -106,6 +109,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
     title="HED (softedge) Processor",
     tags=["controlnet", "hed", "softedge"],
     category="controlnet",
+    version="1.0.0",
 )
 class HedImageProcessorInvocation(ImageProcessorInvocation):
     """Applies HED edge detection to image"""
@@ -134,6 +138,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     title="Lineart Processor",
     tags=["controlnet", "lineart"],
     category="controlnet",
+    version="1.0.0",
 )
 class LineartImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art processing to image"""
@@ -155,6 +160,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     title="Lineart Anime Processor",
     tags=["controlnet", "lineart", "anime"],
     category="controlnet",
+    version="1.0.0",
 )
 class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art anime processing to image"""
@@ -177,6 +183,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     title="Openpose Processor",
     tags=["controlnet", "openpose", "pose"],
     category="controlnet",
+    version="1.0.0",
 )
 class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Openpose processing to image"""
@@ -201,6 +208,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
     title="Midas Depth Processor",
     tags=["controlnet", "midas"],
     category="controlnet",
+    version="1.0.0",
 )
 class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Midas depth processing to image"""
@@ -227,6 +235,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     title="Normal BAE Processor",
     tags=["controlnet"],
     category="controlnet",
+    version="1.0.0",
 )
 class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies NormalBae processing to image"""
@@ -242,7 +251,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
         return processed_image


-@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
+@invocation(
+    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
+)
 class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     """Applies MLSD processing to image"""

@@ -263,7 +274,9 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
         return processed_image


-@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
+@invocation(
+    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
+)
 class PidiImageProcessorInvocation(ImageProcessorInvocation):
     """Applies PIDI processing to image"""

@@ -289,6 +302,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     title="Content Shuffle Processor",
     tags=["controlnet", "contentshuffle"],
     category="controlnet",
+    version="1.0.0",
 )
 class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     """Applies content shuffle processing to image"""
@@ -318,6 +332,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     title="Zoe (Depth) Processor",
     tags=["controlnet", "zoe", "depth"],
     category="controlnet",
+    version="1.0.0",
 )
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
@@ -333,6 +348,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     title="Mediapipe Face Processor",
     tags=["controlnet", "mediapipe", "face"],
     category="controlnet",
+    version="1.0.0",
 )
 class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     """Applies mediapipe face processing to image"""
@@ -355,6 +371,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     title="Leres (Depth) Processor",
     tags=["controlnet", "leres", "depth"],
     category="controlnet",
+    version="1.0.0",
 )
 class LeresImageProcessorInvocation(ImageProcessorInvocation):
     """Applies leres processing to image"""
@@ -383,6 +400,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     title="Tile Resample Processor",
     tags=["controlnet", "tile"],
     category="controlnet",
+    version="1.0.0",
 )
 class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     """Tile resampler processor"""
@@ -422,6 +440,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     title="Segment Anything Processor",
     tags=["controlnet", "segmentanything"],
     category="controlnet",
+    version="1.0.0",
 )
 class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     """Applies segment anything processing to image"""
@@ -10,12 +10,7 @@ from invokeai.app.models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


-@invocation(
-    "cv_inpaint",
-    title="OpenCV Inpaint",
-    tags=["opencv", "inpaint"],
-    category="inpaint",
-)
+@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
 class CvInpaintInvocation(BaseInvocation):
     """Simple inpaint using opencv."""

@@ -16,7 +16,7 @@ from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation


-@invocation("show_image", title="Show Image", tags=["image"], category="image")
+@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
 class ShowImageInvocation(BaseInvocation):
     """Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""

@@ -36,7 +36,7 @@ class ShowImageInvocation(BaseInvocation):
         )


-@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
+@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
 class BlankImageInvocation(BaseInvocation):
     """Creates a blank image and forwards it to the pipeline"""

@@ -65,7 +65,7 @@ class BlankImageInvocation(BaseInvocation):
         )


-@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
+@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
 class ImageCropInvocation(BaseInvocation):
     """Crops an image to a specified box. The box can be outside of the image."""

@@ -98,7 +98,7 @@ class ImageCropInvocation(BaseInvocation):
         )


-@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
+@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.0")
 class ImagePasteInvocation(BaseInvocation):
     """Pastes an image into another image."""

@@ -146,7 +146,7 @@ class ImagePasteInvocation(BaseInvocation):
         )


-@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
+@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
 class MaskFromAlphaInvocation(BaseInvocation):
     """Extracts the alpha channel of an image as a mask."""

@@ -177,7 +177,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
         )


-@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
+@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
 class ImageMultiplyInvocation(BaseInvocation):
     """Multiplies two images together using `PIL.ImageChops.multiply()`."""

@@ -210,7 +210,7 @@ class ImageMultiplyInvocation(BaseInvocation):
 IMAGE_CHANNELS = Literal["A", "R", "G", "B"]


-@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
+@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
 class ImageChannelInvocation(BaseInvocation):
     """Gets a channel from an image."""

@@ -242,7 +242,7 @@ class ImageChannelInvocation(BaseInvocation):
 IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]


-@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
+@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
 class ImageConvertInvocation(BaseInvocation):
     """Converts an image to a different mode."""

@@ -271,7 +271,7 @@ class ImageConvertInvocation(BaseInvocation):
         )


-@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
+@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
 class ImageBlurInvocation(BaseInvocation):
     """Blurs an image"""

@@ -325,7 +325,7 @@ PIL_RESAMPLING_MAP = {
 }


-@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
+@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
 class ImageResizeInvocation(BaseInvocation):
     """Resizes an image to specific dimensions"""

@@ -365,7 +365,7 @@ class ImageResizeInvocation(BaseInvocation):
         )


-@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
+@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
 class ImageScaleInvocation(BaseInvocation):
     """Scales an image by a factor"""

@@ -406,7 +406,7 @@ class ImageScaleInvocation(BaseInvocation):
         )


-@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
+@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
 class ImageLerpInvocation(BaseInvocation):
     """Linear interpolation of all pixels of an image"""

@@ -439,7 +439,7 @@ class ImageLerpInvocation(BaseInvocation):
         )


-@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
+@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
 class ImageInverseLerpInvocation(BaseInvocation):
     """Inverse linear interpolation of all pixels of an image"""

@@ -472,7 +472,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
         )


-@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
+@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
 class ImageNSFWBlurInvocation(BaseInvocation):
     """Add blur to NSFW-flagged images"""

@@ -517,7 +517,9 @@ class ImageNSFWBlurInvocation(BaseInvocation):
         return caution.resize((caution.width // 2, caution.height // 2))


-@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
+@invocation(
+    "img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
+)
 class ImageWatermarkInvocation(BaseInvocation):
     """Add an invisible watermark to an image"""

@@ -548,7 +550,7 @@ class ImageWatermarkInvocation(BaseInvocation):
         )


-@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
+@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
 class MaskEdgeInvocation(BaseInvocation):
     """Applies an edge mask to an image"""

@@ -593,7 +595,9 @@ class MaskEdgeInvocation(BaseInvocation):
         )


-@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
+@invocation(
+    "mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
+)
 class MaskCombineInvocation(BaseInvocation):
     """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""

@@ -623,7 +627,7 @@ class MaskCombineInvocation(BaseInvocation):
         )


-@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
+@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
 class ColorCorrectInvocation(BaseInvocation):
     """
     Shifts the colors of a target image to match the reference image, optionally
@@ -728,7 +732,7 @@ class ColorCorrectInvocation(BaseInvocation):
         )


-@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
+@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
 class ImageHueAdjustmentInvocation(BaseInvocation):
     """Adjusts the Hue of an image."""

@@ -774,6 +778,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
     title="Adjust Image Luminosity",
     tags=["image", "luminosity", "hsl"],
     category="image",
+    version="1.0.0",
 )
 class ImageLuminosityAdjustmentInvocation(BaseInvocation):
     """Adjusts the Luminosity (Value) of an image."""
@@ -826,6 +831,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
     title="Adjust Image Saturation",
     tags=["image", "saturation", "hsl"],
     category="image",
+    version="1.0.0",
 )
 class ImageSaturationAdjustmentInvocation(BaseInvocation):
     """Adjusts the Saturation of an image."""
@@ -116,7 +116,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
     return si


-@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
+@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
 class InfillColorInvocation(BaseInvocation):
     """Infills transparent areas of an image with a solid color"""

@@ -151,7 +151,7 @@ class InfillColorInvocation(BaseInvocation):
         )


-@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
+@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
 class InfillTileInvocation(BaseInvocation):
     """Infills transparent areas of an image with tiles of the image"""

@@ -187,7 +187,9 @@ class InfillTileInvocation(BaseInvocation):
         )


-@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
+@invocation(
+    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
+)
 class InfillPatchMatchInvocation(BaseInvocation):
     """Infills transparent areas of an image using the PatchMatch algorithm"""

@@ -218,7 +220,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
         )


-@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
+@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
 class LaMaInfillInvocation(BaseInvocation):
     """Infills transparent areas of an image using the LaMa model"""

@@ -76,7 +76,7 @@ class SchedulerOutput(BaseInvocationOutput):
     scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)


-@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
+@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
 class SchedulerInvocation(BaseInvocation):
     """Selects a scheduler."""

@@ -88,7 +88,9 @@ class SchedulerInvocation(BaseInvocation):
         return SchedulerOutput(scheduler=self.scheduler)


-@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
+@invocation(
+    "create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
+)
 class CreateDenoiseMaskInvocation(BaseInvocation):
     """Creates mask for denoising model run."""

@@ -188,6 +190,7 @@ def get_scheduler(
     title="Denoise Latents",
     tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
     category="latents",
+    version="1.0.0",
 )
 class DenoiseLatentsInvocation(BaseInvocation):
     """Denoises noisy latents to decodable images"""
@@ -210,12 +213,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
     )
     unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
     control: Union[ControlField, list[ControlField]] = InputField(
-        default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
+        default=None,
+        description=FieldDescriptions.control,
+        input=Input.Connection,
+        ui_order=5,
     )
     latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
     denoise_mask: Optional[DenoiseMaskField] = InputField(
-        default=None,
-        description=FieldDescriptions.mask,
+        default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
     )
     # ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
     # ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
@@ -322,7 +327,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         # really only need model for dtype and device
         model: StableDiffusionGeneratorPipeline,
-        control_input: List[ControlField],
+        control_input: Union[ControlField, List[ControlField]],
         latents_shape: List[int],
         exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
@@ -573,7 +578,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return build_latents_output(latents_name=name, latents=result_latents, seed=seed)


-@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
+@invocation(
+    "l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
+)
 class LatentsToImageInvocation(BaseInvocation):
     """Generates an image from latents."""

@@ -670,7 +677,7 @@ class LatentsToImageInvocation(BaseInvocation):
 LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]


-@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
+@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
 class ResizeLatentsInvocation(BaseInvocation):
     """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""

@@ -714,7 +721,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)


-@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
+@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
 class ScaleLatentsInvocation(BaseInvocation):
     """Scales latents by a given factor."""

@@ -750,7 +757,9 @@ class ScaleLatentsInvocation(BaseInvocation):
         return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)


-@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
+@invocation(
+    "i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
+)
 class ImageToLatentsInvocation(BaseInvocation):
     """Encodes an image into latents."""

@@ -830,7 +839,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         return build_latents_output(latents_name=name, latents=latents, seed=None)


-@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
+@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
 class BlendLatentsInvocation(BaseInvocation):
     """Blend two latents using a given alpha. Latents must have same size."""

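One typing change in the latents hunks deserves a note: the control-preparation helper's `control_input` parameter is now `Union[ControlField, List[ControlField]]`, matching the `control` field, so callers may hand in a single control or a list. The usual way to handle such a field is to normalise it up front; a generic, hedged sketch (not the method's actual body):

```python
from typing import List, Optional, TypeVar, Union

T = TypeVar("T")


def as_list(value: Union[T, List[T], None]) -> List[T]:
    """Accept one item, a list of items, or None, and always return a list."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


print(as_list(None))        # []
print(as_list("control"))   # ['control']
print(as_list(["a", "b"]))  # ['a', 'b']
```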
@@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerOutput
 from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation


-@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
+@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
 class AddInvocation(BaseInvocation):
     """Adds two numbers"""

@@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation):
         return IntegerOutput(value=self.a + self.b)


-@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
+@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
 class SubtractInvocation(BaseInvocation):
     """Subtracts two numbers"""

@@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation):
         return IntegerOutput(value=self.a - self.b)


-@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
+@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
 class MultiplyInvocation(BaseInvocation):
     """Multiplies two numbers"""

@@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation):
         return IntegerOutput(value=self.a * self.b)


-@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
+@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
 class DivideInvocation(BaseInvocation):
     """Divides two numbers"""

@@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation):
         return IntegerOutput(value=int(self.a / self.b))


-@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
+@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
 class RandomIntInvocation(BaseInvocation):
     """Outputs a single random integer."""

@@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
     )
     refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
     refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
-    refiner_positive_aesthetic_store: Optional[float] = Field(
+    refiner_positive_aesthetic_score: Optional[float] = Field(
         default=None, description="The aesthetic score used for the refiner"
     )
-    refiner_negative_aesthetic_store: Optional[float] = Field(
+    refiner_negative_aesthetic_score: Optional[float] = Field(
         default=None, description="The aesthetic score used for the refiner"
     )
     refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
@@ -98,7 +98,9 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
     metadata: CoreMetadata = OutputField(description="The core metadata for the image")


-@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
+@invocation(
+    "metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
+)
 class MetadataAccumulatorInvocation(BaseInvocation):
     """Outputs a Core Metadata Object"""

@@ -160,11 +162,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
         default=None,
         description="The scheduler used for the refiner",
     )
-    refiner_positive_aesthetic_store: Optional[float] = InputField(
+    refiner_positive_aesthetic_score: Optional[float] = InputField(
         default=None,
         description="The aesthetic score used for the refiner",
     )
-    refiner_negative_aesthetic_store: Optional[float] = InputField(
+    refiner_negative_aesthetic_score: Optional[float] = InputField(
         default=None,
         description="The aesthetic score used for the refiner",
     )
@ -73,7 +73,7 @@ class LoRAModelField(BaseModel):
    base_model: BaseModelType = Field(description="Base model")


-@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
+@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
class MainModelLoaderInvocation(BaseInvocation):
    """Loads a main model, outputting its submodels."""

@ -173,7 +173,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
    clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")


-@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
+@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
class LoraLoaderInvocation(BaseInvocation):
    """Apply selected lora to unet and text_encoder."""

@ -244,19 +244,19 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
    clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")


-@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
+@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
class SDXLLoraLoaderInvocation(BaseInvocation):
    """Apply selected lora to unet and text_encoder."""

    lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
-    unet: Optional[UNetField] = Field(
-        default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
+    unet: Optional[UNetField] = InputField(
+        default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
    )
-    clip: Optional[ClipField] = Field(
+    clip: Optional[ClipField] = InputField(
        default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
    )
-    clip2: Optional[ClipField] = Field(
+    clip2: Optional[ClipField] = InputField(
        default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
    )

@ -338,7 +338,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


-@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
+@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
    """Loads a VAE model, outputting a VaeLoaderOutput"""

@ -376,7 +376,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
    vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")


-@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
+@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
class SeamlessModeInvocation(BaseInvocation):
    """Applies the seamless transformation to the Model UNet and VAE."""
@ -78,7 +78,7 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
    )


-@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
+@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
class NoiseInvocation(BaseInvocation):
    """Generates latent noise."""
@ -56,7 +56,7 @@ ORT_TO_NP_TYPE = {
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]


-@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
+@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
class ONNXPromptInvocation(BaseInvocation):
    prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
    clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)

@ -143,6 +143,7 @@ class ONNXPromptInvocation(BaseInvocation):
    title="ONNX Text to Latents",
    tags=["latents", "inference", "txt2img", "onnx"],
    category="latents",
+    version="1.0.0",
)
class ONNXTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

@ -319,6 +320,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
    title="ONNX Latents to Image",
    tags=["latents", "image", "vae", "onnx"],
    category="image",
+    version="1.0.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

@ -403,7 +405,7 @@ class OnnxModelField(BaseModel):
    model_type: ModelType = Field(description="Model Type")


-@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
+@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
class OnnxModelLoaderInvocation(BaseInvocation):
    """Loads a main model, outputting its submodels."""
@ -45,7 +45,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


-@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
+@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
class FloatLinearRangeInvocation(BaseInvocation):
    """Creates a range"""

@ -96,7 +96,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]


# actually I think for now could just use CollectionOutput (which is list[Any]
-@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
+@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
class StepParamEasingInvocation(BaseInvocation):
    """Experimental per-step parameter easing for denoising steps"""
@ -14,7 +14,6 @@ from .baseinvocation import (
    InvocationContext,
    OutputField,
    UIComponent,
-    UIType,
    invocation,
    invocation_output,
)

@ -40,10 +39,14 @@ class BooleanOutput(BaseInvocationOutput):
class BooleanCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of booleans"""

-    collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
+    collection: list[bool] = OutputField(
+        description="The output boolean collection",
+    )


-@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
+@invocation(
+    "boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
+)
class BooleanInvocation(BaseInvocation):
    """A boolean primitive value"""

@ -58,13 +61,12 @@ class BooleanInvocation(BaseInvocation):
    title="Boolean Collection Primitive",
    tags=["primitives", "boolean", "collection"],
    category="primitives",
+    version="1.0.0",
)
class BooleanCollectionInvocation(BaseInvocation):
    """A collection of boolean primitive values"""

-    collection: list[bool] = InputField(
-        default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
-    )
+    collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")

    def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
        return BooleanCollectionOutput(collection=self.collection)
@ -86,10 +88,14 @@ class IntegerOutput(BaseInvocationOutput):
class IntegerCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of integers"""

-    collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
+    collection: list[int] = OutputField(
+        description="The int collection",
+    )


-@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
+@invocation(
+    "integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
+)
class IntegerInvocation(BaseInvocation):
    """An integer primitive value"""

@ -104,13 +110,12 @@ class IntegerInvocation(BaseInvocation):
    title="Integer Collection Primitive",
    tags=["primitives", "integer", "collection"],
    category="primitives",
+    version="1.0.0",
)
class IntegerCollectionInvocation(BaseInvocation):
    """A collection of integer primitive values"""

-    collection: list[int] = InputField(
-        default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
-    )
+    collection: list[int] = InputField(default_factory=list, description="The collection of integer values")

    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
        return IntegerCollectionOutput(collection=self.collection)
@ -132,10 +137,12 @@ class FloatOutput(BaseInvocationOutput):
class FloatCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of floats"""

-    collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
+    collection: list[float] = OutputField(
+        description="The float collection",
+    )


-@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
+@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
class FloatInvocation(BaseInvocation):
    """A float primitive value"""

@ -150,13 +157,12 @@ class FloatInvocation(BaseInvocation):
    title="Float Collection Primitive",
    tags=["primitives", "float", "collection"],
    category="primitives",
+    version="1.0.0",
)
class FloatCollectionInvocation(BaseInvocation):
    """A collection of float primitive values"""

-    collection: list[float] = InputField(
-        default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
-    )
+    collection: list[float] = InputField(default_factory=list, description="The collection of float values")

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        return FloatCollectionOutput(collection=self.collection)
@ -178,10 +184,12 @@ class StringOutput(BaseInvocationOutput):
class StringCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of strings"""

-    collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
+    collection: list[str] = OutputField(
+        description="The output strings",
+    )


-@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
+@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
class StringInvocation(BaseInvocation):
    """A string primitive value"""

@ -196,13 +204,12 @@ class StringInvocation(BaseInvocation):
    title="String Collection Primitive",
    tags=["primitives", "string", "collection"],
    category="primitives",
+    version="1.0.0",
)
class StringCollectionInvocation(BaseInvocation):
    """A collection of string primitive values"""

-    collection: list[str] = InputField(
-        default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
-    )
+    collection: list[str] = InputField(default_factory=list, description="The collection of string values")

    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
        return StringCollectionOutput(collection=self.collection)
@ -232,10 +239,12 @@ class ImageOutput(BaseInvocationOutput):
class ImageCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of images"""

-    collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
+    collection: list[ImageField] = OutputField(
+        description="The output images",
+    )


-@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
+@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
class ImageInvocation(BaseInvocation):
    """An image primitive value"""

@ -256,13 +265,12 @@ class ImageInvocation(BaseInvocation):
    title="Image Collection Primitive",
    tags=["primitives", "image", "collection"],
    category="primitives",
+    version="1.0.0",
)
class ImageCollectionInvocation(BaseInvocation):
    """A collection of image primitive values"""

-    collection: list[ImageField] = InputField(
-        default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
-    )
+    collection: list[ImageField] = InputField(description="The collection of image values")

    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
        return ImageCollectionOutput(collection=self.collection)
@ -316,11 +324,12 @@ class LatentsCollectionOutput(BaseInvocationOutput):

    collection: list[LatentsField] = OutputField(
        description=FieldDescriptions.latents,
-        ui_type=UIType.LatentsCollection,
    )


-@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
+@invocation(
+    "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
+)
class LatentsInvocation(BaseInvocation):
    """A latents tensor primitive value"""

@ -337,12 +346,13 @@ class LatentsInvocation(BaseInvocation):
    title="Latents Collection Primitive",
    tags=["primitives", "latents", "collection"],
    category="primitives",
+    version="1.0.0",
)
class LatentsCollectionInvocation(BaseInvocation):
    """A collection of latents tensor primitive values"""

    collection: list[LatentsField] = InputField(
-        description="The collection of latents tensors", ui_type=UIType.LatentsCollection
+        description="The collection of latents tensors",
    )

    def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
@ -385,10 +395,12 @@ class ColorOutput(BaseInvocationOutput):
class ColorCollectionOutput(BaseInvocationOutput):
    """Base class for nodes that output a collection of colors"""

-    collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
+    collection: list[ColorField] = OutputField(
+        description="The output colors",
+    )


-@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
+@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
class ColorInvocation(BaseInvocation):
    """A color primitive value"""

@ -422,7 +434,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):

    collection: list[ConditioningField] = OutputField(
        description="The output conditioning tensors",
-        ui_type=UIType.ConditioningCollection,
    )
@ -431,6 +442,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
    title="Conditioning Primitive",
    tags=["primitives", "conditioning"],
    category="primitives",
+    version="1.0.0",
)
class ConditioningInvocation(BaseInvocation):
    """A conditioning tensor primitive value"""

@ -446,6 +458,7 @@ class ConditioningInvocation(BaseInvocation):
    title="Conditioning Collection Primitive",
    tags=["primitives", "conditioning", "collection"],
    category="primitives",
+    version="1.0.0",
)
class ConditioningCollectionInvocation(BaseInvocation):
    """A collection of conditioning tensor primitive values"""

@ -453,7 +466,6 @@ class ConditioningCollectionInvocation(BaseInvocation):
    collection: list[ConditioningField] = InputField(
        default_factory=list,
        description="The collection of conditioning tensors",
-        ui_type=UIType.ConditioningCollection,
    )

    def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
@ -10,7 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation


-@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
+@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
class DynamicPromptInvocation(BaseInvocation):
    """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation):
        return StringCollectionOutput(collection=prompts)


-@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
+@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
class PromptsFromFileInvocation(BaseInvocation):
    """Loads prompts from a text file"""
@ -33,7 +33,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


-@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
+@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

@ -119,6 +119,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
    title="SDXL Refiner Model",
    tags=["model", "sdxl", "refiner"],
    category="model",
+    version="1.0.0",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""
@ -23,7 +23,7 @@ ESRGAN_MODELS = Literal[
]


-@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
+@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
class ESRGANInvocation(BaseInvocation):
    """Upscales an image using RealESRGAN."""
@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    if to_type in get_args(from_type):
        return True

+    # allow int -> float, pydantic will cast for us
+    if from_type is int and to_type is float:
+        return True
+
    # if not issubclass(from_type, to_type):
    if not is_union_subtype(from_type, to_type):
        return False
@ -50,6 +50,7 @@ class ModelProbe(object):
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
+        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
    }
@ -1,6 +0,0 @@
-from ldm.modules.image_degradation.bsrgan import (  # noqa: F401
-    degradation_bsrgan_variant as degradation_fn_bsr,
-)
-from ldm.modules.image_degradation.bsrgan_light import (  # noqa: F401
-    degradation_bsrgan_variant as degradation_fn_bsr_light,
-)
@ -1,794 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# Super-Resolution
|
|
||||||
# --------------------------------------------
|
|
||||||
#
|
|
||||||
# Kai Zhang (cskaizhang@gmail.com)
|
|
||||||
# https://github.com/cszn
|
|
||||||
# From 2019/03--2021/08
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import albumentations
|
|
||||||
import cv2
|
|
||||||
import ldm.modules.image_degradation.utils_image as util
|
|
||||||
import numpy as np
|
|
||||||
import scipy
|
|
||||||
import scipy.stats as ss
|
|
||||||
import torch
|
|
||||||
from scipy import ndimage
|
|
||||||
from scipy.interpolate import interp2d
|
|
||||||
from scipy.linalg import orth
|
|
||||||
|
|
||||||
|
|
||||||
def modcrop_np(img, sf):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
img: numpy image, WxH or WxHxC
|
|
||||||
sf: scale factor
|
|
||||||
Return:
|
|
||||||
cropped image
|
|
||||||
"""
|
|
||||||
w, h = img.shape[:2]
|
|
||||||
im = np.copy(img)
|
|
||||||
return im[: w - w % sf, : h - h % sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# anisotropic Gaussian kernels
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def analytic_kernel(k):
|
|
||||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
|
||||||
k_size = k.shape[0]
|
|
||||||
# Calculate the big kernels size
|
|
||||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
|
||||||
# Loop over the small kernel to fill the big one
|
|
||||||
for r in range(k_size):
|
|
||||||
for c in range(k_size):
|
|
||||||
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
|
|
||||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
|
||||||
crop = k_size // 2
|
|
||||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
|
||||||
# Normalize to 1
|
|
||||||
return cropped_big_k / cropped_big_k.sum()
|
|
||||||
|
|
||||||
|
|
||||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
|
||||||
"""generate an anisotropic Gaussian kernel
|
|
||||||
Args:
|
|
||||||
ksize : e.g., 15, kernel size
|
|
||||||
theta : [0, pi], rotation angle range
|
|
||||||
l1 : [0.1,50], scaling of eigenvalues
|
|
||||||
l2 : [0.1,l1], scaling of eigenvalues
|
|
||||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
|
||||||
Returns:
|
|
||||||
k : kernel
|
|
||||||
"""
|
|
||||||
|
|
||||||
v = np.dot(
|
|
||||||
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
|
|
||||||
np.array([1.0, 0.0]),
|
|
||||||
)
|
|
||||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
|
||||||
D = np.array([[l1, 0], [0, l2]])
|
|
||||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
|
||||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
|
||||||
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def gm_blur_kernel(mean, cov, size=15):
|
|
||||||
center = size / 2.0 + 0.5
|
|
||||||
k = np.zeros([size, size])
|
|
||||||
for y in range(size):
|
|
||||||
for x in range(size):
|
|
||||||
cy = y - center + 1
|
|
||||||
cx = x - center + 1
|
|
||||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
|
||||||
|
|
||||||
k = k / np.sum(k)
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def shift_pixel(x, sf, upper_left=True):
|
|
||||||
"""shift pixel for super-resolution with different scale factors
|
|
||||||
Args:
|
|
||||||
x: WxHxC or WxH
|
|
||||||
sf: scale factor
|
|
||||||
upper_left: shift direction
|
|
||||||
"""
|
|
||||||
h, w = x.shape[:2]
|
|
||||||
shift = (sf - 1) * 0.5
|
|
||||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
|
||||||
if upper_left:
|
|
||||||
x1 = xv + shift
|
|
||||||
y1 = yv + shift
|
|
||||||
else:
|
|
||||||
x1 = xv - shift
|
|
||||||
y1 = yv - shift
|
|
||||||
|
|
||||||
x1 = np.clip(x1, 0, w - 1)
|
|
||||||
y1 = np.clip(y1, 0, h - 1)
|
|
||||||
|
|
||||||
if x.ndim == 2:
|
|
||||||
x = interp2d(xv, yv, x)(x1, y1)
|
|
||||||
if x.ndim == 3:
|
|
||||||
for i in range(x.shape[-1]):
|
|
||||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def blur(x, k):
|
|
||||||
"""
|
|
||||||
x: image, NxcxHxW
|
|
||||||
k: kernel, Nx1xhxw
|
|
||||||
"""
|
|
||||||
n, c = x.shape[:2]
|
|
||||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
|
||||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
|
|
||||||
k = k.repeat(1, c, 1, 1)
|
|
||||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
|
||||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
|
||||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
|
||||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def gen_kernel(
|
|
||||||
k_size=np.array([15, 15]),
|
|
||||||
scale_factor=np.array([4, 4]),
|
|
||||||
min_var=0.6,
|
|
||||||
max_var=10.0,
|
|
||||||
noise_level=0,
|
|
||||||
):
|
|
||||||
""" "
|
|
||||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
|
||||||
# Kai Zhang
|
|
||||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
|
||||||
# max_var = 2.5 * sf
|
|
||||||
"""
|
|
||||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
|
||||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
theta = np.random.rand() * np.pi # random theta
|
|
||||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
|
||||||
|
|
||||||
# Set COV matrix using Lambdas and Theta
|
|
||||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
|
||||||
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
|
||||||
SIGMA = Q @ LAMBDA @ Q.T
|
|
||||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
|
||||||
|
|
||||||
# Set expectation position (shifting kernel for aligned image)
|
|
||||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
|
||||||
MU = MU[None, None, :, None]
|
|
||||||
|
|
||||||
# Create meshgrid for Gaussian
|
|
||||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
|
||||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
|
||||||
|
|
||||||
# Calcualte Gaussian for every pixel of the kernel
|
|
||||||
ZZ = Z - MU
|
|
||||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
|
||||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
|
||||||
|
|
||||||
# shift the kernel so it will be centered
|
|
||||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
|
||||||
|
|
||||||
# Normalize the kernel and return
|
|
||||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
|
||||||
kernel = raw_kernel / np.sum(raw_kernel)
|
|
||||||
return kernel
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_gaussian(hsize, sigma):
|
|
||||||
hsize = [hsize, hsize]
|
|
||||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
|
||||||
std = sigma
|
|
||||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
|
||||||
arg = -(x * x + y * y) / (2 * std * std)
|
|
||||||
h = np.exp(arg)
|
|
||||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
|
||||||
sumh = h.sum()
|
|
||||||
if sumh != 0:
|
|
||||||
h = h / sumh
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_laplacian(alpha):
|
|
||||||
alpha = max([0, min([alpha, 1])])
|
|
||||||
h1 = alpha / (alpha + 1)
|
|
||||||
h2 = (1 - alpha) / (alpha + 1)
|
|
||||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
|
||||||
h = np.array(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial(filter_type, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
python code from:
|
|
||||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
|
||||||
"""
|
|
||||||
if filter_type == "gaussian":
|
|
||||||
return fspecial_gaussian(*args, **kwargs)
|
|
||||||
if filter_type == "laplacian":
|
|
||||||
return fspecial_laplacian(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# degradation models
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def bicubic_degradation(x, sf=3):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
bicubicly downsampled LR image
|
|
||||||
"""
|
|
||||||
x = util.imresize_np(x, scale=1 / sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def srmd_degradation(x, k, sf=3):
|
|
||||||
"""blur + bicubic downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2018learning,
|
|
||||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={3262--3271},
|
|
||||||
year={2018}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def dpsr_degradation(x, k, sf=3):
|
|
||||||
"""bicubic downsampling + blur
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2019deep,
|
|
||||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={1671--1681},
|
|
||||||
year={2019}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def classical_degradation(x, k, sf=3):
|
|
||||||
"""blur + downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]/[0, 255]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
"""
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
|
||||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
|
||||||
st = 0
|
|
||||||
return x[st::sf, st::sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
|
||||||
"""USM sharpening. borrowed from real-ESRGAN
|
|
||||||
Input image: I; Blurry image: B.
|
|
||||||
1. K = I + weight * (I - B)
|
|
||||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
|
||||||
3. Blur mask:
|
|
||||||
4. Out = Mask * K + (1 - Mask) * I
|
|
||||||
Args:
|
|
||||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
|
||||||
weight (float): Sharp weight. Default: 1.
|
|
||||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
|
||||||
threshold (int):
|
|
||||||
"""
|
|
||||||
if radius % 2 == 0:
|
|
||||||
radius += 1
|
|
||||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
|
||||||
residual = img - blur
|
|
||||||
mask = np.abs(residual) * 255 > threshold
|
|
||||||
mask = mask.astype("float32")
|
|
||||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
|
||||||
|
|
||||||
K = img + weight * residual
|
|
||||||
K = np.clip(K, 0, 1)
|
|
||||||
return soft_mask * K + (1 - soft_mask) * img
|
|
||||||
|
|
||||||
|
|
||||||
def add_blur(img, sf=4):
|
|
||||||
wd2 = 4.0 + sf
|
|
||||||
wd = 2.0 + 0.2 * sf
|
|
||||||
if random.random() < 0.5:
|
|
||||||
l1 = wd2 * random.random()
|
|
||||||
l2 = wd2 * random.random()
|
|
||||||
k = anisotropic_Gaussian(
|
|
||||||
ksize=2 * random.randint(2, 11) + 3,
|
|
||||||
theta=random.random() * np.pi,
|
|
||||||
l1=l1,
|
|
||||||
l2=l2,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random())
|
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_resize(img, sf=4):
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.8: # up
|
|
||||||
sf1 = random.uniform(1, 2)
|
|
||||||
elif rnum < 0.7: # down
|
|
||||||
sf1 = random.uniform(0.5 / sf, 1)
|
|
||||||
else:
|
|
||||||
sf1 = 1.0
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
# noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
# rnum = np.random.rand()
|
|
||||||
# if rnum > 0.6: # add color Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
# else: # add noise
|
|
||||||
# L = noise_level2 / 255.
|
|
||||||
# D = np.diag(np.random.rand(3))
|
|
||||||
# U = orth(np.random.rand(3, 3))
|
|
||||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
# img = np.clip(img, 0.0, 1.0)
|
|
||||||
# return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.6: # add color Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else: # add noise
|
|
||||||
L = noise_level2 / 255.0
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
rnum = random.random()
|
|
||||||
if rnum > 0.6:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else:
|
|
||||||
L = noise_level2 / 255.0
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_Poisson_noise(img):
|
|
||||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
|
|
||||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
|
||||||
else:
|
|
||||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
|
||||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
|
||||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
|
||||||
img += noise_gray[:, :, np.newaxis]
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_JPEG_noise(img):
|
|
||||||
quality_factor = random.randint(30, 95)
|
|
||||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
|
||||||
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
|
||||||
img = cv2.imdecode(encimg, 1)
|
|
||||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
|
||||||
h, w = lq.shape[:2]
|
|
||||||
rnd_h = random.randint(0, h - lq_patchsize)
|
|
||||||
rnd_w = random.randint(0, w - lq_patchsize)
|
|
||||||
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
|
|
||||||
|
|
||||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
|
||||||
hq = hq[
|
|
||||||
rnd_h_H : rnd_h_H + lq_patchsize * sf,
|
|
||||||
rnd_w_H : rnd_w_H + lq_patchsize * sf,
|
|
||||||
:,
|
|
||||||
]
|
|
||||||
return lq, hq
|
|
||||||
|
|
||||||
|
|
||||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
|
||||||
sf_ori = sf
|
|
||||||
|
|
||||||
h1, w1 = img.shape[:2]
|
|
||||||
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
|
||||||
raise ValueError(f"img size ({h1}X{w1}) is too small!")
|
|
||||||
|
|
||||||
hq = img.copy()
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
img = util.imresize_np(img, 1 / 2, True)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
|
||||||
shuffle_order[idx2],
|
|
||||||
shuffle_order[idx1],
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
if i == 0:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 1:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = img.shape[1], img.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.75:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
|
||||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / sf * a), int(1 / sf * b)),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
elif i == 6:
|
|
||||||
# add processed camera sensor noise
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
|
||||||
|
|
||||||
return img, hq
|
|
||||||
|
|
||||||
|
|
||||||
# todo no isp_model?
|
|
||||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
image = util.uint2single(image)
|
|
||||||
jpeg_prob, scale2_prob = 0.9, 0.25
|
|
||||||
# isp_prob = 0.25 # uncomment with `if i== 6` block below
|
|
||||||
# sf_ori = sf # uncomment with `if i== 6` block below
|
|
||||||
|
|
||||||
h1, w1 = image.shape[:2]
|
|
||||||
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = image.shape[:2]
|
|
||||||
|
|
||||||
# hq = image.copy() # uncomment with `if i== 6` block below
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
image = cv2.resize(
|
|
||||||
image,
|
|
||||||
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image = util.imresize_np(image, 1 / 2, True)
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
|
||||||
shuffle_order[idx2],
|
|
||||||
shuffle_order[idx1],
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
if i == 0:
|
|
||||||
image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
elif i == 1:
|
|
||||||
image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = image.shape[1], image.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.75:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
image = cv2.resize(
|
|
||||||
image,
|
|
||||||
(
|
|
||||||
int(1 / sf1 * image.shape[1]),
|
|
||||||
int(1 / sf1 * image.shape[0]),
|
|
||||||
),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
|
||||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
image = cv2.resize(
|
|
||||||
image,
|
|
||||||
(int(1 / sf * a), int(1 / sf * b)),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
|
|
||||||
# elif i == 6:
|
|
||||||
# # add processed camera sensor noise
|
|
||||||
# if random.random() < isp_prob and isp_model is not None:
|
|
||||||
# with torch.no_grad():
|
|
||||||
# img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
image = util.single2uint(image)
|
|
||||||
example = {"image": image}
|
|
||||||
return example
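# Minimal usage sketch (an illustrative addition, not part of the original module):
# the variant takes a uint8 RGB image and returns a dict whose "image" entry is
# the degraded uint8 result, spatially reduced by the scale factor.
def _example_degradation_bsrgan_variant():
    hr_uint8 = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    lq_uint8 = degradation_bsrgan_variant(hr_uint8, sf=4)["image"]
    print(hr_uint8.shape, "->", lq_uint8.shape)  # (256, 256, 3) -> (64, 64, 3)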
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: in case of a pickle error, replace `a += x` with `a = a + x` in add_speckle_noise etc.
|
|
||||||
def degradation_bsrgan_plus(
|
|
||||||
img,
|
|
||||||
sf=4,
|
|
||||||
shuffle_prob=0.5,
|
|
||||||
use_sharp=True,
|
|
||||||
lq_patchsize=64,
|
|
||||||
isp_model=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This is an extended degradation model by combining
|
|
||||||
the degradation models of BSRGAN and Real-ESRGAN
|
|
||||||
----------
|
|
||||||
img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
|
|
||||||
sf: scale factor
|
|
||||||
shuffle_prob: probability of shuffling the degradation order
|
|
||||||
use_sharp: whether to apply USM sharpening to the image before degradation
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
|
|
||||||
h1, w1 = img.shape[:2]
|
|
||||||
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
|
||||||
raise ValueError(f"img size ({h1}X{w1}) is too small!")
|
|
||||||
|
|
||||||
if use_sharp:
|
|
||||||
img = add_sharpening(img)
|
|
||||||
hq = img.copy()
|
|
||||||
|
|
||||||
if random.random() < shuffle_prob:
|
|
||||||
shuffle_order = random.sample(range(13), 13)
|
|
||||||
else:
|
|
||||||
shuffle_order = list(range(13))
|
|
||||||
# local shuffle for noise, JPEG is always the last one
|
|
||||||
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
|
|
||||||
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
|
|
||||||
|
|
||||||
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
if i == 0:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
elif i == 1:
|
|
||||||
img = add_resize(img, sf=sf)
|
|
||||||
elif i == 2:
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
|
||||||
elif i == 3:
|
|
||||||
if random.random() < poisson_prob:
|
|
||||||
img = add_Poisson_noise(img)
|
|
||||||
elif i == 4:
|
|
||||||
if random.random() < speckle_prob:
|
|
||||||
img = add_speckle_noise(img)
|
|
||||||
elif i == 5:
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
elif i == 6:
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
elif i == 7:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
elif i == 8:
|
|
||||||
img = add_resize(img, sf=sf)
|
|
||||||
elif i == 9:
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
|
||||||
elif i == 10:
|
|
||||||
if random.random() < poisson_prob:
|
|
||||||
img = add_Poisson_noise(img)
|
|
||||||
elif i == 11:
|
|
||||||
if random.random() < speckle_prob:
|
|
||||||
img = add_speckle_noise(img)
|
|
||||||
elif i == 12:
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
else:
|
|
||||||
print("check the shuffle!")
|
|
||||||
|
|
||||||
# resize to desired size
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
img, hq = random_crop(img, hq, sf, lq_patchsize)
|
|
||||||
|
|
||||||
return img, hq
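# Minimal usage sketch (an illustrative addition, not part of the original module):
# the "plus" pipeline expects a float image in [0, 1] that is at least
# lq_patchsize * sf on each side and returns an LQ/HQ patch pair.
def _example_degradation_bsrgan_plus():
    hr = np.random.rand(256, 256, 3).astype(np.float32)
    lq, hq = degradation_bsrgan_plus(hr, sf=4, shuffle_prob=0.5, lq_patchsize=64)
    print(lq.shape, hq.shape)  # (64, 64, 3) and (256, 256, 3)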
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("hey")
|
|
||||||
img = util.imread_uint("utils/test.png", 3)
|
|
||||||
print(img)
|
|
||||||
img = util.uint2single(img)
|
|
||||||
print(img)
|
|
||||||
img = img[:448, :448]
|
|
||||||
h = img.shape[0] // 4
|
|
||||||
print("resizing to", h)
|
|
||||||
sf = 4
|
|
||||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
|
||||||
for i in range(20):
|
|
||||||
print(i)
|
|
||||||
img_lq = deg_fn(img)
|
|
||||||
print(img_lq)
|
|
||||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
|
||||||
print(img_lq.shape)
|
|
||||||
print("bicubic", img_lq_bicubic.shape)
|
|
||||||
# print(img_hq.shape)
|
|
||||||
lq_nearest = cv2.resize(
|
|
||||||
util.single2uint(img_lq),
|
|
||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0,
|
|
||||||
)
|
|
||||||
lq_bicubic_nearest = cv2.resize(
|
|
||||||
util.single2uint(img_lq_bicubic),
|
|
||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0,
|
|
||||||
)
|
|
||||||
# img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
|
||||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
|
|
||||||
util.imsave(img_concat, str(i) + ".png")
|
|
@ -1,704 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import random
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import albumentations
|
|
||||||
import cv2
|
|
||||||
import ldm.modules.image_degradation.utils_image as util
|
|
||||||
import numpy as np
|
|
||||||
import scipy
|
|
||||||
import scipy.stats as ss
|
|
||||||
import torch
|
|
||||||
from scipy import ndimage
|
|
||||||
from scipy.interpolate import interp2d
|
|
||||||
from scipy.linalg import orth
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# Super-Resolution
|
|
||||||
# --------------------------------------------
|
|
||||||
#
|
|
||||||
# Kai Zhang (cskaizhang@gmail.com)
|
|
||||||
# https://github.com/cszn
|
|
||||||
# From 2019/03--2021/08
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def modcrop_np(img, sf):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
img: numpy image, WxH or WxHxC
|
|
||||||
sf: scale factor
|
|
||||||
Return:
|
|
||||||
cropped image
|
|
||||||
"""
|
|
||||||
w, h = img.shape[:2]
|
|
||||||
im = np.copy(img)
|
|
||||||
return im[: w - w % sf, : h - h % sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# anisotropic Gaussian kernels
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def analytic_kernel(k):
|
|
||||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
|
||||||
k_size = k.shape[0]
|
|
||||||
# Calculate the big kernel's size
|
|
||||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
|
||||||
# Loop over the small kernel to fill the big one
|
|
||||||
for r in range(k_size):
|
|
||||||
for c in range(k_size):
|
|
||||||
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
|
|
||||||
# Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
|
|
||||||
crop = k_size // 2
|
|
||||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
|
||||||
# Normalize to 1
|
|
||||||
return cropped_big_k / cropped_big_k.sum()
|
|
||||||
|
|
||||||
|
|
||||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
|
||||||
"""generate an anisotropic Gaussian kernel
|
|
||||||
Args:
|
|
||||||
ksize : e.g., 15, kernel size
|
|
||||||
theta : [0, pi], rotation angle range
|
|
||||||
l1 : [0.1,50], scaling of eigenvalues
|
|
||||||
l2 : [0.1,l1], scaling of eigenvalues
|
|
||||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
|
||||||
Returns:
|
|
||||||
k : kernel
|
|
||||||
"""
|
|
||||||
|
|
||||||
v = np.dot(
|
|
||||||
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
|
|
||||||
np.array([1.0, 0.0]),
|
|
||||||
)
|
|
||||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
|
||||||
D = np.array([[l1, 0], [0, l2]])
|
|
||||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
|
||||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
|
||||||
|
|
||||||
return k
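# Minimal usage sketch (an illustrative addition, not part of the original module):
# sample an anisotropic Gaussian kernel; it is normalized to sum to 1, and
# setting l1 == l2 would give an isotropic kernel instead.
def _example_anisotropic_gaussian():
    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=2)
    print(k.shape, round(float(k.sum()), 6))  # (15, 15) 1.0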
|
|
||||||
|
|
||||||
|
|
||||||
def gm_blur_kernel(mean, cov, size=15):
|
|
||||||
center = size / 2.0 + 0.5
|
|
||||||
k = np.zeros([size, size])
|
|
||||||
for y in range(size):
|
|
||||||
for x in range(size):
|
|
||||||
cy = y - center + 1
|
|
||||||
cx = x - center + 1
|
|
||||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
|
||||||
|
|
||||||
k = k / np.sum(k)
|
|
||||||
return k
|
|
||||||
|
|
||||||
|
|
||||||
def shift_pixel(x, sf, upper_left=True):
|
|
||||||
"""shift pixel for super-resolution with different scale factors
|
|
||||||
Args:
|
|
||||||
x: WxHxC or WxH
|
|
||||||
sf: scale factor
|
|
||||||
upper_left: shift direction
|
|
||||||
"""
|
|
||||||
h, w = x.shape[:2]
|
|
||||||
shift = (sf - 1) * 0.5
|
|
||||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
|
||||||
if upper_left:
|
|
||||||
x1 = xv + shift
|
|
||||||
y1 = yv + shift
|
|
||||||
else:
|
|
||||||
x1 = xv - shift
|
|
||||||
y1 = yv - shift
|
|
||||||
|
|
||||||
x1 = np.clip(x1, 0, w - 1)
|
|
||||||
y1 = np.clip(y1, 0, h - 1)
|
|
||||||
|
|
||||||
if x.ndim == 2:
|
|
||||||
x = interp2d(xv, yv, x)(x1, y1)
|
|
||||||
if x.ndim == 3:
|
|
||||||
for i in range(x.shape[-1]):
|
|
||||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def blur(x, k):
|
|
||||||
"""
|
|
||||||
x: image, NxcxHxW
|
|
||||||
k: kernel, Nx1xhxw
|
|
||||||
"""
|
|
||||||
n, c = x.shape[:2]
|
|
||||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
|
||||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
|
|
||||||
k = k.repeat(1, c, 1, 1)
|
|
||||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
|
||||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
|
||||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
|
||||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
return x
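# Minimal usage sketch (an illustrative addition, not part of the original module):
# blur() applies one kernel per sample of a batched NCHW tensor via a grouped
# convolution with replicate padding, so the spatial size is preserved.
def _example_blur():
    x = torch.rand(2, 3, 32, 32)        # two RGB images
    k = torch.ones(2, 1, 5, 5) / 25.0   # one 5x5 box kernel per image
    print(blur(x, k).shape)             # torch.Size([2, 3, 32, 32])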
|
|
||||||
|
|
||||||
|
|
||||||
def gen_kernel(
|
|
||||||
k_size=np.array([15, 15]),
|
|
||||||
scale_factor=np.array([4, 4]),
|
|
||||||
min_var=0.6,
|
|
||||||
max_var=10.0,
|
|
||||||
noise_level=0,
|
|
||||||
):
|
|
||||||
""" "
|
|
||||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
|
||||||
# Kai Zhang
|
|
||||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
|
||||||
# max_var = 2.5 * sf
|
|
||||||
"""
|
|
||||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
|
||||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
|
||||||
theta = np.random.rand() * np.pi # random theta
|
|
||||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
|
||||||
|
|
||||||
# Set COV matrix using Lambdas and Theta
|
|
||||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
|
||||||
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
|
||||||
SIGMA = Q @ LAMBDA @ Q.T
|
|
||||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
|
||||||
|
|
||||||
# Set expectation position (shifting kernel for aligned image)
|
|
||||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
|
||||||
MU = MU[None, None, :, None]
|
|
||||||
|
|
||||||
# Create meshgrid for Gaussian
|
|
||||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
|
||||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
|
||||||
|
|
||||||
# Calculate Gaussian for every pixel of the kernel
|
|
||||||
ZZ = Z - MU
|
|
||||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
|
||||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
|
||||||
|
|
||||||
# shift the kernel so it will be centered
|
|
||||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
|
||||||
|
|
||||||
# Normalize the kernel and return
|
|
||||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
|
||||||
kernel = raw_kernel / np.sum(raw_kernel)
|
|
||||||
return kernel
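# Minimal usage sketch (an illustrative addition, not part of the original module):
# gen_kernel draws a random, shifted anisotropic Gaussian; the result is a
# normalized kernel of shape k_size suitable for classical/SRMD-style degradation.
def _example_gen_kernel():
    k = gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]))
    print(k.shape, round(float(k.sum()), 6))  # (15, 15) 1.0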
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_gaussian(hsize, sigma):
|
|
||||||
hsize = [hsize, hsize]
|
|
||||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
|
||||||
std = sigma
|
|
||||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
|
||||||
arg = -(x * x + y * y) / (2 * std * std)
|
|
||||||
h = np.exp(arg)
|
|
||||||
h[h < np.finfo(float).eps * h.max()] = 0  # use numpy's finfo; scipy.finfo is not available in recent SciPy
|
|
||||||
sumh = h.sum()
|
|
||||||
if sumh != 0:
|
|
||||||
h = h / sumh
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial_laplacian(alpha):
|
|
||||||
alpha = max([0, min([alpha, 1])])
|
|
||||||
h1 = alpha / (alpha + 1)
|
|
||||||
h2 = (1 - alpha) / (alpha + 1)
|
|
||||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
|
||||||
h = np.array(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
def fspecial(filter_type, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
python code from:
|
|
||||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
|
||||||
"""
|
|
||||||
if filter_type == "gaussian":
|
|
||||||
return fspecial_gaussian(*args, **kwargs)
|
|
||||||
if filter_type == "laplacian":
|
|
||||||
return fspecial_laplacian(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# degradation models
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def bicubic_degradation(x, sf=3):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
bicubicly downsampled LR image
|
|
||||||
"""
|
|
||||||
x = util.imresize_np(x, scale=1 / sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def srmd_degradation(x, k, sf=3):
|
|
||||||
"""blur + bicubic downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2018learning,
|
|
||||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={3262--3271},
|
|
||||||
year={2018}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def dpsr_degradation(x, k, sf=3):
|
|
||||||
"""bicubic downsampling + blur
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
Reference:
|
|
||||||
@inproceedings{zhang2019deep,
|
|
||||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
|
||||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
|
||||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
|
||||||
pages={1671--1681},
|
|
||||||
year={2019}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
x = bicubic_degradation(x, sf=sf)
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def classical_degradation(x, k, sf=3):
|
|
||||||
"""blur + downsampling
|
|
||||||
Args:
|
|
||||||
x: HxWxC image, [0, 1]/[0, 255]
|
|
||||||
k: hxw, double
|
|
||||||
sf: down-scale factor
|
|
||||||
Return:
|
|
||||||
downsampled LR image
|
|
||||||
"""
|
|
||||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
|
||||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
|
||||||
st = 0
|
|
||||||
return x[st::sf, st::sf, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
|
||||||
"""USM sharpening. borrowed from real-ESRGAN
|
|
||||||
Input image: I; Blurry image: B.
|
|
||||||
1. K = I + weight * (I - B)
|
|
||||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
|
||||||
3. Blur mask:
|
|
||||||
4. Out = Mask * K + (1 - Mask) * I
|
|
||||||
Args:
|
|
||||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
|
||||||
weight (float): Sharp weight. Default: 0.5.
|
|
||||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
|
||||||
threshold (int): Threshold on |I - B| (on a 0-255 scale) used to build the mask. Default: 10.
|
|
||||||
"""
|
|
||||||
if radius % 2 == 0:
|
|
||||||
radius += 1
|
|
||||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
|
||||||
residual = img - blur
|
|
||||||
mask = np.abs(residual) * 255 > threshold
|
|
||||||
mask = mask.astype("float32")
|
|
||||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
|
||||||
|
|
||||||
K = img + weight * residual
|
|
||||||
K = np.clip(K, 0, 1)
|
|
||||||
return soft_mask * K + (1 - soft_mask) * img
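# Minimal usage sketch (an illustrative addition, not part of the original module):
# USM sharpening on a float image in [0, 1]; the blurred mask restricts the
# sharpening to regions where |I - B| exceeds the threshold, and the output is a
# blend of the clipped sharpened image K and the original input.
def _example_add_sharpening():
    img = np.random.rand(64, 64, 3).astype(np.float32)
    out = add_sharpening(img, weight=0.5, radius=51, threshold=10)
    print(out.shape, out.dtype)  # (64, 64, 3) float32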
|
|
||||||
|
|
||||||
|
|
||||||
def add_blur(img, sf=4):
|
|
||||||
wd2 = 4.0 + sf
|
|
||||||
wd = 2.0 + 0.2 * sf
|
|
||||||
|
|
||||||
wd2 = wd2 / 4
|
|
||||||
wd = wd / 4
|
|
||||||
|
|
||||||
if random.random() < 0.5:
|
|
||||||
l1 = wd2 * random.random()
|
|
||||||
l2 = wd2 * random.random()
|
|
||||||
k = anisotropic_Gaussian(
|
|
||||||
ksize=random.randint(2, 11) + 3,
|
|
||||||
theta=random.random() * np.pi,
|
|
||||||
l1=l1,
|
|
||||||
l2=l2,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random())
|
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_resize(img, sf=4):
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.8: # up
|
|
||||||
sf1 = random.uniform(1, 2)
|
|
||||||
elif rnum < 0.7: # down
|
|
||||||
sf1 = random.uniform(0.5 / sf, 1)
|
|
||||||
else:
|
|
||||||
sf1 = 1.0
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
# noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
# rnum = np.random.rand()
|
|
||||||
# if rnum > 0.6: # add color Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
# else: # add noise
|
|
||||||
# L = noise_level2 / 255.
|
|
||||||
# D = np.diag(np.random.rand(3))
|
|
||||||
# U = orth(np.random.rand(3, 3))
|
|
||||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
# img = np.clip(img, 0.0, 1.0)
|
|
||||||
# return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
rnum = np.random.rand()
|
|
||||||
if rnum > 0.6: # add color Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
|
||||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else: # add noise
|
|
||||||
L = noise_level2 / 255.0
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
|
||||||
noise_level = random.randint(noise_level1, noise_level2)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
rnum = random.random()
|
|
||||||
if rnum > 0.6:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
|
||||||
elif rnum < 0.4:
|
|
||||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
|
||||||
else:
|
|
||||||
L = noise_level2 / 255.0
|
|
||||||
D = np.diag(np.random.rand(3))
|
|
||||||
U = orth(np.random.rand(3, 3))
|
|
||||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
|
||||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_Poisson_noise(img):
|
|
||||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
|
|
||||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
|
||||||
else:
|
|
||||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
|
||||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
|
||||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
|
||||||
img += noise_gray[:, :, np.newaxis]
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def add_JPEG_noise(img):
|
|
||||||
quality_factor = random.randint(80, 95)
|
|
||||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
|
||||||
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
|
||||||
img = cv2.imdecode(encimg, 1)
|
|
||||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
|
||||||
return img
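# Minimal usage sketch (an illustrative addition, not part of the original module):
# JPEG round-trip on a float RGB image in [0, 1]; the quality factor is drawn
# uniformly from 80-95 inside add_JPEG_noise, so the mean error stays small.
def _example_add_jpeg_noise():
    clean = np.random.rand(64, 64, 3).astype(np.float32)
    noisy = add_JPEG_noise(clean)
    print(noisy.shape, noisy.dtype, float(np.abs(noisy - clean).mean()))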
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
|
||||||
h, w = lq.shape[:2]
|
|
||||||
rnd_h = random.randint(0, h - lq_patchsize)
|
|
||||||
rnd_w = random.randint(0, w - lq_patchsize)
|
|
||||||
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
|
|
||||||
|
|
||||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
|
||||||
hq = hq[
|
|
||||||
rnd_h_H : rnd_h_H + lq_patchsize * sf,
|
|
||||||
rnd_w_H : rnd_w_H + lq_patchsize * sf,
|
|
||||||
:,
|
|
||||||
]
|
|
||||||
return lq, hq
|
|
||||||
|
|
||||||
|
|
||||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
|
||||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
|
||||||
"""
|
|
||||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
|
||||||
sf_ori = sf
|
|
||||||
|
|
||||||
h1, w1 = img.shape[:2]
|
|
||||||
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
|
||||||
raise ValueError(f"img size ({h1}X{w1}) is too small!")
|
|
||||||
|
|
||||||
hq = img.copy()
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
img = util.imresize_np(img, 1 / 2, True)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
|
||||||
shuffle_order[idx2],
|
|
||||||
shuffle_order[idx1],
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
if i == 0:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 1:
|
|
||||||
img = add_blur(img, sf=sf)
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = img.shape[1], img.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.75:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
|
||||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
img = cv2.resize(
|
|
||||||
img,
|
|
||||||
(int(1 / sf * a), int(1 / sf * b)),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
img = np.clip(img, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
elif i == 6:
|
|
||||||
# add processed camera sensor noise
|
|
||||||
if random.random() < isp_prob and isp_model is not None:
|
|
||||||
with torch.no_grad():
|
|
||||||
img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
img = add_JPEG_noise(img)
|
|
||||||
|
|
||||||
# random crop
|
|
||||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
|
||||||
|
|
||||||
return img, hq
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: no isp_model?
|
|
||||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
|
||||||
"""
|
|
||||||
This is the degradation model of BSRGAN from the paper
|
|
||||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
|
||||||
----------
|
|
||||||
sf: scale factor
|
|
||||||
isp_model: camera ISP model
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
example: dict with key "image" holding the degraded low-quality image (uint8, HxWxC, spatial size reduced by sf)
|
|
||||||
(unlike degradation_bsrgan, no separate high-quality patch is returned by this variant)
|
|
||||||
"""
|
|
||||||
image = util.uint2single(image)
|
|
||||||
jpeg_prob, scale2_prob = 0.9, 0.25
|
|
||||||
# isp_prob = 0.25 # uncomment with `if i== 6` block below
|
|
||||||
# sf_ori = sf # uncomment with `if i== 6` block below
|
|
||||||
|
|
||||||
h1, w1 = image.shape[:2]
|
|
||||||
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
|
||||||
h, w = image.shape[:2]
|
|
||||||
|
|
||||||
# hq = image.copy() # uncomment with `if i== 6` block below
|
|
||||||
|
|
||||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
|
||||||
if np.random.rand() < 0.5:
|
|
||||||
image = cv2.resize(
|
|
||||||
image,
|
|
||||||
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
image = util.imresize_np(image, 1 / 2, True)
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
sf = 2
|
|
||||||
|
|
||||||
shuffle_order = random.sample(range(7), 7)
|
|
||||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
|
||||||
if idx1 > idx2: # keep downsample3 last
|
|
||||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
|
||||||
shuffle_order[idx2],
|
|
||||||
shuffle_order[idx1],
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in shuffle_order:
|
|
||||||
if i == 0:
|
|
||||||
image = add_blur(image, sf=sf)
|
|
||||||
|
|
||||||
# the second blur step (i == 1) is intentionally skipped in this lighter variant
elif i == 1:
    pass
|
|
||||||
|
|
||||||
elif i == 2:
|
|
||||||
a, b = image.shape[1], image.shape[0]
|
|
||||||
# downsample2
|
|
||||||
if random.random() < 0.8:
|
|
||||||
sf1 = random.uniform(1, 2 * sf)
|
|
||||||
image = cv2.resize(
|
|
||||||
image,
|
|
||||||
(
|
|
||||||
int(1 / sf1 * image.shape[1]),
|
|
||||||
int(1 / sf1 * image.shape[0]),
|
|
||||||
),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
|
||||||
k_shifted = shift_pixel(k, sf)
|
|
||||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
|
||||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
|
||||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
|
||||||
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 3:
|
|
||||||
# downsample3
|
|
||||||
image = cv2.resize(
|
|
||||||
image,
|
|
||||||
(int(1 / sf * a), int(1 / sf * b)),
|
|
||||||
interpolation=random.choice([1, 2, 3]),
|
|
||||||
)
|
|
||||||
image = np.clip(image, 0.0, 1.0)
|
|
||||||
|
|
||||||
elif i == 4:
|
|
||||||
# add Gaussian noise
|
|
||||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
|
||||||
|
|
||||||
elif i == 5:
|
|
||||||
# add JPEG noise
|
|
||||||
if random.random() < jpeg_prob:
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
#
|
|
||||||
# elif i == 6:
|
|
||||||
# # add processed camera sensor noise
|
|
||||||
# if random.random() < isp_prob and isp_model is not None:
|
|
||||||
# with torch.no_grad():
|
|
||||||
# img, hq = isp_model.forward(img.copy(), hq)
|
|
||||||
|
|
||||||
# add final JPEG compression noise
|
|
||||||
image = add_JPEG_noise(image)
|
|
||||||
image = util.single2uint(image)
|
|
||||||
example = {"image": image}
|
|
||||||
return example
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("hey")
|
|
||||||
img = util.imread_uint("utils/test.png", 3)
|
|
||||||
img = img[:448, :448]
|
|
||||||
h = img.shape[0] // 4
|
|
||||||
print("resizing to", h)
|
|
||||||
sf = 4
|
|
||||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
|
||||||
for i in range(20):
|
|
||||||
print(i)
|
|
||||||
img_hq = img
|
|
||||||
img_lq = deg_fn(img)["image"]
|
|
||||||
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
|
|
||||||
print(img_lq)
|
|
||||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
|
|
||||||
"image"
|
|
||||||
]
|
|
||||||
print(img_lq.shape)
|
|
||||||
print("bicubic", img_lq_bicubic.shape)
|
|
||||||
print(img_hq.shape)
|
|
||||||
lq_nearest = cv2.resize(
|
|
||||||
util.single2uint(img_lq),
|
|
||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0,
|
|
||||||
)
|
|
||||||
lq_bicubic_nearest = cv2.resize(
|
|
||||||
util.single2uint(img_lq_bicubic),
|
|
||||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
|
||||||
interpolation=0,
|
|
||||||
)
|
|
||||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
|
||||||
util.imsave(img_concat, str(i) + ".png")
|
|
Binary file not shown.
Before Width: | Height: | Size: 431 KiB |
@ -1,968 +0,0 @@
|
|||||||
import math
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torchvision.utils import make_grid
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# Kai Zhang (github: https://github.com/cszn)
|
|
||||||
# 03/Mar/2019
|
|
||||||
# --------------------------------------------
|
|
||||||
# https://github.com/twhui/SRGAN-pyTorch
|
|
||||||
# https://github.com/xinntao/BasicSR
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
IMG_EXTENSIONS = [
|
|
||||||
".jpg",
|
|
||||||
".JPG",
|
|
||||||
".jpeg",
|
|
||||||
".JPEG",
|
|
||||||
".png",
|
|
||||||
".PNG",
|
|
||||||
".ppm",
|
|
||||||
".PPM",
|
|
||||||
".bmp",
|
|
||||||
".BMP",
|
|
||||||
".tif",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def is_image_file(filename):
|
|
||||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp():
|
|
||||||
return datetime.now().strftime("%y%m%d-%H%M%S")
|
|
||||||
|
|
||||||
|
|
||||||
def imshow(x, title=None, cbar=False, figsize=None):
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
plt.figure(figsize=figsize)
|
|
||||||
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
|
|
||||||
if title:
|
|
||||||
plt.title(title)
|
|
||||||
if cbar:
|
|
||||||
plt.colorbar()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def surf(Z, cmap="rainbow", figsize=None):
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
plt.figure(figsize=figsize)
|
|
||||||
ax3 = plt.axes(projection="3d")
|
|
||||||
|
|
||||||
w, h = Z.shape[:2]
|
|
||||||
xx = np.arange(0, w, 1)
|
|
||||||
yy = np.arange(0, h, 1)
|
|
||||||
X, Y = np.meshgrid(xx, yy)
|
|
||||||
ax3.plot_surface(X, Y, Z, cmap=cmap)
|
|
||||||
# ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# get image paths
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_image_paths(dataroot):
|
|
||||||
paths = None # return None if dataroot is None
|
|
||||||
if dataroot is not None:
|
|
||||||
paths = sorted(_get_paths_from_images(dataroot))
|
|
||||||
return paths
|
|
||||||
|
|
||||||
|
|
||||||
def _get_paths_from_images(path):
|
|
||||||
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
|
|
||||||
images = []
|
|
||||||
for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
|
|
||||||
for fname in sorted(fnames):
|
|
||||||
if is_image_file(fname):
|
|
||||||
img_path = os.path.join(dirpath, fname)
|
|
||||||
images.append(img_path)
|
|
||||||
assert images, "{:s} has no valid image file".format(path)
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# split large images into small images
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
|
|
||||||
w, h = img.shape[:2]
|
|
||||||
patches = []
|
|
||||||
if w > p_max and h > p_max:
|
|
||||||
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))  # np.int was removed from NumPy
|
|
||||||
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
|
|
||||||
w1.append(w - p_size)
|
|
||||||
h1.append(h - p_size)
|
|
||||||
# print(w1)
|
|
||||||
# print(h1)
|
|
||||||
for i in w1:
|
|
||||||
for j in h1:
|
|
||||||
patches.append(img[i : i + p_size, j : j + p_size, :])
|
|
||||||
else:
|
|
||||||
patches.append(img)
|
|
||||||
|
|
||||||
return patches
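# Minimal usage sketch (an illustrative addition, not part of the original module):
# images larger than p_max in both dimensions are tiled into overlapping
# p_size x p_size patches; smaller images come back as a one-element list.
def _example_patches_from_image():
    big = np.zeros((1200, 1600, 3), dtype=np.uint8)
    small = np.zeros((256, 256, 3), dtype=np.uint8)
    print(len(patches_from_image(big, p_size=512, p_overlap=64, p_max=800)))  # 12
    print(len(patches_from_image(small)))                                     # 1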
|
|
||||||
|
|
||||||
|
|
||||||
def imssave(imgs, img_path):
|
|
||||||
"""
|
|
||||||
imgs: list, N images of size WxHxC
|
|
||||||
"""
|
|
||||||
img_name, ext = os.path.splitext(os.path.basename(img_path))
|
|
||||||
|
|
||||||
for i, img in enumerate(imgs):
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = img[:, :, [2, 1, 0]]
|
|
||||||
new_path = os.path.join(
|
|
||||||
os.path.dirname(img_path),
|
|
||||||
img_name + str("_s{:04d}".format(i)) + ".png",
|
|
||||||
)
|
|
||||||
cv2.imwrite(new_path, img)
|
|
||||||
|
|
||||||
|
|
||||||
def split_imageset(
|
|
||||||
original_dataroot,
|
|
||||||
taget_dataroot,
|
|
||||||
n_channels=3,
|
|
||||||
p_size=800,
|
|
||||||
p_overlap=96,
|
|
||||||
p_max=1000,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
split the large images from original_dataroot into small overlapping images of size (p_size)x(p_size),
|
|
||||||
and save them into taget_dataroot; only images larger than (p_max)x(p_max)
|
|
||||||
will be split.
|
|
||||||
Args:
|
|
||||||
original_dataroot:
|
|
||||||
taget_dataroot:
|
|
||||||
p_size: size of small images
|
|
||||||
p_overlap: overlap between adjacent patches; the training patch size is a good choice
|
|
||||||
p_max: images smaller than (p_max)x(p_max) are kept unchanged.
|
|
||||||
"""
|
|
||||||
paths = get_image_paths(original_dataroot)
|
|
||||||
for img_path in paths:
|
|
||||||
# img_name, ext = os.path.splitext(os.path.basename(img_path))
|
|
||||||
img = imread_uint(img_path, n_channels=n_channels)
|
|
||||||
patches = patches_from_image(img, p_size, p_overlap, p_max)
|
|
||||||
imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
|
|
||||||
# if original_dataroot == taget_dataroot:
|
|
||||||
# del img_path
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# makedir
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def mkdir(path):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
os.makedirs(path)
|
|
||||||
|
|
||||||
|
|
||||||
def mkdirs(paths):
|
|
||||||
if isinstance(paths, str):
|
|
||||||
mkdir(paths)
|
|
||||||
else:
|
|
||||||
for path in paths:
|
|
||||||
mkdir(path)
|
|
||||||
|
|
||||||
|
|
||||||
def mkdir_and_rename(path):
|
|
||||||
if os.path.exists(path):
|
|
||||||
new_name = path + "_archived_" + get_timestamp()
|
|
||||||
logger.error("Path already exists. Renaming it to [{:s}]".format(new_name))
|
|
||||||
os.replace(path, new_name)
|
|
||||||
os.makedirs(path)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# read image from path
|
|
||||||
# opencv is fast, but read BGR numpy image
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# get uint8 image of size HxWxn_channels (RGB)
|
|
||||||
# --------------------------------------------
|
|
||||||
def imread_uint(path, n_channels=3):
|
|
||||||
# input: path
|
|
||||||
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
|
||||||
if n_channels == 1:
|
|
||||||
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
|
|
||||||
img = np.expand_dims(img, axis=2) # HxWx1
|
|
||||||
elif n_channels == 3:
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
|
|
||||||
else:
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# matlab's imwrite
|
|
||||||
# --------------------------------------------
|
|
||||||
def imsave(img, img_path):
|
|
||||||
img = np.squeeze(img)
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = img[:, :, [2, 1, 0]]
|
|
||||||
cv2.imwrite(img_path, img)
|
|
||||||
|
|
||||||
|
|
||||||
def imwrite(img, img_path):
|
|
||||||
img = np.squeeze(img)
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = img[:, :, [2, 1, 0]]
|
|
||||||
cv2.imwrite(img_path, img)
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# get single image of size HxWxn_channels (BGR)
|
|
||||||
# --------------------------------------------
|
|
||||||
def read_img(path):
|
|
||||||
# read image by cv2
|
|
||||||
# return: Numpy float32, HWC, BGR, [0,1]
|
|
||||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
|
|
||||||
img = img.astype(np.float32) / 255.0
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
# some images have 4 channels
|
|
||||||
if img.shape[2] > 3:
|
|
||||||
img = img[:, :, :3]
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# image format conversion
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(single) <---> numpy(uint)
|
|
||||||
# numpy(single) <---> tensor
|
|
||||||
# numpy(uint) <---> tensor
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(single) [0, 1] <---> numpy(uint)
|
|
||||||
# --------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def uint2single(img):
|
|
||||||
return np.float32(img / 255.0)
|
|
||||||
|
|
||||||
|
|
||||||
def single2uint(img):
|
|
||||||
return np.uint8((img.clip(0, 1) * 255.0).round())
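# Minimal usage sketch (an illustrative addition, not part of the original module):
# the uint8 <-> float32 helpers round-trip exactly, since single2uint clips to
# [0, 1] and rounds before casting back.
def _example_uint_roundtrip():
    u = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)
    print(np.array_equal(single2uint(uint2single(u)), u))  # True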
|
|
||||||
|
|
||||||
|
|
||||||
def uint162single(img):
|
|
||||||
return np.float32(img / 65535.0)
|
|
||||||
|
|
||||||
|
|
||||||
def single2uint16(img):
|
|
||||||
return np.uint16((img.clip(0, 1) * 65535.0).round())
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(uint) (HxWxC or HxW) <---> tensor
|
|
||||||
# --------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# convert uint to 4-dimensional torch tensor
|
|
||||||
def uint2tensor4(img):
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# convert uint to 3-dimensional torch tensor
|
|
||||||
def uint2tensor3(img):
|
|
||||||
if img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
|
|
||||||
|
|
||||||
|
|
||||||
# convert 2/3/4-dimensional torch tensor to uint
|
|
||||||
def tensor2uint(img):
|
|
||||||
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
return np.uint8((img * 255.0).round())
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# numpy(single) (HxWxC) <---> tensor
|
|
||||||
# --------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
# convert single (HxWxC) to 3-dimensional torch tensor
|
|
||||||
def single2tensor3(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
|
|
||||||
|
|
||||||
|
|
||||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
|
||||||
def single2tensor4(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# convert torch tensor to single
|
|
||||||
def tensor2single(img):
|
|
||||||
img = img.data.squeeze().float().cpu().numpy()
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# convert torch tensor to single
|
|
||||||
def tensor2single3(img):
|
|
||||||
img = img.data.squeeze().float().cpu().numpy()
|
|
||||||
if img.ndim == 3:
|
|
||||||
img = np.transpose(img, (1, 2, 0))
|
|
||||||
elif img.ndim == 2:
|
|
||||||
img = np.expand_dims(img, axis=2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def single2tensor5(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def single32tensor5(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def single42tensor4(img):
|
|
||||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
|
|
||||||
|
|
||||||
|
|
||||||
# from skimage.io import imread, imsave
|
|
||||||
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
|
||||||
"""
|
|
||||||
Converts a torch Tensor into an image Numpy array of BGR channel order
|
|
||||||
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
|
||||||
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
|
||||||
"""
|
|
||||||
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
|
||||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
|
||||||
n_dim = tensor.dim()
|
|
||||||
if n_dim == 4:
|
|
||||||
n_img = len(tensor)
|
|
||||||
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
|
||||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
|
||||||
elif n_dim == 3:
|
|
||||||
img_np = tensor.numpy()
|
|
||||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
|
||||||
elif n_dim == 2:
|
|
||||||
img_np = tensor.numpy()
|
|
||||||
else:
|
|
||||||
raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim))
|
|
||||||
if out_type == np.uint8:
|
|
||||||
img_np = (img_np * 255.0).round()
|
|
||||||
# Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
|
|
||||||
return img_np.astype(out_type)
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# Augmentation, flip and/or rotate
|
|
||||||
# --------------------------------------------
|
|
||||||
# The following two are enough.
|
|
||||||
# (1) augment_img: numpy image of WxHxC or WxH
|
|
||||||
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img(img, mode=0):
|
|
||||||
"""Kai Zhang (github: https://github.com/cszn)"""
|
|
||||||
if mode == 0:
|
|
||||||
return img
|
|
||||||
elif mode == 1:
|
|
||||||
return np.flipud(np.rot90(img))
|
|
||||||
elif mode == 2:
|
|
||||||
return np.flipud(img)
|
|
||||||
elif mode == 3:
|
|
||||||
return np.rot90(img, k=3)
|
|
||||||
elif mode == 4:
|
|
||||||
return np.flipud(np.rot90(img, k=2))
|
|
||||||
elif mode == 5:
|
|
||||||
return np.rot90(img)
|
|
||||||
elif mode == 6:
|
|
||||||
return np.rot90(img, k=2)
|
|
||||||
elif mode == 7:
|
|
||||||
return np.flipud(np.rot90(img, k=3))
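# Minimal usage sketch (an illustrative addition, not part of the original module):
# the 8 modes cover the dihedral group of the square (identity, three rotations,
# and the four flipped counterparts); mode 0 returns the input unchanged.
def _example_augment_img():
    img = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
    outs = [augment_img(img, mode=m) for m in range(8)]
    print(np.array_equal(outs[0], img), [o.shape for o in outs])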
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img_tensor4(img, mode=0):
|
|
||||||
"""Kai Zhang (github: https://github.com/cszn)"""
|
|
||||||
if mode == 0:
|
|
||||||
return img
|
|
||||||
elif mode == 1:
|
|
||||||
return img.rot90(1, [2, 3]).flip([2])
|
|
||||||
elif mode == 2:
|
|
||||||
return img.flip([2])
|
|
||||||
elif mode == 3:
|
|
||||||
return img.rot90(3, [2, 3])
|
|
||||||
elif mode == 4:
|
|
||||||
return img.rot90(2, [2, 3]).flip([2])
|
|
||||||
elif mode == 5:
|
|
||||||
return img.rot90(1, [2, 3])
|
|
||||||
elif mode == 6:
|
|
||||||
return img.rot90(2, [2, 3])
|
|
||||||
elif mode == 7:
|
|
||||||
return img.rot90(3, [2, 3]).flip([2])
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img_tensor(img, mode=0):
|
|
||||||
"""Kai Zhang (github: https://github.com/cszn)"""
|
|
||||||
img_size = img.size()
|
|
||||||
img_np = img.data.cpu().numpy()
|
|
||||||
if len(img_size) == 3:
|
|
||||||
img_np = np.transpose(img_np, (1, 2, 0))
|
|
||||||
elif len(img_size) == 4:
|
|
||||||
img_np = np.transpose(img_np, (2, 3, 1, 0))
|
|
||||||
img_np = augment_img(img_np, mode=mode)
|
|
||||||
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
|
|
||||||
if len(img_size) == 3:
|
|
||||||
img_tensor = img_tensor.permute(2, 0, 1)
|
|
||||||
elif len(img_size) == 4:
|
|
||||||
img_tensor = img_tensor.permute(3, 2, 0, 1)
|
|
||||||
|
|
||||||
return img_tensor.type_as(img)
|
|
||||||
|
|
||||||
|
|
||||||
def augment_img_np3(img, mode=0):
|
|
||||||
if mode == 0:
|
|
||||||
return img
|
|
||||||
elif mode == 1:
|
|
||||||
return img.transpose(1, 0, 2)
|
|
||||||
elif mode == 2:
|
|
||||||
return img[::-1, :, :]
|
|
||||||
elif mode == 3:
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
elif mode == 4:
|
|
||||||
return img[:, ::-1, :]
|
|
||||||
elif mode == 5:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
elif mode == 6:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
return img
|
|
||||||
elif mode == 7:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def augment_imgs(img_list, hflip=True, rot=True):
|
|
||||||
# horizontal flip OR rotate
|
|
||||||
hflip = hflip and random.random() < 0.5
|
|
||||||
vflip = rot and random.random() < 0.5
|
|
||||||
rot90 = rot and random.random() < 0.5
|
|
||||||
|
|
||||||
def _augment(img):
|
|
||||||
if hflip:
|
|
||||||
img = img[:, ::-1, :]
|
|
||||||
if vflip:
|
|
||||||
img = img[::-1, :, :]
|
|
||||||
if rot90:
|
|
||||||
img = img.transpose(1, 0, 2)
|
|
||||||
return img
|
|
||||||
|
|
||||||
return [_augment(img) for img in img_list]
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# modcrop and shave
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def modcrop(img_in, scale):
|
|
||||||
# img_in: Numpy, HWC or HW
|
|
||||||
img = np.copy(img_in)
|
|
||||||
if img.ndim == 2:
|
|
||||||
H, W = img.shape
|
|
||||||
H_r, W_r = H % scale, W % scale
|
|
||||||
img = img[: H - H_r, : W - W_r]
|
|
||||||
elif img.ndim == 3:
|
|
||||||
H, W, C = img.shape
|
|
||||||
H_r, W_r = H % scale, W % scale
|
|
||||||
img = img[: H - H_r, : W - W_r, :]
|
|
||||||
else:
|
|
||||||
raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def shave(img_in, border=0):
|
|
||||||
# img_in: Numpy, HWC or HW
|
|
||||||
img = np.copy(img_in)
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
img = img[border : h - border, border : w - border]
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# image processing process on numpy image
|
|
||||||
# channel_convert(in_c, tar_type, img_list):
|
|
||||||
# rgb2ycbcr(img, only_y=True):
|
|
||||||
# bgr2ycbcr(img, only_y=True):
|
|
||||||
# ycbcr2rgb(img):
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def rgb2ycbcr(img, only_y=True):
|
|
||||||
"""same as matlab rgb2ycbcr
|
|
||||||
only_y: only return Y channel
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
"""
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img = img.astype(np.float32)  # work on a float copy; the bare .astype call discarded its result
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.0
|
|
||||||
# convert
|
|
||||||
if only_y:
|
|
||||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
|
||||||
else:
|
|
||||||
rlt = np.matmul(
|
|
||||||
img,
|
|
||||||
[
|
|
||||||
[65.481, -37.797, 112.0],
|
|
||||||
[128.553, -74.203, -93.786],
|
|
||||||
[24.966, 112.0, -18.214],
|
|
||||||
],
|
|
||||||
) / 255.0 + [16, 128, 128]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.0
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
def ycbcr2rgb(img):
|
|
||||||
"""same as matlab ycbcr2rgb
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
"""
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img = img.astype(np.float32)  # work on a float copy; the bare .astype call discarded its result
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.0
|
|
||||||
# convert
|
|
||||||
rlt = np.matmul(
|
|
||||||
img,
|
|
||||||
[
|
|
||||||
[0.00456621, 0.00456621, 0.00456621],
|
|
||||||
[0, -0.00153632, 0.00791071],
|
|
||||||
[0.00625893, -0.00318811, 0],
|
|
||||||
],
|
|
||||||
) * 255.0 + [-222.921, 135.576, -276.836]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.0
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
def bgr2ycbcr(img, only_y=True):
|
|
||||||
"""bgr version of rgb2ycbcr
|
|
||||||
only_y: only return Y channel
|
|
||||||
Input:
|
|
||||||
uint8, [0, 255]
|
|
||||||
float, [0, 1]
|
|
||||||
"""
|
|
||||||
in_img_type = img.dtype
|
|
||||||
img = img.astype(np.float32)  # work on a float copy; the bare .astype call discarded its result
|
|
||||||
if in_img_type != np.uint8:
|
|
||||||
img *= 255.0
|
|
||||||
# convert
|
|
||||||
if only_y:
|
|
||||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
|
||||||
else:
|
|
||||||
rlt = np.matmul(
|
|
||||||
img,
|
|
||||||
[
|
|
||||||
[24.966, 112.0, -18.214],
|
|
||||||
[128.553, -74.203, -93.786],
|
|
||||||
[65.481, -37.797, 112.0],
|
|
||||||
],
|
|
||||||
) / 255.0 + [16, 128, 128]
|
|
||||||
if in_img_type == np.uint8:
|
|
||||||
rlt = rlt.round()
|
|
||||||
else:
|
|
||||||
rlt /= 255.0
|
|
||||||
return rlt.astype(in_img_type)
|
|
||||||
|
|
||||||
|
|
||||||
def channel_convert(in_c, tar_type, img_list):
|
|
||||||
# conversion among BGR, gray and y
|
|
||||||
if in_c == 3 and tar_type == "gray": # BGR to gray
|
|
||||||
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
|
||||||
return [np.expand_dims(img, axis=2) for img in gray_list]
|
|
||||||
elif in_c == 3 and tar_type == "y": # BGR to y
|
|
||||||
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
|
||||||
return [np.expand_dims(img, axis=2) for img in y_list]
|
|
||||||
elif in_c == 1 and tar_type == "RGB": # gray/y to BGR
|
|
||||||
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
|
||||||
else:
|
|
||||||
return img_list
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
# --------------------------------------------
|
|
||||||
# metric, PSNR and SSIM
|
|
||||||
# --------------------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# PSNR
|
|
||||||
# --------------------------------------------
|
|
||||||
def calculate_psnr(img1, img2, border=0):
|
|
||||||
# img1 and img2 have range [0, 255]
|
|
||||||
# img1 = img1.squeeze()
|
|
||||||
# img2 = img2.squeeze()
|
|
||||||
if not img1.shape == img2.shape:
|
|
||||||
raise ValueError("Input images must have the same dimensions.")
|
|
||||||
h, w = img1.shape[:2]
|
|
||||||
img1 = img1[border : h - border, border : w - border]
|
|
||||||
img2 = img2[border : h - border, border : w - border]
|
|
||||||
|
|
||||||
img1 = img1.astype(np.float64)
|
|
||||||
img2 = img2.astype(np.float64)
|
|
||||||
mse = np.mean((img1 - img2) ** 2)
|
|
||||||
if mse == 0:
|
|
||||||
return float("inf")
|
|
||||||
return 20 * math.log10(255.0 / math.sqrt(mse))
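# Minimal usage sketch (an illustrative addition, not part of the original module):
# PSNR = 20 * log10(255 / sqrt(MSE)) for [0, 255] inputs; identical images give
# inf, and a +/-1 perturbation lands around 48 dB.
def _example_calculate_psnr():
    a = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8)
    b = np.clip(a.astype(np.int16) + 1, 0, 255).astype(np.uint8)
    print(calculate_psnr(a, a), round(calculate_psnr(a, b), 2))  # inf, ~48.1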
|
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
|
||||||
# SSIM
|
|
||||||
# --------------------------------------------
|
|
||||||
def calculate_ssim(img1, img2, border=0):
|
|
||||||
"""calculate SSIM
|
|
||||||
the same outputs as MATLAB's
|
|
||||||
img1, img2: [0, 255]
|
|
||||||
"""
|
|
||||||
# img1 = img1.squeeze()
|
|
||||||
# img2 = img2.squeeze()
|
|
||||||
if not img1.shape == img2.shape:
|
|
||||||
raise ValueError("Input images must have the same dimensions.")
|
|
||||||
h, w = img1.shape[:2]
|
|
||||||
img1 = img1[border : h - border, border : w - border]
|
|
||||||
img2 = img2[border : h - border, border : w - border]
|
|
||||||
|
|
||||||
if img1.ndim == 2:
|
|
||||||
return ssim(img1, img2)
|
|
||||||
elif img1.ndim == 3:
|
|
||||||
if img1.shape[2] == 3:
|
|
||||||
ssims = []
|
|
||||||
for i in range(3):
|
|
||||||
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
|
|
||||||
return np.array(ssims).mean()
|
|
||||||
elif img1.shape[2] == 1:
|
|
||||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
|
||||||
else:
|
|
||||||
raise ValueError("Wrong input image dimensions.")
|
|
||||||
|
|
||||||
|
|
||||||
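# Illustrative usage sketch (not from the upstream file): PSNR and SSIM between
# an image and a lightly perturbed copy, both in the expected [0, 255] range.
def _example_psnr_ssim():  # illustrative only
    rng = np.random.default_rng(0)
    img = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float64)
    noisy = np.clip(img + rng.normal(0, 5, img.shape), 0, 255)
    return calculate_psnr(img, noisy), calculate_ssim(img, noisy)

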
def ssim(img1, img2):
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
"""


# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
        -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
    ) * (((absx > 1) * (absx <= 2)).type_as(absx))


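# Illustrative sketch (not from the upstream file; assumes torch is imported at
# the top of this module): the bicubic kernel above is 1 at x=0 and decays to 0
# for |x| >= 2.
def _example_cubic_kernel():  # illustrative only
    xs = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
    return cubic(xs)  # tensor([1.0000, 0.5625, 0.0000, -0.0625, 0.0000])

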
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
        out_length, P
    )

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


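# Illustrative sketch (not from the upstream file): the per-row weights and
# source indices used to shrink a length-8 signal to length 4 with the helper
# above; each output sample is a weighted sum over weights.size(1) inputs.
def _example_weights_indices():  # illustrative only
    weights, indices, sym_len_s, sym_len_e = calculate_weights_indices(
        8, 4, 0.5, "cubic", 4, antialiasing=True
    )
    return weights.shape, indices.shape, sym_len_s, sym_len_e

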
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = (
        in_C,
        math.ceil(in_H * scale),
        math.ceil(in_W * scale),
    )
    kernel_width = 4
    kernel = "cubic"

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing
    )
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing
    )
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2


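# Illustrative usage sketch (not from the upstream file): 2x bicubic upscaling
# of a random CHW tensor in [0, 1] with the torch version above.
def _example_imresize_tensor():  # illustrative only
    x = torch.rand(3, 32, 48)  # CHW, float in [0, 1]
    return imresize(x, 2, antialiasing=True).shape  # torch.Size([3, 64, 96])

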
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = (
        in_C,
        math.ceil(in_H * scale),
        math.ceil(in_W * scale),
    )
    kernel_width = 4
    kernel = "cubic"

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing
    )
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing
    )
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()


if __name__ == "__main__":
    print("---")
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)

@ -10,7 +10,6 @@ from .devices import ( # noqa: F401
     normalize_device,
     torch_dtype,
 )
-from .log import write_log # noqa: F401
 from .util import ( # noqa: F401
     ask_user,
     download_with_resume,
@ -75,6 +75,7 @@
     "@reduxjs/toolkit": "^1.9.5",
     "@roarr/browser-log-writer": "^1.1.5",
     "@stevebel/png": "^1.5.1",
+    "compare-versions": "^6.1.0",
     "dateformat": "^5.0.3",
     "formik": "^2.4.3",
     "framer-motion": "^10.16.1",
@ -84,6 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
 import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
 import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
 import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
+import { addWorkflowLoadedListener } from './listeners/workflowLoaded';

 export const listenerMiddleware = createListenerMiddleware();

@ -202,6 +203,9 @@ addBoardIdSelectedListener();
 // Node schemas
 addReceivedOpenAPISchemaListener();

+// Workflows
+addWorkflowLoadedListener();
+
 // DND
 addImageDroppedListener();

@ -0,0 +1,55 @@
+import { logger } from 'app/logging/logger';
+import { workflowLoadRequested } from 'features/nodes/store/actions';
+import { workflowLoaded } from 'features/nodes/store/nodesSlice';
+import { $flow } from 'features/nodes/store/reactFlowInstance';
+import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
+import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
+import { setActiveTab } from 'features/ui/store/uiSlice';
+import { startAppListening } from '..';
+
+export const addWorkflowLoadedListener = () => {
+  startAppListening({
+    actionCreator: workflowLoadRequested,
+    effect: (action, { dispatch, getState }) => {
+      const log = logger('nodes');
+      const workflow = action.payload;
+      const nodeTemplates = getState().nodes.nodeTemplates;
+
+      const { workflow: validatedWorkflow, errors } = validateWorkflow(
+        workflow,
+        nodeTemplates
+      );
+
+      dispatch(workflowLoaded(validatedWorkflow));
+
+      if (!errors.length) {
+        dispatch(
+          addToast(
+            makeToast({
+              title: 'Workflow Loaded',
+              status: 'success',
+            })
+          )
+        );
+      } else {
+        dispatch(
+          addToast(
+            makeToast({
+              title: 'Workflow Loaded with Warnings',
+              status: 'warning',
+            })
+          )
+        );
+        errors.forEach(({ message, ...rest }) => {
+          log.warn(rest, message);
+        });
+      }
+
+      dispatch(setActiveTab('nodes'));
+      requestAnimationFrame(() => {
+        $flow.get()?.fitView();
+      });
+    },
+  });
+};
@ -63,7 +63,11 @@ const selector = createSelector(
         return;
       }

-      if (fieldTemplate.required && !field.value && !hasConnection) {
+      if (
+        fieldTemplate.required &&
+        field.value === undefined &&
+        !hasConnection
+      ) {
         reasons.push(
           `${node.data.label || nodeTemplate.title} -> ${
             field.label || fieldTemplate.title
@ -1,2 +1,2 @@
 export const colorTokenToCssVar = (colorToken: string) =>
-  `var(--invokeai-colors-${colorToken.split('.').join('-')}`;
+  `var(--invokeai-colors-${colorToken.split('.').join('-')})`;
@ -17,16 +17,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
|
||||||
setShouldShowImageDetails,
|
setShouldShowImageDetails,
|
||||||
setShouldShowProgressInViewer,
|
setShouldShowProgressInViewer,
|
||||||
} from 'features/ui/store/uiSlice';
|
} from 'features/ui/store/uiSlice';
|
||||||
@ -110,7 +107,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||||
lastSelectedImage?.image_name ?? skipToken,
|
lastSelectedImage ?? skipToken,
|
||||||
{
|
{
|
||||||
selectFromResult: (res) => ({
|
selectFromResult: (res) => ({
|
||||||
isLoading: res.isFetching,
|
isLoading: res.isFetching,
|
||||||
@ -124,16 +121,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
if (!workflow) {
|
if (!workflow) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(workflowLoaded(workflow));
|
dispatch(workflowLoadRequested(workflow));
|
||||||
dispatch(setActiveTab('nodes'));
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: 'Workflow Loaded',
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}, [dispatch, workflow]);
|
}, [dispatch, workflow]);
|
||||||
|
|
||||||
const handleClickUseAllParameters = useCallback(() => {
|
const handleClickUseAllParameters = useCallback(() => {
|
||||||
|
@ -7,12 +7,9 @@ import {
|
|||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/changeBoardModal/store/slice';
|
} from 'features/changeBoardModal/store/slice';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -36,6 +33,7 @@ import {
|
|||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||||
|
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||||
|
|
||||||
type SingleSelectionMenuItemsProps = {
|
type SingleSelectionMenuItemsProps = {
|
||||||
imageDTO: ImageDTO;
|
imageDTO: ImageDTO;
|
||||||
@ -52,7 +50,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
|
|
||||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||||
imageDTO.image_name,
|
imageDTO,
|
||||||
{
|
{
|
||||||
selectFromResult: (res) => ({
|
selectFromResult: (res) => ({
|
||||||
isLoading: res.isFetching,
|
isLoading: res.isFetching,
|
||||||
@ -102,16 +100,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
if (!workflow) {
|
if (!workflow) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(workflowLoaded(workflow));
|
dispatch(workflowLoadRequested(workflow));
|
||||||
dispatch(setActiveTab('nodes'));
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: 'Workflow Loaded',
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}, [dispatch, workflow]);
|
}, [dispatch, workflow]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
|
@ -101,13 +101,15 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallSeed}
|
onClick={handleRecallSeed}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.model !== undefined && metadata.model !== null && (
|
{metadata.model !== undefined &&
|
||||||
<ImageMetadataItem
|
metadata.model !== null &&
|
||||||
label="Model"
|
metadata.model.model_name && (
|
||||||
value={metadata.model.model_name}
|
<ImageMetadataItem
|
||||||
onClick={handleRecallModel}
|
label="Model"
|
||||||
/>
|
value={metadata.model.model_name}
|
||||||
)}
|
onClick={handleRecallModel}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{metadata.width && (
|
{metadata.width && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label="Width"
|
label="Width"
|
||||||
|
@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
// dispatch(setShouldShowImageDetails(false));
|
// dispatch(setShouldShowImageDetails(false));
|
||||||
// });
|
// });
|
||||||
|
|
||||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
|
||||||
image.image_name,
|
selectFromResult: (res) => ({
|
||||||
{
|
metadata: res?.currentData?.metadata,
|
||||||
selectFromResult: (res) => ({
|
workflow: res?.currentData?.workflow,
|
||||||
metadata: res?.currentData?.metadata,
|
}),
|
||||||
workflow: res?.currentData?.workflow,
|
});
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
@ -13,6 +14,7 @@ import {
|
|||||||
OnConnectStart,
|
OnConnectStart,
|
||||||
OnEdgesChange,
|
OnEdgesChange,
|
||||||
OnEdgesDelete,
|
OnEdgesDelete,
|
||||||
|
OnInit,
|
||||||
OnMoveEnd,
|
OnMoveEnd,
|
||||||
OnNodesChange,
|
OnNodesChange,
|
||||||
OnNodesDelete,
|
OnNodesDelete,
|
||||||
@ -147,6 +149,11 @@ export const Flow = () => {
|
|||||||
dispatch(contextMenusClosed());
|
dispatch(contextMenusClosed());
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
|
const onInit: OnInit = useCallback((flow) => {
|
||||||
|
$flow.set(flow);
|
||||||
|
flow.fitView();
|
||||||
|
}, []);
|
||||||
|
|
||||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
dispatch(selectionCopied());
|
dispatch(selectionCopied());
|
||||||
@ -170,6 +177,7 @@ export const Flow = () => {
|
|||||||
edgeTypes={edgeTypes}
|
edgeTypes={edgeTypes}
|
||||||
nodes={nodes}
|
nodes={nodes}
|
||||||
edges={edges}
|
edges={edges}
|
||||||
|
onInit={onInit}
|
||||||
onNodesChange={onNodesChange}
|
onNodesChange={onNodesChange}
|
||||||
onEdgesChange={onEdgesChange}
|
onEdgesChange={onEdgesChange}
|
||||||
onEdgesDelete={onEdgesDelete}
|
onEdgesDelete={onEdgesDelete}
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
Tooltip,
|
Tooltip,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
|
import { compare } from 'compare-versions';
|
||||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||||
@ -20,6 +21,7 @@ import { isInvocationNodeData } from 'features/nodes/types/types';
|
|||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { FaInfoCircle } from 'react-icons/fa';
|
import { FaInfoCircle } from 'react-icons/fa';
|
||||||
import NotesTextarea from './NotesTextarea';
|
import NotesTextarea from './NotesTextarea';
|
||||||
|
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -29,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
|||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
const label = useNodeLabel(nodeId);
|
const label = useNodeLabel(nodeId);
|
||||||
const title = useNodeTemplateTitle(nodeId);
|
const title = useNodeTemplateTitle(nodeId);
|
||||||
|
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -50,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
|||||||
>
|
>
|
||||||
<Icon
|
<Icon
|
||||||
as={FaInfoCircle}
|
as={FaInfoCircle}
|
||||||
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
|
sx={{
|
||||||
|
boxSize: 4,
|
||||||
|
w: 8,
|
||||||
|
color: doVersionsMatch ? 'base.400' : 'error.400',
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
@ -92,16 +99,59 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
|||||||
return 'Unknown Node';
|
return 'Unknown Node';
|
||||||
}, [data, nodeTemplate]);
|
}, [data, nodeTemplate]);
|
||||||
|
|
||||||
|
const versionComponent = useMemo(() => {
|
||||||
|
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!data.version) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version unknown
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!nodeTemplate.version) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version {data.version} (unknown template)
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version {data.version} (update node)
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||||
|
return (
|
||||||
|
<Text as="span" sx={{ color: 'error.500' }}>
|
||||||
|
Version {data.version} (update app)
|
||||||
|
</Text>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return <Text as="span">Version {data.version}</Text>;
|
||||||
|
}, [data, nodeTemplate]);
|
||||||
|
|
||||||
if (!isInvocationNodeData(data)) {
|
if (!isInvocationNodeData(data)) {
|
||||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ flexDir: 'column' }}>
|
<Flex sx={{ flexDir: 'column' }}>
|
||||||
<Text sx={{ fontWeight: 600 }}>{title}</Text>
|
<Text as="span" sx={{ fontWeight: 600 }}>
|
||||||
|
{title}
|
||||||
|
</Text>
|
||||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||||
{nodeTemplate?.description}
|
{nodeTemplate?.description}
|
||||||
</Text>
|
</Text>
|
||||||
|
{versionComponent}
|
||||||
{data?.notes && <Text>{data.notes}</Text>}
|
{data?.notes && <Text>{data.notes}</Text>}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
import { Tooltip } from '@chakra-ui/react';
|
||||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||||
import {
|
import {
|
||||||
|
COLLECTION_TYPES,
|
||||||
FIELDS,
|
FIELDS,
|
||||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||||
|
MODEL_TYPES,
|
||||||
|
POLYMORPHIC_TYPES,
|
||||||
} from 'features/nodes/types/constants';
|
} from 'features/nodes/types/constants';
|
||||||
import {
|
import {
|
||||||
InputFieldTemplate,
|
InputFieldTemplate,
|
||||||
@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
|
|||||||
borderWidth: 0,
|
borderWidth: 0,
|
||||||
zIndex: 1,
|
zIndex: 1,
|
||||||
};
|
};
|
||||||
|
``;
|
||||||
|
|
||||||
export const inputHandleStyles: CSSProperties = {
|
export const inputHandleStyles: CSSProperties = {
|
||||||
left: '-1rem',
|
left: '-1rem',
|
||||||
@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
|
|||||||
connectionError,
|
connectionError,
|
||||||
} = props;
|
} = props;
|
||||||
const { name, type } = fieldTemplate;
|
const { name, type } = fieldTemplate;
|
||||||
const { color, title } = FIELDS[type];
|
const { color: typeColor, title } = FIELDS[type];
|
||||||
|
|
||||||
const styles: CSSProperties = useMemo(() => {
|
const styles: CSSProperties = useMemo(() => {
|
||||||
|
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||||
|
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||||
|
const isModelType = MODEL_TYPES.includes(type);
|
||||||
|
const color = colorTokenToCssVar(typeColor);
|
||||||
const s: CSSProperties = {
|
const s: CSSProperties = {
|
||||||
backgroundColor: colorTokenToCssVar(color),
|
backgroundColor:
|
||||||
|
isCollectionType || isPolymorphicType
|
||||||
|
? 'var(--invokeai-colors-base-900)'
|
||||||
|
: color,
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
width: '1rem',
|
width: '1rem',
|
||||||
height: '1rem',
|
height: '1rem',
|
||||||
borderWidth: 0,
|
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||||
|
borderStyle: 'solid',
|
||||||
|
borderColor: color,
|
||||||
|
borderRadius: isModelType ? 4 : '100%',
|
||||||
zIndex: 1,
|
zIndex: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
|
|||||||
|
|
||||||
return s;
|
return s;
|
||||||
}, [
|
}, [
|
||||||
color,
|
|
||||||
connectionError,
|
connectionError,
|
||||||
handleType,
|
handleType,
|
||||||
isConnectionInProgress,
|
isConnectionInProgress,
|
||||||
isConnectionStartField,
|
isConnectionStartField,
|
||||||
|
type,
|
||||||
|
typeColor,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const tooltip = useMemo(() => {
|
const tooltip = useMemo(() => {
|
||||||
|
@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
sx={{
|
sx={{
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
|
h: 'full',
|
||||||
mb: 0,
|
mb: 0,
|
||||||
px: 1,
|
px: 1,
|
||||||
gap: 2,
|
gap: 2,
|
||||||
|
@ -3,18 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
|||||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import BooleanInputField from './inputs/BooleanInputField';
|
import BooleanInputField from './inputs/BooleanInputField';
|
||||||
import ClipInputField from './inputs/ClipInputField';
|
|
||||||
import CollectionInputField from './inputs/CollectionInputField';
|
|
||||||
import CollectionItemInputField from './inputs/CollectionItemInputField';
|
|
||||||
import ColorInputField from './inputs/ColorInputField';
|
import ColorInputField from './inputs/ColorInputField';
|
||||||
import ConditioningInputField from './inputs/ConditioningInputField';
|
|
||||||
import ControlInputField from './inputs/ControlInputField';
|
|
||||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||||
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
|
|
||||||
import EnumInputField from './inputs/EnumInputField';
|
import EnumInputField from './inputs/EnumInputField';
|
||||||
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
|
|
||||||
import ImageInputField from './inputs/ImageInputField';
|
import ImageInputField from './inputs/ImageInputField';
|
||||||
import LatentsInputField from './inputs/LatentsInputField';
|
|
||||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||||
import MainModelInputField from './inputs/MainModelInputField';
|
import MainModelInputField from './inputs/MainModelInputField';
|
||||||
import NumberInputField from './inputs/NumberInputField';
|
import NumberInputField from './inputs/NumberInputField';
|
||||||
@ -22,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
|
|||||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||||
import StringInputField from './inputs/StringInputField';
|
import StringInputField from './inputs/StringInputField';
|
||||||
import UnetInputField from './inputs/UnetInputField';
|
|
||||||
import VaeInputField from './inputs/VaeInputField';
|
|
||||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||||
|
|
||||||
type InputFieldProps = {
|
type InputFieldProps = {
|
||||||
@ -31,7 +21,6 @@ type InputFieldProps = {
|
|||||||
fieldName: string;
|
fieldName: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
// build an individual input element based on the schema
|
|
||||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||||
const field = useFieldData(nodeId, fieldName);
|
const field = useFieldData(nodeId, fieldName);
|
||||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||||
@ -93,88 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'LatentsField' &&
|
|
||||||
fieldTemplate?.type === 'LatentsField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<LatentsInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'DenoiseMaskField' &&
|
|
||||||
fieldTemplate?.type === 'DenoiseMaskField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<DenoiseMaskInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'ConditioningField' &&
|
|
||||||
fieldTemplate?.type === 'ConditioningField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<ConditioningInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
|
|
||||||
return (
|
|
||||||
<UnetInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
|
|
||||||
return (
|
|
||||||
<ClipInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
|
|
||||||
return (
|
|
||||||
<VaeInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'ControlField' &&
|
|
||||||
fieldTemplate?.type === 'ControlField'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<ControlInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
field?.type === 'MainModelField' &&
|
field?.type === 'MainModelField' &&
|
||||||
fieldTemplate?.type === 'MainModelField'
|
fieldTemplate?.type === 'MainModelField'
|
||||||
@ -240,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
|
|
||||||
return (
|
|
||||||
<CollectionInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'CollectionItem' &&
|
|
||||||
fieldTemplate?.type === 'CollectionItem'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<CollectionItemInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||||
return (
|
return (
|
||||||
<ColorInputField
|
<ColorInputField
|
||||||
@ -273,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
field?.type === 'ImageCollection' &&
|
|
||||||
fieldTemplate?.type === 'ImageCollection'
|
|
||||||
) {
|
|
||||||
return (
|
|
||||||
<ImageCollectionInputField
|
|
||||||
nodeId={nodeId}
|
|
||||||
field={field}
|
|
||||||
fieldTemplate={fieldTemplate}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
field?.type === 'SDXLMainModelField' &&
|
field?.type === 'SDXLMainModelField' &&
|
||||||
fieldTemplate?.type === 'SDXLMainModelField'
|
fieldTemplate?.type === 'SDXLMainModelField'
|
||||||
@ -309,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field && fieldTemplate) {
|
||||||
|
// Fallback for when there is no component for the type
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box p={1}>
|
<Box p={1}>
|
||||||
<Text
|
<Text
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
import {
|
import {
|
||||||
ControlInputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
ControlInputFieldValue,
|
ControlInputFieldValue,
|
||||||
|
ControlPolymorphicInputFieldTemplate,
|
||||||
|
ControlPolymorphicInputFieldValue,
|
||||||
FieldComponentProps,
|
FieldComponentProps,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
const ControlInputFieldComponent = (
|
const ControlInputFieldComponent = (
|
||||||
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
|
_props: FieldComponentProps<
|
||||||
|
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
|
||||||
|
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
|
||||||
|
>
|
||||||
) => {
|
) => {
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
@ -9,9 +9,9 @@ import {
|
|||||||
} from 'features/dnd/types';
|
} from 'features/dnd/types';
|
||||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
|
FieldComponentProps,
|
||||||
ImageInputFieldTemplate,
|
ImageInputFieldTemplate,
|
||||||
ImageInputFieldValue,
|
ImageInputFieldValue,
|
||||||
FieldComponentProps,
|
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaUndo } from 'react-icons/fa';
|
||||||
|
@ -2,11 +2,16 @@ import {
|
|||||||
LatentsInputFieldTemplate,
|
LatentsInputFieldTemplate,
|
||||||
LatentsInputFieldValue,
|
LatentsInputFieldValue,
|
||||||
FieldComponentProps,
|
FieldComponentProps,
|
||||||
|
LatentsPolymorphicInputFieldValue,
|
||||||
|
LatentsPolymorphicInputFieldTemplate,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
const LatentsInputFieldComponent = (
|
const LatentsInputFieldComponent = (
|
||||||
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
_props: FieldComponentProps<
|
||||||
|
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
|
||||||
|
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
|
||||||
|
>
|
||||||
) => {
|
) => {
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
|
FieldComponentProps,
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
FloatInputFieldValue,
|
FloatInputFieldValue,
|
||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
IntegerInputFieldValue,
|
IntegerInputFieldValue,
|
||||||
FieldComponentProps,
|
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo, useEffect, useMemo, useState } from 'react';
|
import { memo, useEffect, useMemo, useState } from 'react';
|
||||||
|
|
||||||
|
@ -9,13 +9,20 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
|
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
DRAG_HANDLE_CLASSNAME,
|
DRAG_HANDLE_CLASSNAME,
|
||||||
NODE_WIDTH,
|
NODE_WIDTH,
|
||||||
} from 'features/nodes/types/constants';
|
} from 'features/nodes/types/constants';
|
||||||
import { NodeStatus } from 'features/nodes/types/types';
|
import { NodeStatus } from 'features/nodes/types/types';
|
||||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||||
import { PropsWithChildren, memo, useCallback, useMemo } from 'react';
|
import {
|
||||||
|
MouseEvent,
|
||||||
|
PropsWithChildren,
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useMemo,
|
||||||
|
} from 'react';
|
||||||
|
|
||||||
type NodeWrapperProps = PropsWithChildren & {
|
type NodeWrapperProps = PropsWithChildren & {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -57,9 +64,15 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
|
|
||||||
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
|
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
|
||||||
|
|
||||||
const handleClick = useCallback(() => {
|
const handleClick = useCallback(
|
||||||
dispatch(contextMenusClosed());
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
}, [dispatch]);
|
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
|
||||||
|
dispatch(nodeExclusivelySelected(nodeId));
|
||||||
|
}
|
||||||
|
dispatch(contextMenusClosed());
|
||||||
|
},
|
||||||
|
[dispatch, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
|
@ -138,13 +138,14 @@ export const useBuildNodeData = () => {
|
|||||||
data: {
|
data: {
|
||||||
id: nodeId,
|
id: nodeId,
|
||||||
type,
|
type,
|
||||||
inputs,
|
version: template.version,
|
||||||
outputs,
|
|
||||||
isOpen: true,
|
|
||||||
label: '',
|
label: '',
|
||||||
notes: '',
|
notes: '',
|
||||||
|
isOpen: true,
|
||||||
embedWorkflow: false,
|
embedWorkflow: false,
|
||||||
isIntermediate: true,
|
isIntermediate: true,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,33 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import { compareVersions } from 'compare-versions';
+import { useMemo } from 'react';
+import { isInvocationNode } from '../types/types';
+
+export const useDoNodeVersionsMatch = (nodeId: string) => {
+  const selector = useMemo(
+    () =>
+      createSelector(
+        stateSelector,
+        ({ nodes }) => {
+          const node = nodes.nodes.find((node) => node.id === nodeId);
+          if (!isInvocationNode(node)) {
+            return false;
+          }
+          const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
+          if (!nodeTemplate?.version || !node.data?.version) {
+            return false;
+          }
+          return compareVersions(nodeTemplate.version, node.data.version) === 0;
+        },
+        defaultSelectorOptions
+      ),
+    [nodeId]
+  );
+
+  const nodeTemplate = useAppSelector(selector);
+
+  return nodeTemplate;
+};
@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
       if (!isInvocationNode(node)) {
         return;
       }
-      return Boolean(node?.data.inputs[fieldName]?.value);
+      return node?.data.inputs[fieldName]?.value !== undefined;
     },
     defaultSelectorOptions
   ),
@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
||||||
import { COLLECTION_TYPES } from '../types/constants';
|
import {
|
||||||
|
COLLECTION_MAP,
|
||||||
|
COLLECTION_TYPES,
|
||||||
|
POLYMORPHIC_TO_SINGLE_MAP,
|
||||||
|
POLYMORPHIC_TYPES,
|
||||||
|
} from '../types/constants';
|
||||||
import { InvocationNodeData } from '../types/types';
|
import { InvocationNodeData } from '../types/types';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||||
|
* TODO: Figure out how to do this without duplicating all the logic
|
||||||
|
*/
|
||||||
|
|
||||||
export const useIsValidConnection = () => {
|
export const useIsValidConnection = () => {
|
||||||
const flow = useReactFlow();
|
const flow = useReactFlow();
|
||||||
const shouldValidateGraph = useAppSelector(
|
const shouldValidateGraph = useAppSelector(
|
||||||
@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
edges
|
||||||
|
.filter((edge) => {
|
||||||
|
return edge.target === target && edge.targetHandle === targetHandle;
|
||||||
|
})
|
||||||
|
.find((edge) => {
|
||||||
|
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||||
|
})
|
||||||
|
) {
|
||||||
|
// We already have a connection from this source to this target
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Connection is invalid if target already has a connection
|
// Connection is invalid if target already has a connection
|
||||||
if (
|
if (
|
||||||
edges.find((edge) => {
|
edges.find((edge) => {
|
||||||
@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection types must be the same for a connection
|
/**
|
||||||
if (
|
* Connection types must be the same for a connection, with exceptions:
|
||||||
sourceType !== targetType &&
|
* - CollectionItem can connect to any non-Collection
|
||||||
sourceType !== 'CollectionItem' &&
|
* - Non-Collections can connect to CollectionItem
|
||||||
targetType !== 'CollectionItem'
|
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||||
) {
|
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||||
if (
|
* - Any Collection can connect to a Generic Collection
|
||||||
!(
|
*/
|
||||||
COLLECTION_TYPES.includes(targetType) &&
|
|
||||||
COLLECTION_TYPES.includes(sourceType)
|
if (sourceType !== targetType) {
|
||||||
)
|
const isCollectionItemToNonCollection =
|
||||||
) {
|
sourceType === 'CollectionItem' &&
|
||||||
return false;
|
!COLLECTION_TYPES.includes(targetType);
|
||||||
}
|
|
||||||
|
const isNonCollectionToCollectionItem =
|
||||||
|
targetType === 'CollectionItem' &&
|
||||||
|
!COLLECTION_TYPES.includes(sourceType) &&
|
||||||
|
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||||
|
|
||||||
|
const isAnythingToPolymorphicOfSameBaseType =
|
||||||
|
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||||
|
(() => {
|
||||||
|
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const baseType =
|
||||||
|
POLYMORPHIC_TO_SINGLE_MAP[
|
||||||
|
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||||
|
];
|
||||||
|
|
||||||
|
const collectionType =
|
||||||
|
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||||
|
|
||||||
|
return sourceType === baseType || sourceType === collectionType;
|
||||||
|
})();
|
||||||
|
|
||||||
|
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||||
|
sourceType === 'Collection' &&
|
||||||
|
(COLLECTION_TYPES.includes(targetType) ||
|
||||||
|
POLYMORPHIC_TYPES.includes(targetType));
|
||||||
|
|
||||||
|
const isCollectionToGenericCollection =
|
||||||
|
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||||
|
|
||||||
|
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||||
|
|
||||||
|
return (
|
||||||
|
isCollectionItemToNonCollection ||
|
||||||
|
isNonCollectionToCollectionItem ||
|
||||||
|
isAnythingToPolymorphicOfSameBaseType ||
|
||||||
|
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||||
|
isCollectionToGenericCollection ||
|
||||||
|
isIntToFloat
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graphs much be acyclic (no loops!)
|
// Graphs much be acyclic (no loops!)
|
||||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||||
},
|
},
|
||||||
|
@@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
-import { workflowLoaded } from 'features/nodes/store/nodesSlice';
-import { zValidatedWorkflow } from 'features/nodes/types/types';
+import { zWorkflow } from 'features/nodes/types/types';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { ZodError } from 'zod';
import { fromZodError, fromZodIssue } from 'zod-validation-error';
+import { workflowLoadRequested } from '../store/actions';

export const useLoadWorkflowFromFile = () => {
  const dispatch = useAppDispatch();
@@ -24,7 +24,7 @@ export const useLoadWorkflowFromFile = () => {

        try {
          const parsedJSON = JSON.parse(String(rawJSON));
-          const result = zValidatedWorkflow.safeParse(parsedJSON);
+          const result = zWorkflow.safeParse(parsedJSON);

          if (!result.success) {
            const { message } = fromZodError(result.error, {
@@ -45,32 +45,8 @@ export const useLoadWorkflowFromFile = () => {
            reader.abort();
            return;
          }
-          dispatch(workflowLoaded(result.data.workflow));

-          if (!result.data.warnings.length) {
-            dispatch(
-              addToast(
-                makeToast({
-                  title: 'Workflow Loaded',
-                  status: 'success',
-                })
-              )
-            );
-            reader.abort();
-            return;
-          }
-
-          dispatch(
-            addToast(
-              makeToast({
-                title: 'Workflow Loaded with Warnings',
-                status: 'warning',
-              })
-            )
-          );
-          result.data.warnings.forEach(({ message, ...rest }) => {
-            logger.warn(rest, message);
-          });
+          dispatch(workflowLoadRequested(result.data));

          reader.abort();
        } catch {
@@ -1,5 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api/types';
+import { Workflow } from '../types/types';

export const textToImageGraphBuilt = createAction<Graph>(
  'nodes/textToImageGraphBuilt'
@@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
  canvasGraphBuilt,
  nodesGraphBuilt
);
+
+export const workflowLoadRequested = createAction<Workflow>(
+  'nodes/workflowLoadRequested'
+);
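For context, this is roughly how a typed action such as workflowLoadRequested is created, dispatched and matched with Redux Toolkit. The Workflow stand-in below is a trimmed, hypothetical payload shape; the real type lives in features/nodes/types/types.

```ts
import { createAction } from '@reduxjs/toolkit';

// Hypothetical minimal payload shape, for illustration only.
type Workflow = { name: string; nodes: unknown[]; edges: unknown[] };

const workflowLoadRequested = createAction<Workflow>('nodes/workflowLoadRequested');

// Calling the action creator produces a plain, serializable action object...
const action = workflowLoadRequested({ name: 'demo', nodes: [], edges: [] });
console.log(action.type); // 'nodes/workflowLoadRequested'

// ...and the generated matcher is what a listener middleware would key on.
if (workflowLoadRequested.match(action)) {
  console.log(action.payload.name); // 'demo'
}
```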
@@ -443,6 +443,17 @@ const nodesSlice = createSlice({
      }
      node.data.notes = notes;
    },
+    nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
+      const nodeId = action.payload;
+      state.nodes = applyNodeChanges(
+        state.nodes.map((n) => ({
+          id: n.id,
+          type: 'select',
+          selected: n.id === nodeId ? true : false,
+        })),
+        state.nodes
+      );
+    },
    selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
      state.selectedNodes = action.payload;
    },
@@ -892,6 +903,7 @@ export const {
  nodeEmbedWorkflowChanged,
  nodeIsIntermediateChanged,
  mouseOverNodeChanged,
+  nodeExclusivelySelected,
} = nodesSlice.actions;

export default nodesSlice.reducer;
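The new nodeExclusivelySelected reducer leans on reactflow's applyNodeChanges with one 'select' change per node. A rough standalone sketch of that pattern, using a simplified node list and a hypothetical exclusivelySelect helper:

```ts
import { applyNodeChanges, Node, NodeChange } from 'reactflow';

// Simplified node list for illustration only.
const nodes: Node[] = [
  { id: 'a', position: { x: 0, y: 0 }, data: {}, selected: true },
  { id: 'b', position: { x: 100, y: 0 }, data: {}, selected: false },
];

// Build one 'select' change per node so exactly one node ends up selected.
const exclusivelySelect = (nodeId: string, current: Node[]): Node[] => {
  const changes: NodeChange[] = current.map((n) => ({
    id: n.id,
    type: 'select',
    selected: n.id === nodeId,
  }));
  return applyNodeChanges(changes, current);
};

console.log(exclusivelySelect('b', nodes).map((n) => [n.id, n.selected]));
// [['a', false], ['b', true]]
```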
@@ -0,0 +1,4 @@
+import { atom } from 'nanostores';
+import { ReactFlowInstance } from 'reactflow';
+
+export const $flow = atom<ReactFlowInstance | null>(null);
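The $flow store is a nanostores atom, so it can be read and written outside the React tree. A small sketch of typical usage, with a stand-in FlowLike type instead of the real ReactFlowInstance:

```ts
import { atom } from 'nanostores';

// Stand-in for ReactFlowInstance, for this sketch only.
type FlowLike = { fitView: () => void };

export const $flow = atom<FlowLike | null>(null);

// A component would set the atom once the flow instance is available...
$flow.set({ fitView: () => console.log('fit!') });

// ...and non-React code can read or subscribe to it directly.
$flow.get()?.fitView(); // logs 'fit!'
const unsubscribe = $flow.subscribe((flow) => {
  console.log('flow is', flow ? 'ready' : 'not ready');
});
unsubscribe();
```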
@@ -1,10 +1,20 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
-import { COLLECTION_TYPES } from 'features/nodes/types/constants';
+import {
+  COLLECTION_MAP,
+  COLLECTION_TYPES,
+  POLYMORPHIC_TO_SINGLE_MAP,
+  POLYMORPHIC_TYPES,
+} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
import { HandleType } from 'reactflow';

+/**
+ * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
+ * TODO: Figure out how to do this without duplicating all the logic
+ */
+
export const makeConnectionErrorSelector = (
  nodeId: string,
  fieldName: string,
@@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
    const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
      state.nodes;

-    if (!state.nodes.shouldValidateGraph) {
-      // manual override!
-      return null;
-    }
-
    if (!connectionStartParams || !currentConnectionFieldType) {
      return 'No connection in progress';
    }
@@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
      return 'No connection data';
    }

-    const targetFieldType =
+    const targetType =
      handleType === 'target' ? fieldType : currentConnectionFieldType;
-    const sourceFieldType =
+    const sourceType =
      handleType === 'source' ? fieldType : currentConnectionFieldType;

    if (nodeId === connectionNodeId) {
@@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
    }

    if (
-      fieldType !== currentConnectionFieldType &&
-      fieldType !== 'CollectionItem' &&
-      currentConnectionFieldType !== 'CollectionItem'
-    ) {
-      if (
-        !(
-          COLLECTION_TYPES.includes(targetFieldType) &&
-          COLLECTION_TYPES.includes(sourceFieldType)
-        )
-      ) {
-        // except for collection items, field types must match
-        return 'Field types must match';
-      }
-    }
-
-    if (
-      handleType === 'target' &&
      edges.find((edge) => {
        return edge.target === nodeId && edge.targetHandle === fieldName;
      }) &&
      // except CollectionItem inputs can have multiples
-      targetFieldType !== 'CollectionItem'
+      targetType !== 'CollectionItem'
    ) {
-      return 'Inputs may only have one connection';
+      return 'Input may only have one connection';
+    }
+
+    /**
+     * Connection types must be the same for a connection, with exceptions:
+     * - CollectionItem can connect to any non-Collection
+     * - Non-Collections can connect to CollectionItem
+     * - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
+     * - Generic Collection can connect to any other Collection or Polymorphic
+     * - Any Collection can connect to a Generic Collection
+     */
+
+    if (sourceType !== targetType) {
+      const isCollectionItemToNonCollection =
+        sourceType === 'CollectionItem' &&
+        !COLLECTION_TYPES.includes(targetType);
+
+      const isNonCollectionToCollectionItem =
+        targetType === 'CollectionItem' &&
+        !COLLECTION_TYPES.includes(sourceType) &&
+        !POLYMORPHIC_TYPES.includes(sourceType);
+
+      const isAnythingToPolymorphicOfSameBaseType =
+        POLYMORPHIC_TYPES.includes(targetType) &&
+        (() => {
+          if (!POLYMORPHIC_TYPES.includes(targetType)) {
+            return false;
+          }
+          const baseType =
+            POLYMORPHIC_TO_SINGLE_MAP[
+              targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
+            ];
+
+          const collectionType =
+            COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
+
+          return sourceType === baseType || sourceType === collectionType;
+        })();
+
+      const isGenericCollectionToAnyCollectionOrPolymorphic =
+        sourceType === 'Collection' &&
+        (COLLECTION_TYPES.includes(targetType) ||
+          POLYMORPHIC_TYPES.includes(targetType));
+
+      const isCollectionToGenericCollection =
+        targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
+
+      const isIntToFloat = sourceType === 'integer' && targetType === 'float';
+
+      if (
+        !(
+          isCollectionItemToNonCollection ||
+          isNonCollectionToCollectionItem ||
+          isAnythingToPolymorphicOfSameBaseType ||
+          isGenericCollectionToAnyCollectionOrPolymorphic ||
+          isCollectionToGenericCollection ||
+          isIntToFloat
+        )
+      ) {
+        return 'Field types must match';
+      }
    }

    const isGraphAcyclic = getIsGraphAcyclic(
@@ -17,176 +17,297 @@ export const KIND_MAP = {
export const COLLECTION_TYPES: FieldType[] = [
  'Collection',
  'IntegerCollection',
+  'BooleanCollection',
  'FloatCollection',
  'StringCollection',
-  'BooleanCollection',
  'ImageCollection',
+  'LatentsCollection',
+  'ConditioningCollection',
+  'ControlCollection',
+  'ColorCollection',
];

+export const POLYMORPHIC_TYPES = [
+  'IntegerPolymorphic',
+  'BooleanPolymorphic',
+  'FloatPolymorphic',
+  'StringPolymorphic',
+  'ImagePolymorphic',
+  'LatentsPolymorphic',
+  'ConditioningPolymorphic',
+  'ControlPolymorphic',
+  'ColorPolymorphic',
+];
+
+export const MODEL_TYPES = [
+  'ControlNetModelField',
+  'LoRAModelField',
+  'MainModelField',
+  'ONNXModelField',
+  'SDXLMainModelField',
+  'SDXLRefinerModelField',
+  'VaeModelField',
+  'UNetField',
+  'VaeField',
+  'ClipField',
+];
+
+export const COLLECTION_MAP = {
+  integer: 'IntegerCollection',
+  boolean: 'BooleanCollection',
+  number: 'FloatCollection',
+  float: 'FloatCollection',
+  string: 'StringCollection',
+  ImageField: 'ImageCollection',
+  LatentsField: 'LatentsCollection',
+  ConditioningField: 'ConditioningCollection',
+  ControlField: 'ControlCollection',
+  ColorField: 'ColorCollection',
+};
+export const isCollectionItemType = (
+  itemType: string | undefined
+): itemType is keyof typeof COLLECTION_MAP =>
+  Boolean(itemType && itemType in COLLECTION_MAP);
+
+export const SINGLE_TO_POLYMORPHIC_MAP = {
+  integer: 'IntegerPolymorphic',
+  boolean: 'BooleanPolymorphic',
+  number: 'FloatPolymorphic',
+  float: 'FloatPolymorphic',
+  string: 'StringPolymorphic',
+  ImageField: 'ImagePolymorphic',
+  LatentsField: 'LatentsPolymorphic',
+  ConditioningField: 'ConditioningPolymorphic',
+  ControlField: 'ControlPolymorphic',
+  ColorField: 'ColorPolymorphic',
+};
+
+export const POLYMORPHIC_TO_SINGLE_MAP = {
+  IntegerPolymorphic: 'integer',
+  BooleanPolymorphic: 'boolean',
+  FloatPolymorphic: 'float',
+  StringPolymorphic: 'string',
+  ImagePolymorphic: 'ImageField',
+  LatentsPolymorphic: 'LatentsField',
+  ConditioningPolymorphic: 'ConditioningField',
+  ControlPolymorphic: 'ControlField',
+  ColorPolymorphic: 'ColorField',
+};
+
+export const isPolymorphicItemType = (
+  itemType: string | undefined
+): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
+  Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
+
export const FIELDS: Record<FieldType, FieldUIConfig> = {
-  integer: {
-    title: 'Integer',
-    description: 'Integers are whole numbers, without a decimal point.',
-    color: 'red.500',
-  },
-  float: {
-    title: 'Float',
-    description: 'Floats are numbers with a decimal point.',
-    color: 'orange.500',
-  },
-  string: {
-    title: 'String',
-    description: 'Strings are text.',
-    color: 'yellow.500',
-  },
  boolean: {
-    title: 'Boolean',
    color: 'green.500',
    description: 'Booleans are true or false.',
+    title: 'Boolean',
  },
-  enum: {
-    title: 'Enum',
-    description: 'Enums are values that may be one of a number of options.',
-    color: 'blue.500',
+  BooleanCollection: {
+    color: 'green.500',
+    description: 'A collection of booleans.',
+    title: 'Boolean Collection',
  },
-  array: {
-    title: 'Array',
-    description: 'Enums are values that may be one of a number of options.',
-    color: 'base.500',
-  },
-  ImageField: {
-    title: 'Image',
-    description: 'Images may be passed between nodes.',
-    color: 'purple.500',
-  },
-  DenoiseMaskField: {
-    title: 'Denoise Mask',
-    description: 'Denoise Mask may be passed between nodes',
-    color: 'base.500',
-  },
-  LatentsField: {
-    title: 'Latents',
-    description: 'Latents may be passed between nodes.',
-    color: 'pink.500',
-  },
-  LatentsCollection: {
-    title: 'Latents Collection',
-    description: 'Latents may be passed between nodes.',
-    color: 'pink.500',
-  },
-  ConditioningField: {
-    color: 'cyan.500',
-    title: 'Conditioning',
-    description: 'Conditioning may be passed between nodes.',
-  },
-  ConditioningCollection: {
-    color: 'cyan.500',
-    title: 'Conditioning Collection',
-    description: 'Conditioning may be passed between nodes.',
-  },
-  ImageCollection: {
-    title: 'Image Collection',
-    description: 'A collection of images.',
-    color: 'base.300',
-  },
-  UNetField: {
-    color: 'red.500',
-    title: 'UNet',
-    description: 'UNet submodel.',
+  BooleanPolymorphic: {
+    color: 'green.500',
+    description: 'A collection of booleans.',
+    title: 'Boolean Polymorphic',
  },
  ClipField: {
    color: 'green.500',
-    title: 'Clip',
    description: 'Tokenizer and text_encoder submodels.',
-  },
-  VaeField: {
-    color: 'blue.500',
-    title: 'Vae',
-    description: 'Vae submodel.',
-  },
-  ControlField: {
-    color: 'cyan.500',
-    title: 'Control',
-    description: 'Control info passed between nodes.',
-  },
-  MainModelField: {
-    color: 'teal.500',
-    title: 'Model',
-    description: 'TODO',
-  },
-  SDXLRefinerModelField: {
-    color: 'teal.500',
-    title: 'Refiner Model',
-    description: 'TODO',
-  },
-  VaeModelField: {
-    color: 'teal.500',
-    title: 'VAE',
-    description: 'TODO',
-  },
-  LoRAModelField: {
-    color: 'teal.500',
-    title: 'LoRA',
-    description: 'TODO',
-  },
-  ControlNetModelField: {
-    color: 'teal.500',
-    title: 'ControlNet',
-    description: 'TODO',
-  },
-  Scheduler: {
-    color: 'base.500',
-    title: 'Scheduler',
-    description: 'TODO',
+    title: 'Clip',
  },
  Collection: {
    color: 'base.500',
-    title: 'Collection',
    description: 'TODO',
+    title: 'Collection',
  },
  CollectionItem: {
    color: 'base.500',
-    title: 'Collection Item',
    description: 'TODO',
+    title: 'Collection Item',
+  },
+  ColorCollection: {
+    color: 'pink.300',
+    description: 'A collection of colors.',
+    title: 'Color Collection',
  },
  ColorField: {
-    title: 'Color',
+    color: 'pink.300',
    description: 'A RGBA color.',
-    color: 'base.500',
+    title: 'Color',
  },
-  BooleanCollection: {
-    title: 'Boolean Collection',
-    description: 'A collection of booleans.',
-    color: 'green.500',
+  ColorPolymorphic: {
+    color: 'pink.300',
+    description: 'A collection of colors.',
+    title: 'Color Polymorphic',
  },
-  IntegerCollection: {
-    title: 'Integer Collection',
-    description: 'A collection of integers.',
-    color: 'red.500',
+  ConditioningCollection: {
+    color: 'cyan.500',
+    description: 'Conditioning may be passed between nodes.',
+    title: 'Conditioning Collection',
+  },
+  ConditioningField: {
+    color: 'cyan.500',
+    description: 'Conditioning may be passed between nodes.',
+    title: 'Conditioning',
+  },
+  ConditioningPolymorphic: {
+    color: 'cyan.500',
+    description: 'Conditioning may be passed between nodes.',
+    title: 'Conditioning Polymorphic',
+  },
+  ControlCollection: {
+    color: 'teal.500',
+    description: 'Control info passed between nodes.',
+    title: 'Control Collection',
+  },
+  ControlField: {
+    color: 'teal.500',
+    description: 'Control info passed between nodes.',
+    title: 'Control',
+  },
+  ControlNetModelField: {
+    color: 'teal.500',
+    description: 'TODO',
+    title: 'ControlNet',
+  },
+  ControlPolymorphic: {
+    color: 'teal.500',
+    description: 'Control info passed between nodes.',
+    title: 'Control Polymorphic',
+  },
+  DenoiseMaskField: {
+    color: 'blue.300',
+    description: 'Denoise Mask may be passed between nodes',
+    title: 'Denoise Mask',
+  },
+  enum: {
+    color: 'blue.500',
+    description: 'Enums are values that may be one of a number of options.',
+    title: 'Enum',
+  },
+  float: {
+    color: 'orange.500',
+    description: 'Floats are numbers with a decimal point.',
+    title: 'Float',
  },
  FloatCollection: {
    color: 'orange.500',
-    title: 'Float Collection',
    description: 'A collection of floats.',
+    title: 'Float Collection',
  },
-  ColorCollection: {
-    color: 'base.500',
-    title: 'Color Collection',
-    description: 'A collection of colors.',
+  FloatPolymorphic: {
+    color: 'orange.500',
+    description: 'A collection of floats.',
+    title: 'Float Polymorphic',
+  },
+  ImageCollection: {
+    color: 'purple.500',
+    description: 'A collection of images.',
+    title: 'Image Collection',
+  },
+  ImageField: {
+    color: 'purple.500',
+    description: 'Images may be passed between nodes.',
+    title: 'Image',
+  },
+  ImagePolymorphic: {
+    color: 'purple.500',
+    description: 'A collection of images.',
+    title: 'Image Polymorphic',
+  },
+  integer: {
+    color: 'red.500',
+    description: 'Integers are whole numbers, without a decimal point.',
+    title: 'Integer',
+  },
+  IntegerCollection: {
+    color: 'red.500',
+    description: 'A collection of integers.',
+    title: 'Integer Collection',
+  },
+  IntegerPolymorphic: {
+    color: 'red.500',
+    description: 'A collection of integers.',
+    title: 'Integer Polymorphic',
+  },
+  LatentsCollection: {
+    color: 'pink.500',
+    description: 'Latents may be passed between nodes.',
+    title: 'Latents Collection',
+  },
+  LatentsField: {
+    color: 'pink.500',
+    description: 'Latents may be passed between nodes.',
+    title: 'Latents',
+  },
+  LatentsPolymorphic: {
+    color: 'pink.500',
+    description: 'Latents may be passed between nodes.',
+    title: 'Latents Polymorphic',
+  },
+  LoRAModelField: {
+    color: 'teal.500',
+    description: 'TODO',
+    title: 'LoRA',
+  },
+  MainModelField: {
+    color: 'teal.500',
+    description: 'TODO',
+    title: 'Model',
  },
  ONNXModelField: {
-    color: 'base.500',
-    title: 'ONNX Model',
+    color: 'teal.500',
    description: 'ONNX model field.',
+    title: 'ONNX Model',
+  },
+  Scheduler: {
+    color: 'base.500',
+    description: 'TODO',
+    title: 'Scheduler',
  },
  SDXLMainModelField: {
-    color: 'base.500',
-    title: 'SDXL Model',
+    color: 'teal.500',
    description: 'SDXL model field.',
+    title: 'SDXL Model',
+  },
+  SDXLRefinerModelField: {
+    color: 'teal.500',
+    description: 'TODO',
+    title: 'Refiner Model',
+  },
+  string: {
+    color: 'yellow.500',
+    description: 'Strings are text.',
+    title: 'String',
  },
  StringCollection: {
    color: 'yellow.500',
-    title: 'String Collection',
    description: 'A collection of strings.',
+    title: 'String Collection',
+  },
+  StringPolymorphic: {
+    color: 'yellow.500',
+    description: 'A collection of strings.',
+    title: 'String Polymorphic',
+  },
+  UNetField: {
+    color: 'red.500',
+    description: 'UNet submodel.',
+    title: 'UNet',
+  },
+  VaeField: {
+    color: 'blue.500',
+    description: 'Vae submodel.',
+    title: 'Vae',
+  },
+  VaeModelField: {
+    color: 'teal.500',
+    description: 'TODO',
+    title: 'VAE',
  },
};
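The new maps make it cheap to hop between a base field type and its Collection and Polymorphic variants. A minimal sketch under trimmed copies of COLLECTION_MAP and SINGLE_TO_POLYMORPHIC_MAP (relatedTypes is a hypothetical helper, not part of the codebase):

```ts
// Trimmed copies of the new constants, for illustration only.
const COLLECTION_MAP = {
  integer: 'IntegerCollection',
  ImageField: 'ImageCollection',
} as const;
const SINGLE_TO_POLYMORPHIC_MAP = {
  integer: 'IntegerPolymorphic',
  ImageField: 'ImagePolymorphic',
} as const;

const isCollectionItemType = (
  itemType: string | undefined
): itemType is keyof typeof COLLECTION_MAP =>
  Boolean(itemType && itemType in COLLECTION_MAP);

// Given an item type coming out of the OpenAPI schema, derive the related
// Collection and Polymorphic field types, if any.
const relatedTypes = (itemType: string | undefined) => {
  if (!isCollectionItemType(itemType)) {
    return undefined;
  }
  return {
    collection: COLLECTION_MAP[itemType],
    polymorphic: SINGLE_TO_POLYMORPHIC_MAP[itemType],
  };
};

console.log(relatedTypes('integer'));
// { collection: 'IntegerCollection', polymorphic: 'IntegerPolymorphic' }
console.log(relatedTypes('enum')); // undefined
```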
@@ -1,7 +1,9 @@
import {
  SchedulerParam,
  zBaseModel,
+  zMainModel,
  zMainOrOnnxModel,
+  zOnnxModel,
  zSDXLRefinerModel,
  zScheduler,
} from 'features/parameters/types/parameterSchemas';
@@ -9,7 +11,7 @@ import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow';
-import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
+import { Graph, _InputField, _OutputField } from 'services/api/types';
import {
  AnyInvocationType,
  AnyResult,
@@ -50,6 +52,10 @@ export type InvocationTemplate = {
   * The type of this node's output
   */
  outputType: string; // TODO: generate a union of output types
+  /**
+   * The invocation's version.
+   */
+  version?: string;
};

export type FieldUIConfig = {
@@ -60,50 +66,48 @@ export type FieldUIConfig = {

// TODO: Get this from the OpenAPI schema? may be tricky...
export const zFieldType = z.enum([
-  // region Primitives
-  'integer',
-  'float',
  'boolean',
-  'string',
-  'array',
-  'ImageField',
-  'DenoiseMaskField',
-  'LatentsField',
-  'ConditioningField',
-  'ControlField',
-  'ColorField',
-  'ImageCollection',
-  'ConditioningCollection',
-  'ColorCollection',
-  'LatentsCollection',
-  'IntegerCollection',
-  'FloatCollection',
-  'StringCollection',
  'BooleanCollection',
-  // endregion
-
-  // region Models
-  'MainModelField',
-  'SDXLMainModelField',
-  'SDXLRefinerModelField',
-  'ONNXModelField',
-  'VaeModelField',
-  'LoRAModelField',
-  'ControlNetModelField',
-  'UNetField',
-  'VaeField',
+  'BooleanPolymorphic',
  'ClipField',
-  // endregion
-
-  // region Iterate/Collect
  'Collection',
  'CollectionItem',
-  // endregion
-
-  // region Misc
+  'ColorCollection',
+  'ColorField',
+  'ColorPolymorphic',
+  'ConditioningCollection',
+  'ConditioningField',
+  'ConditioningPolymorphic',
+  'ControlCollection',
+  'ControlField',
+  'ControlNetModelField',
+  'ControlPolymorphic',
+  'DenoiseMaskField',
  'enum',
+  'float',
+  'FloatCollection',
+  'FloatPolymorphic',
+  'ImageCollection',
+  'ImageField',
+  'ImagePolymorphic',
+  'integer',
+  'IntegerCollection',
+  'IntegerPolymorphic',
+  'LatentsCollection',
+  'LatentsField',
+  'LatentsPolymorphic',
+  'LoRAModelField',
+  'MainModelField',
+  'ONNXModelField',
  'Scheduler',
-  // endregion
+  'SDXLMainModelField',
+  'SDXLRefinerModelField',
+  'string',
+  'StringCollection',
+  'StringPolymorphic',
+  'UNetField',
+  'VaeField',
+  'VaeModelField',
]);

export type FieldType = z.infer<typeof zFieldType>;
@@ -120,38 +124,6 @@ export const isFieldType = (value: unknown): value is FieldType =>
  zFieldType.safeParse(value).success ||
  zReservedFieldType.safeParse(value).success;

-/**
- * An input field template is generated on each page load from the OpenAPI schema.
- *
- * The template provides the field type and other field metadata (e.g. title, description,
- * maximum length, pattern to match, etc).
- */
-export type InputFieldTemplate =
-  | IntegerInputFieldTemplate
-  | FloatInputFieldTemplate
-  | StringInputFieldTemplate
-  | BooleanInputFieldTemplate
-  | ImageInputFieldTemplate
-  | DenoiseMaskInputFieldTemplate
-  | LatentsInputFieldTemplate
-  | ConditioningInputFieldTemplate
-  | UNetInputFieldTemplate
-  | ClipInputFieldTemplate
-  | VaeInputFieldTemplate
-  | ControlInputFieldTemplate
-  | EnumInputFieldTemplate
-  | MainModelInputFieldTemplate
-  | SDXLMainModelInputFieldTemplate
-  | SDXLRefinerModelInputFieldTemplate
-  | VaeModelInputFieldTemplate
-  | LoRAModelInputFieldTemplate
-  | ControlNetModelInputFieldTemplate
-  | CollectionInputFieldTemplate
-  | CollectionItemInputFieldTemplate
-  | ColorInputFieldTemplate
-  | ImageCollectionInputFieldTemplate
-  | SchedulerInputFieldTemplate;
-
/**
 * Indicates the kind of input(s) this field may have.
 */
@@ -230,24 +202,88 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
});
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;

+export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('IntegerCollection'),
+  value: z.array(z.number().int()).optional(),
+});
+export type IntegerCollectionInputFieldValue = z.infer<
+  typeof zIntegerCollectionInputFieldValue
+>;
+
+export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('IntegerPolymorphic'),
+  value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
+});
+export type IntegerPolymorphicInputFieldValue = z.infer<
+  typeof zIntegerPolymorphicInputFieldValue
+>;
+
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('float'),
  value: z.number().optional(),
});
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;

+export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('FloatCollection'),
+  value: z.array(z.number()).optional(),
+});
+export type FloatCollectionInputFieldValue = z.infer<
+  typeof zFloatCollectionInputFieldValue
+>;
+
+export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('FloatPolymorphic'),
+  value: z.union([z.number(), z.array(z.number())]).optional(),
+});
+export type FloatPolymorphicInputFieldValue = z.infer<
+  typeof zFloatPolymorphicInputFieldValue
+>;
+
export const zStringInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('string'),
  value: z.string().optional(),
});
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;

+export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('StringCollection'),
+  value: z.array(z.string()).optional(),
+});
+export type StringCollectionInputFieldValue = z.infer<
+  typeof zStringCollectionInputFieldValue
+>;
+
+export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('StringPolymorphic'),
+  value: z.union([z.string(), z.array(z.string())]).optional(),
+});
+export type StringPolymorphicInputFieldValue = z.infer<
+  typeof zStringPolymorphicInputFieldValue
+>;
+
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('boolean'),
  value: z.boolean().optional(),
});
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;

+export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('BooleanCollection'),
+  value: z.array(z.boolean()).optional(),
+});
+export type BooleanCollectionInputFieldValue = z.infer<
+  typeof zBooleanCollectionInputFieldValue
+>;
+
+export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('BooleanPolymorphic'),
+  value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
+});
+export type BooleanPolymorphicInputFieldValue = z.infer<
+  typeof zBooleanPolymorphicInputFieldValue
+>;
+
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('enum'),
  value: z.union([z.string(), z.number()]).optional(),
@@ -260,6 +296,22 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
});
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;

+export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('LatentsCollection'),
+  value: z.array(zLatentsField).optional(),
+});
+export type LatentsCollectionInputFieldValue = z.infer<
+  typeof zLatentsCollectionInputFieldValue
+>;
+
+export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('LatentsPolymorphic'),
+  value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
+});
+export type LatentsPolymorphicInputFieldValue = z.infer<
+  typeof zLatentsPolymorphicInputFieldValue
+>;
+
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('DenoiseMaskField'),
  value: zDenoiseMaskField.optional(),
@@ -276,6 +328,26 @@ export type ConditioningInputFieldValue = z.infer<
  typeof zConditioningInputFieldValue
>;

+export const zConditioningCollectionInputFieldValue =
+  zInputFieldValueBase.extend({
+    type: z.literal('ConditioningCollection'),
+    value: z.array(zConditioningField).optional(),
+  });
+export type ConditioningCollectionInputFieldValue = z.infer<
+  typeof zConditioningCollectionInputFieldValue
+>;
+
+export const zConditioningPolymorphicInputFieldValue =
+  zInputFieldValueBase.extend({
+    type: z.literal('ConditioningPolymorphic'),
+    value: z
+      .union([zConditioningField, z.array(zConditioningField)])
+      .optional(),
+  });
+export type ConditioningPolymorphicInputFieldValue = z.infer<
+  typeof zConditioningPolymorphicInputFieldValue
+>;
+
export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>;

@@ -300,6 +372,22 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({
});
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;

+export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('ControlPolymorphic'),
+  value: z.union([zControlField, z.array(zControlField)]).optional(),
+});
+export type ControlPolymorphicInputFieldValue = z.infer<
+  typeof zControlPolymorphicInputFieldValue
+>;
+
+export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('ControlCollection'),
+  value: z.array(zControlField).optional(),
+});
+export type ControlCollectionInputFieldValue = z.infer<
+  typeof zControlCollectionInputFieldValue
+>;
+
export const zModelType = z.enum([
  'onnx',
  'main',
@@ -379,6 +467,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
});
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;

+export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('ImagePolymorphic'),
+  value: z.union([zImageField, z.array(zImageField)]).optional(),
+});
+export type ImagePolymorphicInputFieldValue = z.infer<
+  typeof zImagePolymorphicInputFieldValue
+>;
+
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('ImageCollection'),
  value: z.array(zImageField).optional(),
@@ -471,6 +567,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
});
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;

+export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('ColorCollection'),
+  value: z.array(zColorField).optional(),
+});
+export type ColorCollectionInputFieldValue = z.infer<
+  typeof zColorCollectionInputFieldValue
+>;
+
+export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('ColorPolymorphic'),
+  value: z.union([zColorField, z.array(zColorField)]).optional(),
+});
+export type ColorPolymorphicInputFieldValue = z.infer<
+  typeof zColorPolymorphicInputFieldValue
+>;
+
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('Scheduler'),
  value: zScheduler.optional(),
@@ -480,30 +592,47 @@ export type SchedulerInputFieldValue = z.infer<
>;

export const zInputFieldValue = z.discriminatedUnion('type', [
-  zIntegerInputFieldValue,
-  zFloatInputFieldValue,
-  zStringInputFieldValue,
+  zBooleanCollectionInputFieldValue,
  zBooleanInputFieldValue,
-  zImageInputFieldValue,
-  zLatentsInputFieldValue,
-  zDenoiseMaskInputFieldValue,
-  zConditioningInputFieldValue,
-  zUNetInputFieldValue,
+  zBooleanPolymorphicInputFieldValue,
  zClipInputFieldValue,
-  zVaeInputFieldValue,
-  zControlInputFieldValue,
-  zEnumInputFieldValue,
-  zMainModelInputFieldValue,
-  zSDXLMainModelInputFieldValue,
-  zSDXLRefinerModelInputFieldValue,
-  zVaeModelInputFieldValue,
-  zLoRAModelInputFieldValue,
-  zControlNetModelInputFieldValue,
  zCollectionInputFieldValue,
  zCollectionItemInputFieldValue,
  zColorInputFieldValue,
+  zColorCollectionInputFieldValue,
+  zColorPolymorphicInputFieldValue,
+  zConditioningInputFieldValue,
+  zConditioningCollectionInputFieldValue,
+  zConditioningPolymorphicInputFieldValue,
+  zControlInputFieldValue,
+  zControlNetModelInputFieldValue,
+  zControlCollectionInputFieldValue,
+  zControlPolymorphicInputFieldValue,
+  zDenoiseMaskInputFieldValue,
+  zEnumInputFieldValue,
+  zFloatCollectionInputFieldValue,
+  zFloatInputFieldValue,
+  zFloatPolymorphicInputFieldValue,
  zImageCollectionInputFieldValue,
+  zImagePolymorphicInputFieldValue,
+  zImageInputFieldValue,
+  zIntegerCollectionInputFieldValue,
+  zIntegerPolymorphicInputFieldValue,
+  zIntegerInputFieldValue,
+  zLatentsInputFieldValue,
+  zLatentsCollectionInputFieldValue,
+  zLatentsPolymorphicInputFieldValue,
+  zLoRAModelInputFieldValue,
+  zMainModelInputFieldValue,
  zSchedulerInputFieldValue,
+  zSDXLMainModelInputFieldValue,
+  zSDXLRefinerModelInputFieldValue,
+  zStringCollectionInputFieldValue,
+  zStringPolymorphicInputFieldValue,
+  zStringInputFieldValue,
+  zUNetInputFieldValue,
+  zVaeInputFieldValue,
+  zVaeModelInputFieldValue,
]);

export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@ -512,7 +641,6 @@ export type InputFieldTemplateBase = {
|
|||||||
name: string;
|
name: string;
|
||||||
title: string;
|
title: string;
|
||||||
description: string;
|
description: string;
|
||||||
type: FieldType;
|
|
||||||
required: boolean;
|
required: boolean;
|
||||||
fieldKind: 'input';
|
fieldKind: 'input';
|
||||||
} & _InputField;
|
} & _InputField;
|
||||||
@ -527,6 +655,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
exclusiveMinimum?: boolean;
|
exclusiveMinimum?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'IntegerCollection';
|
||||||
|
default: number[];
|
||||||
|
item_default?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type IntegerPolymorphicInputFieldTemplate = Omit<
|
||||||
|
IntegerInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'IntegerPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
type: 'float';
|
type: 'float';
|
||||||
default: number;
|
default: number;
|
||||||
@ -537,6 +678,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
exclusiveMinimum?: boolean;
|
exclusiveMinimum?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'FloatCollection';
|
||||||
|
default: number[];
|
||||||
|
item_default?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type FloatPolymorphicInputFieldTemplate = Omit<
|
||||||
|
FloatInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'FloatPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
type: 'string';
|
type: 'string';
|
||||||
default: string;
|
default: string;
|
||||||
@ -545,19 +699,53 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
pattern?: string;
|
pattern?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'StringCollection';
|
||||||
|
default: string[];
|
||||||
|
item_default?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type StringPolymorphicInputFieldTemplate = Omit<
|
||||||
|
StringInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'StringPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: boolean;
|
default: boolean;
|
||||||
type: 'boolean';
|
type: 'boolean';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
type: 'BooleanCollection';
|
||||||
|
default: boolean[];
|
||||||
|
item_default?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type BooleanPolymorphicInputFieldTemplate = Omit<
|
||||||
|
BooleanInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'BooleanPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: ImageDTO;
|
default: ImageField;
|
||||||
type: 'ImageField';
|
type: 'ImageField';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: ImageField[];
|
default: ImageField[];
|
||||||
type: 'ImageCollection';
|
type: 'ImageCollection';
|
||||||
|
item_default?: ImageField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ImagePolymorphicInputFieldTemplate = Omit<
|
||||||
|
ImageInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ImagePolymorphic';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
@ -566,15 +754,40 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string;
|
default: LatentsField;
|
||||||
type: 'LatentsField';
|
type: 'LatentsField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: LatentsField[];
|
||||||
|
type: 'LatentsCollection';
|
||||||
|
item_default?: LatentsField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: LatentsField;
|
||||||
|
type: 'LatentsPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'ConditioningField';
|
type: 'ConditioningField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ConditioningCollectionInputFieldTemplate =
|
||||||
|
InputFieldTemplateBase & {
|
||||||
|
default: ConditioningField[];
|
||||||
|
type: 'ConditioningCollection';
|
||||||
|
item_default?: ConditioningField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ConditioningPolymorphicInputFieldTemplate = Omit<
|
||||||
|
ConditioningInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ConditioningPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: undefined;
|
default: undefined;
|
||||||
type: 'UNetField';
|
type: 'UNetField';
|
||||||
@ -595,6 +808,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'ControlField';
|
type: 'ControlField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'ControlCollection';
|
||||||
|
item_default?: ControlField;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ControlPolymorphicInputFieldTemplate = Omit<
|
||||||
|
ControlInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ControlPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string | number;
|
default: string | number;
|
||||||
type: 'enum';
|
type: 'enum';
|
||||||
@ -647,6 +873,18 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'ColorField';
|
type: 'ColorField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ColorPolymorphicInputFieldTemplate = Omit<
|
||||||
|
ColorInputFieldTemplate,
|
||||||
|
'type'
|
||||||
|
> & {
|
||||||
|
type: 'ColorPolymorphic';
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: [];
|
||||||
|
type: 'ColorCollection';
|
||||||
|
};
|
||||||
|
|
||||||
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
  default: SchedulerParam;
  type: 'Scheduler';
@@ -657,6 +895,55 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
  type: 'WorkflowField';
};

+/**
+ * An input field template is generated on each page load from the OpenAPI schema.
+ *
+ * The template provides the field type and other field metadata (e.g. title, description,
+ * maximum length, pattern to match, etc).
+ */
+export type InputFieldTemplate =
+  | BooleanCollectionInputFieldTemplate
+  | BooleanPolymorphicInputFieldTemplate
+  | BooleanInputFieldTemplate
+  | ClipInputFieldTemplate
+  | CollectionInputFieldTemplate
+  | CollectionItemInputFieldTemplate
+  | ColorInputFieldTemplate
+  | ColorCollectionInputFieldTemplate
+  | ColorPolymorphicInputFieldTemplate
+  | ConditioningInputFieldTemplate
+  | ConditioningCollectionInputFieldTemplate
+  | ConditioningPolymorphicInputFieldTemplate
+  | ControlInputFieldTemplate
+  | ControlCollectionInputFieldTemplate
+  | ControlNetModelInputFieldTemplate
+  | ControlPolymorphicInputFieldTemplate
+  | DenoiseMaskInputFieldTemplate
+  | EnumInputFieldTemplate
+  | FloatCollectionInputFieldTemplate
+  | FloatInputFieldTemplate
+  | FloatPolymorphicInputFieldTemplate
+  | ImageCollectionInputFieldTemplate
+  | ImagePolymorphicInputFieldTemplate
+  | ImageInputFieldTemplate
+  | IntegerCollectionInputFieldTemplate
+  | IntegerPolymorphicInputFieldTemplate
+  | IntegerInputFieldTemplate
+  | LatentsInputFieldTemplate
+  | LatentsCollectionInputFieldTemplate
+  | LatentsPolymorphicInputFieldTemplate
+  | LoRAModelInputFieldTemplate
+  | MainModelInputFieldTemplate
+  | SchedulerInputFieldTemplate
+  | SDXLMainModelInputFieldTemplate
+  | SDXLRefinerModelInputFieldTemplate
+  | StringCollectionInputFieldTemplate
+  | StringPolymorphicInputFieldTemplate
+  | StringInputFieldTemplate
+  | UNetInputFieldTemplate
+  | VaeInputFieldTemplate
+  | VaeModelInputFieldTemplate;

export const isInputFieldValue = (
  field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
@@ -679,6 +966,7 @@ export type InvocationSchemaExtra = {
  title: string;
  category?: string;
  tags?: string[];
+  version?: string;
  properties: Omit<
    NonNullable<OpenAPIV3.SchemaObject['properties']> &
      (_InputField | _OutputField),
@@ -729,8 +1017,22 @@ export type InvocationSchemaObject = (
) & { class: 'invocation' };

export const isSchemaObject = (
-  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
-): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
+  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
+): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
+
+export const isArraySchemaObject = (
+  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
+): obj is OpenAPIV3.ArraySchemaObject =>
+  Boolean(obj && !('$ref' in obj) && obj.type === 'array');
+
+export const isNonArraySchemaObject = (
+  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
+): obj is OpenAPIV3.NonArraySchemaObject =>
+  Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
+
+export const isRefObject = (
+  obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
+): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);

export const isInvocationSchemaObject = (
  obj:
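
The new guards accept `undefined` so callers can probe optional schema members (such as `items` or the entries of `anyOf`) without separate null checks. A minimal sketch of how they narrow a value, assuming the guards above are imported from this types module (the `maybeSchema` value is purely illustrative):

```typescript
import { OpenAPIV3 } from 'openapi-types';

// Hypothetical input: could be a $ref, an inline schema, or missing entirely.
const maybeSchema: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined = {
  type: 'array',
  items: { $ref: '#/components/schemas/ImageField' },
};

if (isArraySchemaObject(maybeSchema)) {
  // Narrowed to ArraySchemaObject, so `.items` is available.
  console.log(isRefObject(maybeSchema.items)); // true for the example above
} else if (isRefObject(maybeSchema)) {
  console.log(maybeSchema.$ref);
}
```
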
@@ -769,12 +1071,14 @@ export const zCoreMetadata = z
    steps: z.number().int().nullish(),
    scheduler: z.string().nullish(),
    clip_skip: z.number().int().nullish(),
-    model: zMainOrOnnxModel.nullish(),
-    controlnets: z.array(zControlField).nullish(),
+    model: z
+      .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
+      .nullish(),
+    controlnets: z.array(zControlField.deepPartial()).nullish(),
    loras: z
      .array(
        z.object({
-          lora: zLoRAModelField,
+          lora: zLoRAModelField.deepPartial(),
          weight: z.number(),
        })
      )
@@ -784,18 +1088,41 @@ export const zCoreMetadata = z
    init_image: z.string().nullish(),
    positive_style_prompt: z.string().nullish(),
    negative_style_prompt: z.string().nullish(),
-    refiner_model: zSDXLRefinerModel.nullish(),
+    refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
    refiner_cfg_scale: z.number().nullish(),
    refiner_steps: z.number().int().nullish(),
    refiner_scheduler: z.string().nullish(),
-    refiner_positive_aesthetic_store: z.number().nullish(),
-    refiner_negative_aesthetic_store: z.number().nullish(),
+    refiner_positive_aesthetic_score: z.number().nullish(),
+    refiner_negative_aesthetic_score: z.number().nullish(),
    refiner_start: z.number().nullish(),
  })
-  .catchall(z.record(z.any()));
+  .passthrough();

export type CoreMetadata = z.infer<typeof zCoreMetadata>;

+export const zSemVer = z.string().refine((val) => {
+  const [major, minor, patch] = val.split('.');
+  return (
+    major !== undefined &&
+    Number.isInteger(Number(major)) &&
+    minor !== undefined &&
+    Number.isInteger(Number(minor)) &&
+    patch !== undefined &&
+    Number.isInteger(Number(patch))
+  );
+});
+
+export const zParsedSemver = zSemVer.transform((val) => {
+  const [major, minor, patch] = val.split('.');
+  return {
+    major: Number(major),
+    minor: Number(minor),
+    patch: Number(patch),
+  };
+});
+
+export type SemVer = z.infer<typeof zSemVer>;
+
export const zInvocationNodeData = z.object({
  id: z.string().trim().min(1),
  // no easy way to build this dynamically, and we don't want to anyways, because this will be used
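
For reference, a quick sketch of how the two schemas behave, using the names defined above (a loose validation: it only checks that all three dot-separated parts parse as integers):

```typescript
// zSemVer passes the string through unchanged when it looks like "$MAJOR.$MINOR.$PATCH".
zSemVer.parse('1.0.0'); // => '1.0.0'
// zSemVer.parse('1.0');    // throws: no patch component
// zSemVer.parse('v1.0.0'); // throws: 'v1' is not an integer

// zParsedSemver additionally splits the string into numeric parts.
const parsed = zParsedSemver.parse('2.13.4');
// parsed => { major: 2, minor: 13, patch: 4 }
```
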
@@ -808,6 +1135,7 @@ export const zInvocationNodeData = z.object({
  notes: z.string(),
  embedWorkflow: z.boolean(),
  isIntermediate: z.boolean(),
+  version: zSemVer.optional(),
});

// Massage this to get better type safety while developing
@@ -896,20 +1224,6 @@ export const zFieldIdentifier = z.object({

export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;

-export const zSemVer = z.string().refine((val) => {
-  const [major, minor, patch] = val.split('.');
-  return (
-    major !== undefined &&
-    minor !== undefined &&
-    patch !== undefined &&
-    Number.isInteger(Number(major)) &&
-    Number.isInteger(Number(minor)) &&
-    Number.isInteger(Number(patch))
-  );
-});
-
-export type SemVer = z.infer<typeof zSemVer>;
-
export type WorkflowWarning = {
  message: string;
  issues: string[];
@@ -1,5 +1,14 @@
+import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import {
+  COLLECTION_MAP,
+  POLYMORPHIC_TYPES,
+  SINGLE_TO_POLYMORPHIC_MAP,
+  isCollectionItemType,
+  isPolymorphicItemType,
+} from '../types/constants';
+import {
+  BooleanCollectionInputFieldTemplate,
  BooleanInputFieldTemplate,
  ClipInputFieldTemplate,
  CollectionInputFieldTemplate,
@@ -11,10 +20,13 @@ import {
  DenoiseMaskInputFieldTemplate,
  EnumInputFieldTemplate,
  FieldType,
+  FloatCollectionInputFieldTemplate,
+  FloatPolymorphicInputFieldTemplate,
  FloatInputFieldTemplate,
  ImageCollectionInputFieldTemplate,
  ImageInputFieldTemplate,
  InputFieldTemplateBase,
+  IntegerCollectionInputFieldTemplate,
  IntegerInputFieldTemplate,
  InvocationFieldSchema,
  InvocationSchemaObject,
@@ -24,11 +36,32 @@ import {
  SDXLMainModelInputFieldTemplate,
  SDXLRefinerModelInputFieldTemplate,
  SchedulerInputFieldTemplate,
+  StringCollectionInputFieldTemplate,
  StringInputFieldTemplate,
  UNetInputFieldTemplate,
  VaeInputFieldTemplate,
  VaeModelInputFieldTemplate,
+  isArraySchemaObject,
+  isNonArraySchemaObject,
+  isRefObject,
+  isSchemaObject,
+  ControlPolymorphicInputFieldTemplate,
+  ColorPolymorphicInputFieldTemplate,
+  ColorCollectionInputFieldTemplate,
+  IntegerPolymorphicInputFieldTemplate,
+  StringPolymorphicInputFieldTemplate,
+  BooleanPolymorphicInputFieldTemplate,
+  ImagePolymorphicInputFieldTemplate,
+  LatentsPolymorphicInputFieldTemplate,
+  LatentsCollectionInputFieldTemplate,
+  ConditioningPolymorphicInputFieldTemplate,
+  ConditioningCollectionInputFieldTemplate,
+  ControlCollectionInputFieldTemplate,
+  ImageField,
+  LatentsField,
+  ConditioningField,
} from '../types/types';
+import { ControlField } from 'services/api/types';

export type BaseFieldProperties = 'name' | 'title' | 'description';

@@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
 * @example
 * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
 */
-export const refObjectToFieldType = (
-  refObject: OpenAPIV3.ReferenceObject
-): FieldType => {
-  const name = refObject.$ref.split('/').slice(-1)[0];
-  if (!name) {
-    throw `Unknown field type: ${name}`;
-  }
-  return name as FieldType;
-};
+export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
+  refObject.$ref.split('/').slice(-1)[0];

const buildIntegerInputFieldTemplate = ({
  schemaObject,
@@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
  return template;
};

+const buildIntegerPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
+  const template: IntegerPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'IntegerPolymorphic',
+    default: schemaObject.default ?? 0,
+  };
+
+  if (schemaObject.multipleOf !== undefined) {
+    template.multipleOf = schemaObject.multipleOf;
+  }
+
+  if (schemaObject.maximum !== undefined) {
+    template.maximum = schemaObject.maximum;
+  }
+
+  if (schemaObject.exclusiveMaximum !== undefined) {
+    template.exclusiveMaximum = schemaObject.exclusiveMaximum;
+  }
+
+  if (schemaObject.minimum !== undefined) {
+    template.minimum = schemaObject.minimum;
+  }
+
+  if (schemaObject.exclusiveMinimum !== undefined) {
+    template.exclusiveMinimum = schemaObject.exclusiveMinimum;
+  }
+
+  return template;
+};
+
+const buildIntegerCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
+  const item_default =
+    isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
+      ? schemaObject.item_default
+      : 0;
+  const template: IntegerCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'IntegerCollection',
+    default: schemaObject.default ?? [],
+    item_default,
+  };
+
+  return template;
+};
+
const buildFloatInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
  return template;
};

+const buildFloatPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
+  const template: FloatPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'FloatPolymorphic',
+    default: schemaObject.default ?? 0,
+  };
+  if (schemaObject.multipleOf !== undefined) {
+    template.multipleOf = schemaObject.multipleOf;
+  }
+
+  if (schemaObject.maximum !== undefined) {
+    template.maximum = schemaObject.maximum;
+  }
+
+  if (schemaObject.exclusiveMaximum !== undefined) {
+    template.exclusiveMaximum = schemaObject.exclusiveMaximum;
+  }
+
+  if (schemaObject.minimum !== undefined) {
+    template.minimum = schemaObject.minimum;
+  }
+
+  if (schemaObject.exclusiveMinimum !== undefined) {
+    template.exclusiveMinimum = schemaObject.exclusiveMinimum;
+  }
+  return template;
+};
+
+const buildFloatCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
+  const item_default = isNumber(schemaObject.item_default)
+    ? schemaObject.item_default
+    : 0;
+  const template: FloatCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'FloatCollection',
+    default: schemaObject.default ?? [],
+    item_default,
+  };
+
+  return template;
+};
+
const buildStringInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
  return template;
};

+const buildStringPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
+  const template: StringPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'StringPolymorphic',
+    default: schemaObject.default ?? '',
+  };
+
+  if (schemaObject.minLength !== undefined) {
+    template.minLength = schemaObject.minLength;
+  }
+
+  if (schemaObject.maxLength !== undefined) {
+    template.maxLength = schemaObject.maxLength;
+  }
+
+  if (schemaObject.pattern !== undefined) {
+    template.pattern = schemaObject.pattern;
+  }
+
+  return template;
+};
+
+const buildStringCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
+  const item_default = isString(schemaObject.item_default)
+    ? schemaObject.item_default
+    : '';
+  const template: StringCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'StringCollection',
+    default: schemaObject.default ?? [],
+    item_default,
+  };
+
+  return template;
+};
+
const buildBooleanInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
  return template;
};

+const buildBooleanPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
+  const template: BooleanPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'BooleanPolymorphic',
+    default: schemaObject.default ?? false,
+  };
+
+  return template;
+};
+
+const buildBooleanCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
+  const item_default =
+    schemaObject.item_default && isBoolean(schemaObject.item_default)
+      ? schemaObject.item_default
+      : false;
+  const template: BooleanCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'BooleanCollection',
+    default: schemaObject.default ?? [],
+    item_default,
+  };
+
+  return template;
+};
+
const buildMainModelInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
  return template;
};

+const buildImagePolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
+  const template: ImagePolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'ImagePolymorphic',
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
const buildImageCollectionInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -257,7 +468,8 @@ const buildImageCollectionInputFieldTemplate = ({
  const template: ImageCollectionInputFieldTemplate = {
    ...baseField,
    type: 'ImageCollection',
-    default: schemaObject.default ?? undefined,
+    default: schemaObject.default ?? [],
+    item_default: (schemaObject.item_default as ImageField) ?? undefined,
  };

  return template;
@@ -289,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
  return template;
};

+const buildLatentsPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
+  const template: LatentsPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'LatentsPolymorphic',
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
+const buildLatentsCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
+  const template: LatentsCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'LatentsCollection',
+    default: schemaObject.default ?? [],
+    item_default: (schemaObject.item_default as LatentsField) ?? undefined,
+  };
+
+  return template;
+};
+
const buildConditioningInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -302,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
  return template;
};

+const buildConditioningPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
+  const template: ConditioningPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'ConditioningPolymorphic',
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
+const buildConditioningCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
+  const template: ConditioningCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'ConditioningCollection',
+    default: schemaObject.default ?? [],
+    item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
+  };
+
+  return template;
+};
+
const buildUNetInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -355,6 +621,33 @@ const buildControlInputFieldTemplate = ({
  return template;
};

+const buildControlPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
+  const template: ControlPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'ControlPolymorphic',
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
+const buildControlCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
+  const template: ControlCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'ControlCollection',
+    default: schemaObject.default ?? [],
+    item_default: (schemaObject.item_default as ControlField) ?? undefined,
+  };
+
+  return template;
+};
+
const buildEnumInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -408,6 +701,32 @@ const buildColorInputFieldTemplate = ({
  return template;
};

+const buildColorPolymorphicInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
+  const template: ColorPolymorphicInputFieldTemplate = {
+    ...baseField,
+    type: 'ColorPolymorphic',
+    default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
+  };
+
+  return template;
+};
+
+const buildColorCollectionInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
+  const template: ColorCollectionInputFieldTemplate = {
+    ...baseField,
+    type: 'ColorCollection',
+    default: schemaObject.default ?? [],
+  };
+
+  return template;
+};
+
const buildSchedulerInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -421,45 +740,138 @@ const buildSchedulerInputFieldTemplate = ({
  return template;
};

-export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
-  let fieldType = '';
-  const { ui_type } = schemaObject;
-  if (ui_type) {
-    fieldType = ui_type;
-  } else if (!schemaObject.type) {
-    // console.log('refObject', schemaObject);
-    // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
-    if (schemaObject.allOf) {
-      fieldType = refObjectToFieldType(
-        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-        schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
-      );
-    } else if (schemaObject.anyOf) {
-      fieldType = refObjectToFieldType(
-        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-        schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
-      );
-    } else if (schemaObject.oneOf) {
-      fieldType = refObjectToFieldType(
-        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-        schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
-      );
-    }
-  } else if (schemaObject.enum) {
-    fieldType = 'enum';
-  } else if (schemaObject.type) {
-    if (schemaObject.type === 'number') {
-      // floats are "number" in OpenAPI, while ints are "integer"
-      fieldType = 'float';
-    } else {
-      fieldType = schemaObject.type;
-    }
-  }
-  return fieldType;
-};
+export const getFieldType = (
+  schemaObject: InvocationFieldSchema
+): string | undefined => {
+  if (schemaObject?.ui_type) {
+    return schemaObject.ui_type;
+  } else if (!schemaObject.type) {
+    // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
+
+    if (schemaObject.allOf) {
+      const allOf = schemaObject.allOf;
+      if (allOf && allOf[0] && isRefObject(allOf[0])) {
+        return refObjectToSchemaName(allOf[0]);
+      }
+    } else if (schemaObject.anyOf) {
+      const anyOf = schemaObject.anyOf;
+      /**
+       * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
+       * - an `anyOf` with two items
+       * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
+       * - the other is a `SchemaObject` or `ReferenceObject` of type T
+       *
+       * Any other cases we ignore.
+       */
+
+      let firstType: string | undefined;
+      let secondType: string | undefined;
+
+      if (isArraySchemaObject(anyOf[0])) {
+        // first is array, second is not
+        const first = anyOf[0].items;
+        const second = anyOf[1];
+        if (isRefObject(first) && isRefObject(second)) {
+          firstType = refObjectToSchemaName(first);
+          secondType = refObjectToSchemaName(second);
+        } else if (
+          isNonArraySchemaObject(first) &&
+          isNonArraySchemaObject(second)
+        ) {
+          firstType = first.type;
+          secondType = second.type;
+        }
+      } else if (isArraySchemaObject(anyOf[1])) {
+        // first is not array, second is
+        const first = anyOf[0];
+        const second = anyOf[1].items;
+        if (isRefObject(first) && isRefObject(second)) {
+          firstType = refObjectToSchemaName(first);
+          secondType = refObjectToSchemaName(second);
+        } else if (
+          isNonArraySchemaObject(first) &&
+          isNonArraySchemaObject(second)
+        ) {
+          firstType = first.type;
+          secondType = second.type;
+        }
+      }
+      if (firstType === secondType && isPolymorphicItemType(firstType)) {
+        return SINGLE_TO_POLYMORPHIC_MAP[firstType];
+      }
+    }
+  } else if (schemaObject.enum) {
+    return 'enum';
+  } else if (schemaObject.type) {
+    if (schemaObject.type === 'number') {
+      // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
+      return 'float';
+    } else if (schemaObject.type === 'array') {
+      const itemType = isSchemaObject(schemaObject.items)
+        ? schemaObject.items.type
+        : refObjectToSchemaName(schemaObject.items);
+
+      if (isCollectionItemType(itemType)) {
+        return COLLECTION_MAP[itemType];
+      }
+
+      return;
+    } else {
+      return schemaObject.type;
+    }
+  }
+  return;
+};
+
+const TEMPLATE_BUILDER_MAP = {
+  boolean: buildBooleanInputFieldTemplate,
+  BooleanCollection: buildBooleanCollectionInputFieldTemplate,
+  BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
+  ClipField: buildClipInputFieldTemplate,
+  Collection: buildCollectionInputFieldTemplate,
+  CollectionItem: buildCollectionItemInputFieldTemplate,
+  ColorCollection: buildColorCollectionInputFieldTemplate,
+  ColorField: buildColorInputFieldTemplate,
+  ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
+  ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
+  ConditioningField: buildConditioningInputFieldTemplate,
+  ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
+  ControlCollection: buildControlCollectionInputFieldTemplate,
+  ControlField: buildControlInputFieldTemplate,
+  ControlNetModelField: buildControlNetModelInputFieldTemplate,
+  ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
+  DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
+  enum: buildEnumInputFieldTemplate,
+  float: buildFloatInputFieldTemplate,
+  FloatCollection: buildFloatCollectionInputFieldTemplate,
+  FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
+  ImageCollection: buildImageCollectionInputFieldTemplate,
+  ImageField: buildImageInputFieldTemplate,
+  ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
+  integer: buildIntegerInputFieldTemplate,
+  IntegerCollection: buildIntegerCollectionInputFieldTemplate,
+  IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
+  LatentsCollection: buildLatentsCollectionInputFieldTemplate,
+  LatentsField: buildLatentsInputFieldTemplate,
+  LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
+  LoRAModelField: buildLoRAModelInputFieldTemplate,
+  MainModelField: buildMainModelInputFieldTemplate,
+  Scheduler: buildSchedulerInputFieldTemplate,
+  SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
+  SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
+  string: buildStringInputFieldTemplate,
+  StringCollection: buildStringCollectionInputFieldTemplate,
+  StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
+  UNetField: buildUNetInputFieldTemplate,
+  VaeField: buildVaeInputFieldTemplate,
+  VaeModelField: buildVaeModelInputFieldTemplate,
+};
+
+const isTemplatedFieldType = (
+  fieldType: string | undefined
+): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
+  Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
+
/**
 * Builds an input field from an invocation schema property.
 * @param fieldSchema The schema object
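
For context, a polymorphic input such as `image: ImageField | list[ImageField]` typically shows up in the generated OpenAPI document as an `anyOf` pairing the single type with an array of that type. A hedged sketch of the shape the new `anyOf` branch of `getFieldType` is matching on (the property name and ref are made-up examples):

```typescript
// Illustrative only: both arms resolve to 'ImageField', so getFieldType
// returns SINGLE_TO_POLYMORPHIC_MAP['ImageField'] (i.e. 'ImagePolymorphic').
const polymorphicImageInput = {
  anyOf: [
    { $ref: '#/components/schemas/ImageField' },
    { type: 'array', items: { $ref: '#/components/schemas/ImageField' } },
  ],
};
```
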
@@ -474,7 +886,8 @@ export const buildInputFieldTemplate = (
  const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;

  const extra = {
-    input,
+    // TODO: Can we support polymorphic inputs in the UI?
+    input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
    ui_hidden,
    ui_component,
    ui_type,
@@ -490,146 +903,12 @@ export const buildInputFieldTemplate = (
    ...extra,
  };

-  if (fieldType === 'ImageField') {
-    return buildImageInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'ImageCollection') {
-    return buildImageCollectionInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'DenoiseMaskField') {
-    return buildDenoiseMaskInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'LatentsField') {
-    return buildLatentsInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'ConditioningField') {
-    return buildConditioningInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'UNetField') {
-    return buildUNetInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'ClipField') {
-    return buildClipInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'VaeField') {
-    return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
-  }
-  if (fieldType === 'ControlField') {
-    return buildControlInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'MainModelField') {
-    return buildMainModelInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'SDXLRefinerModelField') {
-    return buildRefinerModelInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'SDXLMainModelField') {
-    return buildSDXLMainModelInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'VaeModelField') {
-    return buildVaeModelInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'LoRAModelField') {
-    return buildLoRAModelInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'ControlNetModelField') {
-    return buildControlNetModelInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'enum') {
-    return buildEnumInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'integer') {
-    return buildIntegerInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'float') {
-    return buildFloatInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'string') {
-    return buildStringInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'boolean') {
-    return buildBooleanInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'Collection') {
-    return buildCollectionInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'CollectionItem') {
-    return buildCollectionItemInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'ColorField') {
-    return buildColorInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  if (fieldType === 'Scheduler') {
-    return buildSchedulerInputFieldTemplate({
-      schemaObject: fieldSchema,
-      baseField,
-    });
-  }
-  return;
+  if (!isTemplatedFieldType(fieldType)) {
+    return;
+  }
+
+  return TEMPLATE_BUILDER_MAP[fieldType]({
+    schemaObject: fieldSchema,
+    baseField,
+  });
};
@@ -1,104 +1,79 @@
import { InputFieldTemplate, InputFieldValue } from '../types/types';

+const FIELD_VALUE_FALLBACK_MAP = {
+  'enum.number': 0,
+  'enum.string': '',
+  boolean: false,
+  BooleanCollection: [],
+  BooleanPolymorphic: false,
+  ClipField: undefined,
+  Collection: [],
+  CollectionItem: undefined,
+  ColorCollection: [],
+  ColorField: undefined,
+  ColorPolymorphic: undefined,
+  ConditioningCollection: [],
+  ConditioningField: undefined,
+  ConditioningPolymorphic: undefined,
+  ControlCollection: [],
+  ControlField: undefined,
+  ControlNetModelField: undefined,
+  ControlPolymorphic: undefined,
+  DenoiseMaskField: undefined,
+  float: 0,
+  FloatCollection: [],
+  FloatPolymorphic: 0,
+  ImageCollection: [],
+  ImageField: undefined,
+  ImagePolymorphic: undefined,
+  integer: 0,
+  IntegerCollection: [],
+  IntegerPolymorphic: 0,
+  LatentsCollection: [],
+  LatentsField: undefined,
+  LatentsPolymorphic: undefined,
+  LoRAModelField: undefined,
+  MainModelField: undefined,
+  ONNXModelField: undefined,
+  Scheduler: 'euler',
+  SDXLMainModelField: undefined,
+  SDXLRefinerModelField: undefined,
+  string: '',
+  StringCollection: [],
+  StringPolymorphic: '',
+  UNetField: undefined,
+  VaeField: undefined,
+  VaeModelField: undefined,
+};
+
export const buildInputFieldValue = (
  id: string,
  template: InputFieldTemplate
): InputFieldValue => {
-  const fieldValue: InputFieldValue = {
+  // TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
+  // resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
+  // `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
+  // `InputFieldValue` union, but TS doesn't seem to like it...
+  const fieldValue = {
    id,
    name: template.name,
    type: template.type,
    label: '',
    fieldKind: 'input',
-  };
-
-  if (template.type === 'string') {
-    fieldValue.value = template.default ?? '';
-  }
-
-  if (template.type === 'integer') {
-    fieldValue.value = template.default ?? 0;
-  }
-
-  if (template.type === 'float') {
-    fieldValue.value = template.default ?? 0;
-  }
-
-  if (template.type === 'boolean') {
-    fieldValue.value = template.default ?? false;
-  }
+  } as InputFieldValue;

  if (template.type === 'enum') {
    if (template.enumType === 'number') {
-      fieldValue.value = template.default ?? 0;
+      fieldValue.value =
+        template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
    }
    if (template.enumType === 'string') {
-      fieldValue.value = template.default ?? '';
+      fieldValue.value =
+        template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
    }
-  }
-
-  if (template.type === 'Collection') {
-    fieldValue.value = template.default ?? 1;
-  }
-
-  if (template.type === 'ImageField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'ImageCollection') {
-    fieldValue.value = [];
-  }
-
-  if (template.type === 'DenoiseMaskField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'LatentsField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'ConditioningField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'UNetField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'ClipField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'VaeField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'ControlField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'MainModelField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'SDXLRefinerModelField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'VaeModelField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'LoRAModelField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'ControlNetModelField') {
-    fieldValue.value = undefined;
-  }
-
-  if (template.type === 'Scheduler') {
-    fieldValue.value = 'euler';
-  }
+  } else {
+    fieldValue.value =
+      template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
  }

  return fieldValue;
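
In practice the fallback map means a freshly created node input gets a sensible empty value whenever the template carries no default. A rough sketch of the behaviour, with the template object abbreviated and purely illustrative:

```typescript
// Hypothetical template for a 'string' input with no default set.
const stringTemplate = {
  name: 'prompt',
  type: 'string',
  // ...other InputFieldTemplate members omitted for brevity
} as InputFieldTemplate;

const value = buildInputFieldValue('node-1.prompt', stringTemplate);
// value.value === '' because FIELD_VALUE_FALLBACK_MAP.string is ''
```
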
@@ -1,4 +1,6 @@
import * as png from '@stevebel/png';
+import { logger } from 'app/logging/logger';
+import { parseify } from 'common/util/serialize';
import {
  ImageMetadataAndWorkflow,
  zCoreMetadata,
@@ -18,6 +20,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
    const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
    if (metadataResult.success) {
      data.metadata = metadataResult.data;
+    } else {
+      logger('system').error(
+        { error: parseify(metadataResult.error) },
+        'Problem reading metadata from image'
+      );
    }
  }

@@ -26,6 +33,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
    const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
    if (workflowResult.success) {
      data.workflow = workflowResult.data;
+    } else {
+      logger('system').error(
+        { error: parseify(workflowResult.error) },
+        'Problem reading workflow from image'
+      );
    }
  }

@@ -60,9 +60,9 @@ export const addSDXLRefinerToGraph = (

  if (metadataAccumulator) {
    metadataAccumulator.refiner_model = refinerModel;
-    metadataAccumulator.refiner_positive_aesthetic_store =
+    metadataAccumulator.refiner_positive_aesthetic_score =
      refinerPositiveAestheticScore;
-    metadataAccumulator.refiner_negative_aesthetic_store =
+    metadataAccumulator.refiner_negative_aesthetic_score =
      refinerNegativeAestheticScore;
    metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
    metadataAccumulator.refiner_scheduler = refinerScheduler;
@@ -73,6 +73,7 @@ export const parseSchema = (
    const title = schema.title.replace('Invocation', '');
    const tags = schema.tags ?? [];
    const description = schema.description ?? '';
+    const version = schema.version ?? '';

    const inputs = reduce(
      schema.properties,
@@ -225,11 +226,12 @@ export const parseSchema = (
    const invocation: InvocationTemplate = {
      title,
      type,
+      version,
      tags,
      description,
+      outputType,
      inputs,
      outputs,
-      outputType,
    };

    Object.assign(invocationsAccumulator, { [type]: invocation });
@@ -0,0 +1,96 @@
+import { compareVersions } from 'compare-versions';
+import { cloneDeep, keyBy } from 'lodash-es';
+import {
+  InvocationTemplate,
+  Workflow,
+  WorkflowWarning,
+  isWorkflowInvocationNode,
+} from '../types/types';
+import { parseify } from 'common/util/serialize';
+
+export const validateWorkflow = (
+  workflow: Workflow,
+  nodeTemplates: Record<string, InvocationTemplate>
+) => {
+  const clone = cloneDeep(workflow);
+  const { nodes, edges } = clone;
+  const errors: WorkflowWarning[] = [];
+  const invocationNodes = nodes.filter(isWorkflowInvocationNode);
+  const keyedNodes = keyBy(invocationNodes, 'id');
+  nodes.forEach((node) => {
+    if (!isWorkflowInvocationNode(node)) {
+      return;
+    }
+
+    const nodeTemplate = nodeTemplates[node.data.type];
+    if (!nodeTemplate) {
+      errors.push({
+        message: `Node "${node.data.type}" skipped`,
+        issues: [`Node type "${node.data.type}" does not exist`],
+        data: node,
+      });
+      return;
+    }
+
+    if (
+      nodeTemplate.version &&
+      node.data.version &&
+      compareVersions(nodeTemplate.version, node.data.version) !== 0
+    ) {
+      errors.push({
+        message: `Node "${node.data.type}" has mismatched version`,
+        issues: [
+          `Node "${node.data.type}" v${node.data.version} may be incompatible with installed v${nodeTemplate.version}`,
+        ],
+        data: { node, nodeTemplate: parseify(nodeTemplate) },
+      });
+      return;
+    }
+  });
+  edges.forEach((edge, i) => {
+    const sourceNode = keyedNodes[edge.source];
+    const targetNode = keyedNodes[edge.target];
+    const issues: string[] = [];
+    if (!sourceNode) {
+      issues.push(`Output node ${edge.source} does not exist`);
+    } else if (
+      edge.type === 'default' &&
+      !(edge.sourceHandle in sourceNode.data.outputs)
+    ) {
+      issues.push(
+        `Output field "${edge.source}.${edge.sourceHandle}" does not exist`
+      );
+    }
+    if (!targetNode) {
+      issues.push(`Input node ${edge.target} does not exist`);
+    } else if (
+      edge.type === 'default' &&
+      !(edge.targetHandle in targetNode.data.inputs)
+    ) {
+      issues.push(
+        `Input field "${edge.target}.${edge.targetHandle}" does not exist`
+      );
+    }
+    if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
+      issues.push(
+        `Source node "${edge.source}" missing template "${sourceNode?.data.type}"`
+      );
+    }
+    if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
+      issues.push(
+        `Source node "${edge.target}" missing template "${targetNode?.data.type}"`
+      );
+    }
+    if (issues.length) {
+      delete edges[i];
+      const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
+      const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
+      errors.push({
+        message: `Edge "${src} -> ${tgt}" skipped`,
+        issues,
+        data: edge,
+      });
+    }
+  });
+  return { workflow: clone, errors };
+};
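
A hedged usage sketch of the new validator; the `loadedWorkflow` and `nodeTemplates` values are placeholders for whatever the caller already has in hand:

```typescript
// Illustrative call site: surface the warnings but still load whatever survived validation.
const { workflow: validated, errors } = validateWorkflow(loadedWorkflow, nodeTemplates);
errors.forEach((warning) => {
  console.warn(warning.message, warning.issues);
});
// `validated` has unknown or broken edges dropped and can be handed to the editor.
```
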
@@ -341,8 +341,8 @@ export const useRecallParameters = () => {
        refiner_cfg_scale,
        refiner_steps,
        refiner_scheduler,
-        refiner_positive_aesthetic_store,
-        refiner_negative_aesthetic_store,
+        refiner_positive_aesthetic_score,
+        refiner_negative_aesthetic_score,
        refiner_start,
      } = metadata;

@@ -403,21 +403,21 @@ export const useRecallParameters = () => {

      if (
        isValidSDXLRefinerPositiveAestheticScore(
-          refiner_positive_aesthetic_store
+          refiner_positive_aesthetic_score
        )
      ) {
        dispatch(
-          setRefinerPositiveAestheticScore(refiner_positive_aesthetic_store)
+          setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)
        );
      }

      if (
        isValidSDXLRefinerNegativeAestheticScore(
-          refiner_negative_aesthetic_store
+          refiner_negative_aesthetic_score
        )
      ) {
        dispatch(
-          setRefinerNegativeAestheticScore(refiner_negative_aesthetic_store)
+          setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)
        );
      }

@@ -1,11 +1,11 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
-import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
@@ -14,6 +14,7 @@ import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
+import { getModelName } from './util';

type AdvancedAddCheckpointProps = {
  model_path?: string;
@@ -28,7 +29,7 @@ export default function AdvancedAddCheckpoint(

  const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
    initialValues: {
-      model_name: model_path?.split('\\').splice(-1)[0]?.split('.')[0] ?? '',
+      model_name: model_path ? getModelName(model_path) : '',
      base_model: 'sd-1',
      model_type: 'main',
      path: model_path ? model_path : '',
@@ -100,6 +101,17 @@ export default function AdvancedAddCheckpoint(
            label="Model Location"
            required
            {...advancedAddCheckpointForm.getInputProps('path')}
+            onBlur={(e) => {
+              if (advancedAddCheckpointForm.values['model_name'] === '') {
+                const modelName = getModelName(e.currentTarget.value);
+                if (modelName) {
+                  advancedAddCheckpointForm.setFieldValue(
+                    'model_name',
+                    modelName as string
+                  );
+                }
+              }
+            }}
          />
          <IAIMantineTextInput
            label="Description"
@@ -1,16 +1,17 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
-import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
+import { getModelName } from './util';

type AdvancedAddDiffusersProps = {
  model_path?: string;
@@ -25,7 +26,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {

  const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
    initialValues: {
-      model_name: model_path?.split('\\').splice(-1)[0] ?? '',
+      model_name: model_path ? getModelName(model_path, false) : '',
      base_model: 'sd-1',
      model_type: 'main',
      path: model_path ? model_path : '',
@@ -92,6 +93,17 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
            label="Model Location"
            placeholder="Provide the path to a local folder where your Diffusers Model is stored"
            {...advancedAddDiffusersForm.getInputProps('path')}
+            onBlur={(e) => {
+              if (advancedAddDiffusersForm.values['model_name'] === '') {
+                const modelName = getModelName(e.currentTarget.value, false);
+                if (modelName) {
+                  advancedAddDiffusersForm.setFieldValue(
+                    'model_name',
+                    modelName as string
+                  );
+                }
+              }
+            }}
          />
          <IAIMantineTextInput
            label="Description"
@ -0,0 +1,15 @@
export function getModelName(filepath: string, isCheckpoint: boolean = true) {
  let regex;
  if (isCheckpoint) {
    regex = new RegExp('[^\\\\/]+(?=\\.)');
  } else {
    regex = new RegExp('[^\\\\/]+(?=[\\\\/]?$)');
  }

  const match = filepath.match(regex);
  if (match) {
    return match[0];
  } else {
    return '';
  }
}
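For reference, a minimal sketch of how the new `getModelName` helper behaves on the two path styles it targets; the sample paths below are made up for illustration and are not part of the change:

```ts
import { getModelName } from './util';

// Checkpoint path: the regex grabs the file name up to its last extension dot.
getModelName('C:\\models\\analog-diffusion-1.0.safetensors'); // "analog-diffusion-1.0"

// Diffusers path: with isCheckpoint=false the regex grabs the last folder segment,
// with or without a trailing slash.
getModelName('/home/user/models/stable-diffusion-1.5/', false); // "stable-diffusion-1.5"
```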
@ -28,6 +28,8 @@ import {
  } from '../util';
  import { boardsApi } from './boards';
  import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
+ import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
+ import { $authToken, $projectId } from '../client';

  export const imagesApi = api.injectEndpoints({
    endpoints: (build) => ({
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
        ],
        keepUnusedDataFor: 86400, // 24 hours
      }),
-     getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
-       query: (image_name) => ({
-         url: `images/i/${image_name}/full`,
-         responseHandler: async (res) => {
-           return await res.blob();
-         },
-       }),
-       providesTags: (result, error, image_name) => [
-         { type: 'ImageMetadataFromFile', id: image_name },
+     getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
+       queryFn: async (args: ImageDTO, api, extraOptions) => {
+         const authToken = $authToken.get();
+         const projectId = $projectId.get();
+         const customBaseQuery = fetchBaseQuery({
+           baseUrl: '',
+           prepareHeaders: (headers) => {
+             if (authToken) {
+               headers.set('Authorization', `Bearer ${authToken}`);
+             }
+             if (projectId) {
+               headers.set('project-id', projectId);
+             }
+
+             return headers;
+           },
+           responseHandler: async (res) => {
+             return await res.blob();
+           },
+         });
+
+         const response = await customBaseQuery(
+           args.image_url,
+           api,
+           extraOptions
+         );
+         const data = await getMetadataAndWorkflowFromImageBlob(
+           response.data as Blob
+         );
+         return { data };
+       },
+       providesTags: (result, error, image_dto) => [
+         { type: 'ImageMetadataFromFile', id: image_dto.image_name },
        ],
-       transformResponse: (response: Blob) =>
-         getMetadataAndWorkflowFromImageBlob(response),
        keepUnusedDataFor: 86400, // 24 hours
      }),
      clearIntermediates: build.mutation<number, void>({
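A quick, hypothetical sketch of how a component would consume the reworked endpoint: the query argument is now a full `ImageDTO` rather than a bare image name, and RTK Query's usual `use...Query` hook generation and the import path are assumed here (the diff only shows the endpoint definition):

```ts
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';

// imageDTO is assumed to carry image_url and image_name; the queryFn above
// fetches image_url directly with the auth/project headers attached.
const { data: metadataAndWorkflow } = useGetImageMetadataFromFileQuery(imageDTO, {
  skip: !imageDTO,
});
```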
126  invokeai/frontend/web/src/services/api/schema.d.ts (vendored)
File diff suppressed because one or more lines are too long
@ -2970,6 +2970,11 @@ commondir@^1.0.1:
    resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b"
    integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==

+ compare-versions@^6.1.0:
+   version "6.1.0"
+   resolved "https://registry.yarnpkg.com/compare-versions/-/compare-versions-6.1.0.tgz#3f2131e3ae93577df111dba133e6db876ffe127a"
+   integrity sha512-LNZQXhqUvqUTotpZ00qLSaify3b4VFD588aRr8MKFw4CMUr98ytzCW5wDH5qx/DEY5kCDXcbcRuCqL0szEf2tg==
+
  compute-scroll-into-view@1.0.20:
    version "1.0.20"
    resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.20.tgz#1768b5522d1172754f5d0c9b02de3af6be506a43"
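The lockfile picks up `compare-versions`, a small semver comparison helper. A minimal sketch of its API, assuming that is what the new dependency is used for on the frontend (the lockfile entry itself does not say):

```ts
import { compare, compareVersions } from 'compare-versions';

// compareVersions returns -1 | 0 | 1; compare answers an explicit operator.
compareVersions('1.1.0', '1.0.3'); // 1
compare('1.0.0', '1.0.1', '>=');   // false
```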
@ -1,283 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "ycYWcsEKc6w7"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"# Stable Diffusion AI Notebook (Release 2.0.0)\n",
|
|
||||||
"\n",
|
|
||||||
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
|
|
||||||
"#### Instructions:\n",
|
|
||||||
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
|
|
||||||
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n",
|
|
||||||
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
|
|
||||||
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n",
|
|
||||||
"4. To quit Dream bot use `q` command. <br> \n",
|
|
||||||
"---\n",
|
|
||||||
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
|
|
||||||
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n",
|
|
||||||
"##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
|
|
||||||
"---\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "dr32VLxlnouf"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## ◢ Installation"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "a2Z5Qu_o8VtQ"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 1. Check current GPU assigned\n",
|
|
||||||
"!nvidia-smi -L\n",
|
|
||||||
"!nvidia-smi"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "vbI9ZsQHzjqF"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 2. Download stable-diffusion Repository\n",
|
|
||||||
"from os.path import exists\n",
|
|
||||||
"\n",
|
|
||||||
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
|
|
||||||
"%cd /content/InvokeAI/\n",
|
|
||||||
"!git checkout --quiet tags/v2.0.0"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "QbXcGXYEFSNB"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 3. Install dependencies\n",
|
|
||||||
"import gc\n",
|
|
||||||
"\n",
|
|
||||||
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
|
|
||||||
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
|
|
||||||
"!pip install colab-xterm\n",
|
|
||||||
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
|
|
||||||
"!pip install clean-fid torchtext\n",
|
|
||||||
"!pip install transformers\n",
|
|
||||||
"gc.collect()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "8rSMhgnAttQa"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 4. Restart Runtime\n",
|
|
||||||
"exit()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "ChIDWxLVHGGJ"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 5. Load small ML models required\n",
|
|
||||||
"import gc\n",
|
|
||||||
"\n",
|
|
||||||
"%cd /content/InvokeAI/\n",
|
|
||||||
"!python scripts/preload_models.py\n",
|
|
||||||
"gc.collect()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "795x1tMoo8b1"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## ◢ Configuration"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "YEWPV-sF1RDM"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 6. Mount google Drive\n",
|
|
||||||
"from google.colab import drive\n",
|
|
||||||
"\n",
|
|
||||||
"drive.mount(\"/content/drive\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "zRTJeZ461WGu"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 7. Drive Path to model\n",
|
|
||||||
"# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
|
|
||||||
"# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
|
|
||||||
"# @markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
|
|
||||||
"from os.path import exists\n",
|
|
||||||
"\n",
|
|
||||||
"model_path = \"\" # @param {type:\"string\"}\n",
|
|
||||||
"if exists(model_path):\n",
|
|
||||||
" print(\"✅ Valid directory\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" print(\"❌ File doesn't exist\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "UY-NNz4I8_aG"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 8. Symlink to model\n",
|
|
||||||
"\n",
|
|
||||||
"from os.path import exists\n",
|
|
||||||
"import os\n",
|
|
||||||
"\n",
|
|
||||||
"# Folder creation if it doesn't exist\n",
|
|
||||||
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
|
|
||||||
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
|
|
||||||
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Symbolic link if it doesn't exist\n",
|
|
||||||
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
|
|
||||||
" print(\"❗ Symlink already created\")\n",
|
|
||||||
"else:\n",
|
|
||||||
" src = model_path\n",
|
|
||||||
" dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
|
|
||||||
" os.symlink(src, dst)\n",
|
|
||||||
" print(\"✅ Symbolic link created successfully\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "Mc28N0_NrCQH"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## ◢ Execution"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "ir4hCrMIuUpl"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# @title 9. Run Terminal and Execute Dream bot\n",
|
|
||||||
"# @markdown <font color=\"blue\">Steps:</font> <br>\n",
|
|
||||||
"# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
|
|
||||||
"# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
|
|
||||||
"# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
|
|
||||||
"# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
|
|
||||||
"\n",
|
|
||||||
"%load_ext colabxterm\n",
|
|
||||||
"%xterm\n",
|
|
||||||
"gc.collect()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"cellView": "form",
|
|
||||||
"id": "qnLohSHmKoGk"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"#@title 10. Show the last 15 generated images\n",
|
|
||||||
"import glob\n",
|
|
||||||
"import matplotlib.pyplot as plt\n",
|
|
||||||
"import matplotlib.image as mpimg\n",
|
|
||||||
"%matplotlib inline\n",
|
|
||||||
"\n",
|
|
||||||
"images = []\n",
|
|
||||||
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
|
|
||||||
" images.append(mpimg.imread(img_path))\n",
|
|
||||||
"\n",
|
|
||||||
"images = images[:15] \n",
|
|
||||||
"\n",
|
|
||||||
"plt.figure(figsize=(20,10))\n",
|
|
||||||
"\n",
|
|
||||||
"columns = 5\n",
|
|
||||||
"for i, image in enumerate(images):\n",
|
|
||||||
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
|
|
||||||
" ax.axes.xaxis.set_visible(False)\n",
|
|
||||||
" ax.axes.yaxis.set_visible(False)\n",
|
|
||||||
" ax.axis('off')\n",
|
|
||||||
" plt.imshow(image)\n",
|
|
||||||
" gc.collect()\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"accelerator": "GPU",
|
|
||||||
"colab": {
|
|
||||||
"collapsed_sections": [],
|
|
||||||
"private_outputs": true,
|
|
||||||
"provenance": []
|
|
||||||
},
|
|
||||||
"gpuClass": "standard",
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3.9.12 64-bit",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python",
|
|
||||||
"version": "3.9.12"
|
|
||||||
},
|
|
||||||
"vscode": {
|
|
||||||
"interpreter": {
|
|
||||||
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 0
|
|
||||||
}
|
|
@ -1,339 +0,0 @@
|
|||||||
from torchvision.datasets.utils import download_url
|
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
|
|
||||||
# todo ?
|
|
||||||
from google.colab import files
|
|
||||||
from IPython.display import Image as ipyimg
|
|
||||||
import ipywidgets as widgets
|
|
||||||
from PIL import Image
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import torchvision
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.util import ismap
|
|
||||||
import time
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from ldm.invoke.devices import choose_torch_device
|
|
||||||
|
|
||||||
|
|
||||||
def download_models(mode):
|
|
||||||
if mode == "superresolution":
|
|
||||||
# this is the small bsr light model
|
|
||||||
url_conf = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
|
||||||
url_ckpt = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
|
||||||
|
|
||||||
path_conf = "logs/diffusion/superresolution_bsr/configs/project.yaml"
|
|
||||||
path_ckpt = "logs/diffusion/superresolution_bsr/checkpoints/last.ckpt"
|
|
||||||
|
|
||||||
download_url(url_conf, path_conf)
|
|
||||||
download_url(url_ckpt, path_ckpt)
|
|
||||||
|
|
||||||
path_conf = path_conf + "/?dl=1" # fix it
|
|
||||||
path_ckpt = path_ckpt + "/?dl=1" # fix it
|
|
||||||
return path_conf, path_ckpt
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt):
|
|
||||||
print(f"Loading model from {ckpt}")
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
global_step = pl_sd["global_step"]
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
|
||||||
return {"model": model}, global_step
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(mode):
|
|
||||||
path_conf, path_ckpt = download_models(mode)
|
|
||||||
config = OmegaConf.load(path_conf)
|
|
||||||
model, step = load_model_from_config(config, path_ckpt)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_custom_cond(mode):
|
|
||||||
dest = "data/example_conditioning"
|
|
||||||
|
|
||||||
if mode == "superresolution":
|
|
||||||
uploaded_img = files.upload()
|
|
||||||
filename = next(iter(uploaded_img))
|
|
||||||
name, filetype = filename.split(".") # todo assumes just one dot in name !
|
|
||||||
os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
|
|
||||||
|
|
||||||
elif mode == "text_conditional":
|
|
||||||
w = widgets.Text(value="A cake with cream!", disabled=True)
|
|
||||||
display(w) # noqa: F821
|
|
||||||
|
|
||||||
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
|
|
||||||
f.write(w.value)
|
|
||||||
|
|
||||||
elif mode == "class_conditional":
|
|
||||||
w = widgets.IntSlider(min=0, max=1000)
|
|
||||||
display(w) # noqa: F821
|
|
||||||
with open(f"{dest}/{mode}/custom.txt", "w") as f:
|
|
||||||
f.write(w.value)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"cond not implemented for mode{mode}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_cond_options(mode):
|
|
||||||
path = "data/example_conditioning"
|
|
||||||
path = os.path.join(path, mode)
|
|
||||||
onlyfiles = [f for f in sorted(os.listdir(path))]
|
|
||||||
return path, onlyfiles
|
|
||||||
|
|
||||||
|
|
||||||
def select_cond_path(mode):
|
|
||||||
path = "data/example_conditioning" # todo
|
|
||||||
path = os.path.join(path, mode)
|
|
||||||
onlyfiles = [f for f in sorted(os.listdir(path))]
|
|
||||||
|
|
||||||
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
|
|
||||||
display(selected) # noqa: F821
|
|
||||||
selected_path = os.path.join(path, selected.value)
|
|
||||||
return selected_path
|
|
||||||
|
|
||||||
|
|
||||||
def get_cond(mode, selected_path):
|
|
||||||
example = dict()
|
|
||||||
if mode == "superresolution":
|
|
||||||
up_f = 4
|
|
||||||
visualize_cond_img(selected_path)
|
|
||||||
|
|
||||||
c = Image.open(selected_path)
|
|
||||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
|
||||||
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
|
|
||||||
c_up = rearrange(c_up, "1 c h w -> 1 h w c")
|
|
||||||
c = rearrange(c, "1 c h w -> 1 h w c")
|
|
||||||
c = 2.0 * c - 1.0
|
|
||||||
|
|
||||||
device = choose_torch_device()
|
|
||||||
c = c.to(device)
|
|
||||||
example["LR_image"] = c
|
|
||||||
example["image"] = c_up
|
|
||||||
|
|
||||||
return example
|
|
||||||
|
|
||||||
|
|
||||||
def visualize_cond_img(path):
|
|
||||||
display(ipyimg(filename=path)) # noqa: F821
|
|
||||||
|
|
||||||
|
|
||||||
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
|
|
||||||
example = get_cond(task, selected_path)
|
|
||||||
|
|
||||||
save_intermediate_vid = False
|
|
||||||
n_runs = 1
|
|
||||||
masked = False
|
|
||||||
guider = None
|
|
||||||
ckwargs = None
|
|
||||||
mode = "ddim"
|
|
||||||
ddim_use_x0_pred = False
|
|
||||||
temperature = 1.0
|
|
||||||
eta = 1.0
|
|
||||||
make_progrow = True
|
|
||||||
custom_shape = None
|
|
||||||
|
|
||||||
height, width = example["image"].shape[1:3]
|
|
||||||
split_input = height >= 128 and width >= 128
|
|
||||||
|
|
||||||
if split_input:
|
|
||||||
ks = 128
|
|
||||||
stride = 64
|
|
||||||
vqf = 4 #
|
|
||||||
model.split_input_params = {
|
|
||||||
"ks": (ks, ks),
|
|
||||||
"stride": (stride, stride),
|
|
||||||
"vqf": vqf,
|
|
||||||
"patch_distributed_vq": True,
|
|
||||||
"tie_braker": False,
|
|
||||||
"clip_max_weight": 0.5,
|
|
||||||
"clip_min_weight": 0.01,
|
|
||||||
"clip_max_tie_weight": 0.5,
|
|
||||||
"clip_min_tie_weight": 0.01,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
if hasattr(model, "split_input_params"):
|
|
||||||
delattr(model, "split_input_params")
|
|
||||||
|
|
||||||
invert_mask = False
|
|
||||||
|
|
||||||
x_T = None
|
|
||||||
for n in range(n_runs):
|
|
||||||
if custom_shape is not None:
|
|
||||||
x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
|
||||||
x_T = repeat(x_T, "1 c h w -> b c h w", b=custom_shape[0])
|
|
||||||
|
|
||||||
logs = make_convolutional_sample(
|
|
||||||
example,
|
|
||||||
model,
|
|
||||||
mode=mode,
|
|
||||||
custom_steps=custom_steps,
|
|
||||||
eta=eta,
|
|
||||||
swap_mode=False,
|
|
||||||
masked=masked,
|
|
||||||
invert_mask=invert_mask,
|
|
||||||
quantize_x0=False,
|
|
||||||
custom_schedule=None,
|
|
||||||
decode_interval=10,
|
|
||||||
resize_enabled=resize_enabled,
|
|
||||||
custom_shape=custom_shape,
|
|
||||||
temperature=temperature,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
corrector=guider,
|
|
||||||
corrector_kwargs=ckwargs,
|
|
||||||
x_T=x_T,
|
|
||||||
save_intermediate_vid=save_intermediate_vid,
|
|
||||||
make_progrow=make_progrow,
|
|
||||||
ddim_use_x0_pred=ddim_use_x0_pred,
|
|
||||||
)
|
|
||||||
return logs
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convsample_ddim(
|
|
||||||
model,
|
|
||||||
cond,
|
|
||||||
steps,
|
|
||||||
shape,
|
|
||||||
eta=1.0,
|
|
||||||
callback=None,
|
|
||||||
normals_sequence=None,
|
|
||||||
mask=None,
|
|
||||||
x0=None,
|
|
||||||
quantize_x0=False,
|
|
||||||
img_callback=None,
|
|
||||||
temperature=1.0,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
score_corrector=None,
|
|
||||||
corrector_kwargs=None,
|
|
||||||
x_T=None,
|
|
||||||
log_every_t=None,
|
|
||||||
):
|
|
||||||
ddim = DDIMSampler(model)
|
|
||||||
bs = shape[0] # dont know where this comes from but wayne
|
|
||||||
shape = shape[1:] # cut batch dim
|
|
||||||
print(f"Sampling with eta = {eta}; steps: {steps}")
|
|
||||||
samples, intermediates = ddim.sample(
|
|
||||||
steps,
|
|
||||||
batch_size=bs,
|
|
||||||
shape=shape,
|
|
||||||
conditioning=cond,
|
|
||||||
callback=callback,
|
|
||||||
normals_sequence=normals_sequence,
|
|
||||||
quantize_x0=quantize_x0,
|
|
||||||
eta=eta,
|
|
||||||
mask=mask,
|
|
||||||
x0=x0,
|
|
||||||
temperature=temperature,
|
|
||||||
verbose=False,
|
|
||||||
score_corrector=score_corrector,
|
|
||||||
corrector_kwargs=corrector_kwargs,
|
|
||||||
x_T=x_T,
|
|
||||||
)
|
|
||||||
|
|
||||||
return samples, intermediates
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def make_convolutional_sample(
|
|
||||||
batch,
|
|
||||||
model,
|
|
||||||
mode="vanilla",
|
|
||||||
custom_steps=None,
|
|
||||||
eta=1.0,
|
|
||||||
swap_mode=False,
|
|
||||||
masked=False,
|
|
||||||
invert_mask=True,
|
|
||||||
quantize_x0=False,
|
|
||||||
custom_schedule=None,
|
|
||||||
decode_interval=1000,
|
|
||||||
resize_enabled=False,
|
|
||||||
custom_shape=None,
|
|
||||||
temperature=1.0,
|
|
||||||
noise_dropout=0.0,
|
|
||||||
corrector=None,
|
|
||||||
corrector_kwargs=None,
|
|
||||||
x_T=None,
|
|
||||||
save_intermediate_vid=False,
|
|
||||||
make_progrow=True,
|
|
||||||
ddim_use_x0_pred=False,
|
|
||||||
):
|
|
||||||
log = dict()
|
|
||||||
|
|
||||||
z, c, x, xrec, xc = model.get_input(
|
|
||||||
batch,
|
|
||||||
model.first_stage_key,
|
|
||||||
return_first_stage_outputs=True,
|
|
||||||
force_c_encode=not (hasattr(model, "split_input_params") and model.cond_stage_key == "coordinates_bbox"),
|
|
||||||
return_original_cond=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
log_every_t = 1 if save_intermediate_vid else None
|
|
||||||
|
|
||||||
if custom_shape is not None:
|
|
||||||
z = torch.randn(custom_shape)
|
|
||||||
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
|
||||||
|
|
||||||
z0 = None
|
|
||||||
|
|
||||||
log["input"] = x
|
|
||||||
log["reconstruction"] = xrec
|
|
||||||
|
|
||||||
if ismap(xc):
|
|
||||||
log["original_conditioning"] = model.to_rgb(xc)
|
|
||||||
if hasattr(model, "cond_stage_key"):
|
|
||||||
log[model.cond_stage_key] = model.to_rgb(xc)
|
|
||||||
|
|
||||||
else:
|
|
||||||
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
|
||||||
if model.cond_stage_model:
|
|
||||||
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
|
||||||
if model.cond_stage_key == "class_label":
|
|
||||||
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
|
||||||
|
|
||||||
with model.ema_scope("Plotting"):
|
|
||||||
t0 = time.time()
|
|
||||||
img_cb = None
|
|
||||||
|
|
||||||
sample, intermediates = convsample_ddim(
|
|
||||||
model,
|
|
||||||
c,
|
|
||||||
steps=custom_steps,
|
|
||||||
shape=z.shape,
|
|
||||||
eta=eta,
|
|
||||||
quantize_x0=quantize_x0,
|
|
||||||
img_callback=img_cb,
|
|
||||||
mask=None,
|
|
||||||
x0=z0,
|
|
||||||
temperature=temperature,
|
|
||||||
noise_dropout=noise_dropout,
|
|
||||||
score_corrector=corrector,
|
|
||||||
corrector_kwargs=corrector_kwargs,
|
|
||||||
x_T=x_T,
|
|
||||||
log_every_t=log_every_t,
|
|
||||||
)
|
|
||||||
t1 = time.time()
|
|
||||||
|
|
||||||
if ddim_use_x0_pred:
|
|
||||||
sample = intermediates["pred_x0"][-1]
|
|
||||||
|
|
||||||
x_sample = model.decode_first_stage(sample)
|
|
||||||
|
|
||||||
try:
|
|
||||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
|
||||||
log["sample_noquant"] = x_sample_noquant
|
|
||||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
log["sample"] = x_sample
|
|
||||||
log["time"] = t1 - t0
|
|
||||||
|
|
||||||
return log
|
|
@ -74,6 +74,7 @@ dependencies = [
    "rich~=13.3",
    "safetensors==0.3.1",
    "scikit-image~=0.21.0",
+   "semver~=3.0.1",
    "send2trash",
    "test-tube~=0.7.5",
    "torch~=2.0.1",
@ -1,52 +0,0 @@
import os
import torch
import cv2
import numpy as np
from PIL import Image

from diffusers.utils import load_image
from diffusers.models.controlnet import ControlNetModel
from invokeai.backend.generator import Txt2Img
from invokeai.backend.model_management import ModelManager


print("loading 'Girl with a Pearl Earring' image")
image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image.show()

print("preprocessing image with Canny edge detection")
image_np = np.array(image)
low_threshold = 100
high_threshold = 200
canny_np = cv2.Canny(image_np, low_threshold, high_threshold)
canny_image = Image.fromarray(canny_np)
canny_image.show()

# using invokeai model management for base model
print("loading base model stable-diffusion-1.5")
model_config_path = os.getcwd() + "/../configs/models.yaml"
model_manager = ModelManager(model_config_path)
model = model_manager.get_model("stable-diffusion-1.5")

print("loading control model lllyasviel/sd-controlnet-canny")
canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to(
    "cuda"
)

print("testing Txt2Img() constructor with control_model arg")
txt2img_canny = Txt2Img(model, control_model=canny_controlnet)

print("testing Txt2Img.generate() with control_image arg")
outputs = txt2img_canny.generate(
    prompt="old man",
    control_image=canny_image,
    control_weight=1.0,
    seed=0,
    num_steps=30,
    precision="float16",
)
generate_output = next(outputs)
out_image = generate_output.image
out_image.show()
@ -1,33 +0,0 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
"""

import argparse
import json

from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

tmpl = {}

for key, tensor in ckpt.items():
    tmpl[key] = list(tensor.shape)

try:
    with open(opt.template, "w") as f:
        json.dump(tmpl, f)
    print(f"Template written out as {opt.template}")
except Exception as e:
    print(f"An exception occurred while writing template: {str(e)}")
@ -1,14 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

import warnings

from invokeai.app.cli_app import invoke_cli

warnings.warn(
    "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
    DeprecationWarning,
)


invoke_cli()
@ -1,4 +0,0 @@
from invokeai.backend.install.migrate_to_3 import main

if __name__=='__main__':
    main()
2  scripts/invokeai-model-install.py (Normal file → Executable file)
@ -1,3 +1,5 @@
+ #!/usr/bin/env python
+
  from invokeai.frontend.install.model_install import main

  main()
@ -1,41 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
|
|
||||||
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
|
|
||||||
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
|
|
||||||
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
|
|
||||||
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
|
|
||||||
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
|
|
||||||
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
|
|
||||||
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
|
|
||||||
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cd models/first_stage_models/kl-f4
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../kl-f8
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../kl-f16
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../kl-f32
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../vq-f4
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../vq-f4-noattn
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../vq-f8
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../vq-f8-n256
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../vq-f16
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../..
|
|
@ -1,49 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
|
|
||||||
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
|
|
||||||
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
|
|
||||||
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
|
|
||||||
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
|
|
||||||
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
|
|
||||||
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
|
|
||||||
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
|
|
||||||
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
|
|
||||||
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
|
|
||||||
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cd models/ldm/celeba256
|
|
||||||
unzip -o celeba-256.zip
|
|
||||||
|
|
||||||
cd ../ffhq256
|
|
||||||
unzip -o ffhq-256.zip
|
|
||||||
|
|
||||||
cd ../lsun_churches256
|
|
||||||
unzip -o lsun_churches-256.zip
|
|
||||||
|
|
||||||
cd ../lsun_beds256
|
|
||||||
unzip -o lsun_beds-256.zip
|
|
||||||
|
|
||||||
cd ../text2img256
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../cin256
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../semantic_synthesis512
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../semantic_synthesis256
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../bsr_sr
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../layout2img-openimages256
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../inpainting_big
|
|
||||||
unzip -o model.zip
|
|
||||||
|
|
||||||
cd ../..
|
|
@ -1,285 +0,0 @@
|
|||||||
"""make variations of input image"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import PIL
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
from itertools import islice
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torchvision.utils import make_grid
|
|
||||||
from torch import autocast
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from pytorch_lightning import seed_everything
|
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
from ldm.invoke.devices import choose_torch_device
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(it, size):
|
|
||||||
it = iter(it)
|
|
||||||
return iter(lambda: tuple(islice(it, size)), ())
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
|
||||||
print(f"Loading model from {ckpt}")
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
if "global_step" in pl_sd:
|
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
if len(m) > 0 and verbose:
|
|
||||||
print("missing keys:")
|
|
||||||
print(m)
|
|
||||||
if len(u) > 0 and verbose:
|
|
||||||
print("unexpected keys:")
|
|
||||||
print(u)
|
|
||||||
|
|
||||||
model.to(choose_torch_device())
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_img(path):
|
|
||||||
image = Image.open(path).convert("RGB")
|
|
||||||
w, h = image.size
|
|
||||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
|
||||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
|
||||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
|
||||||
image = torch.from_numpy(image)
|
|
||||||
return 2.0 * image - 1.0
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
|
||||||
type=str,
|
|
||||||
nargs="?",
|
|
||||||
default="a painting of a virus monster playing guitar",
|
|
||||||
help="the prompt to render",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_grid",
|
|
||||||
action="store_true",
|
|
||||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_save",
|
|
||||||
action="store_true",
|
|
||||||
help="do not save indiviual samples. For speed measurements.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--ddim_steps",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="number of ddim sampling steps",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--plms",
|
|
||||||
action="store_true",
|
|
||||||
help="use plms sampling",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fixed_code",
|
|
||||||
action="store_true",
|
|
||||||
help="if enabled, uses the same starting code across all samples ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--ddim_eta",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n_iter",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="sample this often",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--C",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="latent channels",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--f",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="downsampling factor, most often 8 or 16",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n_samples",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n_rows",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="rows in the grid (default: n_samples)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--scale",
|
|
||||||
type=float,
|
|
||||||
default=5.0,
|
|
||||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--strength",
|
|
||||||
type=float,
|
|
||||||
default=0.75,
|
|
||||||
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--from-file",
|
|
||||||
type=str,
|
|
||||||
help="if specified, load prompts from this file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--config",
|
|
||||||
type=str,
|
|
||||||
default="configs/stable-diffusion/v1-inference.yaml",
|
|
||||||
help="path to config which constructs model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ckpt",
|
|
||||||
type=str,
|
|
||||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
|
||||||
help="path to checkpoint of model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=42,
|
|
||||||
help="the seed (for reproducible sampling)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
|
|
||||||
)
|
|
||||||
|
|
||||||
opt = parser.parse_args()
|
|
||||||
seed_everything(opt.seed)
|
|
||||||
|
|
||||||
config = OmegaConf.load(f"{opt.config}")
|
|
||||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
|
||||||
|
|
||||||
device = torch.device(choose_torch_device())
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
if opt.plms:
|
|
||||||
raise NotImplementedError("PLMS sampler not (yet) supported")
|
|
||||||
sampler = PLMSSampler(model)
|
|
||||||
else:
|
|
||||||
sampler = DDIMSampler(model)
|
|
||||||
|
|
||||||
os.makedirs(opt.outdir, exist_ok=True)
|
|
||||||
outpath = opt.outdir
|
|
||||||
|
|
||||||
batch_size = opt.n_samples
|
|
||||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
|
||||||
if not opt.from_file:
|
|
||||||
prompt = opt.prompt
|
|
||||||
assert prompt is not None
|
|
||||||
data = [batch_size * [prompt]]
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f"reading prompts from {opt.from_file}")
|
|
||||||
with open(opt.from_file, "r") as f:
|
|
||||||
data = f.read().splitlines()
|
|
||||||
data = list(chunk(data, batch_size))
|
|
||||||
|
|
||||||
sample_path = os.path.join(outpath, "samples")
|
|
||||||
os.makedirs(sample_path, exist_ok=True)
|
|
||||||
base_count = len(os.listdir(sample_path))
|
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
|
||||||
|
|
||||||
assert os.path.isfile(opt.init_img)
|
|
||||||
init_image = load_img(opt.init_img).to(device)
|
|
||||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
|
||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
|
||||||
|
|
||||||
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
|
|
||||||
|
|
||||||
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
|
|
||||||
t_enc = int(opt.strength * opt.ddim_steps)
|
|
||||||
print(f"target t_enc is {t_enc} steps")
|
|
||||||
|
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
|
||||||
if device.type in ["mps", "cpu"]:
|
|
||||||
precision_scope = nullcontext # have to use f32 on mps
|
|
||||||
with torch.no_grad():
|
|
||||||
with precision_scope(device.type):
|
|
||||||
with model.ema_scope():
|
|
||||||
all_samples = list()
|
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
|
||||||
for prompts in tqdm(data, desc="data"):
|
|
||||||
uc = None
|
|
||||||
if opt.scale != 1.0:
|
|
||||||
uc = model.get_learned_conditioning(batch_size * [""])
|
|
||||||
if isinstance(prompts, tuple):
|
|
||||||
prompts = list(prompts)
|
|
||||||
c = model.get_learned_conditioning(prompts)
|
|
||||||
|
|
||||||
# encode (scaled latent)
|
|
||||||
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
|
|
||||||
# decode it
|
|
||||||
samples = sampler.decode(
|
|
||||||
z_enc,
|
|
||||||
c,
|
|
||||||
t_enc,
|
|
||||||
unconditional_guidance_scale=opt.scale,
|
|
||||||
unconditional_conditioning=uc,
|
|
||||||
)
|
|
||||||
|
|
||||||
x_samples = model.decode_first_stage(samples)
|
|
||||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
|
|
||||||
if not opt.skip_save:
|
|
||||||
for x_sample in x_samples:
|
|
||||||
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
|
|
||||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
|
||||||
os.path.join(sample_path, f"{base_count:05}.png")
|
|
||||||
)
|
|
||||||
base_count += 1
|
|
||||||
all_samples.append(x_samples)
|
|
||||||
|
|
||||||
if not opt.skip_grid:
|
|
||||||
# additionally, save as grid
|
|
||||||
grid = torch.stack(all_samples, 0)
|
|
||||||
grid = rearrange(grid, "n b c h w -> (n b) c h w")
|
|
||||||
grid = make_grid(grid, nrow=n_rows)
|
|
||||||
|
|
||||||
# to image
|
|
||||||
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
|
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
|
|
||||||
grid_count += 1
|
|
||||||
|
|
||||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,94 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from main import instantiate_from_config
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.invoke.devices import choose_torch_device
|
|
||||||
|
|
||||||
|
|
||||||
def make_batch(image, mask, device):
|
|
||||||
image = np.array(Image.open(image).convert("RGB"))
|
|
||||||
image = image.astype(np.float32) / 255.0
|
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
|
||||||
image = torch.from_numpy(image)
|
|
||||||
|
|
||||||
mask = np.array(Image.open(mask).convert("L"))
|
|
||||||
mask = mask.astype(np.float32) / 255.0
|
|
||||||
mask = mask[None, None]
|
|
||||||
mask[mask < 0.5] = 0
|
|
||||||
mask[mask >= 0.5] = 1
|
|
||||||
mask = torch.from_numpy(mask)
|
|
||||||
|
|
||||||
masked_image = (1 - mask) * image
|
|
||||||
|
|
||||||
batch = {"image": image, "mask": mask, "masked_image": masked_image}
|
|
||||||
for k in batch:
|
|
||||||
batch[k] = batch[k].to(device=device)
|
|
||||||
batch[k] = batch[k] * 2.0 - 1.0
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--indir",
|
|
||||||
type=str,
|
|
||||||
nargs="?",
|
|
||||||
help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--outdir",
|
|
||||||
type=str,
|
|
||||||
nargs="?",
|
|
||||||
help="dir to write results to",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--steps",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="number of ddim sampling steps",
|
|
||||||
)
|
|
||||||
opt = parser.parse_args()
|
|
||||||
|
|
||||||
masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
|
|
||||||
images = [x.replace("_mask.png", ".png") for x in masks]
|
|
||||||
print(f"Found {len(masks)} inputs.")
|
|
||||||
|
|
||||||
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False)
|
|
||||||
|
|
||||||
device = choose_torch_device()
|
|
||||||
model = model.to(device)
|
|
||||||
sampler = DDIMSampler(model)
|
|
||||||
|
|
||||||
os.makedirs(opt.outdir, exist_ok=True)
|
|
||||||
with torch.no_grad():
|
|
||||||
with model.ema_scope():
|
|
||||||
for image, mask in tqdm(zip(images, masks)):
|
|
||||||
outpath = os.path.join(opt.outdir, os.path.split(image)[1])
|
|
||||||
batch = make_batch(image, mask, device=device)
|
|
||||||
|
|
||||||
# encode masked image and concat downsampled mask
|
|
||||||
c = model.cond_stage_model.encode(batch["masked_image"])
|
|
||||||
cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
|
|
||||||
c = torch.cat((c, cc), dim=1)
|
|
||||||
|
|
||||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
|
||||||
samples_ddim, _ = sampler.sample(
|
|
||||||
S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
|
|
||||||
)
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
|
||||||
|
|
||||||
image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
|
||||||
|
|
||||||
inpainted = (1 - mask) * image + mask * predicted_image
|
|
||||||
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
|
||||||
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
|
|
@ -1,397 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
from itertools import islice
|
|
||||||
from einops import rearrange
|
|
||||||
from torchvision.utils import make_grid
|
|
||||||
import scann
|
|
||||||
import time
|
|
||||||
from multiprocessing import cpu_count
|
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config, parallel_data_prefetch
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
|
|
||||||
|
|
||||||
DATABASES = [
|
|
||||||
"openimages",
|
|
||||||
"artbench-art_nouveau",
|
|
||||||
"artbench-baroque",
|
|
||||||
"artbench-expressionism",
|
|
||||||
"artbench-impressionism",
|
|
||||||
"artbench-post_impressionism",
|
|
||||||
"artbench-realism",
|
|
||||||
"artbench-romanticism",
|
|
||||||
"artbench-renaissance",
|
|
||||||
"artbench-surrealism",
|
|
||||||
"artbench-ukiyo_e",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(it, size):
|
|
||||||
it = iter(it)
|
|
||||||
return iter(lambda: tuple(islice(it, size)), ())
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
|
||||||
print(f"Loading model from {ckpt}")
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
if "global_step" in pl_sd:
|
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
if len(m) > 0 and verbose:
|
|
||||||
print("missing keys:")
|
|
||||||
print(m)
|
|
||||||
if len(u) > 0 and verbose:
|
|
||||||
print("unexpected keys:")
|
|
||||||
print(u)
|
|
||||||
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Searcher(object):
|
|
||||||
def __init__(self, database, retriever_version="ViT-L/14"):
|
|
||||||
assert database in DATABASES
|
|
||||||
# self.database = self.load_database(database)
|
|
||||||
self.database_name = database
|
|
||||||
self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
|
|
||||||
self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
|
|
||||||
self.retriever = self.load_retriever(version=retriever_version)
|
|
||||||
self.database = {"embedding": [], "img_id": [], "patch_coords": []}
|
|
||||||
self.load_database()
|
|
||||||
        self.load_searcher()

    def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
        print("Start training searcher")
        # normalize the database embeddings so the dot-product metric behaves like cosine similarity
        searcher = scann.scann_ops_pybind.builder(
            self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
        )
        self.searcher = searcher.score_brute_force().build()
        print("Finish training searcher")

        if searcher_savedir is not None:
            print(f'Save trained searcher under "{searcher_savedir}"')
            os.makedirs(searcher_savedir, exist_ok=True)
            self.searcher.serialize(searcher_savedir)

    def load_single_file(self, saved_embeddings):
        compressed = np.load(saved_embeddings)
        self.database = {key: compressed[key] for key in compressed.files}
        print("Finished loading of clip embeddings.")

    def load_multi_files(self, data_archive):
        out_data = {key: [] for key in self.database}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                out_data[key].append(d[key])

        return out_data

    def load_database(self):
        print(f'Load saved patch embedding from "{self.database_path}"')
        file_content = glob.glob(os.path.join(self.database_path, "*.npz"))

        if len(file_content) == 1:
            self.load_single_file(file_content[0])
        elif len(file_content) > 1:
            data = [np.load(f) for f in file_content]
            prefetched_data = parallel_data_prefetch(
                self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
            )

            self.database = {
                key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
            }
        else:
            raise ValueError(f'No npz files found in "{self.database_path}". Does this directory exist?')

        print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')

    def load_retriever(
        self,
        version="ViT-L/14",
    ):
        model = FrozenClipImageEmbedder(model=version)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        return model

    def load_searcher(self):
        print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
        self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
        print("Finished loading searcher.")

    def search(self, x, k):
        if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
        assert self.searcher is not None, "Cannot search with uninitialized searcher"
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if len(x.shape) == 3:
            x = x[:, 0]
        query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]

        start = time.time()
        nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
        end = time.time()

        out_embeddings = self.database["embedding"][nns]
        out_img_ids = self.database["img_id"][nns]
        out_pc = self.database["patch_coords"][nns]

        out = {
            "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
            "img_ids": out_img_ids,
            "patch_coords": out_pc,
            "queries": x,
            "exec_time": end - start,
            "nns": nns,
            "q_embeddings": query_embeddings,
        }

        return out

    def __call__(self, x, n):
        return self.search(x, n)
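

# --- Illustrative example (not part of the original script) ---------------------------------
# A minimal sketch of using the Searcher above on its own, assuming a database named
# "artbench-surrealism" and a batch of CLIP embeddings `clip_queries`; both names are
# assumptions made for this example, not values taken from the file.
def _example_searcher_usage(clip_queries):
    # clip_queries: torch.Tensor or np.ndarray of shape (batch, 768) holding CLIP embeddings
    searcher = Searcher("artbench-surrealism")
    result = searcher(clip_queries, 10)  # retrieve the 10 nearest neighbours per query
    # result["nn_embeddings"]: (batch, 10, 768) L2-normalized neighbour embeddings;
    # result["img_ids"] / result["patch_coords"] identify the retrieved patches.
    return result["nn_embeddings"]

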
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )

    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
    )

    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )

    parser.add_argument(
        "--n_repeat",
        type=int,
        default=1,
        help="number of repeats in CLIP latent space",
    )

    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=768,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--W",
        type=int,
        default=768,
        help="image width, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a batch size",
    )

    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )

    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )

    parser.add_argument(
        "--config",
        type=str,
        default="configs/retrieval-augmented-diffusion/768x768.yaml",
        help="path to config which constructs model",
    )

    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/rdm/rdm768x768/model.ckpt",
        help="path to checkpoint of model",
    )

    parser.add_argument(
        "--clip_type",
        type=str,
        default="ViT-L/14",
        help="which CLIP model to use for retrieval and NN encoding",
    )
    parser.add_argument(
        "--database",
        type=str,
        default="artbench-surrealism",
        choices=DATABASES,
        help="The database used for the search, only applied when --use_neighbors=True",
    )
    parser.add_argument(
        "--use_neighbors",
        default=False,
        action="store_true",
        help="Include neighbors in addition to text prompt for conditioning",
    )
    parser.add_argument(
        "--knn",
        default=10,
        type=int,
        help="The number of included neighbors, only applied when --use_neighbors=True",
    )

    opt = parser.parse_args()

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    print(f"sampling scale for cfg is {opt.scale:.2f}")

    searcher = None
    if opt.use_neighbors:
        searcher = Searcher(opt.database)

    with torch.no_grad():
        with model.ema_scope():
            for n in trange(opt.n_iter, desc="Sampling"):
                all_samples = list()
                for prompts in tqdm(data, desc="data"):
                    print("sampling prompts:", prompts)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = clip_text_encoder.encode(prompts)
                    uc = None
                    if searcher is not None:
                        nn_dict = searcher(c, opt.knn)
                        c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1)
                    if opt.scale != 1.0:
                        uc = torch.zeros_like(c)
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for f16 model
                    samples_ddim, _ = sampler.sample(
                        S=opt.ddim_steps,
                        conditioning=c,
                        batch_size=c.shape[0],
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=opt.scale,
                        unconditional_conditioning=uc,
                        eta=opt.ddim_eta,
                    )

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, f"{base_count:05}.png")
                        )
                        base_count += 1
                    all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
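
# Illustrative invocation only; the script path and database name below are assumptions,
# not taken from this diff:
#
#   python scripts/knn2img.py --prompt "an oil painting of a lighthouse in a storm" \
#       --use_neighbors --knn 10 --database artbench-surrealism --ddim_steps 50 --scale 5.0
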
File diff suppressed because one or more lines are too long
@ -1,898 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
|
|
||||||
from packaging import version
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
from functools import partial
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
|
||||||
from pytorch_lightning.trainer import Trainer
|
|
||||||
from pytorch_lightning.callbacks import Callback
|
|
||||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
|
||||||
from pytorch_lightning.utilities import rank_zero_info
|
|
||||||
|
|
||||||
from ldm.data.base import Txt2ImgIterableBaseDataset
|
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
|
|
||||||
|
|
||||||
def fix_func(orig):
|
|
||||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
||||||
|
|
||||||
def new_func(*args, **kw):
|
|
||||||
device = kw.get("device", "mps")
|
|
||||||
kw["device"] = "cpu"
|
|
||||||
return orig(*args, **kw).to(device)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
return orig
|
|
||||||
|
|
||||||
|
|
||||||
torch.rand = fix_func(torch.rand)
|
|
||||||
torch.rand_like = fix_func(torch.rand_like)
|
|
||||||
torch.randn = fix_func(torch.randn)
|
|
||||||
torch.randn_like = fix_func(torch.randn_like)
|
|
||||||
torch.randint = fix_func(torch.randint)
|
|
||||||
torch.randint_like = fix_func(torch.randint_like)
|
|
||||||
torch.bernoulli = fix_func(torch.bernoulli)
|
|
||||||
torch.multinomial = fix_func(torch.multinomial)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
|
||||||
print(f"Loading model from {ckpt}")
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
config.model.params.ckpt_path = ckpt
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
if len(m) > 0 and verbose:
|
|
||||||
print("missing keys:")
|
|
||||||
print(m)
|
|
||||||
if len(u) > 0 and verbose:
|
|
||||||
print("unexpected keys:")
|
|
||||||
print(u)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
model.cuda()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser(**parser_kwargs):
|
|
||||||
def str2bool(v):
|
|
||||||
if isinstance(v, bool):
|
|
||||||
return v
|
|
||||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
|
||||||
return True
|
|
||||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
|
||||||
parser.add_argument(
|
|
||||||
"-n",
|
|
||||||
"--name",
|
|
||||||
type=str,
|
|
||||||
const=True,
|
|
||||||
default="",
|
|
||||||
nargs="?",
|
|
||||||
help="postfix for logdir",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-r",
|
|
||||||
"--resume",
|
|
||||||
type=str,
|
|
||||||
const=True,
|
|
||||||
default="",
|
|
||||||
nargs="?",
|
|
||||||
help="resume from logdir or checkpoint in logdir",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-b",
|
|
||||||
"--base",
|
|
||||||
nargs="*",
|
|
||||||
metavar="base_config.yaml",
|
|
||||||
help="paths to base configs. Loaded from left-to-right. "
|
|
||||||
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
|
||||||
default=list(),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-t",
|
|
||||||
"--train",
|
|
||||||
type=str2bool,
|
|
||||||
const=True,
|
|
||||||
default=False,
|
|
||||||
nargs="?",
|
|
||||||
help="train",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-test",
|
|
||||||
type=str2bool,
|
|
||||||
const=True,
|
|
||||||
default=False,
|
|
||||||
nargs="?",
|
|
||||||
help="disable test",
|
|
||||||
)
|
|
||||||
parser.add_argument("-p", "--project", help="name of new or path to existing project")
|
|
||||||
parser.add_argument(
|
|
||||||
"-d",
|
|
||||||
"--debug",
|
|
||||||
type=str2bool,
|
|
||||||
nargs="?",
|
|
||||||
const=True,
|
|
||||||
default=False,
|
|
||||||
help="enable post-mortem debugging",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-s",
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=23,
|
|
||||||
help="seed for seed_everything",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-f",
|
|
||||||
"--postfix",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="post-postfix for default name",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-l",
|
|
||||||
"--logdir",
|
|
||||||
type=str,
|
|
||||||
default="logs",
|
|
||||||
help="directory for logging dat shit",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--scale_lr",
|
|
||||||
type=str2bool,
|
|
||||||
nargs="?",
|
|
||||||
const=True,
|
|
||||||
default=True,
|
|
||||||
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--datadir_in_name",
|
|
||||||
type=str2bool,
|
|
||||||
nargs="?",
|
|
||||||
const=True,
|
|
||||||
default=True,
|
|
||||||
help="Prepend the final directory in the data_root to the output directory name",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--actual_resume",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="Path to model to actually resume from",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--data_root",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to directory with training images",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding_manager_ckpt",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="Initialize embedding manager from a checkpoint",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--init_word",
|
|
||||||
type=str,
|
|
||||||
help="Word to use as source for initial token embedding.",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def nondefault_trainer_args(opt):
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser = Trainer.add_argparse_args(parser)
|
|
||||||
args = parser.parse_args([])
|
|
||||||
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
|
||||||
|
|
||||||
|
|
||||||
class WrappedDataset(Dataset):
|
|
||||||
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
|
||||||
|
|
||||||
def __init__(self, dataset):
|
|
||||||
self.data = dataset
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.data[idx]
|
|
||||||
|
|
||||||
|
|
||||||
def worker_init_fn(_):
|
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
|
||||||
|
|
||||||
dataset = worker_info.dataset
|
|
||||||
worker_id = worker_info.id
|
|
||||||
|
|
||||||
if isinstance(dataset, Txt2ImgIterableBaseDataset):
|
|
||||||
split_size = dataset.num_records // worker_info.num_workers
|
|
||||||
# reset num_records to the true number to retain reliable length information
|
|
||||||
dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
|
|
||||||
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
|
||||||
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
|
||||||
else:
|
|
||||||
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
|
||||||
|
|
||||||
|
|
||||||
class DataModuleFromConfig(pl.LightningDataModule):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
batch_size,
|
|
||||||
train=None,
|
|
||||||
validation=None,
|
|
||||||
test=None,
|
|
||||||
predict=None,
|
|
||||||
wrap=False,
|
|
||||||
num_workers=None,
|
|
||||||
shuffle_test_loader=False,
|
|
||||||
use_worker_init_fn=False,
|
|
||||||
shuffle_val_dataloader=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.dataset_configs = dict()
|
|
||||||
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
|
||||||
self.use_worker_init_fn = use_worker_init_fn
|
|
||||||
if train is not None:
|
|
||||||
self.dataset_configs["train"] = train
|
|
||||||
self.train_dataloader = self._train_dataloader
|
|
||||||
if validation is not None:
|
|
||||||
self.dataset_configs["validation"] = validation
|
|
||||||
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
|
|
||||||
if test is not None:
|
|
||||||
self.dataset_configs["test"] = test
|
|
||||||
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
|
|
||||||
if predict is not None:
|
|
||||||
self.dataset_configs["predict"] = predict
|
|
||||||
self.predict_dataloader = self._predict_dataloader
|
|
||||||
self.wrap = wrap
|
|
||||||
|
|
||||||
def prepare_data(self):
|
|
||||||
for data_cfg in self.dataset_configs.values():
|
|
||||||
instantiate_from_config(data_cfg)
|
|
||||||
|
|
||||||
def setup(self, stage=None):
|
|
||||||
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
|
|
||||||
if self.wrap:
|
|
||||||
for k in self.datasets:
|
|
||||||
self.datasets[k] = WrappedDataset(self.datasets[k])
|
|
||||||
|
|
||||||
def _train_dataloader(self):
|
|
||||||
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
|
|
||||||
if is_iterable_dataset or self.use_worker_init_fn:
|
|
||||||
init_fn = worker_init_fn
|
|
||||||
else:
|
|
||||||
init_fn = None
|
|
||||||
return DataLoader(
|
|
||||||
self.datasets["train"],
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
num_workers=self.num_workers,
|
|
||||||
shuffle=False if is_iterable_dataset else True,
|
|
||||||
worker_init_fn=init_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _val_dataloader(self, shuffle=False):
|
|
||||||
if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
|
||||||
init_fn = worker_init_fn
|
|
||||||
else:
|
|
||||||
init_fn = None
|
|
||||||
return DataLoader(
|
|
||||||
self.datasets["validation"],
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
num_workers=self.num_workers,
|
|
||||||
worker_init_fn=init_fn,
|
|
||||||
shuffle=shuffle,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _test_dataloader(self, shuffle=False):
|
|
||||||
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
|
|
||||||
if is_iterable_dataset or self.use_worker_init_fn:
|
|
||||||
init_fn = worker_init_fn
|
|
||||||
else:
|
|
||||||
init_fn = None
|
|
||||||
|
|
||||||
# do not shuffle dataloader for iterable dataset
|
|
||||||
shuffle = shuffle and (not is_iterable_dataset)
|
|
||||||
|
|
||||||
return DataLoader(
|
|
||||||
self.datasets["test"],
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
num_workers=self.num_workers,
|
|
||||||
worker_init_fn=init_fn,
|
|
||||||
shuffle=shuffle,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _predict_dataloader(self, shuffle=False):
|
|
||||||
if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
|
||||||
init_fn = worker_init_fn
|
|
||||||
else:
|
|
||||||
init_fn = None
|
|
||||||
return DataLoader(
|
|
||||||
self.datasets["predict"],
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
num_workers=self.num_workers,
|
|
||||||
worker_init_fn=init_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SetupCallback(Callback):
|
|
||||||
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
|
||||||
super().__init__()
|
|
||||||
self.resume = resume
|
|
||||||
self.now = now
|
|
||||||
self.logdir = logdir
|
|
||||||
self.ckptdir = ckptdir
|
|
||||||
self.cfgdir = cfgdir
|
|
||||||
self.config = config
|
|
||||||
self.lightning_config = lightning_config
|
|
||||||
|
|
||||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
|
||||||
if trainer.global_rank == 0:
|
|
||||||
print("Summoning checkpoint.")
|
|
||||||
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
|
||||||
trainer.save_checkpoint(ckpt_path)
|
|
||||||
|
|
||||||
def on_pretrain_routine_start(self, trainer, pl_module):
|
|
||||||
if trainer.global_rank == 0:
|
|
||||||
# Create logdirs and save configs
|
|
||||||
os.makedirs(self.logdir, exist_ok=True)
|
|
||||||
os.makedirs(self.ckptdir, exist_ok=True)
|
|
||||||
os.makedirs(self.cfgdir, exist_ok=True)
|
|
||||||
|
|
||||||
if "callbacks" in self.lightning_config:
|
|
||||||
if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
|
|
||||||
os.makedirs(
|
|
||||||
os.path.join(self.ckptdir, "trainstep_checkpoints"),
|
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
print("Project config")
|
|
||||||
print(OmegaConf.to_yaml(self.config))
|
|
||||||
OmegaConf.save(
|
|
||||||
self.config,
|
|
||||||
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Lightning config")
|
|
||||||
print(OmegaConf.to_yaml(self.lightning_config))
|
|
||||||
OmegaConf.save(
|
|
||||||
OmegaConf.create({"lightning": self.lightning_config}),
|
|
||||||
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# ModelCheckpoint callback created log directory --- remove it
|
|
||||||
if not self.resume and os.path.exists(self.logdir):
|
|
||||||
dst, name = os.path.split(self.logdir)
|
|
||||||
dst = os.path.join(dst, "child_runs", name)
|
|
||||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
|
||||||
try:
|
|
||||||
os.rename(self.logdir, dst)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageLogger(Callback):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
batch_frequency,
|
|
||||||
max_images,
|
|
||||||
clamp=True,
|
|
||||||
increase_log_steps=True,
|
|
||||||
rescale=True,
|
|
||||||
disabled=False,
|
|
||||||
log_on_batch_idx=False,
|
|
||||||
log_first_step=False,
|
|
||||||
log_images_kwargs=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.rescale = rescale
|
|
||||||
self.batch_freq = batch_frequency
|
|
||||||
self.max_images = max_images
|
|
||||||
self.logger_log_images = {}
|
|
||||||
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
|
||||||
if not increase_log_steps:
|
|
||||||
self.log_steps = [self.batch_freq]
|
|
||||||
self.clamp = clamp
|
|
||||||
self.disabled = disabled
|
|
||||||
self.log_on_batch_idx = log_on_batch_idx
|
|
||||||
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
|
||||||
self.log_first_step = log_first_step
|
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
|
|
||||||
root = os.path.join(save_dir, "images", split)
|
|
||||||
for k in images:
|
|
||||||
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
|
||||||
if self.rescale:
|
|
||||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
|
||||||
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
|
||||||
grid = grid.numpy()
|
|
||||||
grid = (grid * 255).astype(np.uint8)
|
|
||||||
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
|
|
||||||
path = os.path.join(root, filename)
|
|
||||||
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
|
||||||
Image.fromarray(grid).save(path)
|
|
||||||
|
|
||||||
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
|
||||||
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
|
||||||
if (
|
|
||||||
self.check_frequency(check_idx)
|
|
||||||
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
|
|
||||||
and callable(pl_module.log_images)
|
|
||||||
and self.max_images > 0
|
|
||||||
):
|
|
||||||
logger = type(pl_module.logger)
|
|
||||||
|
|
||||||
is_train = pl_module.training
|
|
||||||
if is_train:
|
|
||||||
pl_module.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
|
||||||
|
|
||||||
for k in images:
|
|
||||||
N = min(images[k].shape[0], self.max_images)
|
|
||||||
images[k] = images[k][:N]
|
|
||||||
if isinstance(images[k], torch.Tensor):
|
|
||||||
images[k] = images[k].detach().cpu()
|
|
||||||
if self.clamp:
|
|
||||||
images[k] = torch.clamp(images[k], -1.0, 1.0)
|
|
||||||
|
|
||||||
self.log_local(
|
|
||||||
pl_module.logger.save_dir,
|
|
||||||
split,
|
|
||||||
images,
|
|
||||||
pl_module.global_step,
|
|
||||||
pl_module.current_epoch,
|
|
||||||
batch_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
|
||||||
logger_log_images(pl_module, images, pl_module.global_step, split)
|
|
||||||
|
|
||||||
if is_train:
|
|
||||||
pl_module.train()
|
|
||||||
|
|
||||||
def check_frequency(self, check_idx):
|
|
||||||
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
|
||||||
check_idx > 0 or self.log_first_step
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
self.log_steps.pop(0)
|
|
||||||
except IndexError as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
|
|
||||||
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
|
||||||
self.log_img(pl_module, batch, batch_idx, split="train")
|
|
||||||
|
|
||||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
|
|
||||||
if not self.disabled and pl_module.global_step > 0:
|
|
||||||
self.log_img(pl_module, batch, batch_idx, split="val")
|
|
||||||
if hasattr(pl_module, "calibrate_grad_norm"):
|
|
||||||
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
|
|
||||||
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
|
||||||
|
|
||||||
|
|
||||||
class CUDACallback(Callback):
|
|
||||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
|
||||||
def on_train_epoch_start(self, trainer, pl_module):
|
|
||||||
# Reset the memory use counter
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
|
||||||
torch.cuda.synchronize(trainer.root_gpu)
|
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize(trainer.root_gpu)
|
|
||||||
epoch_time = time.time() - self.start_time
|
|
||||||
|
|
||||||
try:
|
|
||||||
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
|
||||||
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
|
|
||||||
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
|
||||||
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModeSwapCallback(Callback):
|
|
||||||
def __init__(self, swap_step=2000):
|
|
||||||
super().__init__()
|
|
||||||
self.is_frozen = False
|
|
||||||
self.swap_step = swap_step
|
|
||||||
|
|
||||||
def on_train_epoch_start(self, trainer, pl_module):
|
|
||||||
if trainer.global_step < self.swap_step and not self.is_frozen:
|
|
||||||
self.is_frozen = True
|
|
||||||
trainer.optimizers = [pl_module.configure_opt_embedding()]
|
|
||||||
|
|
||||||
if trainer.global_step > self.swap_step and self.is_frozen:
|
|
||||||
self.is_frozen = False
|
|
||||||
trainer.optimizers = [pl_module.configure_opt_model()]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# custom parser to specify config files, train, test and debug mode,
|
|
||||||
# postfix, resume.
|
|
||||||
# `--key value` arguments are interpreted as arguments to the trainer.
|
|
||||||
# `nested.key=value` arguments are interpreted as config parameters.
|
|
||||||
# configs are merged from left-to-right followed by command line parameters.
|
|
||||||
|
|
||||||
# model:
|
|
||||||
# base_learning_rate: float
|
|
||||||
# target: path to lightning module
|
|
||||||
# params:
|
|
||||||
# key: value
|
|
||||||
# data:
|
|
||||||
# target: main.DataModuleFromConfig
|
|
||||||
# params:
|
|
||||||
# batch_size: int
|
|
||||||
# wrap: bool
|
|
||||||
# train:
|
|
||||||
# target: path to train dataset
|
|
||||||
# params:
|
|
||||||
# key: value
|
|
||||||
# validation:
|
|
||||||
# target: path to validation dataset
|
|
||||||
# params:
|
|
||||||
# key: value
|
|
||||||
# test:
|
|
||||||
# target: path to test dataset
|
|
||||||
# params:
|
|
||||||
# key: value
|
|
||||||
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
|
||||||
# trainer:
|
|
||||||
# additional arguments to trainer
|
|
||||||
# logger:
|
|
||||||
# logger to instantiate
|
|
||||||
# modelcheckpoint:
|
|
||||||
# modelcheckpoint to instantiate
|
|
||||||
# callbacks:
|
|
||||||
# callback1:
|
|
||||||
# target: importpath
|
|
||||||
# params:
|
|
||||||
# key: value
|
|
||||||
|
|
||||||
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
|
||||||
|
|
||||||
# add cwd for convenience and to make classes in this file available when
|
|
||||||
# running as `python main.py`
|
|
||||||
# (in particular `main.DataModuleFromConfig`)
|
|
||||||
sys.path.append(os.getcwd())
|
|
||||||
|
|
||||||
parser = get_parser()
|
|
||||||
parser = Trainer.add_argparse_args(parser)
|
|
||||||
|
|
||||||
opt, unknown = parser.parse_known_args()
|
|
||||||
if opt.name and opt.resume:
|
|
||||||
raise ValueError(
|
|
||||||
"-n/--name and -r/--resume cannot be specified both."
|
|
||||||
"If you want to resume training in a new log folder, "
|
|
||||||
"use -n/--name in combination with --resume_from_checkpoint"
|
|
||||||
)
|
|
||||||
if opt.resume:
|
|
||||||
if not os.path.exists(opt.resume):
|
|
||||||
raise ValueError("Cannot find {}".format(opt.resume))
|
|
||||||
if os.path.isfile(opt.resume):
|
|
||||||
paths = opt.resume.split("/")
|
|
||||||
# idx = len(paths)-paths[::-1].index("logs")+1
|
|
||||||
# logdir = "/".join(paths[:idx])
|
|
||||||
logdir = "/".join(paths[:-2])
|
|
||||||
ckpt = opt.resume
|
|
||||||
else:
|
|
||||||
assert os.path.isdir(opt.resume), opt.resume
|
|
||||||
logdir = opt.resume.rstrip("/")
|
|
||||||
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
|
||||||
|
|
||||||
opt.resume_from_checkpoint = ckpt
|
|
||||||
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
|
||||||
opt.base = base_configs + opt.base
|
|
||||||
_tmp = logdir.split("/")
|
|
||||||
nowname = _tmp[-1]
|
|
||||||
else:
|
|
||||||
if opt.name:
|
|
||||||
name = "_" + opt.name
|
|
||||||
elif opt.base:
|
|
||||||
cfg_fname = os.path.split(opt.base[0])[-1]
|
|
||||||
cfg_name = os.path.splitext(cfg_fname)[0]
|
|
||||||
name = "_" + cfg_name
|
|
||||||
else:
|
|
||||||
name = ""
|
|
||||||
|
|
||||||
if opt.datadir_in_name:
|
|
||||||
now = os.path.basename(os.path.normpath(opt.data_root)) + now
|
|
||||||
|
|
||||||
nowname = now + name + opt.postfix
|
|
||||||
logdir = os.path.join(opt.logdir, nowname)
|
|
||||||
|
|
||||||
ckptdir = os.path.join(logdir, "checkpoints")
|
|
||||||
cfgdir = os.path.join(logdir, "configs")
|
|
||||||
seed_everything(opt.seed)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# init and save configs
|
|
||||||
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
|
||||||
cli = OmegaConf.from_dotlist(unknown)
|
|
||||||
config = OmegaConf.merge(*configs, cli)
|
|
||||||
lightning_config = config.pop("lightning", OmegaConf.create())
|
|
||||||
# merge trainer cli with config
|
|
||||||
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
|
||||||
# default to ddp
|
|
||||||
trainer_config["accelerator"] = "auto"
|
|
||||||
for k in nondefault_trainer_args(opt):
|
|
||||||
trainer_config[k] = getattr(opt, k)
|
|
||||||
if "gpus" not in trainer_config:
|
|
||||||
del trainer_config["accelerator"]
|
|
||||||
cpu = True
|
|
||||||
else:
|
|
||||||
gpuinfo = trainer_config["gpus"]
|
|
||||||
print(f"Running on GPUs {gpuinfo}")
|
|
||||||
cpu = False
|
|
||||||
trainer_opt = argparse.Namespace(**trainer_config)
|
|
||||||
lightning_config.trainer = trainer_config
|
|
||||||
|
|
||||||
# model
|
|
||||||
|
|
||||||
# config.model.params.personalization_config.params.init_word = opt.init_word
|
|
||||||
config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt
|
|
||||||
|
|
||||||
if opt.init_word:
|
|
||||||
config.model.params.personalization_config.params.initializer_words = [opt.init_word]
|
|
||||||
|
|
||||||
if opt.actual_resume:
|
|
||||||
model = load_model_from_config(config, opt.actual_resume)
|
|
||||||
else:
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
|
|
||||||
# trainer and callbacks
|
|
||||||
trainer_kwargs = dict()
|
|
||||||
|
|
||||||
# default logger configs
|
|
||||||
def_logger = "csv"
|
|
||||||
def_logger_target = "CSVLogger"
|
|
||||||
default_logger_cfgs = {
|
|
||||||
"wandb": {
|
|
||||||
"target": "pytorch_lightning.loggers.WandbLogger",
|
|
||||||
"params": {
|
|
||||||
"name": nowname,
|
|
||||||
"save_dir": logdir,
|
|
||||||
"offline": opt.debug,
|
|
||||||
"id": nowname,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
def_logger: {
|
|
||||||
"target": "pytorch_lightning.loggers." + def_logger_target,
|
|
||||||
"params": {
|
|
||||||
"name": def_logger,
|
|
||||||
"save_dir": logdir,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
default_logger_cfg = default_logger_cfgs[def_logger]
|
|
||||||
if "logger" in lightning_config:
|
|
||||||
logger_cfg = lightning_config.logger
|
|
||||||
else:
|
|
||||||
logger_cfg = OmegaConf.create()
|
|
||||||
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
|
||||||
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
|
||||||
|
|
||||||
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
|
||||||
# specify which metric is used to determine best models
|
|
||||||
default_modelckpt_cfg = {
|
|
||||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
|
||||||
"params": {
|
|
||||||
"dirpath": ckptdir,
|
|
||||||
"filename": "{epoch:06}",
|
|
||||||
"verbose": True,
|
|
||||||
"save_last": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
if hasattr(model, "monitor"):
|
|
||||||
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
|
||||||
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
|
||||||
default_modelckpt_cfg["params"]["save_top_k"] = 1
|
|
||||||
|
|
||||||
if "modelcheckpoint" in lightning_config:
|
|
||||||
modelckpt_cfg = lightning_config.modelcheckpoint
|
|
||||||
else:
|
|
||||||
modelckpt_cfg = OmegaConf.create()
|
|
||||||
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
|
||||||
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
|
||||||
if version.parse(pl.__version__) < version.parse("1.4.0"):
|
|
||||||
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
|
||||||
|
|
||||||
# add callback which sets up log directory
|
|
||||||
default_callbacks_cfg = {
|
|
||||||
"setup_callback": {
|
|
||||||
"target": "main.SetupCallback",
|
|
||||||
"params": {
|
|
||||||
"resume": opt.resume,
|
|
||||||
"now": now,
|
|
||||||
"logdir": logdir,
|
|
||||||
"ckptdir": ckptdir,
|
|
||||||
"cfgdir": cfgdir,
|
|
||||||
"config": config,
|
|
||||||
"lightning_config": lightning_config,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"image_logger": {
|
|
||||||
"target": "main.ImageLogger",
|
|
||||||
"params": {
|
|
||||||
"batch_frequency": 750,
|
|
||||||
"max_images": 4,
|
|
||||||
"clamp": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"learning_rate_logger": {
|
|
||||||
"target": "main.LearningRateMonitor",
|
|
||||||
"params": {
|
|
||||||
"logging_interval": "step",
|
|
||||||
# "log_momentum": True
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"cuda_callback": {"target": "main.CUDACallback"},
|
|
||||||
}
|
|
||||||
if version.parse(pl.__version__) >= version.parse("1.4.0"):
|
|
||||||
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
|
|
||||||
|
|
||||||
if "callbacks" in lightning_config:
|
|
||||||
callbacks_cfg = lightning_config.callbacks
|
|
||||||
else:
|
|
||||||
callbacks_cfg = OmegaConf.create()
|
|
||||||
|
|
||||||
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
|
|
||||||
print(
|
|
||||||
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
|
|
||||||
)
|
|
||||||
default_metrics_over_trainsteps_ckpt_dict = {
|
|
||||||
"metrics_over_trainsteps_checkpoint": {
|
|
||||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
|
||||||
"params": {
|
|
||||||
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
|
|
||||||
"filename": "{epoch:06}-{step:09}",
|
|
||||||
"verbose": True,
|
|
||||||
"save_top_k": -1,
|
|
||||||
"every_n_train_steps": 10000,
|
|
||||||
"save_weights_only": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
|
||||||
|
|
||||||
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
|
||||||
if "ignore_keys_callback" in callbacks_cfg and hasattr(trainer_opt, "resume_from_checkpoint"):
|
|
||||||
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = trainer_opt.resume_from_checkpoint
|
|
||||||
elif "ignore_keys_callback" in callbacks_cfg:
|
|
||||||
del callbacks_cfg["ignore_keys_callback"]
|
|
||||||
|
|
||||||
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
|
||||||
trainer_kwargs["max_steps"] = trainer_opt.max_steps
|
|
||||||
|
|
||||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
||||||
trainer_opt.accelerator = "mps"
|
|
||||||
trainer_opt.detect_anomaly = False
|
|
||||||
|
|
||||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
|
||||||
trainer.logdir = logdir
|
|
||||||
|
|
||||||
# data
|
|
||||||
config.data.params.train.params.data_root = opt.data_root
|
|
||||||
config.data.params.validation.params.data_root = opt.data_root
|
|
||||||
data = instantiate_from_config(config.data)
|
|
||||||
|
|
||||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
|
||||||
# calling these ourselves should not be necessary but it is.
|
|
||||||
# lightning still takes care of proper multiprocessing though
|
|
||||||
data.prepare_data()
|
|
||||||
data.setup()
|
|
||||||
print("#### Data #####")
|
|
||||||
for k in data.datasets:
|
|
||||||
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
|
||||||
|
|
||||||
# configure learning rate
|
|
||||||
bs, base_lr = (
|
|
||||||
config.data.params.batch_size,
|
|
||||||
config.model.base_learning_rate,
|
|
||||||
)
|
|
||||||
if not cpu:
|
|
||||||
gpus = str(lightning_config.trainer.gpus).strip(", ").split(",")
|
|
||||||
ngpu = len(gpus)
|
|
||||||
else:
|
|
||||||
ngpu = 1
|
|
||||||
if "accumulate_grad_batches" in lightning_config.trainer:
|
|
||||||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
|
||||||
else:
|
|
||||||
accumulate_grad_batches = 1
|
|
||||||
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
|
||||||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
|
||||||
if opt.scale_lr:
|
|
||||||
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
|
||||||
print(
|
|
||||||
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
|
||||||
model.learning_rate,
|
|
||||||
accumulate_grad_batches,
|
|
||||||
ngpu,
|
|
||||||
bs,
|
|
||||||
base_lr,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model.learning_rate = base_lr
|
|
||||||
print("++++ NOT USING LR SCALING ++++")
|
|
||||||
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
|
||||||
|
|
||||||
# allow checkpointing via USR1
|
|
||||||
def melk(*args, **kwargs):
|
|
||||||
# run all checkpoint hooks
|
|
||||||
if trainer.global_rank == 0:
|
|
||||||
print("Summoning checkpoint.")
|
|
||||||
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
|
||||||
trainer.save_checkpoint(ckpt_path)
|
|
||||||
|
|
||||||
def divein(*args, **kwargs):
|
|
||||||
if trainer.global_rank == 0:
|
|
||||||
import pudb
|
|
||||||
|
|
||||||
pudb.set_trace()
|
|
||||||
|
|
||||||
import signal
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, melk)
|
|
||||||
signal.signal(signal.SIGTERM, divein)
|
|
||||||
|
|
||||||
# run
|
|
||||||
if opt.train:
|
|
||||||
try:
|
|
||||||
trainer.fit(model, data)
|
|
||||||
except Exception:
|
|
||||||
melk()
|
|
||||||
raise
|
|
||||||
if not opt.no_test and not trainer.interrupted:
|
|
||||||
trainer.test(model, data)
|
|
||||||
except Exception:
|
|
||||||
if opt.debug and trainer.global_rank == 0:
|
|
||||||
try:
|
|
||||||
import pudb as debugger
|
|
||||||
except ImportError:
|
|
||||||
import pdb as debugger
|
|
||||||
debugger.post_mortem()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
# move newly created debug project to debug_runs
|
|
||||||
if opt.debug and not opt.resume and trainer.global_rank == 0:
|
|
||||||
dst, name = os.path.split(logdir)
|
|
||||||
dst = os.path.join(dst, "debug_runs", name)
|
|
||||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
|
||||||
os.rename(logdir, dst)
|
|
||||||
# if trainer.global_rank == 0:
|
|
||||||
# print(trainer.profiler.summary())
|
|
@ -1,130 +0,0 @@
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.invoke.globals import Globals

import argparse
from functools import partial

import torch


def get_placeholder_loop(placeholder_string, embedder, use_bert):
    new_placeholder = None

    while True:
        if new_placeholder is None:
            new_placeholder = input(
                f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
            )
        else:
            new_placeholder = input(
                f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: "
            )

        token = (
            get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
            if use_bert
            else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
        )

        if token is not None:
            return new_placeholder, token


def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )

    tokens = batch_encoding["input_ids"]

    # 49407 is the CLIP end-of-text/padding token id; exactly two non-pad entries
    # (start-of-text plus one token) means the string maps to a single token.
    if torch.count_nonzero(tokens - 49407) == 2:
        return tokens[0, 1]

    return None


def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    if torch.count_nonzero(token) == 3:
        return token[0, 1]

    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--root_dir",
        type=str,
        default=".",
        help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.",
    )

    parser.add_argument(
        "--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged."
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output path for the merged manager",
    )

    parser.add_argument(
        "-sd",
        "--use_bert",
        action="store_true",
        help="Flag to denote that we are not merging stable diffusion embeddings",
    )

    args = parser.parse_args()
    Globals.root = args.root_dir

    if args.use_bert:
        embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
    else:
        embedder = FrozenCLIPEmbedder().cuda()

    EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])

    string_to_token_dict = {}
    string_to_param_dict = torch.nn.ParameterDict()

    placeholder_to_src = {}

    for manager_ckpt in args.manager_ckpts:
        print(f"Parsing {manager_ckpt}...")

        manager = EmbeddingManager()
        manager.load(manager_ckpt)

        for placeholder_string in manager.string_to_token_dict:
            if placeholder_string not in string_to_token_dict:
                string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
                string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]

                placeholder_to_src[placeholder_string] = manager_ckpt
            else:
                new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
                string_to_token_dict[new_placeholder] = new_token
                string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]

                placeholder_to_src[new_placeholder] = manager_ckpt

    print("Saving combined manager...")
    merged_manager = EmbeddingManager()
    merged_manager.string_to_param_dict = string_to_param_dict
    merged_manager.string_to_token_dict = string_to_token_dict
    merged_manager.save(args.output_path)

    print("Managers merged. Final list of placeholders: ")
    print(placeholder_to_src)
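
# Illustrative usage only; the script and checkpoint file names are assumptions:
#
#   python merge_embeddings.py --manager_ckpts embeddings_a.pt embeddings_b.pt \
#       --output_path merged_embeddings.pt --root_dir /path/to/invokeai
#
# Colliding placeholder strings are resolved interactively via get_placeholder_loop().
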
@ -1,305 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
|
|
||||||
|
|
||||||
def rescale(x: float) -> float:
|
|
||||||
return (x + 1.0) / 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def custom_to_pil(x):
|
|
||||||
x = x.detach().cpu()
|
|
||||||
x = torch.clamp(x, -1.0, 1.0)
|
|
||||||
x = (x + 1.0) / 2.0
|
|
||||||
x = x.permute(1, 2, 0).numpy()
|
|
||||||
x = (255 * x).astype(np.uint8)
|
|
||||||
x = Image.fromarray(x)
|
|
||||||
if not x.mode == "RGB":
|
|
||||||
x = x.convert("RGB")
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def custom_to_np(x):
|
|
||||||
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
|
|
||||||
sample = x.detach().cpu()
|
|
||||||
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
|
|
||||||
sample = sample.permute(0, 2, 3, 1)
|
|
||||||
sample = sample.contiguous()
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
def logs2pil(logs, keys=["sample"]):
|
|
||||||
imgs = dict()
|
|
||||||
for k in logs:
|
|
||||||
try:
|
|
||||||
if len(logs[k].shape) == 4:
|
|
||||||
img = custom_to_pil(logs[k][0, ...])
|
|
||||||
elif len(logs[k].shape) == 3:
|
|
||||||
img = custom_to_pil(logs[k])
|
|
||||||
else:
|
|
||||||
print(f"Unknown format for key {k}. ")
|
|
||||||
img = None
|
|
||||||
except Exception:
|
|
||||||
img = None
|
|
||||||
imgs[k] = img
|
|
||||||
return imgs
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
|
|
||||||
if not make_prog_row:
|
|
||||||
return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)
|
|
||||||
else:
|
|
||||||
return model.progressive_denoising(None, shape, verbose=True)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def convsample_ddim(model, steps, shape, eta=1.0):
|
|
||||||
ddim = DDIMSampler(model)
|
|
||||||
bs = shape[0]
|
|
||||||
shape = shape[1:]
|
|
||||||
samples, intermediates = ddim.sample(
|
|
||||||
steps,
|
|
||||||
batch_size=bs,
|
|
||||||
shape=shape,
|
|
||||||
eta=eta,
|
|
||||||
verbose=False,
|
|
||||||
)
|
|
||||||
return samples, intermediates
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def make_convolutional_sample(
|
|
||||||
model,
|
|
||||||
batch_size,
|
|
||||||
vanilla=False,
|
|
||||||
custom_steps=None,
|
|
||||||
eta=1.0,
|
|
||||||
):
|
|
||||||
log = dict()
|
|
||||||
|
|
||||||
shape = [
|
|
||||||
batch_size,
|
|
||||||
model.model.diffusion_model.in_channels,
|
|
||||||
model.model.diffusion_model.image_size,
|
|
||||||
model.model.diffusion_model.image_size,
|
|
||||||
]
|
|
||||||
|
|
||||||
with model.ema_scope("Plotting"):
|
|
||||||
t0 = time.time()
|
|
||||||
if vanilla:
|
|
||||||
sample, progrow = convsample(model, shape, make_prog_row=True)
|
|
||||||
else:
|
|
||||||
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
|
|
||||||
|
|
||||||
t1 = time.time()
|
|
||||||
|
|
||||||
x_sample = model.decode_first_stage(sample)
|
|
||||||
|
|
||||||
log["sample"] = x_sample
|
|
||||||
log["time"] = t1 - t0
|
|
||||||
log["throughput"] = sample.shape[0] / (t1 - t0)
|
|
||||||
print(f'Throughput for this batch: {log["throughput"]}')
|
|
||||||
return log
|
|
||||||
|
|
||||||
|
|
||||||
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
|
|
||||||
if vanilla:
|
|
||||||
print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
|
|
||||||
else:
|
|
||||||
print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")
|
|
||||||
|
|
||||||
tstart = time.time()
|
|
||||||
n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1
|
|
||||||
# path = logdir
|
|
||||||
if model.cond_stage_model is None:
|
|
||||||
all_images = []
|
|
||||||
|
|
||||||
print(f"Running unconditional sampling for {n_samples} samples")
|
|
||||||
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
|
|
||||||
logs = make_convolutional_sample(
|
|
||||||
model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
|
|
||||||
)
|
|
||||||
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
|
|
||||||
all_images.extend([custom_to_np(logs["sample"])])
|
|
||||||
if n_saved >= n_samples:
|
|
||||||
print(f"Finish after generating {n_saved} samples")
|
|
||||||
break
|
|
||||||
all_img = np.concatenate(all_images, axis=0)
|
|
||||||
all_img = all_img[:n_samples]
|
|
||||||
shape_str = "x".join([str(x) for x in all_img.shape])
|
|
||||||
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
|
|
||||||
np.savez(nppath, all_img)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Currently only sampling for unconditional models supported.")
|
|
||||||
|
|
||||||
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
|
|
||||||
|
|
||||||
|
|
||||||
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
|
|
||||||
for k in logs:
|
|
||||||
if k == key:
|
|
||||||
batch = logs[key]
|
|
||||||
if np_path is None:
|
|
||||||
for x in batch:
|
|
||||||
img = custom_to_pil(x)
|
|
||||||
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
|
|
||||||
img.save(imgpath)
|
|
||||||
n_saved += 1
|
|
||||||
else:
|
|
||||||
npbatch = custom_to_np(batch)
|
|
||||||
shape_str = "x".join([str(x) for x in npbatch.shape])
|
|
||||||
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
|
|
||||||
np.savez(nppath, npbatch)
|
|
||||||
n_saved += npbatch.shape[0]
|
|
||||||
return n_saved
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"-r",
|
|
||||||
"--resume",
|
|
||||||
type=str,
|
|
||||||
nargs="?",
|
|
||||||
help="load from logdir or checkpoint in logdir",
|
|
||||||
)
|
|
||||||
parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
|
|
||||||
parser.add_argument(
|
|
||||||
"-e",
|
|
||||||
"--eta",
|
|
||||||
type=float,
|
|
||||||
nargs="?",
|
|
||||||
help="eta for ddim sampling (0.0 yields deterministic sampling)",
|
|
||||||
default=1.0,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-v",
|
|
||||||
"--vanilla_sample",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="vanilla sampling (default option is DDIM sampling)?",
|
|
||||||
)
|
|
||||||
parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
|
|
||||||
parser.add_argument(
|
|
||||||
"-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
|
|
||||||
)
|
|
||||||
parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)
|
|
||||||
    return parser


def load_model_from_config(config, sd):
    model = instantiate_from_config(config)
    model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model


def load_model(config, ckpt, gpu, eval_mode):
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model, pl_sd["state_dict"])

    return model, global_step


if __name__ == "__main__":
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    ckpt = None

    if not os.path.exists(opt.resume):
        raise ValueError("Cannot find {}".format(opt.resume))
    if os.path.isfile(opt.resume):
        # paths = opt.resume.split("/")
        try:
            logdir = "/".join(opt.resume.split("/")[:-1])
            # idx = len(paths)-paths[::-1].index("logs")+1
            print(f"Logdir is {logdir}")
        except ValueError:
            paths = opt.resume.split("/")
            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
        ckpt = opt.resume
    else:
        assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
        logdir = opt.resume.rstrip("/")
        ckpt = os.path.join(logdir, "model.ckpt")

    base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)

    gpu = True
    eval_mode = True

    if opt.logdir != "none":
        locallog = logdir.split(os.sep)[-1]
        if locallog == "":
            locallog = logdir.split(os.sep)[-2]
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)

    model, global_step = load_model(config, ckpt, gpu, eval_mode)
    print(f"global step: {global_step}")
    print(75 * "=")
    print("logging to:")
    logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir)
    os.makedirs(numpylogdir)
    print(logdir)
    print(75 * "=")

    # write config out
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, "w") as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    run(
        model,
        imglogdir,
        eta=opt.eta,
        vanilla=opt.vanilla_sample,
        n_samples=opt.n_samples,
        custom_steps=opt.custom_steps,
        batch_size=opt.batch_size,
        nplog=numpylogdir,
    )

    print("done.")
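For reference, each `.npz` written by the `np.savez(nppath, npbatch)` call above stores a single positional array, which NumPy exposes under the key `arr_0`. A minimal sketch of collecting those batches back into one array (the log directory below is a made-up example path, not something this script defines):

```python
import glob
import os

import numpy as np

# Hypothetical "numpy" log directory created by a run of the sampling script above.
np_dir = "logs/my_run/samples/00050000/2023-09-01-12-00-00/numpy"

batches = []
for path in sorted(glob.glob(os.path.join(np_dir, "*-samples.npz"))):
    with np.load(path) as data:
        batches.append(data["arr_0"])  # np.savez with a positional argument stores it as "arr_0"

all_samples = np.concatenate(batches, axis=0)
print(all_samples.shape)  # first axis is the total number of saved samples
```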
@@ -1,169 +0,0 @@
import os
import sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm

from ldm.util import parallel_data_prefetch


def search_bruteforce(searcher):
    return searcher.score_brute_force().build()


def search_partioned_ah(
    searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
    return (
        searcher.tree(
            num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize
        )
        .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
        .reorder(reorder_k)
        .build()
    )


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    return (
        searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
    )


def load_datapool(dpath):
    def load_single_file(saved_embeddings):
        compressed = np.load(saved_embeddings)
        database = {key: compressed[key] for key in compressed.files}
        return database

    def load_multi_files(data_archive):
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            for key in d.files:
                database[key].append(d[key])

        return database

    print(f'Load saved patch embedding from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, "*.npz"))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        data = [np.load(f) for f in file_content]
        prefetched_data = parallel_data_prefetch(
            load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
        )

        data_pool = {
            key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()
        }
    else:
        raise ValueError(f'No npz files found in specified path "{dpath}". Does this directory exist?')

    print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool


def train_searcher(
    opt,
    metric="dot_product",
    partioning_trainsize=None,
    reorder_k=None,
    # todo tune
    aiq_thld=0.2,
    dims_per_block=2,
    num_leaves=None,
    num_leaves_to_search=None,
):
    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    # normalize
    # embeddings =
    searcher = scann.scann_ops_pybind.builder(
        data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric
    )
    pool_size = data_pool["embedding"].shape[0]

    print(*(["#"] * 100))
    print("Initializing scaNN searcher with the following values:")
    print(f"k: {k}")
    print(f"metric: {metric}")
    print(f"reorder_k: {reorder_k}")
    print(f"anisotropic_quantization_threshold: {aiq_thld}")
    print(f"dims_per_block: {dims_per_block}")
    print(*(["#"] * 100))
    print("Start training searcher....")
    print(f"N samples in pool is {pool_size}")

    # this reflects the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size < 2e4:
        print("Using brute force search.")
        searcher = search_bruteforce(searcher)
    elif 2e4 <= pool_size and pool_size < 1e5:
        print("Using asymmetric hashing search and reordering.")
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
print("Using using partioning, asymmetric hashing search and reordering.")
|
|
||||||

        if not partioning_trainsize:
            partioning_trainsize = data_pool["embedding"].shape[0] // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))

        if not num_leaves_to_search:
            num_leaves_to_search = max(num_leaves // 20, 1)

        print("Partitioning params:")
        print(f"num_leaves: {num_leaves}")
        print(f"num_leaves_to_search: {num_leaves_to_search}")
        # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
        searcher = search_partioned_ah(
            searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
        )

print("Finish training searcher")
|
|
||||||
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')


if __name__ == "__main__":
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--database",
        "-d",
        default="data/rdm/retrieval_databases/openimages",
        type=str,
        help="path to folder containing the clip feature of the database",
    )
    parser.add_argument(
        "--target_path",
        "-t",
        default="data/rdm/searchers/openimages",
        type=str,
        help="path to the target folder where the searcher shall be stored.",
    )
    parser.add_argument(
        "--knn",
        "-k",
        default=20,
        type=int,
        help="number of nearest neighbors, for which the searcher shall be optimized",
    )

    opt, _ = parser.parse_known_args()

    train_searcher(
        opt,
    )
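Once `train_searcher` has serialized an index, it can be reloaded and queried with the ScaNN pybind API. A minimal sketch, assuming the default `--target_path` above and 512-dimensional CLIP embeddings (the query values and dimensionality are placeholders, not something this script defines):

```python
import numpy as np
import scann

# Load the searcher serialized by train_searcher() above.
searcher = scann.scann_ops_pybind.load_searcher("data/rdm/searchers/openimages")

# Placeholder query embeddings; normalize them the same way the database was normalized at build time.
queries = np.random.rand(4, 512).astype(np.float32)
queries /= np.linalg.norm(queries, axis=1, keepdims=True)

neighbors, distances = searcher.search_batched(queries)
print(neighbors.shape, distances.shape)  # (num_queries, k) each
```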
@@ -1,316 +0,0 @@
import argparse
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext

import k_diffusion as K
import torch.nn as nn

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.to(choose_torch_device())
    model.eval()
    return model


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )
    parser.add_argument(
        "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action="store_true",
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action="store_true",
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action="store_true",
        help="use plms sampling",
    )
    parser.add_argument(
        "--klms",
        action="store_true",
        help="use klms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action="store_true",
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action="store_true",
        help="if enabled, uses the same starting code across samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
|
||||||
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
    )
    opt = parser.parse_args()

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")

    seed_everything(opt.seed)

    device = torch.device(choose_torch_device())
    model = model.to(device)

    # for klms
    model_wrap = K.external.CompVisDenoiser(model)

    class CFGDenoiser(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.inner_model = model

        def forward(self, x, sigma, uncond, cond, cond_scale):
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            cond_in = torch.cat([uncond, cond])
            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
            return uncond + (cond - uncond) * cond_scale

    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]

    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            if len(data) >= batch_size:
                data = list(chunk(data, batch_size))
            else:
                while len(data) < batch_size:
                    data.append(data[-1])
                data = [data]

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    start_code = None
    if opt.fixed_code:
        shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
        if device.type == "mps":
            start_code = torch.randn(shape, device="cpu").to(device)
        else:
            start_code = torch.randn(shape, device=device)

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    if device.type in ["mps", "cpu"]:
        precision_scope = nullcontext  # have to use f32 on mps
    with torch.no_grad():
        with precision_scope(device.type):
            with model.ema_scope():
                all_samples = list()
                for n in trange(opt.n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

                        if not opt.klms:
                            samples_ddim, _ = sampler.sample(
                                S=opt.ddim_steps,
                                conditioning=c,
                                batch_size=opt.n_samples,
                                shape=shape,
                                verbose=False,
                                unconditional_guidance_scale=opt.scale,
                                unconditional_conditioning=uc,
                                eta=opt.ddim_eta,
                                x_T=start_code,
                            )
                        else:
                            sigmas = model_wrap.get_sigmas(opt.ddim_steps)
                            if start_code is not None:
                                x = start_code
                            else:
                                x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0]  # for GPU draw
                            model_wrap_cfg = CFGDenoiser(model_wrap)
                            extra_args = {"cond": c, "uncond": uc, "cond_scale": opt.scale}
                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        if not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}.png")
                                )
                                base_count += 1

                        if not opt.skip_grid:
                            all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, "n b c h w -> (n b) c h w")
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
                    grid_count += 1

    print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")


if __name__ == "__main__":
    main()
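As a quick illustration of the `chunk` helper defined near the top of this script (used to batch prompts read via `--from-file`), with made-up prompt strings:

```python
from itertools import islice


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


prompts = ["a cat", "a dog", "a fox", "a bird", "a fish"]
print(list(chunk(prompts, 2)))
# [('a cat', 'a dog'), ('a fox', 'a bird'), ('a fish',)] -- the final chunk may be shorter
```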
@@ -1,37 +0,0 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Prints "True" and exits 0 if their metadata match; otherwise prints "False"
and exits with a non-zero status.
"""

import sys
import argparse
import json

from pathlib import Path

from invokeai.backend.model_management.models.base import read_checkpoint_meta

parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")

opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

checkpoint_metadata = {}

for key, tensor in ckpt.items():
    checkpoint_metadata[key] = list(tensor.shape)

with open(opt.template, "r") as f:
    template = json.load(f)

if checkpoint_metadata == template:
    print("True")
    sys.exit(0)
else:
    print("False")
    sys.exit(-1)
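A template for the comparison above can be generated from a known-good checkpoint by recording the same key-to-shape mapping. A minimal sketch (the file names are placeholders):

```python
import json

from invokeai.backend.model_management.models.base import read_checkpoint_meta

ckpt = read_checkpoint_meta("known-good-model.safetensors")  # placeholder path
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

# Record each tensor's shape, exactly as the comparison script expects to find it.
template = {key: list(tensor.shape) for key, tensor in ckpt.items()}

with open("template.json", "w") as f:
    json.dump(template, f, indent=2)
```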
@@ -1,3 +1,10 @@
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvalidVersionError,
+    invocation,
+    invocation_output,
+)
 from .test_nodes import (
     ImageToImageTestInvocation,
     TextToImageTestInvocation,
@@ -20,7 +27,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation

 from invokeai.app.invocations.image import ShowImageInvocation
 from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
-from invokeai.app.invocations.primitives import IntegerInvocation
+from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
 from invokeai.app.services.default_graphs import create_text_to_image
 import pytest

@@ -610,6 +617,79 @@ def test_graph_can_deserialize():
     assert g2.edges[0].destination.field == "image"


+def test_invocation_decorator():
+    invocation_type = "test_invocation"
+    title = "Test Invocation"
+    tags = ["first", "second", "third"]
+    category = "category"
+    version = "1.2.3"
+
+    @invocation(invocation_type, title=title, tags=tags, category=category, version=version)
+    class TestInvocation(BaseInvocation):
+        def invoke(self):
+            pass
+
+    schema = TestInvocation.schema()
+
+    assert schema.get("title") == title
+    assert schema.get("tags") == tags
+    assert schema.get("category") == category
+    assert schema.get("version") == version
+    assert TestInvocation(id="1").type == invocation_type  # type: ignore (type is dynamically added)
+
+
+def test_invocation_version_must_be_semver():
+    invocation_type = "test_invocation"
+    valid_version = "1.0.0"
+    invalid_version = "not_semver"
+
+    @invocation(invocation_type, version=valid_version)
+    class ValidVersionInvocation(BaseInvocation):
+        def invoke(self):
+            pass
+
+    with pytest.raises(InvalidVersionError):
+
+        @invocation(invocation_type, version=invalid_version)
+        class InvalidVersionInvocation(BaseInvocation):
+            def invoke(self):
+                pass
+
+
+def test_invocation_output_decorator():
+    output_type = "test_output"
+
+    @invocation_output(output_type)
+    class TestOutput(BaseInvocationOutput):
+        pass
+
+    assert TestOutput().type == output_type  # type: ignore (type is dynamically added)
+
+
+def test_floats_accept_ints():
+    g = Graph()
+    n1 = IntegerInvocation(id="1", value=1)
+    n2 = FloatInvocation(id="2")
+    g.add_node(n1)
+    g.add_node(n2)
+    e = create_edge(n1.id, "value", n2.id, "value")
+
+    # Not throwing on this line is sufficient
+    g.add_edge(e)
+
+
+def test_ints_do_not_accept_floats():
+    g = Graph()
+    n1 = FloatInvocation(id="1", value=1.0)
+    n2 = IntegerInvocation(id="2")
+    g.add_node(n1)
+    g.add_node(n2)
+    e = create_edge(n1.id, "value", n2.id, "value")
+
+    with pytest.raises(InvalidEdgeError):
+        g.add_edge(e)
+
+
 def test_graph_can_generate_schema():
     # Not throwing on this line is sufficient
     # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation