Merge branch 'main' into feat/ip-adapter

This commit is contained in:
blessedcoolant 2023-09-05 11:35:19 +12:00
commit 6bb378a101
92 changed files with 2245 additions and 7309 deletions

View File

@ -244,8 +244,12 @@ copy-paste the template above.
We can use the `@invocation` decorator to provide some additional info to the
UI, like a custom title, tags and category.
We also encourage providing a version. This must be a
[semver](https://semver.org/) version string ("$MAJOR.$MINOR.$PATCH"). The UI
will let users know if their workflow is using a mismatched version of the node.
```python
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations", version="1.0.0")
class ResizeInvocation(BaseInvocation):
"""Resizes an image"""
@ -279,8 +283,6 @@ take a look a at our [contributing nodes overview](contributingNodes).
## Advanced
-->
### Custom Output Types
Like with custom inputs, sometimes you might find yourself needing custom

View File

@ -22,12 +22,26 @@ To use a community node graph, download the the `.json` node graph file and load
![b920b710-1882-49a0-8d02-82dff2cca907](https://github.com/invoke-ai/InvokeAI/assets/25252829/7660c1ed-bf7d-4d0a-947f-1fc1679557ba)
![71a91805-fda5-481c-b380-264665703133](https://github.com/invoke-ai/InvokeAI/assets/25252829/f8f6a2ee-2b68-4482-87da-b90221d5c3e2)
--------------------------------
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
--------------------------------
### Film Grain
**Description:** This node adds a film grain effect to the input image based on the weights, seeds, and blur radii parameters. It works with RGB input images only.
**Node Link:** https://github.com/JPPhoto/film-grain-node
--------------------------------
### Image Picker
**Description:** This InvokeAI node takes in a collection of images and randomly chooses one. This can be useful when you have a number of poses to choose from for a ControlNet node, or a number of input images for another purpose.
**Node Link:** https://github.com/JPPhoto/film-grain-node
--------------------------------
### Retroize

View File

@ -26,11 +26,16 @@ from typing import (
from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable
import semver
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class InvalidVersionError(ValueError):
pass
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
@ -105,24 +110,39 @@ class UIType(str, Enum):
"""
# region Primitives
Integer = "integer"
Float = "float"
Boolean = "boolean"
String = "string"
Array = "array"
Image = "ImageField"
Latents = "LatentsField"
Color = "ColorField"
Conditioning = "ConditioningField"
Control = "ControlField"
Color = "ColorField"
ImageCollection = "ImageCollection"
ConditioningCollection = "ConditioningCollection"
ColorCollection = "ColorCollection"
LatentsCollection = "LatentsCollection"
IntegerCollection = "IntegerCollection"
FloatCollection = "FloatCollection"
StringCollection = "StringCollection"
Float = "float"
Image = "ImageField"
Integer = "integer"
Latents = "LatentsField"
String = "string"
# endregion
# region Collection Primitives
BooleanCollection = "BooleanCollection"
ColorCollection = "ColorCollection"
ConditioningCollection = "ConditioningCollection"
ControlCollection = "ControlCollection"
FloatCollection = "FloatCollection"
ImageCollection = "ImageCollection"
IntegerCollection = "IntegerCollection"
LatentsCollection = "LatentsCollection"
StringCollection = "StringCollection"
# endregion
# region Polymorphic Primitives
BooleanPolymorphic = "BooleanPolymorphic"
ColorPolymorphic = "ColorPolymorphic"
ConditioningPolymorphic = "ConditioningPolymorphic"
ControlPolymorphic = "ControlPolymorphic"
FloatPolymorphic = "FloatPolymorphic"
ImagePolymorphic = "ImagePolymorphic"
IntegerPolymorphic = "IntegerPolymorphic"
LatentsPolymorphic = "LatentsPolymorphic"
StringPolymorphic = "StringPolymorphic"
# endregion
# region Models
@ -176,6 +196,7 @@ class _InputField(BaseModel):
ui_type: Optional[UIType]
ui_component: Optional[UIComponent]
ui_order: Optional[int]
item_default: Optional[Any]
class _OutputField(BaseModel):
@ -223,6 +244,7 @@ def InputField(
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
ui_order: Optional[int] = None,
item_default: Optional[Any] = None,
**kwargs: Any,
) -> Any:
"""
@ -249,6 +271,11 @@ def InputField(
For this case, you could provide `UIComponent.Textarea`.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
Ignored for non-collection fields..
"""
return Field(
*args,
@ -282,6 +309,7 @@ def InputField(
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
item_default=item_default,
**kwargs,
)
@ -332,6 +360,8 @@ def OutputField(
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
"""
return Field(
*args,
@ -376,6 +406,9 @@ class UIConfigBase(BaseModel):
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: Optional[str] = Field(
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
)
class InvocationContext:
@ -474,6 +507,8 @@ class BaseInvocation(ABC, BaseModel):
schema["tags"] = uiconfig.tags
if uiconfig and hasattr(uiconfig, "category"):
schema["category"] = uiconfig.category
if uiconfig and hasattr(uiconfig, "version"):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"].extend(["type", "id"])
@ -542,7 +577,11 @@ GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
def invocation(
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
invocation_type: str,
title: Optional[str] = None,
tags: Optional[list[str]] = None,
category: Optional[str] = None,
version: Optional[str] = None,
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
"""
Adds metadata to an invocation.
@ -569,6 +608,12 @@ def invocation(
cls.UIConfig.tags = tags
if category is not None:
cls.UIConfig.category = category
if version is not None:
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
# Add the invocation type to the pydantic model of the invocation
invocation_type_annotation = Literal[invocation_type] # type: ignore
@ -580,8 +625,9 @@ def invocation(
config=cls.__config__,
)
cls.__fields__.update({"type": invocation_type_field})
cls.__annotations__.update({"type": invocation_type_annotation})
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
if annotations := cls.__dict__.get("__annotations__", None):
annotations.update({"type": invocation_type_annotation})
return cls
return wrapper
@ -615,7 +661,10 @@ def invocation_output(
config=cls.__config__,
)
cls.__fields__.update({"type": output_type_field})
cls.__annotations__.update({"type": output_type_annotation})
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
if annotations := cls.__dict__.get("__annotations__", None):
annotations.update({"type": output_type_annotation})
return cls

View File

@ -10,7 +10,9 @@ from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
@invocation(
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
)
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
@ -33,6 +35,7 @@ class RangeInvocation(BaseInvocation):
title="Integer Range of Size",
tags=["collection", "integer", "size", "range"],
category="collections",
version="1.0.0",
)
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
@ -50,6 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
title="Random Range",
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.0",
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@ -44,7 +44,7 @@ class ConditioningFieldData:
# PerpNeg = "perp_neg"
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@ -267,6 +267,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.0.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -279,8 +280,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -351,6 +352,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
title="SDXL Refiner Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.0.0",
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -403,7 +405,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""

View File

@ -31,8 +31,8 @@ from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
@ -40,7 +40,9 @@ from .baseinvocation import (
)
@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
@invocation(
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
)
class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet"""
@ -84,6 +86,7 @@ class ImageProcessorInvocation(BaseInvocation):
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.0.0",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
@ -106,6 +109,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.0.0",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
@ -134,6 +138,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.0.0",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
@ -155,6 +160,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.0.0",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
@ -177,6 +183,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Openpose Processor",
tags=["controlnet", "openpose", "pose"],
category="controlnet",
version="1.0.0",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
@ -201,6 +208,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.0.0",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
@ -227,6 +235,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
version="1.0.0",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
@ -242,7 +251,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
@invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
@ -263,7 +274,9 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
@invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
@ -289,6 +302,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.0.0",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
@ -318,6 +332,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.0.0",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
@ -333,6 +348,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.0.0",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
@ -355,6 +371,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.0.0",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@ -383,6 +400,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
version="1.0.0",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@ -422,6 +440,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.0.0",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""

View File

@ -10,12 +10,7 @@ from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@invocation(
"cv_inpaint",
title="OpenCV Inpaint",
tags=["opencv", "inpaint"],
category="inpaint",
)
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""

View File

@ -16,7 +16,7 @@ from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image")
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
@ -36,7 +36,7 @@ class ShowImageInvocation(BaseInvocation):
)
@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
class BlankImageInvocation(BaseInvocation):
"""Creates a blank image and forwards it to the pipeline"""
@ -65,7 +65,7 @@ class BlankImageInvocation(BaseInvocation):
)
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
@ -98,7 +98,7 @@ class ImageCropInvocation(BaseInvocation):
)
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.0")
class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image."""
@ -146,7 +146,7 @@ class ImagePasteInvocation(BaseInvocation):
)
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
@ -177,7 +177,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
)
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
@ -210,7 +210,7 @@ class ImageMultiplyInvocation(BaseInvocation):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image."""
@ -242,7 +242,7 @@ class ImageChannelInvocation(BaseInvocation):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode."""
@ -271,7 +271,7 @@ class ImageConvertInvocation(BaseInvocation):
)
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
class ImageBlurInvocation(BaseInvocation):
"""Blurs an image"""
@ -325,7 +325,7 @@ PIL_RESAMPLING_MAP = {
}
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions"""
@ -365,7 +365,7 @@ class ImageResizeInvocation(BaseInvocation):
)
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor"""
@ -406,7 +406,7 @@ class ImageScaleInvocation(BaseInvocation):
)
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
@ -439,7 +439,7 @@ class ImageLerpInvocation(BaseInvocation):
)
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
@ -472,7 +472,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
)
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images"""
@ -517,7 +517,9 @@ class ImageNSFWBlurInvocation(BaseInvocation):
return caution.resize((caution.width // 2, caution.height // 2))
@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
@invocation(
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
)
class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image"""
@ -548,7 +550,7 @@ class ImageWatermarkInvocation(BaseInvocation):
)
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image"""
@ -593,7 +595,9 @@ class MaskEdgeInvocation(BaseInvocation):
)
@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
@invocation(
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
)
class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
@ -623,7 +627,7 @@ class MaskCombineInvocation(BaseInvocation):
)
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
class ColorCorrectInvocation(BaseInvocation):
"""
Shifts the colors of a target image to match the reference image, optionally
@ -728,7 +732,7 @@ class ColorCorrectInvocation(BaseInvocation):
)
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
@ -774,6 +778,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
title="Adjust Image Luminosity",
tags=["image", "luminosity", "hsl"],
category="image",
version="1.0.0",
)
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
@ -826,6 +831,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
title="Adjust Image Saturation",
tags=["image", "saturation", "hsl"],
category="image",
version="1.0.0",
)
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""

View File

@ -116,7 +116,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
@ -151,7 +151,7 @@ class InfillColorInvocation(BaseInvocation):
)
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image"""
@ -187,7 +187,9 @@ class InfillTileInvocation(BaseInvocation):
)
@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
@invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
)
class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
@ -218,7 +220,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class LaMaInfillInvocation(BaseInvocation):
"""Infills transparent areas of an image using the LaMa model"""

View File

@ -76,7 +76,7 @@ class SchedulerOutput(BaseInvocationOutput):
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
@ -88,7 +88,9 @@ class SchedulerInvocation(BaseInvocation):
return SchedulerOutput(scheduler=self.scheduler)
@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
@invocation(
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
)
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
@ -188,6 +190,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.0.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@ -210,12 +213,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
default=None,
description=FieldDescriptions.control,
input=Input.Connection,
ui_order=5,
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
)
# ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
# ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
@ -322,7 +327,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
@ -573,7 +578,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
@invocation(
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
)
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
@ -670,7 +677,7 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
@ -714,7 +721,7 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
@ -750,7 +757,9 @@ class ScaleLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
@invocation(
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@ -830,7 +839,7 @@ class ImageToLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=latents, seed=None)
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""

View File

@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerOutput
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation):
return IntegerOutput(value=self.a + self.b)
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation):
return IntegerOutput(value=self.a - self.b)
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation):
return IntegerOutput(value=self.a * self.b)
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation):
return IntegerOutput(value=int(self.a / self.b))
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""

View File

@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
)
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
refiner_positive_aesthetic_store: Optional[float] = Field(
refiner_positive_aesthetic_score: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_negative_aesthetic_store: Optional[float] = Field(
refiner_negative_aesthetic_score: Optional[float] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
@ -98,7 +98,9 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
@invocation(
"metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
)
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
@ -160,11 +162,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
default=None,
description="The scheduler used for the refiner",
)
refiner_positive_aesthetic_store: Optional[float] = InputField(
refiner_positive_aesthetic_score: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)
refiner_negative_aesthetic_store: Optional[float] = InputField(
refiner_negative_aesthetic_score: Optional[float] = InputField(
default=None,
description="The aesthetic score used for the refiner",
)

View File

@ -73,7 +73,7 @@ class LoRAModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@ -173,7 +173,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@ -244,19 +244,19 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
clip: Optional[ClipField] = Field(
clip: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
)
clip2: Optional[ClipField] = Field(
clip2: Optional[ClipField] = InputField(
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
)
@ -338,7 +338,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
@ -376,7 +376,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
class SeamlessModeInvocation(BaseInvocation):
"""Applies the seamless transformation to the Model UNet and VAE."""

View File

@ -78,7 +78,7 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
)
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""

View File

@ -56,7 +56,7 @@ ORT_TO_NP_TYPE = {
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
class ONNXPromptInvocation(BaseInvocation):
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@ -143,6 +143,7 @@ class ONNXPromptInvocation(BaseInvocation):
title="ONNX Text to Latents",
tags=["latents", "inference", "txt2img", "onnx"],
category="latents",
version="1.0.0",
)
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
@ -319,6 +320,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"],
category="image",
version="1.0.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
@ -403,7 +405,7 @@ class OnnxModelField(BaseModel):
model_type: ModelType = Field(description="Model Type")
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

View File

@ -45,7 +45,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
@ -96,7 +96,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
# actually I think for now could just use CollectionOutput (which is list[Any]
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""

View File

@ -14,7 +14,6 @@ from .baseinvocation import (
InvocationContext,
OutputField,
UIComponent,
UIType,
invocation,
invocation_output,
)
@ -40,10 +39,14 @@ class BooleanOutput(BaseInvocationOutput):
class BooleanCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of booleans"""
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
collection: list[bool] = OutputField(
description="The output boolean collection",
)
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
@invocation(
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
)
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
@ -58,13 +61,12 @@ class BooleanInvocation(BaseInvocation):
title="Boolean Collection Primitive",
tags=["primitives", "boolean", "collection"],
category="primitives",
version="1.0.0",
)
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
collection: list[bool] = InputField(
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
)
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection)
@ -86,10 +88,14 @@ class IntegerOutput(BaseInvocationOutput):
class IntegerCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of integers"""
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
collection: list[int] = OutputField(
description="The int collection",
)
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
@invocation(
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
)
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
@ -104,13 +110,12 @@ class IntegerInvocation(BaseInvocation):
title="Integer Collection Primitive",
tags=["primitives", "integer", "collection"],
category="primitives",
version="1.0.0",
)
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
collection: list[int] = InputField(
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
)
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection)
@ -132,10 +137,12 @@ class FloatOutput(BaseInvocationOutput):
class FloatCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
collection: list[float] = OutputField(
description="The float collection",
)
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
@ -150,13 +157,12 @@ class FloatInvocation(BaseInvocation):
title="Float Collection Primitive",
tags=["primitives", "float", "collection"],
category="primitives",
version="1.0.0",
)
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
collection: list[float] = InputField(
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
)
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection)
@ -178,10 +184,12 @@ class StringOutput(BaseInvocationOutput):
class StringCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings"""
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
collection: list[str] = OutputField(
description="The output strings",
)
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
@ -196,13 +204,12 @@ class StringInvocation(BaseInvocation):
title="String Collection Primitive",
tags=["primitives", "string", "collection"],
category="primitives",
version="1.0.0",
)
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
collection: list[str] = InputField(
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
)
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection)
@ -232,10 +239,12 @@ class ImageOutput(BaseInvocationOutput):
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
collection: list[ImageField] = OutputField(
description="The output images",
)
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
class ImageInvocation(BaseInvocation):
"""An image primitive value"""
@ -256,13 +265,12 @@ class ImageInvocation(BaseInvocation):
title="Image Collection Primitive",
tags=["primitives", "image", "collection"],
category="primitives",
version="1.0.0",
)
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""
collection: list[ImageField] = InputField(
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
)
collection: list[ImageField] = InputField(description="The collection of image values")
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
@ -316,11 +324,12 @@ class LatentsCollectionOutput(BaseInvocationOutput):
collection: list[LatentsField] = OutputField(
description=FieldDescriptions.latents,
ui_type=UIType.LatentsCollection,
)
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
@invocation(
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
)
class LatentsInvocation(BaseInvocation):
"""A latents tensor primitive value"""
@ -337,12 +346,13 @@ class LatentsInvocation(BaseInvocation):
title="Latents Collection Primitive",
tags=["primitives", "latents", "collection"],
category="primitives",
version="1.0.0",
)
class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values"""
collection: list[LatentsField] = InputField(
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
description="The collection of latents tensors",
)
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
@ -385,10 +395,12 @@ class ColorOutput(BaseInvocationOutput):
class ColorCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of colors"""
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
collection: list[ColorField] = OutputField(
description="The output colors",
)
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
class ColorInvocation(BaseInvocation):
"""A color primitive value"""
@ -422,7 +434,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
collection: list[ConditioningField] = OutputField(
description="The output conditioning tensors",
ui_type=UIType.ConditioningCollection,
)
@ -431,6 +442,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
title="Conditioning Primitive",
tags=["primitives", "conditioning"],
category="primitives",
version="1.0.0",
)
class ConditioningInvocation(BaseInvocation):
"""A conditioning tensor primitive value"""
@ -446,6 +458,7 @@ class ConditioningInvocation(BaseInvocation):
title="Conditioning Collection Primitive",
tags=["primitives", "conditioning", "collection"],
category="primitives",
version="1.0.0",
)
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""
@ -453,7 +466,6 @@ class ConditioningCollectionInvocation(BaseInvocation):
collection: list[ConditioningField] = InputField(
default_factory=list,
description="The collection of conditioning tensors",
ui_type=UIType.ConditioningCollection,
)
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:

View File

@ -10,7 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation):
return StringCollectionOutput(collection=prompts)
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""

View File

@ -33,7 +33,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
@ -119,6 +119,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
version="1.0.0",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""

View File

@ -23,7 +23,7 @@ ESRGAN_MODELS = Literal[
]
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""

View File

@ -112,6 +112,10 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type):
return True
# allow int -> float, pydantic will cast for us
if from_type is int and to_type is float:
return True
# if not issubclass(from_type, to_type):
if not is_union_subtype(from_type, to_type):
return False

View File

@ -50,6 +50,7 @@ class ModelProbe(object):
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
}

View File

@ -1,6 +0,0 @@
from ldm.modules.image_degradation.bsrgan import ( # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr,
)
from ldm.modules.image_degradation.bsrgan_light import ( # noqa: F401
degradation_bsrgan_variant as degradation_fn_bsr_light,
)

View File

@ -1,794 +0,0 @@
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import random
from functools import partial
import albumentations
import cv2
import ldm.modules.image_degradation.utils_image as util
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
def modcrop_np(img, sf):
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[: w - w % sf, : h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v = np.dot(
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
return k
def gm_blur_kernel(mean, cov, size=15):
center = size / 2.0 + 0.5
k = np.zeros([size, size])
for y in range(size):
for x in range(size):
cy = y - center + 1
cx = x - center + 1
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
k = k / np.sum(k)
return k
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf - 1) * 0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w - 1)
y1 = np.clip(y1, 0, h - 1)
if x.ndim == 2:
x = interp2d(xv, yv, x)(x1, y1)
if x.ndim == 3:
for i in range(x.shape[-1]):
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
def blur(x, k):
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
theta = np.random.rand() * np.pi # random theta
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
Z = np.stack([X, Y], 2)[:, :, :, None]
# Calcualte Gaussian for every pixel of the kernel
ZZ = Z - MU
ZZ_t = ZZ.transpose(0, 1, 3, 2)
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel = raw_kernel / np.sum(raw_kernel)
return kernel
def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
h = h / sumh
return h
def fspecial_laplacian(alpha):
alpha = max([0, min([alpha, 1])])
h1 = alpha / (alpha + 1)
h2 = (1 - alpha) / (alpha + 1)
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
h = np.array(h)
return h
def fspecial(filter_type, *args, **kwargs):
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
"""
if filter_type == "gaussian":
return fspecial_gaussian(*args, **kwargs)
if filter_type == "laplacian":
return fspecial_laplacian(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
return x
def classical_degradation(x, k, sf=3):
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype("float32")
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
K = img + weight * residual
K = np.clip(K, 0, 1)
return soft_mask * K + (1 - soft_mask) * img
def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(
ksize=2 * random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
return img
def add_resize(img, sf=4):
rnum = np.random.rand()
if rnum > 0.8: # up
sf1 = random.uniform(1, 2)
elif rnum < 0.7: # down
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else: # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f"img size ({h1}X{w1}) is too small!")
hq = img.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_blur(img, sf=sf)
elif i == 2:
a, b = img.shape[1], img.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
img = add_JPEG_noise(img)
elif i == 6:
# add processed camera sensor noise
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
jpeg_prob, scale2_prob = 0.9, 0.25
# isp_prob = 0.25 # uncomment with `if i== 6` block below
# sf_ori = sf # uncomment with `if i== 6` block below
h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
# hq = image.copy() # uncomment with `if i== 6` block below
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
image = add_blur(image, sf=sf)
elif i == 1:
image = add_blur(image, sf=sf)
elif i == 2:
a, b = image.shape[1], image.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
image = add_JPEG_noise(image)
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(
img,
sf=4,
shuffle_prob=0.5,
use_sharp=True,
lq_patchsize=64,
isp_model=None,
):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
use_shuffle: the degradation shuffle
use_sharp: sharpening the img
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
h1, w1 = img.shape[:2]
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f"img size ({h1}X{w1}) is too small!")
if use_sharp:
img = add_sharpening(img)
hq = img.copy()
if random.random() < shuffle_prob:
shuffle_order = random.sample(range(13), 13)
else:
shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_resize(img, sf=sf)
elif i == 2:
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 3:
if random.random() < poisson_prob:
img = add_Poisson_noise(img)
elif i == 4:
if random.random() < speckle_prob:
img = add_speckle_noise(img)
elif i == 5:
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
elif i == 6:
img = add_JPEG_noise(img)
elif i == 7:
img = add_blur(img, sf=sf)
elif i == 8:
img = add_resize(img, sf=sf)
elif i == 9:
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 10:
if random.random() < poisson_prob:
img = add_Poisson_noise(img)
elif i == 11:
if random.random() < speckle_prob:
img = add_speckle_noise(img)
elif i == 12:
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
else:
print("check the shuffle!")
# resize to desired size
img = cv2.resize(
img,
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf, lq_patchsize)
return img, hq
if __name__ == "__main__":
print("hey")
img = util.imread_uint("utils/test.png", 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
# print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
# img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
util.imsave(img_concat, str(i) + ".png")

View File

@ -1,704 +0,0 @@
# -*- coding: utf-8 -*-
import random
from functools import partial
import albumentations
import cv2
import ldm.modules.image_degradation.utils_image as util
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def modcrop_np(img, sf):
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[: w - w % sf, : h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
"""generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v = np.dot(
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
np.array([1.0, 0.0]),
)
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
return k
def gm_blur_kernel(mean, cov, size=15):
center = size / 2.0 + 0.5
k = np.zeros([size, size])
for y in range(size):
for x in range(size):
cy = y - center + 1
cx = x - center + 1
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
k = k / np.sum(k)
return k
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf - 1) * 0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w - 1)
y1 = np.clip(y1, 0, h - 1)
if x.ndim == 2:
x = interp2d(xv, yv, x)(x1, y1)
if x.ndim == 3:
for i in range(x.shape[-1]):
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
def blur(x, k):
"""
x: image, NxcxHxW
k: kernel, Nx1xhxw
"""
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(
k_size=np.array([15, 15]),
scale_factor=np.array([4, 4]),
min_var=0.6,
max_var=10.0,
noise_level=0,
):
""" "
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
theta = np.random.rand() * np.pi # random theta
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
Z = np.stack([X, Y], 2)[:, :, :, None]
# Calcualte Gaussian for every pixel of the kernel
ZZ = Z - MU
ZZ_t = ZZ.transpose(0, 1, 3, 2)
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel = raw_kernel / np.sum(raw_kernel)
return kernel
def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
h = h / sumh
return h
def fspecial_laplacian(alpha):
alpha = max([0, min([alpha, 1])])
h1 = alpha / (alpha + 1)
h2 = (1 - alpha) / (alpha + 1)
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
h = np.array(h)
return h
def fspecial(filter_type, *args, **kwargs):
"""
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
"""
if filter_type == "gaussian":
return fspecial_gaussian(*args, **kwargs)
if filter_type == "laplacian":
return fspecial_laplacian(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
"""
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
"""
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
"""blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
"""bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
"""
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
return x
def classical_degradation(x, k, sf=3):
"""blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
"""
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype("float32")
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
K = img + weight * residual
K = np.clip(K, 0, 1)
return soft_mask * K + (1 - soft_mask) * img
def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
wd2 = wd2 / 4
wd = wd / 4
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(
ksize=random.randint(2, 11) + 3,
theta=random.random() * np.pi,
l1=l1,
l2=l2,
)
else:
k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
return img
def add_resize(img, sf=4):
rnum = np.random.rand()
if rnum > 0.8: # up
sf1 = random.uniform(1, 2)
elif rnum < 0.7: # down
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(
img,
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
return img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else: # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[
rnd_h_H : rnd_h_H + lq_patchsize * sf,
rnd_w_H : rnd_w_H + lq_patchsize * sf,
:,
]
return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f"img size ({h1}X{w1}) is too small!")
hq = img.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(
img,
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_blur(img, sf=sf)
elif i == 2:
a, b = img.shape[1], img.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(
img,
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(
img,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
img = np.clip(img, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
img = add_JPEG_noise(img)
elif i == 6:
# add processed camera sensor noise
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
jpeg_prob, scale2_prob = 0.9, 0.25
# isp_prob = 0.25 # uncomment with `if i== 6` block below
# sf_ori = sf # uncomment with `if i== 6` block below
h1, w1 = image.shape[:2]
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
# hq = image.copy() # uncomment with `if i== 6` block below
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(
image,
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]),
)
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = (
shuffle_order[idx2],
shuffle_order[idx1],
)
for i in shuffle_order:
if i == 0:
image = add_blur(image, sf=sf)
# elif i == 1:
# image = add_blur(image, sf=sf)
if i == 0:
pass
elif i == 2:
a, b = image.shape[1], image.shape[0]
# downsample2
if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(
image,
(
int(1 / sf1 * image.shape[1]),
int(1 / sf1 * image.shape[0]),
),
interpolation=random.choice([1, 2, 3]),
)
else:
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(
image,
(int(1 / sf * a), int(1 / sf * b)),
interpolation=random.choice([1, 2, 3]),
)
image = np.clip(image, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
image = add_JPEG_noise(image)
#
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
return example
if __name__ == "__main__":
print("hey")
img = util.imread_uint("utils/test.png", 3)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_hq = img
img_lq = deg_fn(img)["image"]
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
"image"
]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(
util.single2uint(img_lq),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
lq_bicubic_nearest = cv2.resize(
util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + ".png")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 431 KiB

View File

@ -1,968 +0,0 @@
import math
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
from torchvision.utils import make_grid
import invokeai.backend.util.logging as logger
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
"""
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
# https://github.com/twhui/SRGAN-pyTorch
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""
IMG_EXTENSIONS = [
".jpg",
".JPG",
".jpeg",
".JPEG",
".png",
".PNG",
".ppm",
".PPM",
".bmp",
".BMP",
".tif",
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")
def imshow(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()
def surf(Z, cmap="rainbow", figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize)
ax3 = plt.axes(projection="3d")
w, h = Z.shape[:2]
xx = np.arange(0, w, 1)
yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy)
ax3.plot_surface(X, Y, Z, cmap=cmap)
# ax3.contour(X,Y,Z, zdim='z',offset=-2cmap=cmap)
plt.show()
"""
# --------------------------------------------
# get image pathes
# --------------------------------------------
"""
def get_image_paths(dataroot):
paths = None # return None if dataroot is None
if dataroot is not None:
paths = sorted(_get_paths_from_images(dataroot))
return paths
def _get_paths_from_images(path):
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, "{:s} has no valid image file".format(path)
return images
"""
# --------------------------------------------
# split large images into small images
# --------------------------------------------
"""
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
w1.append(w - p_size)
h1.append(h - p_size)
# print(w1)
# print(h1)
for i in w1:
for j in h1:
patches.append(img[i : i + p_size, j : j + p_size, :])
else:
patches.append(img)
return patches
def imssave(imgs, img_path):
"""
imgs: list, N images of size WxHxC
"""
img_name, ext = os.path.splitext(os.path.basename(img_path))
for i, img in enumerate(imgs):
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
new_path = os.path.join(
os.path.dirname(img_path),
img_name + str("_s{:04d}".format(i)) + ".png",
)
cv2.imwrite(new_path, img)
def split_imageset(
original_dataroot,
taget_dataroot,
n_channels=3,
p_size=800,
p_overlap=96,
p_max=1000,
):
"""
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
will be splitted.
Args:
original_dataroot:
taget_dataroot:
p_size: size of small images
p_overlap: patch size in training is a good choice
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
"""
paths = get_image_paths(original_dataroot)
for img_path in paths:
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
# if original_dataroot == taget_dataroot:
# del img_path
"""
# --------------------------------------------
# makedir
# --------------------------------------------
"""
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def mkdirs(paths):
if isinstance(paths, str):
mkdir(paths)
else:
for path in paths:
mkdir(path)
def mkdir_and_rename(path):
if os.path.exists(path):
new_name = path + "_archived_" + get_timestamp()
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
os.replace(path, new_name)
os.makedirs(path)
"""
# --------------------------------------------
# read image from path
# opencv is fast, but read BGR numpy image
# --------------------------------------------
"""
# --------------------------------------------
# get uint8 image of size HxWxn_channles (RGB)
# --------------------------------------------
def imread_uint(path, n_channels=3):
# input: path
# output: HxWx3(RGB or GGG), or HxWx1 (G)
if n_channels == 1:
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
img = np.expand_dims(img, axis=2) # HxWx1
elif n_channels == 3:
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
return img
# --------------------------------------------
# matlab's imwrite
# --------------------------------------------
def imsave(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
def imwrite(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
def read_img(path):
# read image by cv2
# return: Numpy float32, HWC, BGR, [0,1]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
img = img.astype(np.float32) / 255.0
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(unit)
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
"""
# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(unit)
# --------------------------------------------
def uint2single(img):
return np.float32(img / 255.0)
def single2uint(img):
return np.uint8((img.clip(0, 1) * 255.0).round())
def uint162single(img):
return np.float32(img / 65535.0)
def single2uint16(img):
return np.uint16((img.clip(0, 1) * 65535.0).round())
# --------------------------------------------
# numpy(unit) (HxWxC or HxW) <---> tensor
# --------------------------------------------
# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return np.uint8((img * 255.0).round())
# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
# convert torch tensor to single
def tensor2single(img):
img = img.data.squeeze().float().cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
return img
# convert torch tensor to single
def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
elif img.ndim == 2:
img = np.expand_dims(img, axis=2)
return img
def single2tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
def single32tensor5(img):
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
def single42tensor4(img):
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
"""
Converts a torch Tensor into an image Numpy array of BGR channel order
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
"""
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 3:
img_np = tensor.numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
elif n_dim == 2:
img_np = tensor.numpy()
else:
raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim))
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return img_np.astype(out_type)
"""
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
"""
def augment_img(img, mode=0):
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
return np.flipud(np.rot90(img))
elif mode == 2:
return np.flipud(img)
elif mode == 3:
return np.rot90(img, k=3)
elif mode == 4:
return np.flipud(np.rot90(img, k=2))
elif mode == 5:
return np.rot90(img)
elif mode == 6:
return np.rot90(img, k=2)
elif mode == 7:
return np.flipud(np.rot90(img, k=3))
def augment_img_tensor4(img, mode=0):
"""Kai Zhang (github: https://github.com/cszn)"""
if mode == 0:
return img
elif mode == 1:
return img.rot90(1, [2, 3]).flip([2])
elif mode == 2:
return img.flip([2])
elif mode == 3:
return img.rot90(3, [2, 3])
elif mode == 4:
return img.rot90(2, [2, 3]).flip([2])
elif mode == 5:
return img.rot90(1, [2, 3])
elif mode == 6:
return img.rot90(2, [2, 3])
elif mode == 7:
return img.rot90(3, [2, 3]).flip([2])
def augment_img_tensor(img, mode=0):
"""Kai Zhang (github: https://github.com/cszn)"""
img_size = img.size()
img_np = img.data.cpu().numpy()
if len(img_size) == 3:
img_np = np.transpose(img_np, (1, 2, 0))
elif len(img_size) == 4:
img_np = np.transpose(img_np, (2, 3, 1, 0))
img_np = augment_img(img_np, mode=mode)
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
if len(img_size) == 3:
img_tensor = img_tensor.permute(2, 0, 1)
elif len(img_size) == 4:
img_tensor = img_tensor.permute(3, 2, 0, 1)
return img_tensor.type_as(img)
def augment_img_np3(img, mode=0):
if mode == 0:
return img
elif mode == 1:
return img.transpose(1, 0, 2)
elif mode == 2:
return img[::-1, :, :]
elif mode == 3:
img = img[::-1, :, :]
img = img.transpose(1, 0, 2)
return img
elif mode == 4:
return img[:, ::-1, :]
elif mode == 5:
img = img[:, ::-1, :]
img = img.transpose(1, 0, 2)
return img
elif mode == 6:
img = img[:, ::-1, :]
img = img[::-1, :, :]
return img
elif mode == 7:
img = img[:, ::-1, :]
img = img[::-1, :, :]
img = img.transpose(1, 0, 2)
return img
def augment_imgs(img_list, hflip=True, rot=True):
# horizontal flip OR rotate
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
"""
def modcrop(img_in, scale):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
if img.ndim == 2:
H, W = img.shape
H_r, W_r = H % scale, W % scale
img = img[: H - H_r, : W - W_r]
elif img.ndim == 3:
H, W, C = img.shape
H_r, W_r = H % scale, W % scale
img = img[: H - H_r, : W - W_r, :]
else:
raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
return img
def shave(img_in, border=0):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
h, w = img.shape[:2]
img = img[border : h - border, border : w - border]
return img
"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
"""
def rgb2ycbcr(img, only_y=True):
"""same as matlab rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.0
return rlt.astype(in_img_type)
def ycbcr2rgb(img):
"""same as matlab ycbcr2rgb
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.0
# convert
rlt = np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.0
return rlt.astype(in_img_type)
def bgr2ycbcr(img, only_y=True):
"""bgr version of rgb2ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img.astype(np.float32)
if in_img_type != np.uint8:
img *= 255.0
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.0
return rlt.astype(in_img_type)
def channel_convert(in_c, tar_type, img_list):
# conversion among BGR, gray and y
if in_c == 3 and tar_type == "gray": # BGR to gray
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
return [np.expand_dims(img, axis=2) for img in gray_list]
elif in_c == 3 and tar_type == "y": # BGR to y
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
return [np.expand_dims(img, axis=2) for img in y_list]
elif in_c == 1 and tar_type == "RGB": # gray/y to BGR
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
else:
return img_list
"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
"""
# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255]
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError("Input images must have the same dimensions.")
h, w = img1.shape[:2]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2) ** 2)
if mse == 0:
return float("inf")
return 20 * math.log10(255.0 / math.sqrt(mse))
# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
"""calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
"""
# img1 = img1.squeeze()
# img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError("Input images must have the same dimensions.")
h, w = img1.shape[:2]
img1 = img1[border : h - border, border : w - border]
img2 = img2[border : h - border, border : w - border]
if img1.ndim == 2:
return ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError("Wrong input image dimensions.")
def ssim(img1, img2):
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
"""
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5+scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
P = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
out_length, P
)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
# apply cubic kernel
if (scale < 1) and (antialiasing):
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, P)
# If a column in weights is all zero, get rid of it. only consider the first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, P - 2)
weights = weights.narrow(1, 1, P - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, P - 2)
weights = weights.narrow(1, 0, P - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: pytorch tensor, CHW or HW [0,1]
# output: CHW or HW [0,1] w/o round
need_squeeze = True if img.dim() == 2 else False
if need_squeeze:
img.unsqueeze_(0)
in_C, in_H, in_W = img.size()
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = "cubic"
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:, :sym_len_Hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_He:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_C, out_H, in_W)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_Ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_We:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_C, out_H, out_W)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i])
if need_squeeze:
out_2.squeeze_()
return out_2
# --------------------------------------------
# imresize for numpy image [0, 1]
# --------------------------------------------
def imresize_np(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: Numpy, HWC or HW [0,1]
# output: HWC or HW [0,1] w/o round
img = torch.from_numpy(img)
need_squeeze = True if img.dim() == 2 else False
if need_squeeze:
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = (
in_C,
math.ceil(in_H * scale),
math.ceil(in_W * scale),
)
kernel_width = 4
kernel = "cubic"
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing
)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing
)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:sym_len_Hs, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[-sym_len_He:, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(out_H, in_W, in_C)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :sym_len_Ws, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, -sym_len_We:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(out_H, out_W, in_C)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i])
if need_squeeze:
out_2.squeeze_()
return out_2.numpy()
if __name__ == "__main__":
print("---")
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
# img_bicubic = imresize_np(img, 1/4)

View File

@ -10,7 +10,6 @@ from .devices import ( # noqa: F401
normalize_device,
torch_dtype,
)
from .log import write_log # noqa: F401
from .util import ( # noqa: F401
ask_user,
download_with_resume,

View File

@ -75,6 +75,7 @@
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"@stevebel/png": "^1.5.1",
"compare-versions": "^6.1.0",
"dateformat": "^5.0.3",
"formik": "^2.4.3",
"framer-motion": "^10.16.1",

View File

@ -84,6 +84,7 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
export const listenerMiddleware = createListenerMiddleware();
@ -202,6 +203,9 @@ addBoardIdSelectedListener();
// Node schemas
addReceivedOpenAPISchemaListener();
// Workflows
addWorkflowLoadedListener();
// DND
addImageDroppedListener();

View File

@ -0,0 +1,55 @@
import { logger } from 'app/logging/logger';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { startAppListening } from '..';
export const addWorkflowLoadedListener = () => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
const { workflow: validatedWorkflow, errors } = validateWorkflow(
workflow,
nodeTemplates
);
dispatch(workflowLoaded(validatedWorkflow));
if (!errors.length) {
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
} else {
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded with Warnings',
status: 'warning',
})
)
);
errors.forEach(({ message, ...rest }) => {
log.warn(rest, message);
});
}
dispatch(setActiveTab('nodes'));
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
},
});
};

View File

@ -63,7 +63,11 @@ const selector = createSelector(
return;
}
if (fieldTemplate.required && !field.value && !hasConnection) {
if (
fieldTemplate.required &&
field.value === undefined &&
!hasConnection
) {
reasons.push(
`${node.data.label || nodeTemplate.title} -> ${
field.label || fieldTemplate.title

View File

@ -1,2 +1,2 @@
export const colorTokenToCssVar = (colorToken: string) =>
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;

View File

@ -17,16 +17,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
@ -110,7 +107,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage?.image_name ?? skipToken,
lastSelectedImage ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
@ -124,16 +121,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
dispatch(workflowLoadRequested(workflow));
}, [dispatch, workflow]);
const handleClickUseAllParameters = useCallback(() => {

View File

@ -7,12 +7,9 @@ import {
isModalOpenChanged,
} from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
@ -36,6 +33,7 @@ import {
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { workflowLoadRequested } from 'features/nodes/store/actions';
type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO;
@ -52,7 +50,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name,
imageDTO,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
@ -102,16 +100,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
dispatch(workflowLoadRequested(workflow));
}, [dispatch, workflow]);
const handleSendToImageToImage = useCallback(() => {

View File

@ -101,13 +101,15 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && metadata.model !== null && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.model !== undefined &&
metadata.model !== null &&
metadata.model.model_name && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.width && (
<ImageMetadataItem
label="Width"

View File

@ -27,15 +27,12 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false));
// });
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
image.image_name,
{
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
});
return (
<Flex

View File

@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
@ -13,6 +14,7 @@ import {
OnConnectStart,
OnEdgesChange,
OnEdgesDelete,
OnInit,
OnMoveEnd,
OnNodesChange,
OnNodesDelete,
@ -147,6 +149,11 @@ export const Flow = () => {
dispatch(contextMenusClosed());
}, [dispatch]);
const onInit: OnInit = useCallback((flow) => {
$flow.set(flow);
flow.fitView();
}, []);
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault();
dispatch(selectionCopied());
@ -170,6 +177,7 @@ export const Flow = () => {
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onInit={onInit}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}

View File

@ -12,6 +12,7 @@ import {
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
@ -20,6 +21,7 @@ import { isInvocationNodeData } from 'features/nodes/types/types';
import { memo, useMemo } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import NotesTextarea from './NotesTextarea';
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
interface Props {
nodeId: string;
@ -29,6 +31,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const label = useNodeLabel(nodeId);
const title = useNodeTemplateTitle(nodeId);
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
return (
<>
@ -50,7 +53,11 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
>
<Icon
as={FaInfoCircle}
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
sx={{
boxSize: 4,
w: 8,
color: doVersionsMatch ? 'base.400' : 'error.400',
}}
/>
</Flex>
</Tooltip>
@ -92,16 +99,59 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
return 'Unknown Node';
}, [data, nodeTemplate]);
const versionComponent = useMemo(() => {
if (!isInvocationNodeData(data) || !nodeTemplate) {
return null;
}
if (!data.version) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version unknown
</Text>
);
}
if (!nodeTemplate.version) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version {data.version} (unknown template)
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '<')) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version {data.version} (update node)
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '>')) {
return (
<Text as="span" sx={{ color: 'error.500' }}>
Version {data.version} (update app)
</Text>
);
}
return <Text as="span">Version {data.version}</Text>;
}, [data, nodeTemplate]);
if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
}
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{title}</Text>
<Text as="span" sx={{ fontWeight: 600 }}>
{title}
</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{nodeTemplate?.description}
</Text>
{versionComponent}
{data?.notes && <Text>{data.notes}</Text>}
</Flex>
);

View File

@ -1,8 +1,11 @@
import { Tooltip } from '@chakra-ui/react';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import {
COLLECTION_TYPES,
FIELDS,
HANDLE_TOOLTIP_OPEN_DELAY,
MODEL_TYPES,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import {
InputFieldTemplate,
@ -18,6 +21,7 @@ export const handleBaseStyles: CSSProperties = {
borderWidth: 0,
zIndex: 1,
};
``;
export const inputHandleStyles: CSSProperties = {
left: '-1rem',
@ -44,15 +48,25 @@ const FieldHandle = (props: FieldHandleProps) => {
connectionError,
} = props;
const { name, type } = fieldTemplate;
const { color, title } = FIELDS[type];
const { color: typeColor, title } = FIELDS[type];
const styles: CSSProperties = useMemo(() => {
const isCollectionType = COLLECTION_TYPES.includes(type);
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
const isModelType = MODEL_TYPES.includes(type);
const color = colorTokenToCssVar(typeColor);
const s: CSSProperties = {
backgroundColor: colorTokenToCssVar(color),
backgroundColor:
isCollectionType || isPolymorphicType
? 'var(--invokeai-colors-base-900)'
: color,
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
borderStyle: 'solid',
borderColor: color,
borderRadius: isModelType ? 4 : '100%',
zIndex: 1,
};
@ -78,11 +92,12 @@ const FieldHandle = (props: FieldHandleProps) => {
return s;
}, [
color,
connectionError,
handleType,
isConnectionInProgress,
isConnectionStartField,
type,
typeColor,
]);
const tooltip = useMemo(() => {

View File

@ -75,6 +75,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
sx={{
display: 'flex',
alignItems: 'center',
h: 'full',
mb: 0,
px: 1,
gap: 2,

View File

@ -3,18 +3,10 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { memo } from 'react';
import BooleanInputField from './inputs/BooleanInputField';
import ClipInputField from './inputs/ClipInputField';
import CollectionInputField from './inputs/CollectionInputField';
import CollectionItemInputField from './inputs/CollectionItemInputField';
import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField';
import LatentsInputField from './inputs/LatentsInputField';
import LoRAModelInputField from './inputs/LoRAModelInputField';
import MainModelInputField from './inputs/MainModelInputField';
import NumberInputField from './inputs/NumberInputField';
@ -22,8 +14,6 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import UnetInputField from './inputs/UnetInputField';
import VaeInputField from './inputs/VaeInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
type InputFieldProps = {
@ -31,7 +21,6 @@ type InputFieldProps = {
fieldName: string;
};
// build an individual input element based on the schema
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const field = useFieldData(nodeId, fieldName);
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
@ -93,88 +82,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'LatentsField' &&
fieldTemplate?.type === 'LatentsField'
) {
return (
<LatentsInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'DenoiseMaskField' &&
fieldTemplate?.type === 'DenoiseMaskField'
) {
return (
<DenoiseMaskInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
) {
return (
<ConditioningInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
return (
<UnetInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
return (
<ClipInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
return (
<VaeInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ControlField' &&
fieldTemplate?.type === 'ControlField'
) {
return (
<ControlInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'MainModelField' &&
fieldTemplate?.type === 'MainModelField'
@ -240,29 +147,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
return (
<CollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'CollectionItem' &&
fieldTemplate?.type === 'CollectionItem'
) {
return (
<CollectionItemInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField
@ -273,19 +157,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'ImageCollection' &&
fieldTemplate?.type === 'ImageCollection'
) {
return (
<ImageCollectionInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'SDXLMainModelField' &&
fieldTemplate?.type === 'SDXLMainModelField'
@ -309,6 +180,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (field && fieldTemplate) {
// Fallback for when there is no component for the type
return null;
}
return (
<Box p={1}>
<Text

View File

@ -1,12 +1,17 @@
import {
ControlInputFieldTemplate,
ControlInputFieldValue,
ControlPolymorphicInputFieldTemplate,
ControlPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const ControlInputFieldComponent = (
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
_props: FieldComponentProps<
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
>
) => {
return null;
};

View File

@ -9,9 +9,9 @@ import {
} from 'features/dnd/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
ImageInputFieldTemplate,
ImageInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo, useCallback, useMemo } from 'react';
import { FaUndo } from 'react-icons/fa';

View File

@ -2,11 +2,16 @@ import {
LatentsInputFieldTemplate,
LatentsInputFieldValue,
FieldComponentProps,
LatentsPolymorphicInputFieldValue,
LatentsPolymorphicInputFieldTemplate,
} from 'features/nodes/types/types';
import { memo } from 'react';
const LatentsInputFieldComponent = (
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
_props: FieldComponentProps<
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
>
) => {
return null;
};

View File

@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
FloatInputFieldTemplate,
FloatInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo, useEffect, useMemo, useState } from 'react';

View File

@ -9,13 +9,20 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import {
DRAG_HANDLE_CLASSNAME,
NODE_WIDTH,
} from 'features/nodes/types/constants';
import { NodeStatus } from 'features/nodes/types/types';
import { contextMenusClosed } from 'features/ui/store/uiSlice';
import { PropsWithChildren, memo, useCallback, useMemo } from 'react';
import {
MouseEvent,
PropsWithChildren,
memo,
useCallback,
useMemo,
} from 'react';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@ -57,9 +64,15 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
const handleClick = useCallback(() => {
dispatch(contextMenusClosed());
}, [dispatch]);
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
dispatch(nodeExclusivelySelected(nodeId));
}
dispatch(contextMenusClosed());
},
[dispatch, nodeId]
);
return (
<Box

View File

@ -138,13 +138,14 @@ export const useBuildNodeData = () => {
data: {
id: nodeId,
type,
inputs,
outputs,
isOpen: true,
version: template.version,
label: '',
notes: '',
isOpen: true,
embedWorkflow: false,
isIntermediate: true,
inputs,
outputs,
},
};

View File

@ -0,0 +1,33 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { compareVersions } from 'compare-versions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useDoNodeVersionsMatch = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
if (!nodeTemplate?.version || !node.data?.version) {
return false;
}
return compareVersions(nodeTemplate.version, node.data.version) === 0;
},
defaultSelectorOptions
),
[nodeId]
);
const nodeTemplate = useAppSelector(selector);
return nodeTemplate;
};

View File

@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
if (!isInvocationNode(node)) {
return;
}
return Boolean(node?.data.inputs[fieldName]?.value);
return node?.data.inputs[fieldName]?.value !== undefined;
},
defaultSelectorOptions
),

View File

@ -3,9 +3,19 @@ import graphlib from '@dagrejs/graphlib';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
import { COLLECTION_TYPES } from '../types/constants';
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from '../types/constants';
import { InvocationNodeData } from '../types/types';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const useIsValidConnection = () => {
const flow = useReactFlow();
const shouldValidateGraph = useAppSelector(
@ -42,6 +52,19 @@ export const useIsValidConnection = () => {
return false;
}
if (
edges
.filter((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
})
.find((edge) => {
edge.source === source && edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return false;
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
@ -53,21 +76,62 @@ export const useIsValidConnection = () => {
return false;
}
// Connection types must be the same for a connection
if (
sourceType !== targetType &&
sourceType !== 'CollectionItem' &&
targetType !== 'CollectionItem'
) {
if (
!(
COLLECTION_TYPES.includes(targetType) &&
COLLECTION_TYPES.includes(sourceType)
)
) {
return false;
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
* - Generic Collection can connect to any other Collection or Polymorphic
* - Any Collection can connect to a Generic Collection
*/
if (sourceType !== targetType) {
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(targetType);
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
return (
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat
);
}
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
},

View File

@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { zValidatedWorkflow } from 'features/nodes/types/types';
import { zWorkflow } from 'features/nodes/types/types';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { ZodError } from 'zod';
import { fromZodError, fromZodIssue } from 'zod-validation-error';
import { workflowLoadRequested } from '../store/actions';
export const useLoadWorkflowFromFile = () => {
const dispatch = useAppDispatch();
@ -24,7 +24,7 @@ export const useLoadWorkflowFromFile = () => {
try {
const parsedJSON = JSON.parse(String(rawJSON));
const result = zValidatedWorkflow.safeParse(parsedJSON);
const result = zWorkflow.safeParse(parsedJSON);
if (!result.success) {
const { message } = fromZodError(result.error, {
@ -45,32 +45,8 @@ export const useLoadWorkflowFromFile = () => {
reader.abort();
return;
}
dispatch(workflowLoaded(result.data.workflow));
if (!result.data.warnings.length) {
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
reader.abort();
return;
}
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded with Warnings',
status: 'warning',
})
)
);
result.data.warnings.forEach(({ message, ...rest }) => {
logger.warn(rest, message);
});
dispatch(workflowLoadRequested(result.data));
reader.abort();
} catch {

View File

@ -1,5 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api/types';
import { Workflow } from '../types/types';
export const textToImageGraphBuilt = createAction<Graph>(
'nodes/textToImageGraphBuilt'
@ -16,3 +17,7 @@ export const isAnyGraphBuilt = isAnyOf(
canvasGraphBuilt,
nodesGraphBuilt
);
export const workflowLoadRequested = createAction<Workflow>(
'nodes/workflowLoadRequested'
);

View File

@ -443,6 +443,17 @@ const nodesSlice = createSlice({
}
node.data.notes = notes;
},
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
const nodeId = action.payload;
state.nodes = applyNodeChanges(
state.nodes.map((n) => ({
id: n.id,
type: 'select',
selected: n.id === nodeId ? true : false,
})),
state.nodes
);
},
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
state.selectedNodes = action.payload;
},
@ -892,6 +903,7 @@ export const {
nodeEmbedWorkflowChanged,
nodeIsIntermediateChanged,
mouseOverNodeChanged,
nodeExclusivelySelected,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@ -0,0 +1,4 @@
import { atom } from 'nanostores';
import { ReactFlowInstance } from 'reactflow';
export const $flow = atom<ReactFlowInstance | null>(null);

View File

@ -1,10 +1,20 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
import { COLLECTION_TYPES } from 'features/nodes/types/constants';
import {
COLLECTION_MAP,
COLLECTION_TYPES,
POLYMORPHIC_TO_SINGLE_MAP,
POLYMORPHIC_TYPES,
} from 'features/nodes/types/constants';
import { FieldType } from 'features/nodes/types/types';
import { HandleType } from 'reactflow';
/**
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
* TODO: Figure out how to do this without duplicating all the logic
*/
export const makeConnectionErrorSelector = (
nodeId: string,
fieldName: string,
@ -19,11 +29,6 @@ export const makeConnectionErrorSelector = (
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
state.nodes;
if (!state.nodes.shouldValidateGraph) {
// manual override!
return null;
}
if (!connectionStartParams || !currentConnectionFieldType) {
return 'No connection in progress';
}
@ -38,9 +43,9 @@ export const makeConnectionErrorSelector = (
return 'No connection data';
}
const targetFieldType =
const targetType =
handleType === 'target' ? fieldType : currentConnectionFieldType;
const sourceFieldType =
const sourceType =
handleType === 'source' ? fieldType : currentConnectionFieldType;
if (nodeId === connectionNodeId) {
@ -55,30 +60,73 @@ export const makeConnectionErrorSelector = (
}
if (
fieldType !== currentConnectionFieldType &&
fieldType !== 'CollectionItem' &&
currentConnectionFieldType !== 'CollectionItem'
) {
if (
!(
COLLECTION_TYPES.includes(targetFieldType) &&
COLLECTION_TYPES.includes(sourceFieldType)
)
) {
// except for collection items, field types must match
return 'Field types must match';
}
}
if (
handleType === 'target' &&
edges.find((edge) => {
return edge.target === nodeId && edge.targetHandle === fieldName;
}) &&
// except CollectionItem inputs can have multiples
targetFieldType !== 'CollectionItem'
targetType !== 'CollectionItem'
) {
return 'Inputs may only have one connection';
return 'Input may only have one connection';
}
/**
* Connection types must be the same for a connection, with exceptions:
* - CollectionItem can connect to any non-Collection
* - Non-Collections can connect to CollectionItem
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
* - Generic Collection can connect to any other Collection or Polymorphic
* - Any Collection can connect to a Generic Collection
*/
if (sourceType !== targetType) {
const isCollectionItemToNonCollection =
sourceType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(targetType);
const isNonCollectionToCollectionItem =
targetType === 'CollectionItem' &&
!COLLECTION_TYPES.includes(sourceType) &&
!POLYMORPHIC_TYPES.includes(sourceType);
const isAnythingToPolymorphicOfSameBaseType =
POLYMORPHIC_TYPES.includes(targetType) &&
(() => {
if (!POLYMORPHIC_TYPES.includes(targetType)) {
return false;
}
const baseType =
POLYMORPHIC_TO_SINGLE_MAP[
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
];
const collectionType =
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
return sourceType === baseType || sourceType === collectionType;
})();
const isGenericCollectionToAnyCollectionOrPolymorphic =
sourceType === 'Collection' &&
(COLLECTION_TYPES.includes(targetType) ||
POLYMORPHIC_TYPES.includes(targetType));
const isCollectionToGenericCollection =
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
if (
!(
isCollectionItemToNonCollection ||
isNonCollectionToCollectionItem ||
isAnythingToPolymorphicOfSameBaseType ||
isGenericCollectionToAnyCollectionOrPolymorphic ||
isCollectionToGenericCollection ||
isIntToFloat
)
) {
return 'Field types must match';
}
}
const isGraphAcyclic = getIsGraphAcyclic(

View File

@ -17,176 +17,297 @@ export const KIND_MAP = {
export const COLLECTION_TYPES: FieldType[] = [
'Collection',
'IntegerCollection',
'BooleanCollection',
'FloatCollection',
'StringCollection',
'BooleanCollection',
'ImageCollection',
'LatentsCollection',
'ConditioningCollection',
'ControlCollection',
'ColorCollection',
];
export const POLYMORPHIC_TYPES = [
'IntegerPolymorphic',
'BooleanPolymorphic',
'FloatPolymorphic',
'StringPolymorphic',
'ImagePolymorphic',
'LatentsPolymorphic',
'ConditioningPolymorphic',
'ControlPolymorphic',
'ColorPolymorphic',
];
export const MODEL_TYPES = [
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
'ONNXModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'VaeField',
'ClipField',
];
export const COLLECTION_MAP = {
integer: 'IntegerCollection',
boolean: 'BooleanCollection',
number: 'FloatCollection',
float: 'FloatCollection',
string: 'StringCollection',
ImageField: 'ImageCollection',
LatentsField: 'LatentsCollection',
ConditioningField: 'ConditioningCollection',
ControlField: 'ControlCollection',
ColorField: 'ColorCollection',
};
export const isCollectionItemType = (
itemType: string | undefined
): itemType is keyof typeof COLLECTION_MAP =>
Boolean(itemType && itemType in COLLECTION_MAP);
export const SINGLE_TO_POLYMORPHIC_MAP = {
integer: 'IntegerPolymorphic',
boolean: 'BooleanPolymorphic',
number: 'FloatPolymorphic',
float: 'FloatPolymorphic',
string: 'StringPolymorphic',
ImageField: 'ImagePolymorphic',
LatentsField: 'LatentsPolymorphic',
ConditioningField: 'ConditioningPolymorphic',
ControlField: 'ControlPolymorphic',
ColorField: 'ColorPolymorphic',
};
export const POLYMORPHIC_TO_SINGLE_MAP = {
IntegerPolymorphic: 'integer',
BooleanPolymorphic: 'boolean',
FloatPolymorphic: 'float',
StringPolymorphic: 'string',
ImagePolymorphic: 'ImageField',
LatentsPolymorphic: 'LatentsField',
ConditioningPolymorphic: 'ConditioningField',
ControlPolymorphic: 'ControlField',
ColorPolymorphic: 'ColorField',
};
export const isPolymorphicItemType = (
itemType: string | undefined
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
color: 'red.500',
},
float: {
title: 'Float',
description: 'Floats are numbers with a decimal point.',
color: 'orange.500',
},
string: {
title: 'String',
description: 'Strings are text.',
color: 'yellow.500',
},
boolean: {
title: 'Boolean',
color: 'green.500',
description: 'Booleans are true or false.',
title: 'Boolean',
},
enum: {
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
color: 'blue.500',
BooleanCollection: {
color: 'green.500',
description: 'A collection of booleans.',
title: 'Boolean Collection',
},
array: {
title: 'Array',
description: 'Enums are values that may be one of a number of options.',
color: 'base.500',
},
ImageField: {
title: 'Image',
description: 'Images may be passed between nodes.',
color: 'purple.500',
},
DenoiseMaskField: {
title: 'Denoise Mask',
description: 'Denoise Mask may be passed between nodes',
color: 'base.500',
},
LatentsField: {
title: 'Latents',
description: 'Latents may be passed between nodes.',
color: 'pink.500',
},
LatentsCollection: {
title: 'Latents Collection',
description: 'Latents may be passed between nodes.',
color: 'pink.500',
},
ConditioningField: {
color: 'cyan.500',
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
ConditioningCollection: {
color: 'cyan.500',
title: 'Conditioning Collection',
description: 'Conditioning may be passed between nodes.',
},
ImageCollection: {
title: 'Image Collection',
description: 'A collection of images.',
color: 'base.300',
},
UNetField: {
color: 'red.500',
title: 'UNet',
description: 'UNet submodel.',
BooleanPolymorphic: {
color: 'green.500',
description: 'A collection of booleans.',
title: 'Boolean Polymorphic',
},
ClipField: {
color: 'green.500',
title: 'Clip',
description: 'Tokenizer and text_encoder submodels.',
},
VaeField: {
color: 'blue.500',
title: 'Vae',
description: 'Vae submodel.',
},
ControlField: {
color: 'cyan.500',
title: 'Control',
description: 'Control info passed between nodes.',
},
MainModelField: {
color: 'teal.500',
title: 'Model',
description: 'TODO',
},
SDXLRefinerModelField: {
color: 'teal.500',
title: 'Refiner Model',
description: 'TODO',
},
VaeModelField: {
color: 'teal.500',
title: 'VAE',
description: 'TODO',
},
LoRAModelField: {
color: 'teal.500',
title: 'LoRA',
description: 'TODO',
},
ControlNetModelField: {
color: 'teal.500',
title: 'ControlNet',
description: 'TODO',
},
Scheduler: {
color: 'base.500',
title: 'Scheduler',
description: 'TODO',
title: 'Clip',
},
Collection: {
color: 'base.500',
title: 'Collection',
description: 'TODO',
title: 'Collection',
},
CollectionItem: {
color: 'base.500',
title: 'Collection Item',
description: 'TODO',
title: 'Collection Item',
},
ColorCollection: {
color: 'pink.300',
description: 'A collection of colors.',
title: 'Color Collection',
},
ColorField: {
title: 'Color',
color: 'pink.300',
description: 'A RGBA color.',
color: 'base.500',
title: 'Color',
},
BooleanCollection: {
title: 'Boolean Collection',
description: 'A collection of booleans.',
color: 'green.500',
ColorPolymorphic: {
color: 'pink.300',
description: 'A collection of colors.',
title: 'Color Polymorphic',
},
IntegerCollection: {
title: 'Integer Collection',
description: 'A collection of integers.',
color: 'red.500',
ConditioningCollection: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning Collection',
},
ConditioningField: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning',
},
ConditioningPolymorphic: {
color: 'cyan.500',
description: 'Conditioning may be passed between nodes.',
title: 'Conditioning Polymorphic',
},
ControlCollection: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Collection',
},
ControlField: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control',
},
ControlNetModelField: {
color: 'teal.500',
description: 'TODO',
title: 'ControlNet',
},
ControlPolymorphic: {
color: 'teal.500',
description: 'Control info passed between nodes.',
title: 'Control Polymorphic',
},
DenoiseMaskField: {
color: 'blue.300',
description: 'Denoise Mask may be passed between nodes',
title: 'Denoise Mask',
},
enum: {
color: 'blue.500',
description: 'Enums are values that may be one of a number of options.',
title: 'Enum',
},
float: {
color: 'orange.500',
description: 'Floats are numbers with a decimal point.',
title: 'Float',
},
FloatCollection: {
color: 'orange.500',
title: 'Float Collection',
description: 'A collection of floats.',
title: 'Float Collection',
},
ColorCollection: {
color: 'base.500',
title: 'Color Collection',
description: 'A collection of colors.',
FloatPolymorphic: {
color: 'orange.500',
description: 'A collection of floats.',
title: 'Float Polymorphic',
},
ImageCollection: {
color: 'purple.500',
description: 'A collection of images.',
title: 'Image Collection',
},
ImageField: {
color: 'purple.500',
description: 'Images may be passed between nodes.',
title: 'Image',
},
ImagePolymorphic: {
color: 'purple.500',
description: 'A collection of images.',
title: 'Image Polymorphic',
},
integer: {
color: 'red.500',
description: 'Integers are whole numbers, without a decimal point.',
title: 'Integer',
},
IntegerCollection: {
color: 'red.500',
description: 'A collection of integers.',
title: 'Integer Collection',
},
IntegerPolymorphic: {
color: 'red.500',
description: 'A collection of integers.',
title: 'Integer Polymorphic',
},
LatentsCollection: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents Collection',
},
LatentsField: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents',
},
LatentsPolymorphic: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
title: 'Latents Polymorphic',
},
LoRAModelField: {
color: 'teal.500',
description: 'TODO',
title: 'LoRA',
},
MainModelField: {
color: 'teal.500',
description: 'TODO',
title: 'Model',
},
ONNXModelField: {
color: 'base.500',
title: 'ONNX Model',
color: 'teal.500',
description: 'ONNX model field.',
title: 'ONNX Model',
},
Scheduler: {
color: 'base.500',
description: 'TODO',
title: 'Scheduler',
},
SDXLMainModelField: {
color: 'base.500',
title: 'SDXL Model',
color: 'teal.500',
description: 'SDXL model field.',
title: 'SDXL Model',
},
SDXLRefinerModelField: {
color: 'teal.500',
description: 'TODO',
title: 'Refiner Model',
},
string: {
color: 'yellow.500',
description: 'Strings are text.',
title: 'String',
},
StringCollection: {
color: 'yellow.500',
title: 'String Collection',
description: 'A collection of strings.',
title: 'String Collection',
},
StringPolymorphic: {
color: 'yellow.500',
description: 'A collection of strings.',
title: 'String Polymorphic',
},
UNetField: {
color: 'red.500',
description: 'UNet submodel.',
title: 'UNet',
},
VaeField: {
color: 'blue.500',
description: 'Vae submodel.',
title: 'Vae',
},
VaeModelField: {
color: 'teal.500',
description: 'TODO',
title: 'VAE',
},
};

View File

@ -1,7 +1,9 @@
import {
SchedulerParam,
zBaseModel,
zMainModel,
zMainOrOnnxModel,
zOnnxModel,
zSDXLRefinerModel,
zScheduler,
} from 'features/parameters/types/parameterSchemas';
@ -9,7 +11,7 @@ import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
import { Graph, _InputField, _OutputField } from 'services/api/types';
import {
AnyInvocationType,
AnyResult,
@ -50,6 +52,10 @@ export type InvocationTemplate = {
* The type of this node's output
*/
outputType: string; // TODO: generate a union of output types
/**
* The invocation's version.
*/
version?: string;
};
export type FieldUIConfig = {
@ -60,50 +66,48 @@ export type FieldUIConfig = {
// TODO: Get this from the OpenAPI schema? may be tricky...
export const zFieldType = z.enum([
// region Primitives
'integer',
'float',
'boolean',
'string',
'array',
'ImageField',
'DenoiseMaskField',
'LatentsField',
'ConditioningField',
'ControlField',
'ColorField',
'ImageCollection',
'ConditioningCollection',
'ColorCollection',
'LatentsCollection',
'IntegerCollection',
'FloatCollection',
'StringCollection',
'BooleanCollection',
// endregion
// region Models
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'ONNXModelField',
'VaeModelField',
'LoRAModelField',
'ControlNetModelField',
'UNetField',
'VaeField',
'BooleanPolymorphic',
'ClipField',
// endregion
// region Iterate/Collect
'Collection',
'CollectionItem',
// endregion
// region Misc
'ColorCollection',
'ColorField',
'ColorPolymorphic',
'ConditioningCollection',
'ConditioningField',
'ConditioningPolymorphic',
'ControlCollection',
'ControlField',
'ControlNetModelField',
'ControlPolymorphic',
'DenoiseMaskField',
'enum',
'float',
'FloatCollection',
'FloatPolymorphic',
'ImageCollection',
'ImageField',
'ImagePolymorphic',
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
'LoRAModelField',
'MainModelField',
'ONNXModelField',
'Scheduler',
// endregion
'SDXLMainModelField',
'SDXLRefinerModelField',
'string',
'StringCollection',
'StringPolymorphic',
'UNetField',
'VaeField',
'VaeModelField',
]);
export type FieldType = z.infer<typeof zFieldType>;
@ -120,38 +124,6 @@ export const isFieldType = (value: unknown): value is FieldType =>
zFieldType.safeParse(value).success ||
zReservedFieldType.safeParse(value).success;
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| IntegerInputFieldTemplate
| FloatInputFieldTemplate
| StringInputFieldTemplate
| BooleanInputFieldTemplate
| ImageInputFieldTemplate
| DenoiseMaskInputFieldTemplate
| LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
| UNetInputFieldTemplate
| ClipInputFieldTemplate
| VaeInputFieldTemplate
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| MainModelInputFieldTemplate
| SDXLMainModelInputFieldTemplate
| SDXLRefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ImageCollectionInputFieldTemplate
| SchedulerInputFieldTemplate;
/**
* Indicates the kind of input(s) this field may have.
*/
@ -230,24 +202,88 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
});
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerCollection'),
value: z.array(z.number().int()).optional(),
});
export type IntegerCollectionInputFieldValue = z.infer<
typeof zIntegerCollectionInputFieldValue
>;
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IntegerPolymorphic'),
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
});
export type IntegerPolymorphicInputFieldValue = z.infer<
typeof zIntegerPolymorphicInputFieldValue
>;
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('float'),
value: z.number().optional(),
});
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatCollection'),
value: z.array(z.number()).optional(),
});
export type FloatCollectionInputFieldValue = z.infer<
typeof zFloatCollectionInputFieldValue
>;
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('FloatPolymorphic'),
value: z.union([z.number(), z.array(z.number())]).optional(),
});
export type FloatPolymorphicInputFieldValue = z.infer<
typeof zFloatPolymorphicInputFieldValue
>;
export const zStringInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('string'),
value: z.string().optional(),
});
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringCollection'),
value: z.array(z.string()).optional(),
});
export type StringCollectionInputFieldValue = z.infer<
typeof zStringCollectionInputFieldValue
>;
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('StringPolymorphic'),
value: z.union([z.string(), z.array(z.string())]).optional(),
});
export type StringPolymorphicInputFieldValue = z.infer<
typeof zStringPolymorphicInputFieldValue
>;
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('boolean'),
value: z.boolean().optional(),
});
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanCollection'),
value: z.array(z.boolean()).optional(),
});
export type BooleanCollectionInputFieldValue = z.infer<
typeof zBooleanCollectionInputFieldValue
>;
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('BooleanPolymorphic'),
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
});
export type BooleanPolymorphicInputFieldValue = z.infer<
typeof zBooleanPolymorphicInputFieldValue
>;
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('enum'),
value: z.union([z.string(), z.number()]).optional(),
@ -260,6 +296,22 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
});
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('LatentsCollection'),
value: z.array(zLatentsField).optional(),
});
export type LatentsCollectionInputFieldValue = z.infer<
typeof zLatentsCollectionInputFieldValue
>;
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('LatentsPolymorphic'),
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
});
export type LatentsPolymorphicInputFieldValue = z.infer<
typeof zLatentsPolymorphicInputFieldValue
>;
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('DenoiseMaskField'),
value: zDenoiseMaskField.optional(),
@ -276,6 +328,26 @@ export type ConditioningInputFieldValue = z.infer<
typeof zConditioningInputFieldValue
>;
export const zConditioningCollectionInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('ConditioningCollection'),
value: z.array(zConditioningField).optional(),
});
export type ConditioningCollectionInputFieldValue = z.infer<
typeof zConditioningCollectionInputFieldValue
>;
export const zConditioningPolymorphicInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('ConditioningPolymorphic'),
value: z
.union([zConditioningField, z.array(zConditioningField)])
.optional(),
});
export type ConditioningPolymorphicInputFieldValue = z.infer<
typeof zConditioningPolymorphicInputFieldValue
>;
export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>;
@ -300,6 +372,22 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({
});
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlPolymorphic'),
value: z.union([zControlField, z.array(zControlField)]).optional(),
});
export type ControlPolymorphicInputFieldValue = z.infer<
typeof zControlPolymorphicInputFieldValue
>;
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ControlCollection'),
value: z.array(zControlField).optional(),
});
export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;
export const zModelType = z.enum([
'onnx',
'main',
@ -379,6 +467,14 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
});
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImagePolymorphic'),
value: z.union([zImageField, z.array(zImageField)]).optional(),
});
export type ImagePolymorphicInputFieldValue = z.infer<
typeof zImagePolymorphicInputFieldValue
>;
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ImageCollection'),
value: z.array(zImageField).optional(),
@ -471,6 +567,22 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
});
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ColorCollection'),
value: z.array(zColorField).optional(),
});
export type ColorCollectionInputFieldValue = z.infer<
typeof zColorCollectionInputFieldValue
>;
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('ColorPolymorphic'),
value: z.union([zColorField, z.array(zColorField)]).optional(),
});
export type ColorPolymorphicInputFieldValue = z.infer<
typeof zColorPolymorphicInputFieldValue
>;
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Scheduler'),
value: zScheduler.optional(),
@ -480,30 +592,47 @@ export type SchedulerInputFieldValue = z.infer<
>;
export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerInputFieldValue,
zFloatInputFieldValue,
zStringInputFieldValue,
zBooleanCollectionInputFieldValue,
zBooleanInputFieldValue,
zImageInputFieldValue,
zLatentsInputFieldValue,
zDenoiseMaskInputFieldValue,
zConditioningInputFieldValue,
zUNetInputFieldValue,
zBooleanPolymorphicInputFieldValue,
zClipInputFieldValue,
zVaeInputFieldValue,
zControlInputFieldValue,
zEnumInputFieldValue,
zMainModelInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zVaeModelInputFieldValue,
zLoRAModelInputFieldValue,
zControlNetModelInputFieldValue,
zCollectionInputFieldValue,
zCollectionItemInputFieldValue,
zColorInputFieldValue,
zColorCollectionInputFieldValue,
zColorPolymorphicInputFieldValue,
zConditioningInputFieldValue,
zConditioningCollectionInputFieldValue,
zConditioningPolymorphicInputFieldValue,
zControlInputFieldValue,
zControlNetModelInputFieldValue,
zControlCollectionInputFieldValue,
zControlPolymorphicInputFieldValue,
zDenoiseMaskInputFieldValue,
zEnumInputFieldValue,
zFloatCollectionInputFieldValue,
zFloatInputFieldValue,
zFloatPolymorphicInputFieldValue,
zImageCollectionInputFieldValue,
zImagePolymorphicInputFieldValue,
zImageInputFieldValue,
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
zLoRAModelInputFieldValue,
zMainModelInputFieldValue,
zSchedulerInputFieldValue,
zSDXLMainModelInputFieldValue,
zSDXLRefinerModelInputFieldValue,
zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue,
zStringInputFieldValue,
zUNetInputFieldValue,
zVaeInputFieldValue,
zVaeModelInputFieldValue,
]);
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
@ -512,7 +641,6 @@ export type InputFieldTemplateBase = {
name: string;
title: string;
description: string;
type: FieldType;
required: boolean;
fieldKind: 'input';
} & _InputField;
@ -527,6 +655,19 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean;
};
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'IntegerCollection';
default: number[];
item_default?: number;
};
export type IntegerPolymorphicInputFieldTemplate = Omit<
IntegerInputFieldTemplate,
'type'
> & {
type: 'IntegerPolymorphic';
};
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
type: 'float';
default: number;
@ -537,6 +678,19 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
exclusiveMinimum?: boolean;
};
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'FloatCollection';
default: number[];
item_default?: number;
};
export type FloatPolymorphicInputFieldTemplate = Omit<
FloatInputFieldTemplate,
'type'
> & {
type: 'FloatPolymorphic';
};
export type StringInputFieldTemplate = InputFieldTemplateBase & {
type: 'string';
default: string;
@ -545,19 +699,53 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
pattern?: string;
};
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'StringCollection';
default: string[];
item_default?: string;
};
export type StringPolymorphicInputFieldTemplate = Omit<
StringInputFieldTemplate,
'type'
> & {
type: 'StringPolymorphic';
};
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
default: boolean;
type: 'boolean';
};
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
type: 'BooleanCollection';
default: boolean[];
item_default?: boolean;
};
export type BooleanPolymorphicInputFieldTemplate = Omit<
BooleanInputFieldTemplate,
'type'
> & {
type: 'BooleanPolymorphic';
};
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
default: ImageDTO;
default: ImageField;
type: 'ImageField';
};
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: ImageField[];
type: 'ImageCollection';
item_default?: ImageField;
};
export type ImagePolymorphicInputFieldTemplate = Omit<
ImageInputFieldTemplate,
'type'
> & {
type: 'ImagePolymorphic';
};
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
@ -566,15 +754,40 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: string;
default: LatentsField;
type: 'LatentsField';
};
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: LatentsField[];
type: 'LatentsCollection';
item_default?: LatentsField;
};
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
default: LatentsField;
type: 'LatentsPolymorphic';
};
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'ConditioningField';
};
export type ConditioningCollectionInputFieldTemplate =
InputFieldTemplateBase & {
default: ConditioningField[];
type: 'ConditioningCollection';
item_default?: ConditioningField;
};
export type ConditioningPolymorphicInputFieldTemplate = Omit<
ConditioningInputFieldTemplate,
'type'
> & {
type: 'ConditioningPolymorphic';
};
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'UNetField';
@ -595,6 +808,19 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
type: 'ControlField';
};
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'ControlCollection';
item_default?: ControlField;
};
export type ControlPolymorphicInputFieldTemplate = Omit<
ControlInputFieldTemplate,
'type'
> & {
type: 'ControlPolymorphic';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number;
type: 'enum';
@ -647,6 +873,18 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
type: 'ColorField';
};
export type ColorPolymorphicInputFieldTemplate = Omit<
ColorInputFieldTemplate,
'type'
> & {
type: 'ColorPolymorphic';
};
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'ColorCollection';
};
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
default: SchedulerParam;
type: 'Scheduler';
@ -657,6 +895,55 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
type: 'WorkflowField';
};
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| BooleanCollectionInputFieldTemplate
| BooleanPolymorphicInputFieldTemplate
| BooleanInputFieldTemplate
| ClipInputFieldTemplate
| CollectionInputFieldTemplate
| CollectionItemInputFieldTemplate
| ColorInputFieldTemplate
| ColorCollectionInputFieldTemplate
| ColorPolymorphicInputFieldTemplate
| ConditioningInputFieldTemplate
| ConditioningCollectionInputFieldTemplate
| ConditioningPolymorphicInputFieldTemplate
| ControlInputFieldTemplate
| ControlCollectionInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ControlPolymorphicInputFieldTemplate
| DenoiseMaskInputFieldTemplate
| EnumInputFieldTemplate
| FloatCollectionInputFieldTemplate
| FloatInputFieldTemplate
| FloatPolymorphicInputFieldTemplate
| ImageCollectionInputFieldTemplate
| ImagePolymorphicInputFieldTemplate
| ImageInputFieldTemplate
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate
| LoRAModelInputFieldTemplate
| MainModelInputFieldTemplate
| SchedulerInputFieldTemplate
| SDXLMainModelInputFieldTemplate
| SDXLRefinerModelInputFieldTemplate
| StringCollectionInputFieldTemplate
| StringPolymorphicInputFieldTemplate
| StringInputFieldTemplate
| UNetInputFieldTemplate
| VaeInputFieldTemplate
| VaeModelInputFieldTemplate;
export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
@ -679,6 +966,7 @@ export type InvocationSchemaExtra = {
title: string;
category?: string;
tags?: string[];
version?: string;
properties: Omit<
NonNullable<OpenAPIV3.SchemaObject['properties']> &
(_InputField | _OutputField),
@ -729,8 +1017,22 @@ export type InvocationSchemaObject = (
) & { class: 'invocation' };
export const isSchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
export const isArraySchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.ArraySchemaObject =>
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
export const isNonArraySchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.NonArraySchemaObject =>
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
export const isRefObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
export const isInvocationSchemaObject = (
obj:
@ -769,12 +1071,14 @@ export const zCoreMetadata = z
steps: z.number().int().nullish(),
scheduler: z.string().nullish(),
clip_skip: z.number().int().nullish(),
model: zMainOrOnnxModel.nullish(),
controlnets: z.array(zControlField).nullish(),
model: z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish(),
controlnets: z.array(zControlField.deepPartial()).nullish(),
loras: z
.array(
z.object({
lora: zLoRAModelField,
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
@ -784,18 +1088,41 @@ export const zCoreMetadata = z
init_image: z.string().nullish(),
positive_style_prompt: z.string().nullish(),
negative_style_prompt: z.string().nullish(),
refiner_model: zSDXLRefinerModel.nullish(),
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
refiner_cfg_scale: z.number().nullish(),
refiner_steps: z.number().int().nullish(),
refiner_scheduler: z.string().nullish(),
refiner_positive_aesthetic_store: z.number().nullish(),
refiner_negative_aesthetic_store: z.number().nullish(),
refiner_positive_aesthetic_score: z.number().nullish(),
refiner_negative_aesthetic_score: z.number().nullish(),
refiner_start: z.number().nullish(),
})
.catchall(z.record(z.any()));
.passthrough();
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
export const zSemVer = z.string().refine((val) => {
const [major, minor, patch] = val.split('.');
return (
major !== undefined &&
Number.isInteger(Number(major)) &&
minor !== undefined &&
Number.isInteger(Number(minor)) &&
patch !== undefined &&
Number.isInteger(Number(patch))
);
});
export const zParsedSemver = zSemVer.transform((val) => {
const [major, minor, patch] = val.split('.');
return {
major: Number(major),
minor: Number(minor),
patch: Number(patch),
};
});
export type SemVer = z.infer<typeof zSemVer>;
export const zInvocationNodeData = z.object({
id: z.string().trim().min(1),
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
@ -808,6 +1135,7 @@ export const zInvocationNodeData = z.object({
notes: z.string(),
embedWorkflow: z.boolean(),
isIntermediate: z.boolean(),
version: zSemVer.optional(),
});
// Massage this to get better type safety while developing
@ -896,20 +1224,6 @@ export const zFieldIdentifier = z.object({
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
export const zSemVer = z.string().refine((val) => {
const [major, minor, patch] = val.split('.');
return (
major !== undefined &&
minor !== undefined &&
patch !== undefined &&
Number.isInteger(Number(major)) &&
Number.isInteger(Number(minor)) &&
Number.isInteger(Number(patch))
);
});
export type SemVer = z.infer<typeof zSemVer>;
export type WorkflowWarning = {
message: string;
issues: string[];

View File

@ -1,5 +1,14 @@
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import {
COLLECTION_MAP,
POLYMORPHIC_TYPES,
SINGLE_TO_POLYMORPHIC_MAP,
isCollectionItemType,
isPolymorphicItemType,
} from '../types/constants';
import {
BooleanCollectionInputFieldTemplate,
BooleanInputFieldTemplate,
ClipInputFieldTemplate,
CollectionInputFieldTemplate,
@ -11,10 +20,13 @@ import {
DenoiseMaskInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatCollectionInputFieldTemplate,
FloatPolymorphicInputFieldTemplate,
FloatInputFieldTemplate,
ImageCollectionInputFieldTemplate,
ImageInputFieldTemplate,
InputFieldTemplateBase,
IntegerCollectionInputFieldTemplate,
IntegerInputFieldTemplate,
InvocationFieldSchema,
InvocationSchemaObject,
@ -24,11 +36,32 @@ import {
SDXLMainModelInputFieldTemplate,
SDXLRefinerModelInputFieldTemplate,
SchedulerInputFieldTemplate,
StringCollectionInputFieldTemplate,
StringInputFieldTemplate,
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
isArraySchemaObject,
isNonArraySchemaObject,
isRefObject,
isSchemaObject,
ControlPolymorphicInputFieldTemplate,
ColorPolymorphicInputFieldTemplate,
ColorCollectionInputFieldTemplate,
IntegerPolymorphicInputFieldTemplate,
StringPolymorphicInputFieldTemplate,
BooleanPolymorphicInputFieldTemplate,
ImagePolymorphicInputFieldTemplate,
LatentsPolymorphicInputFieldTemplate,
LatentsCollectionInputFieldTemplate,
ConditioningPolymorphicInputFieldTemplate,
ConditioningCollectionInputFieldTemplate,
ControlCollectionInputFieldTemplate,
ImageField,
LatentsField,
ConditioningField,
} from '../types/types';
import { ControlField } from 'services/api/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -45,15 +78,8 @@ export type BuildInputFieldArg = {
* @example
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
*/
export const refObjectToFieldType = (
refObject: OpenAPIV3.ReferenceObject
): FieldType => {
const name = refObject.$ref.split('/').slice(-1)[0];
if (!name) {
throw `Unknown field type: ${name}`;
}
return name as FieldType;
};
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
refObject.$ref.split('/').slice(-1)[0];
const buildIntegerInputFieldTemplate = ({
schemaObject,
@ -88,6 +114,57 @@ const buildIntegerInputFieldTemplate = ({
return template;
};
const buildIntegerPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
const template: IntegerPolymorphicInputFieldTemplate = {
...baseField,
type: 'IntegerPolymorphic',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildIntegerCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
const item_default =
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
? schemaObject.item_default
: 0;
const template: IntegerCollectionInputFieldTemplate = {
...baseField,
type: 'IntegerCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildFloatInputFieldTemplate = ({
schemaObject,
baseField,
@ -121,6 +198,54 @@ const buildFloatInputFieldTemplate = ({
return template;
};
const buildFloatPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
const template: FloatPolymorphicInputFieldTemplate = {
...baseField,
type: 'FloatPolymorphic',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
const item_default = isNumber(schemaObject.item_default)
? schemaObject.item_default
: 0;
const template: FloatCollectionInputFieldTemplate = {
...baseField,
type: 'FloatCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildStringInputFieldTemplate = ({
schemaObject,
baseField,
@ -146,6 +271,48 @@ const buildStringInputFieldTemplate = ({
return template;
};
const buildStringPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
const template: StringPolymorphicInputFieldTemplate = {
...baseField,
type: 'StringPolymorphic',
default: schemaObject.default ?? '',
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.pattern !== undefined) {
template.pattern = schemaObject.pattern;
}
return template;
};
const buildStringCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
const item_default = isString(schemaObject.item_default)
? schemaObject.item_default
: '';
const template: StringCollectionInputFieldTemplate = {
...baseField,
type: 'StringCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildBooleanInputFieldTemplate = ({
schemaObject,
baseField,
@ -159,6 +326,37 @@ const buildBooleanInputFieldTemplate = ({
return template;
};
const buildBooleanPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
const template: BooleanPolymorphicInputFieldTemplate = {
...baseField,
type: 'BooleanPolymorphic',
default: schemaObject.default ?? false,
};
return template;
};
const buildBooleanCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
const item_default =
schemaObject.item_default && isBoolean(schemaObject.item_default)
? schemaObject.item_default
: false;
const template: BooleanCollectionInputFieldTemplate = {
...baseField,
type: 'BooleanCollection',
default: schemaObject.default ?? [],
item_default,
};
return template;
};
const buildMainModelInputFieldTemplate = ({
schemaObject,
baseField,
@ -250,6 +448,19 @@ const buildImageInputFieldTemplate = ({
return template;
};
const buildImagePolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
const template: ImagePolymorphicInputFieldTemplate = {
...baseField,
type: 'ImagePolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageCollectionInputFieldTemplate = ({
schemaObject,
baseField,
@ -257,7 +468,8 @@ const buildImageCollectionInputFieldTemplate = ({
const template: ImageCollectionInputFieldTemplate = {
...baseField,
type: 'ImageCollection',
default: schemaObject.default ?? undefined,
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ImageField) ?? undefined,
};
return template;
@ -289,6 +501,33 @@ const buildLatentsInputFieldTemplate = ({
return template;
};
const buildLatentsPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
const template: LatentsPolymorphicInputFieldTemplate = {
...baseField,
type: 'LatentsPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLatentsCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
const template: LatentsCollectionInputFieldTemplate = {
...baseField,
type: 'LatentsCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
};
return template;
};
const buildConditioningInputFieldTemplate = ({
schemaObject,
baseField,
@ -302,6 +541,33 @@ const buildConditioningInputFieldTemplate = ({
return template;
};
const buildConditioningPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
const template: ConditioningPolymorphicInputFieldTemplate = {
...baseField,
type: 'ConditioningPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildConditioningCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
const template: ConditioningCollectionInputFieldTemplate = {
...baseField,
type: 'ConditioningCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
};
return template;
};
const buildUNetInputFieldTemplate = ({
schemaObject,
baseField,
@ -355,6 +621,33 @@ const buildControlInputFieldTemplate = ({
return template;
};
const buildControlPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
const template: ControlPolymorphicInputFieldTemplate = {
...baseField,
type: 'ControlPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildControlCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
const template: ControlCollectionInputFieldTemplate = {
...baseField,
type: 'ControlCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as ControlField) ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@ -408,6 +701,32 @@ const buildColorInputFieldTemplate = ({
return template;
};
const buildColorPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
const template: ColorPolymorphicInputFieldTemplate = {
...baseField,
type: 'ColorPolymorphic',
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
return template;
};
const buildColorCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
const template: ColorCollectionInputFieldTemplate = {
...baseField,
type: 'ColorCollection',
default: schemaObject.default ?? [],
};
return template;
};
const buildSchedulerInputFieldTemplate = ({
schemaObject,
baseField,
@ -421,45 +740,138 @@ const buildSchedulerInputFieldTemplate = ({
return template;
};
export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
let fieldType = '';
const { ui_type } = schemaObject;
if (ui_type) {
fieldType = ui_type;
export const getFieldType = (
schemaObject: InvocationFieldSchema
): string | undefined => {
if (schemaObject?.ui_type) {
return schemaObject.ui_type;
} else if (!schemaObject.type) {
// console.log('refObject', schemaObject);
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
if (schemaObject.allOf) {
fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
const allOf = schemaObject.allOf;
if (allOf && allOf[0] && isRefObject(allOf[0])) {
return refObjectToSchemaName(allOf[0]);
}
} else if (schemaObject.anyOf) {
fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
fieldType = refObjectToFieldType(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
const anyOf = schemaObject.anyOf;
/**
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
* - an `anyOf` with two items
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
* - the other is a `SchemaObject` or `ReferenceObject` of type T
*
* Any other cases we ignore.
*/
let firstType: string | undefined;
let secondType: string | undefined;
if (isArraySchemaObject(anyOf[0])) {
// first is array, second is not
const first = anyOf[0].items;
const second = anyOf[1];
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
} else if (isArraySchemaObject(anyOf[1])) {
// first is not array, second is
const first = anyOf[0];
const second = anyOf[1].items;
if (isRefObject(first) && isRefObject(second)) {
firstType = refObjectToSchemaName(first);
secondType = refObjectToSchemaName(second);
} else if (
isNonArraySchemaObject(first) &&
isNonArraySchemaObject(second)
) {
firstType = first.type;
secondType = second.type;
}
}
if (firstType === secondType && isPolymorphicItemType(firstType)) {
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
}
}
} else if (schemaObject.enum) {
fieldType = 'enum';
return 'enum';
} else if (schemaObject.type) {
if (schemaObject.type === 'number') {
// floats are "number" in OpenAPI, while ints are "integer"
fieldType = 'float';
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
return 'float';
} else if (schemaObject.type === 'array') {
const itemType = isSchemaObject(schemaObject.items)
? schemaObject.items.type
: refObjectToSchemaName(schemaObject.items);
if (isCollectionItemType(itemType)) {
return COLLECTION_MAP[itemType];
}
return;
} else {
fieldType = schemaObject.type;
return schemaObject.type;
}
}
return fieldType;
return;
};
const TEMPLATE_BUILDER_MAP = {
boolean: buildBooleanInputFieldTemplate,
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
ClipField: buildClipInputFieldTemplate,
Collection: buildCollectionInputFieldTemplate,
CollectionItem: buildCollectionItemInputFieldTemplate,
ColorCollection: buildColorCollectionInputFieldTemplate,
ColorField: buildColorInputFieldTemplate,
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
ConditioningField: buildConditioningInputFieldTemplate,
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
ControlCollection: buildControlCollectionInputFieldTemplate,
ControlField: buildControlInputFieldTemplate,
ControlNetModelField: buildControlNetModelInputFieldTemplate,
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
enum: buildEnumInputFieldTemplate,
float: buildFloatInputFieldTemplate,
FloatCollection: buildFloatCollectionInputFieldTemplate,
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
ImageCollection: buildImageCollectionInputFieldTemplate,
ImageField: buildImageInputFieldTemplate,
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
LoRAModelField: buildLoRAModelInputFieldTemplate,
MainModelField: buildMainModelInputFieldTemplate,
Scheduler: buildSchedulerInputFieldTemplate,
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
string: buildStringInputFieldTemplate,
StringCollection: buildStringCollectionInputFieldTemplate,
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
UNetField: buildUNetInputFieldTemplate,
VaeField: buildVaeInputFieldTemplate,
VaeModelField: buildVaeModelInputFieldTemplate,
};
const isTemplatedFieldType = (
fieldType: string | undefined
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
/**
* Builds an input field from an invocation schema property.
* @param fieldSchema The schema object
@ -474,7 +886,8 @@ export const buildInputFieldTemplate = (
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
const extra = {
input,
// TODO: Can we support polymorphic inputs in the UI?
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
ui_hidden,
ui_component,
ui_type,
@ -490,146 +903,12 @@ export const buildInputFieldTemplate = (
...extra,
};
if (fieldType === 'ImageField') {
return buildImageInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
if (!isTemplatedFieldType(fieldType)) {
return;
}
if (fieldType === 'ImageCollection') {
return buildImageCollectionInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'DenoiseMaskField') {
return buildDenoiseMaskInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'LatentsField') {
return buildLatentsInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ConditioningField') {
return buildConditioningInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'UNetField') {
return buildUNetInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ClipField') {
return buildClipInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'VaeField') {
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
}
if (fieldType === 'ControlField') {
return buildControlInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'MainModelField') {
return buildMainModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'SDXLRefinerModelField') {
return buildRefinerModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'SDXLMainModelField') {
return buildSDXLMainModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'VaeModelField') {
return buildVaeModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'LoRAModelField') {
return buildLoRAModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ControlNetModelField') {
return buildControlNetModelInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'enum') {
return buildEnumInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'integer') {
return buildIntegerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'float') {
return buildFloatInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'string') {
return buildStringInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'boolean') {
return buildBooleanInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'Collection') {
return buildCollectionInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'CollectionItem') {
return buildCollectionItemInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'ColorField') {
return buildColorInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
if (fieldType === 'Scheduler') {
return buildSchedulerInputFieldTemplate({
schemaObject: fieldSchema,
baseField,
});
}
return;
return TEMPLATE_BUILDER_MAP[fieldType]({
schemaObject: fieldSchema,
baseField,
});
};

View File

@ -1,104 +1,79 @@
import { InputFieldTemplate, InputFieldValue } from '../types/types';
const FIELD_VALUE_FALLBACK_MAP = {
'enum.number': 0,
'enum.string': '',
boolean: false,
BooleanCollection: [],
BooleanPolymorphic: false,
ClipField: undefined,
Collection: [],
CollectionItem: undefined,
ColorCollection: [],
ColorField: undefined,
ColorPolymorphic: undefined,
ConditioningCollection: [],
ConditioningField: undefined,
ConditioningPolymorphic: undefined,
ControlCollection: [],
ControlField: undefined,
ControlNetModelField: undefined,
ControlPolymorphic: undefined,
DenoiseMaskField: undefined,
float: 0,
FloatCollection: [],
FloatPolymorphic: 0,
ImageCollection: [],
ImageField: undefined,
ImagePolymorphic: undefined,
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,
LoRAModelField: undefined,
MainModelField: undefined,
ONNXModelField: undefined,
Scheduler: 'euler',
SDXLMainModelField: undefined,
SDXLRefinerModelField: undefined,
string: '',
StringCollection: [],
StringPolymorphic: '',
UNetField: undefined,
VaeField: undefined,
VaeModelField: undefined,
};
export const buildInputFieldValue = (
id: string,
template: InputFieldTemplate
): InputFieldValue => {
const fieldValue: InputFieldValue = {
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
// `InputFieldValue` union, but TS doesn't seem to like it...
const fieldValue = {
id,
name: template.name,
type: template.type,
label: '',
fieldKind: 'input',
};
if (template.type === 'string') {
fieldValue.value = template.default ?? '';
}
if (template.type === 'integer') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'float') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'boolean') {
fieldValue.value = template.default ?? false;
}
} as InputFieldValue;
if (template.type === 'enum') {
if (template.enumType === 'number') {
fieldValue.value = template.default ?? 0;
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
}
if (template.enumType === 'string') {
fieldValue.value = template.default ?? '';
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
}
}
if (template.type === 'Collection') {
fieldValue.value = template.default ?? 1;
}
if (template.type === 'ImageField') {
fieldValue.value = undefined;
}
if (template.type === 'ImageCollection') {
fieldValue.value = [];
}
if (template.type === 'DenoiseMaskField') {
fieldValue.value = undefined;
}
if (template.type === 'LatentsField') {
fieldValue.value = undefined;
}
if (template.type === 'ConditioningField') {
fieldValue.value = undefined;
}
if (template.type === 'UNetField') {
fieldValue.value = undefined;
}
if (template.type === 'ClipField') {
fieldValue.value = undefined;
}
if (template.type === 'VaeField') {
fieldValue.value = undefined;
}
if (template.type === 'ControlField') {
fieldValue.value = undefined;
}
if (template.type === 'MainModelField') {
fieldValue.value = undefined;
}
if (template.type === 'SDXLRefinerModelField') {
fieldValue.value = undefined;
}
if (template.type === 'VaeModelField') {
fieldValue.value = undefined;
}
if (template.type === 'LoRAModelField') {
fieldValue.value = undefined;
}
if (template.type === 'ControlNetModelField') {
fieldValue.value = undefined;
}
if (template.type === 'Scheduler') {
fieldValue.value = 'euler';
} else {
fieldValue.value =
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
}
return fieldValue;

View File

@ -1,4 +1,6 @@
import * as png from '@stevebel/png';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import {
ImageMetadataAndWorkflow,
zCoreMetadata,
@ -18,6 +20,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
if (metadataResult.success) {
data.metadata = metadataResult.data;
} else {
logger('system').error(
{ error: parseify(metadataResult.error) },
'Problem reading metadata from image'
);
}
}
@ -26,6 +33,11 @@ export const getMetadataAndWorkflowFromImageBlob = async (
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
if (workflowResult.success) {
data.workflow = workflowResult.data;
} else {
logger('system').error(
{ error: parseify(workflowResult.error) },
'Problem reading workflow from image'
);
}
}

View File

@ -60,9 +60,9 @@ export const addSDXLRefinerToGraph = (
if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel;
metadataAccumulator.refiner_positive_aesthetic_store =
metadataAccumulator.refiner_positive_aesthetic_score =
refinerPositiveAestheticScore;
metadataAccumulator.refiner_negative_aesthetic_store =
metadataAccumulator.refiner_negative_aesthetic_score =
refinerNegativeAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler;

View File

@ -73,6 +73,7 @@ export const parseSchema = (
const title = schema.title.replace('Invocation', '');
const tags = schema.tags ?? [];
const description = schema.description ?? '';
const version = schema.version ?? '';
const inputs = reduce(
schema.properties,
@ -225,11 +226,12 @@ export const parseSchema = (
const invocation: InvocationTemplate = {
title,
type,
version,
tags,
description,
outputType,
inputs,
outputs,
outputType,
};
Object.assign(invocationsAccumulator, { [type]: invocation });

View File

@ -0,0 +1,96 @@
import { compareVersions } from 'compare-versions';
import { cloneDeep, keyBy } from 'lodash-es';
import {
InvocationTemplate,
Workflow,
WorkflowWarning,
isWorkflowInvocationNode,
} from '../types/types';
import { parseify } from 'common/util/serialize';
export const validateWorkflow = (
workflow: Workflow,
nodeTemplates: Record<string, InvocationTemplate>
) => {
const clone = cloneDeep(workflow);
const { nodes, edges } = clone;
const errors: WorkflowWarning[] = [];
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id');
nodes.forEach((node) => {
if (!isWorkflowInvocationNode(node)) {
return;
}
const nodeTemplate = nodeTemplates[node.data.type];
if (!nodeTemplate) {
errors.push({
message: `Node "${node.data.type}" skipped`,
issues: [`Node type "${node.data.type}" does not exist`],
data: node,
});
return;
}
if (
nodeTemplate.version &&
node.data.version &&
compareVersions(nodeTemplate.version, node.data.version) !== 0
) {
errors.push({
message: `Node "${node.data.type}" has mismatched version`,
issues: [
`Node "${node.data.type}" v${node.data.version} may be incompatible with installed v${nodeTemplate.version}`,
],
data: { node, nodeTemplate: parseify(nodeTemplate) },
});
return;
}
});
edges.forEach((edge, i) => {
const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target];
const issues: string[] = [];
if (!sourceNode) {
issues.push(`Output node ${edge.source} does not exist`);
} else if (
edge.type === 'default' &&
!(edge.sourceHandle in sourceNode.data.outputs)
) {
issues.push(
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
);
}
if (!targetNode) {
issues.push(`Input node ${edge.target} does not exist`);
} else if (
edge.type === 'default' &&
!(edge.targetHandle in targetNode.data.inputs)
) {
issues.push(
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
);
}
if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
issues.push(
`Source node "${edge.source}" missing template "${sourceNode?.data.type}"`
);
}
if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
issues.push(
`Source node "${edge.target}" missing template "${targetNode?.data.type}"`
);
}
if (issues.length) {
delete edges[i];
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
errors.push({
message: `Edge "${src} -> ${tgt}" skipped`,
issues,
data: edge,
});
}
});
return { workflow: clone, errors };
};

View File

@ -341,8 +341,8 @@ export const useRecallParameters = () => {
refiner_cfg_scale,
refiner_steps,
refiner_scheduler,
refiner_positive_aesthetic_store,
refiner_negative_aesthetic_store,
refiner_positive_aesthetic_score,
refiner_negative_aesthetic_score,
refiner_start,
} = metadata;
@ -403,21 +403,21 @@ export const useRecallParameters = () => {
if (
isValidSDXLRefinerPositiveAestheticScore(
refiner_positive_aesthetic_store
refiner_positive_aesthetic_score
)
) {
dispatch(
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_store)
setRefinerPositiveAestheticScore(refiner_positive_aesthetic_score)
);
}
if (
isValidSDXLRefinerNegativeAestheticScore(
refiner_negative_aesthetic_store
refiner_negative_aesthetic_score
)
) {
dispatch(
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_store)
setRefinerNegativeAestheticScore(refiner_negative_aesthetic_score)
);
}

View File

@ -1,11 +1,11 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
@ -14,6 +14,7 @@ import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
import { getModelName } from './util';
type AdvancedAddCheckpointProps = {
model_path?: string;
@ -28,7 +29,7 @@ export default function AdvancedAddCheckpoint(
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
initialValues: {
model_name: model_path?.split('\\').splice(-1)[0]?.split('.')[0] ?? '',
model_name: model_path ? getModelName(model_path) : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
@ -100,6 +101,17 @@ export default function AdvancedAddCheckpoint(
label="Model Location"
required
{...advancedAddCheckpointForm.getInputProps('path')}
onBlur={(e) => {
if (advancedAddCheckpointForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value);
if (modelName) {
advancedAddCheckpointForm.setFieldValue(
'model_name',
modelName as string
);
}
}
}}
/>
<IAIMantineTextInput
label="Description"

View File

@ -1,16 +1,17 @@
import { Flex } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useTranslation } from 'react-i18next';
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import BaseModelSelect from '../shared/BaseModelSelect';
import ModelVariantSelect from '../shared/ModelVariantSelect';
import { getModelName } from './util';
type AdvancedAddDiffusersProps = {
model_path?: string;
@ -25,7 +26,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
initialValues: {
model_name: model_path?.split('\\').splice(-1)[0] ?? '',
model_name: model_path ? getModelName(model_path, false) : '',
base_model: 'sd-1',
model_type: 'main',
path: model_path ? model_path : '',
@ -92,6 +93,17 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
label="Model Location"
placeholder="Provide the path to a local folder where your Diffusers Model is stored"
{...advancedAddDiffusersForm.getInputProps('path')}
onBlur={(e) => {
if (advancedAddDiffusersForm.values['model_name'] === '') {
const modelName = getModelName(e.currentTarget.value, false);
if (modelName) {
advancedAddDiffusersForm.setFieldValue(
'model_name',
modelName as string
);
}
}
}}
/>
<IAIMantineTextInput
label="Description"

View File

@ -0,0 +1,15 @@
export function getModelName(filepath: string, isCheckpoint: boolean = true) {
let regex;
if (isCheckpoint) {
regex = new RegExp('[^\\\\/]+(?=\\.)');
} else {
regex = new RegExp('[^\\\\/]+(?=[\\\\/]?$)');
}
const match = filepath.match(regex);
if (match) {
return match[0];
} else {
return '';
}
}

View File

@ -28,6 +28,8 @@ import {
} from '../util';
import { boardsApi } from './boards';
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
import { fetchBaseQuery } from '@reduxjs/toolkit/dist/query';
import { $authToken, $projectId } from '../client';
export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({
@ -115,18 +117,40 @@ export const imagesApi = api.injectEndpoints({
],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
query: (image_name) => ({
url: `images/i/${image_name}/full`,
responseHandler: async (res) => {
return await res.blob();
},
}),
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadataFromFile', id: image_name },
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, ImageDTO>({
queryFn: async (args: ImageDTO, api, extraOptions) => {
const authToken = $authToken.get();
const projectId = $projectId.get();
const customBaseQuery = fetchBaseQuery({
baseUrl: '',
prepareHeaders: (headers) => {
if (authToken) {
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set('project-id', projectId);
}
return headers;
},
responseHandler: async (res) => {
return await res.blob();
},
});
const response = await customBaseQuery(
args.image_url,
api,
extraOptions
);
const data = await getMetadataAndWorkflowFromImageBlob(
response.data as Blob
);
return { data };
},
providesTags: (result, error, image_dto) => [
{ type: 'ImageMetadataFromFile', id: image_dto.image_name },
],
transformResponse: (response: Blob) =>
getMetadataAndWorkflowFromImageBlob(response),
keepUnusedDataFor: 86400, // 24 hours
}),
clearIntermediates: build.mutation<number, void>({

File diff suppressed because one or more lines are too long

View File

@ -2970,6 +2970,11 @@ commondir@^1.0.1:
resolved "https://registry.yarnpkg.com/commondir/-/commondir-1.0.1.tgz#ddd800da0c66127393cca5950ea968a3aaf1253b"
integrity sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==
compare-versions@^6.1.0:
version "6.1.0"
resolved "https://registry.yarnpkg.com/compare-versions/-/compare-versions-6.1.0.tgz#3f2131e3ae93577df111dba133e6db876ffe127a"
integrity sha512-LNZQXhqUvqUTotpZ00qLSaify3b4VFD588aRr8MKFw4CMUr98ytzCW5wDH5qx/DEY5kCDXcbcRuCqL0szEf2tg==
compute-scroll-into-view@1.0.20:
version "1.0.20"
resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.20.tgz#1768b5522d1172754f5d0c9b02de3af6be506a43"

View File

@ -1,283 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ycYWcsEKc6w7"
},
"source": [
"# Stable Diffusion AI Notebook (Release 2.0.0)\n",
"\n",
"<img src=\"https://user-images.githubusercontent.com/60411196/186547976-d9de378a-9de8-4201-9c25-c057a9c59bad.jpeg\" alt=\"stable-diffusion-ai\" width=\"170px\"/> <br>\n",
"#### Instructions:\n",
"1. Execute each cell in order to mount a Dream bot and create images from text. <br>\n",
"2. Once cells 1-8 were run correctly you'll be executing a terminal in cell #9, you'll need to enter `python scripts/dream.py` command to run Dream bot.<br> \n",
"3. After launching dream bot, you'll see: <br> `Dream > ` in terminal. <br> Insert a command, eg. `Dream > Astronaut floating in a distant galaxy`, or type `-h` for help.\n",
"3. After completion you'll see your generated images in path `stable-diffusion/outputs/img-samples/`, you can also show last generated images in cell #10.\n",
"4. To quit Dream bot use `q` command. <br> \n",
"---\n",
"<font color=\"red\">Note:</font> It takes some time to load, but after installing all dependencies you can use the bot all time you want while colab instance is up. <br>\n",
"<font color=\"red\">Requirements:</font> For this notebook to work you need to have [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) stored in your Google Drive, it will be needed in cell #7\n",
"##### For more details visit Github repository: [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)\n",
"---\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dr32VLxlnouf"
},
"source": [
"## ◢ Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "a2Z5Qu_o8VtQ"
},
"outputs": [],
"source": [
"# @title 1. Check current GPU assigned\n",
"!nvidia-smi -L\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbI9ZsQHzjqF"
},
"outputs": [],
"source": [
"# @title 2. Download stable-diffusion Repository\n",
"from os.path import exists\n",
"\n",
"!git clone --quiet https://github.com/invoke-ai/InvokeAI.git # Original repo\n",
"%cd /content/InvokeAI/\n",
"!git checkout --quiet tags/v2.0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "QbXcGXYEFSNB"
},
"outputs": [],
"source": [
"# @title 3. Install dependencies\n",
"import gc\n",
"\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-base.txt\n",
"!wget https://raw.githubusercontent.com/invoke-ai/InvokeAI/development/environments-and-requirements/requirements-win-colab-cuda.txt\n",
"!pip install colab-xterm\n",
"!pip install -r requirements-lin-win-colab-CUDA.txt\n",
"!pip install clean-fid torchtext\n",
"!pip install transformers\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8rSMhgnAttQa"
},
"outputs": [],
"source": [
"# @title 4. Restart Runtime\n",
"exit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ChIDWxLVHGGJ"
},
"outputs": [],
"source": [
"# @title 5. Load small ML models required\n",
"import gc\n",
"\n",
"%cd /content/InvokeAI/\n",
"!python scripts/preload_models.py\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "795x1tMoo8b1"
},
"source": [
"## ◢ Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "YEWPV-sF1RDM"
},
"outputs": [],
"source": [
"# @title 6. Mount google Drive\n",
"from google.colab import drive\n",
"\n",
"drive.mount(\"/content/drive\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "zRTJeZ461WGu"
},
"outputs": [],
"source": [
"# @title 7. Drive Path to model\n",
"# @markdown Path should start with /content/drive/path-to-your-file <br>\n",
"# @markdown <font color=\"red\">Note:</font> Model should be downloaded from https://huggingface.co <br>\n",
"# @markdown Lastest release: [Stable-Diffusion-v-1-4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)\n",
"from os.path import exists\n",
"\n",
"model_path = \"\" # @param {type:\"string\"}\n",
"if exists(model_path):\n",
" print(\"✅ Valid directory\")\n",
"else:\n",
" print(\"❌ File doesn't exist\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "UY-NNz4I8_aG"
},
"outputs": [],
"source": [
"# @title 8. Symlink to model\n",
"\n",
"from os.path import exists\n",
"import os\n",
"\n",
"# Folder creation if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1\"):\n",
" print(\"❗ Dir stable-diffusion-v1 already exists\")\n",
"else:\n",
" %mkdir /content/InvokeAI/models/ldm/stable-diffusion-v1\n",
" print(\"✅ Dir stable-diffusion-v1 created\")\n",
"\n",
"# Symbolic link if it doesn't exist\n",
"if exists(\"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"):\n",
" print(\"❗ Symlink already created\")\n",
"else:\n",
" src = model_path\n",
" dst = \"/content/InvokeAI/models/ldm/stable-diffusion-v1/model.ckpt\"\n",
" os.symlink(src, dst)\n",
" print(\"✅ Symbolic link created successfully\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mc28N0_NrCQH"
},
"source": [
"## ◢ Execution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ir4hCrMIuUpl"
},
"outputs": [],
"source": [
"# @title 9. Run Terminal and Execute Dream bot\n",
"# @markdown <font color=\"blue\">Steps:</font> <br>\n",
"# @markdown 1. Execute command `python scripts/invoke.py` to run InvokeAI.<br>\n",
"# @markdown 2. After initialized you'll see `Dream>` line.<br>\n",
"# @markdown 3. Example text: `Astronaut floating in a distant galaxy` <br>\n",
"# @markdown 4. To quit Dream bot use: `q` command.<br>\n",
"\n",
"%load_ext colabxterm\n",
"%xterm\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qnLohSHmKoGk"
},
"outputs": [],
"source": [
"#@title 10. Show the last 15 generated images\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.image as mpimg\n",
"%matplotlib inline\n",
"\n",
"images = []\n",
"for img_path in sorted(glob.glob('/content/InvokeAI/outputs/img-samples/*.png'), reverse=True):\n",
" images.append(mpimg.imread(img_path))\n",
"\n",
"images = images[:15] \n",
"\n",
"plt.figure(figsize=(20,10))\n",
"\n",
"columns = 5\n",
"for i, image in enumerate(images):\n",
" ax = plt.subplot(len(images) / columns + 1, columns, i + 1)\n",
" ax.axes.xaxis.set_visible(False)\n",
" ax.axes.yaxis.set_visible(False)\n",
" ax.axis('off')\n",
" plt.imshow(image)\n",
" gc.collect()\n",
"\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"private_outputs": true,
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"vscode": {
"interpreter": {
"hash": "4e870c5c5fe42db7e2c5647ae5af656ff3391bf8c2b729cbf7fa0e16ca8cb5af"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,339 +0,0 @@
from torchvision.datasets.utils import download_url
from ldm.util import instantiate_from_config
import torch
import os
# todo ?
from google.colab import files
from IPython.display import Image as ipyimg
import ipywidgets as widgets
from PIL import Image
from einops import rearrange, repeat
import torchvision
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
from omegaconf import OmegaConf
from ldm.invoke.devices import choose_torch_device
def download_models(mode):
if mode == "superresolution":
# this is the small bsr light model
url_conf = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
url_ckpt = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
path_conf = "logs/diffusion/superresolution_bsr/configs/project.yaml"
path_ckpt = "logs/diffusion/superresolution_bsr/checkpoints/last.ckpt"
download_url(url_conf, path_conf)
download_url(url_ckpt, path_ckpt)
path_conf = path_conf + "/?dl=1" # fix it
path_ckpt = path_ckpt + "/?dl=1" # fix it
return path_conf, path_ckpt
else:
raise NotImplementedError
def load_model_from_config(config, ckpt):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return {"model": model}, global_step
def get_model(mode):
path_conf, path_ckpt = download_models(mode)
config = OmegaConf.load(path_conf)
model, step = load_model_from_config(config, path_ckpt)
return model
def get_custom_cond(mode):
dest = "data/example_conditioning"
if mode == "superresolution":
uploaded_img = files.upload()
filename = next(iter(uploaded_img))
name, filetype = filename.split(".") # todo assumes just one dot in name !
os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
elif mode == "text_conditional":
w = widgets.Text(value="A cake with cream!", disabled=True)
display(w) # noqa: F821
with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f:
f.write(w.value)
elif mode == "class_conditional":
w = widgets.IntSlider(min=0, max=1000)
display(w) # noqa: F821
with open(f"{dest}/{mode}/custom.txt", "w") as f:
f.write(w.value)
else:
raise NotImplementedError(f"cond not implemented for mode{mode}")
def get_cond_options(mode):
path = "data/example_conditioning"
path = os.path.join(path, mode)
onlyfiles = [f for f in sorted(os.listdir(path))]
return path, onlyfiles
def select_cond_path(mode):
path = "data/example_conditioning" # todo
path = os.path.join(path, mode)
onlyfiles = [f for f in sorted(os.listdir(path))]
selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False)
display(selected) # noqa: F821
selected_path = os.path.join(path, selected.value)
return selected_path
def get_cond(mode, selected_path):
example = dict()
if mode == "superresolution":
up_f = 4
visualize_cond_img(selected_path)
c = Image.open(selected_path)
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
c_up = rearrange(c_up, "1 c h w -> 1 h w c")
c = rearrange(c, "1 c h w -> 1 h w c")
c = 2.0 * c - 1.0
device = choose_torch_device()
c = c.to(device)
example["LR_image"] = c
example["image"] = c_up
return example
def visualize_cond_img(path):
display(ipyimg(filename=path)) # noqa: F821
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
example = get_cond(task, selected_path)
save_intermediate_vid = False
n_runs = 1
masked = False
guider = None
ckwargs = None
mode = "ddim"
ddim_use_x0_pred = False
temperature = 1.0
eta = 1.0
make_progrow = True
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
if split_input:
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {
"ks": (ks, ks),
"stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01,
}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
invert_mask = False
x_T = None
for n in range(n_runs):
if custom_shape is not None:
x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_T = repeat(x_T, "1 c h w -> b c h w", b=custom_shape[0])
logs = make_convolutional_sample(
example,
model,
mode=mode,
custom_steps=custom_steps,
eta=eta,
swap_mode=False,
masked=masked,
invert_mask=invert_mask,
quantize_x0=False,
custom_schedule=None,
decode_interval=10,
resize_enabled=resize_enabled,
custom_shape=custom_shape,
temperature=temperature,
noise_dropout=0.0,
corrector=guider,
corrector_kwargs=ckwargs,
x_T=x_T,
save_intermediate_vid=save_intermediate_vid,
make_progrow=make_progrow,
ddim_use_x0_pred=ddim_use_x0_pred,
)
return logs
@torch.no_grad()
def convsample_ddim(
model,
cond,
steps,
shape,
eta=1.0,
callback=None,
normals_sequence=None,
mask=None,
x0=None,
quantize_x0=False,
img_callback=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
x_T=None,
log_every_t=None,
):
ddim = DDIMSampler(model)
bs = shape[0] # dont know where this comes from but wayne
shape = shape[1:] # cut batch dim
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(
steps,
batch_size=bs,
shape=shape,
conditioning=cond,
callback=callback,
normals_sequence=normals_sequence,
quantize_x0=quantize_x0,
eta=eta,
mask=mask,
x0=x0,
temperature=temperature,
verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(
batch,
model,
mode="vanilla",
custom_steps=None,
eta=1.0,
swap_mode=False,
masked=False,
invert_mask=True,
quantize_x0=False,
custom_schedule=None,
decode_interval=1000,
resize_enabled=False,
custom_shape=None,
temperature=1.0,
noise_dropout=0.0,
corrector=None,
corrector_kwargs=None,
x_T=None,
save_intermediate_vid=False,
make_progrow=True,
ddim_use_x0_pred=False,
):
log = dict()
z, c, x, xrec, xc = model.get_input(
batch,
model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, "split_input_params") and model.cond_stage_key == "coordinates_bbox"),
return_original_cond=True,
)
log_every_t = 1 if save_intermediate_vid else None
if custom_shape is not None:
z = torch.randn(custom_shape)
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, "cond_stage_key"):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key == "class_label":
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
img_cb = None
sample, intermediates = convsample_ddim(
model,
c,
steps=custom_steps,
shape=z.shape,
eta=eta,
quantize_x0=quantize_x0,
img_callback=img_cb,
mask=None,
x0=z0,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates["pred_x0"][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except Exception:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log

View File

@ -74,6 +74,7 @@ dependencies = [
"rich~=13.3",
"safetensors==0.3.1",
"scikit-image~=0.21.0",
"semver~=3.0.1",
"send2trash",
"test-tube~=0.7.5",
"torch~=2.0.1",

View File

@ -1,52 +0,0 @@
import os
import torch
import cv2
import numpy as np
from PIL import Image
from diffusers.utils import load_image
from diffusers.models.controlnet import ControlNetModel
from invokeai.backend.generator import Txt2Img
from invokeai.backend.model_management import ModelManager
print("loading 'Girl with a Pearl Earring' image")
image = load_image(
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image.show()
print("preprocessing image with Canny edge detection")
image_np = np.array(image)
low_threshold = 100
high_threshold = 200
canny_np = cv2.Canny(image_np, low_threshold, high_threshold)
canny_image = Image.fromarray(canny_np)
canny_image.show()
# using invokeai model management for base model
print("loading base model stable-diffusion-1.5")
model_config_path = os.getcwd() + "/../configs/models.yaml"
model_manager = ModelManager(model_config_path)
model = model_manager.get_model("stable-diffusion-1.5")
print("loading control model lllyasviel/sd-controlnet-canny")
canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to(
"cuda"
)
print("testing Txt2Img() constructor with control_model arg")
txt2img_canny = Txt2Img(model, control_model=canny_controlnet)
print("testing Txt2Img.generate() with control_image arg")
outputs = txt2img_canny.generate(
prompt="old man",
control_image=canny_image,
control_weight=1.0,
seed=0,
num_steps=30,
precision="float16",
)
generate_output = next(outputs)
out_image = generate_output.image
out_image.show()

View File

@ -1,33 +0,0 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
"""
import argparse
import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
tmpl = {}
for key, tensor in ckpt.items():
tmpl[key] = list(tensor.shape)
try:
with open(opt.template, "w") as f:
json.dump(tmpl, f)
print(f"Template written out as {opt.template}")
except Exception as e:
print(f"An exception occurred while writing template: {str(e)}")

View File

@ -1,14 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
import warnings
from invokeai.app.cli_app import invoke_cli
warnings.warn(
"dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
DeprecationWarning,
)
invoke_cli()

View File

@ -1,4 +0,0 @@
from invokeai.backend.install.migrate_to_3 import main
if __name__=='__main__':
main()

2
scripts/invokeai-model-install.py Normal file → Executable file
View File

@ -1,3 +1,5 @@
#!/usr/bin/env python
from invokeai.frontend.install.model_install import main
main()

View File

@ -1,41 +0,0 @@
#!/bin/bash
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
cd models/first_stage_models/kl-f4
unzip -o model.zip
cd ../kl-f8
unzip -o model.zip
cd ../kl-f16
unzip -o model.zip
cd ../kl-f32
unzip -o model.zip
cd ../vq-f4
unzip -o model.zip
cd ../vq-f4-noattn
unzip -o model.zip
cd ../vq-f8
unzip -o model.zip
cd ../vq-f8-n256
unzip -o model.zip
cd ../vq-f16
unzip -o model.zip
cd ../..

View File

@ -1,49 +0,0 @@
#!/bin/bash
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
cd models/ldm/celeba256
unzip -o celeba-256.zip
cd ../ffhq256
unzip -o ffhq-256.zip
cd ../lsun_churches256
unzip -o lsun_churches-256.zip
cd ../lsun_beds256
unzip -o lsun_beds-256.zip
cd ../text2img256
unzip -o model.zip
cd ../cin256
unzip -o model.zip
cd ../semantic_synthesis512
unzip -o model.zip
cd ../semantic_synthesis256
unzip -o model.zip
cd ../bsr_sr
unzip -o model.zip
cd ../layout2img-openimages256
unzip -o model.zip
cd ../inpainting_big
unzip -o model.zip
cd ../..

View File

@ -1,285 +0,0 @@
"""make variations of input image"""
import argparse
import os
import PIL
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.to(choose_torch_device())
model.eval()
return model
def load_img(path):
image = Image.open(path).convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render",
)
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
parser.add_argument(
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples"
)
parser.add_argument(
"--skip_grid",
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action="store_true",
help="do not save indiviual samples. For speed measurements.",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--plms",
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--fixed_code",
action="store_true",
help="if enabled, uses the same starting code across all samples ",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=1,
help="sample this often",
)
parser.add_argument(
"--C",
type=int,
default=4,
help="latent channels",
)
parser.add_argument(
"--f",
type=int,
default=8,
help="downsampling factor, most often 8 or 16",
)
parser.add_argument(
"--n_samples",
type=int,
default=2,
help="how many samples to produce for each given prompt. A.k.a batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=5.0,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--strength",
type=float,
default=0.75,
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/stable-diffusion/v1-inference.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
default="models/ldm/stable-diffusion-v1/model.ckpt",
help="path to checkpoint of model",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
opt = parser.parse_args()
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device(choose_torch_device())
model = model.to(device)
if opt.plms:
raise NotImplementedError("PLMS sampler not (yet) supported")
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
assert os.path.isfile(opt.init_img)
init_image = load_img(opt.init_img).to(device)
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ["mps", "cpu"]:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save:
for x_sample in x_samples:
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
if __name__ == "__main__":
main()

View File

@ -1,94 +0,0 @@
import argparse
import glob
import os
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.devices import choose_torch_device
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask = np.array(Image.open(mask).convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = (1 - mask) * image
batch = {"image": image, "mask": mask, "masked_image": masked_image}
for k in batch:
batch[k] = batch[k].to(device=device)
batch[k] = batch[k] * 2.0 - 1.0
return batch
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--indir",
type=str,
nargs="?",
help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
)
parser.add_argument(
"--steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
opt = parser.parse_args()
masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
images = [x.replace("_mask.png", ".png") for x in masks]
print(f"Found {len(masks)} inputs.")
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False)
device = choose_torch_device()
model = model.to(device)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
with torch.no_grad():
with model.ema_scope():
for image, mask in tqdm(zip(images, masks)):
outpath = os.path.join(opt.outdir, os.path.split(image)[1])
batch = make_batch(image, mask, device=device)
# encode masked image and concat downsampled mask
c = model.cond_stage_model.encode(batch["masked_image"])
cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
c = torch.cat((c, cc), dim=1)
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim, _ = sampler.sample(
S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
inpainted = (1 - mask) * image + mask * predicted_image
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)

View File

@ -1,397 +0,0 @@
import argparse
import glob
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import scann
import time
from multiprocessing import cpu_count
from ldm.util import instantiate_from_config, parallel_data_prefetch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
DATABASES = [
"openimages",
"artbench-art_nouveau",
"artbench-baroque",
"artbench-expressionism",
"artbench-impressionism",
"artbench-post_impressionism",
"artbench-realism",
"artbench-romanticism",
"artbench-renaissance",
"artbench-surrealism",
"artbench-ukiyo_e",
]
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
class Searcher(object):
def __init__(self, database, retriever_version="ViT-L/14"):
assert database in DATABASES
# self.database = self.load_database(database)
self.database_name = database
self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
self.retriever = self.load_retriever(version=retriever_version)
self.database = {"embedding": [], "img_id": [], "patch_coords": []}
self.load_database()
self.load_searcher()
def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
print("Start training searcher")
searcher = scann.scann_ops_pybind.builder(
self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
)
self.searcher = searcher.score_brute_force().build()
print("Finish training searcher")
if searcher_savedir is not None:
print(f'Save trained searcher under "{searcher_savedir}"')
os.makedirs(searcher_savedir, exist_ok=True)
self.searcher.serialize(searcher_savedir)
def load_single_file(self, saved_embeddings):
compressed = np.load(saved_embeddings)
self.database = {key: compressed[key] for key in compressed.files}
print("Finished loading of clip embeddings.")
def load_multi_files(self, data_archive):
out_data = {key: [] for key in self.database}
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
out_data[key].append(d[key])
return out_data
def load_database(self):
print(f'Load saved patch embedding from "{self.database_path}"')
file_content = glob.glob(os.path.join(self.database_path, "*.npz"))
if len(file_content) == 1:
self.load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(
self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
self.database = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
}
else:
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
def load_retriever(
self,
version="ViT-L/14",
):
model = FrozenClipImageEmbedder(model=version)
if torch.cuda.is_available():
model.cuda()
model.eval()
return model
def load_searcher(self):
print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
print("Finished loading searcher.")
def search(self, x, k):
if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, "Cannot search with uninitialized searcher"
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if len(x.shape) == 3:
x = x[:, 0]
query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
start = time.time()
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
end = time.time()
out_embeddings = self.database["embedding"][nns]
out_img_ids = self.database["img_id"][nns]
out_pc = self.database["patch_coords"][nns]
out = {
"nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
"img_ids": out_img_ids,
"patch_coords": out_pc,
"queries": x,
"exec_time": end - start,
"nns": nns,
"q_embeddings": query_embeddings,
}
return out
def __call__(self, x, n):
return self.search(x, n)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render",
)
parser.add_argument(
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--n_repeat",
type=int,
default=1,
help="number of repeats in CLIP latent space",
)
parser.add_argument(
"--plms",
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=1,
help="sample this often",
)
parser.add_argument(
"--H",
type=int,
default=768,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=768,
help="image width, in pixel space",
)
parser.add_argument(
"--n_samples",
type=int,
default=3,
help="how many samples to produce for each given prompt. A.k.a batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=5.0,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/retrieval-augmented-diffusion/768x768.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
default="models/rdm/rdm768x768/model.ckpt",
help="path to checkpoint of model",
)
parser.add_argument(
"--clip_type",
type=str,
default="ViT-L/14",
help="which CLIP model to use for retrieval and NN encoding",
)
parser.add_argument(
"--database",
type=str,
default="artbench-surrealism",
choices=DATABASES,
help="The database used for the search, only applied when --use_neighbors=True",
)
parser.add_argument(
"--use_neighbors",
default=False,
action="store_true",
help="Include neighbors in addition to text prompt for conditioning",
)
parser.add_argument(
"--knn",
default=10,
type=int,
help="The number of included neighbors, only applied when --use_neighbors=True",
)
opt = parser.parse_args()
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
if opt.plms:
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
print(f"sampling scale for cfg is {opt.scale:.2f}")
searcher = None
if opt.use_neighbors:
searcher = Searcher(opt.database)
with torch.no_grad():
with model.ema_scope():
for n in trange(opt.n_iter, desc="Sampling"):
all_samples = list()
for prompts in tqdm(data, desc="data"):
print("sampling prompts:", prompts)
if isinstance(prompts, tuple):
prompts = list(prompts)
c = clip_text_encoder.encode(prompts)
uc = None
if searcher is not None:
nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1)
if opt.scale != 1.0:
uc = torch.zeros_like(c)
if isinstance(prompts, tuple):
prompts = list(prompts)
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid_np = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid_np.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")

File diff suppressed because one or more lines are too long

View File

@ -1,898 +0,0 @@
import argparse
import datetime
import glob
import os
import sys
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl
from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from functools import partial
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
def fix_func(orig):
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"] = "cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
config.model.params.ckpt_path = ckpt
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if torch.cuda.is_available():
model.cuda()
return model
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
"-n",
"--name",
type=str,
const=True,
default="",
nargs="?",
help="postfix for logdir",
)
parser.add_argument(
"-r",
"--resume",
type=str,
const=True,
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
"-t",
"--train",
type=str2bool,
const=True,
default=False,
nargs="?",
help="train",
)
parser.add_argument(
"--no-test",
type=str2bool,
const=True,
default=False,
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
"-d",
"--debug",
type=str2bool,
nargs="?",
const=True,
default=False,
help="enable post-mortem debugging",
)
parser.add_argument(
"-s",
"--seed",
type=int,
default=23,
help="seed for seed_everything",
)
parser.add_argument(
"-f",
"--postfix",
type=str,
default="",
help="post-postfix for default name",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
default="logs",
help="directory for logging dat shit",
)
parser.add_argument(
"--scale_lr",
type=str2bool,
nargs="?",
const=True,
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
"--datadir_in_name",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Prepend the final directory in the data_root to the output directory name",
)
parser.add_argument(
"--actual_resume",
type=str,
default="",
help="Path to model to actually resume from",
)
parser.add_argument(
"--data_root",
type=str,
required=True,
help="Path to directory with training images",
)
parser.add_argument(
"--embedding_manager_ckpt",
type=str,
default="",
help="Initialize embedding manager from a checkpoint",
)
parser.add_argument(
"--init_word",
type=str,
help="Word to use as source for initial token embedding.",
)
return parser
def nondefault_trainer_args(opt):
parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args([])
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, Txt2ImgIterableBaseDataset):
split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False if is_iterable_dataset else True,
worker_init_fn=init_fn,
)
def _val_dataloader(self, shuffle=False):
if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _predict_dataloader(self, shuffle=False):
if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
)
class SetupCallback(Callback):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
self.config = config
self.lightning_config = lightning_config
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_pretrain_routine_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
if "callbacks" in self.lightning_config:
if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
os.makedirs(
os.path.join(self.ckptdir, "trainstep_checkpoints"),
exist_ok=True,
)
print("Project config")
print(OmegaConf.to_yaml(self.config))
OmegaConf.save(
self.config,
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
)
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
OmegaConf.save(
OmegaConf.create({"lightning": self.lightning_config}),
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
os.rename(self.logdir, dst)
except FileNotFoundError:
pass
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (
self.check_frequency(check_idx)
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
and callable(pl_module.log_images)
and self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
check_idx > 0 or self.log_first_step
):
try:
self.log_steps.pop(0)
except IndexError as e:
print(e)
pass
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"):
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
torch.cuda.synchronize(trainer.root_gpu)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
if torch.cuda.is_available():
torch.cuda.synchronize(trainer.root_gpu)
epoch_time = time.time() - self.start_time
try:
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
if torch.cuda.is_available():
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
max_memory = trainer.training_type_plugin.reduce(max_memory)
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
class ModeSwapCallback(Callback):
def __init__(self, swap_step=2000):
super().__init__()
self.is_frozen = False
self.swap_step = swap_step
def on_train_epoch_start(self, trainer, pl_module):
if trainer.global_step < self.swap_step and not self.is_frozen:
self.is_frozen = True
trainer.optimizers = [pl_module.configure_opt_embedding()]
if trainer.global_step > self.swap_step and self.is_frozen:
self.is_frozen = False
trainer.optimizers = [pl_module.configure_opt_model()]
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
# `nested.key=value` arguments are interpreted as config parameters.
# configs are merged from left-to-right followed by command line parameters.
# model:
# base_learning_rate: float
# target: path to lightning module
# params:
# key: value
# data:
# target: main.DataModuleFromConfig
# params:
# batch_size: int
# wrap: bool
# train:
# target: path to train dataset
# params:
# key: value
# validation:
# target: path to validation dataset
# params:
# key: value
# test:
# target: path to test dataset
# params:
# key: value
# lightning: (optional, has sane defaults and can be specified on cmdline)
# trainer:
# additional arguments to trainer
# logger:
# logger to instantiate
# modelcheckpoint:
# modelcheckpoint to instantiate
# callbacks:
# callback1:
# target: importpath
# params:
# key: value
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python main.py`
# (in particular `main.DataModuleFromConfig`)
sys.path.append(os.getcwd())
parser = get_parser()
parser = Trainer.add_argparse_args(parser)
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
"-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
nowname = _tmp[-1]
else:
if opt.name:
name = "_" + opt.name
elif opt.base:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
else:
name = ""
if opt.datadir_in_name:
now = os.path.basename(os.path.normpath(opt.data_root)) + now
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
try:
# init and save configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to ddp
trainer_config["accelerator"] = "auto"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if "gpus" not in trainer_config:
del trainer_config["accelerator"]
cpu = True
else:
gpuinfo = trainer_config["gpus"]
print(f"Running on GPUs {gpuinfo}")
cpu = False
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
# model
# config.model.params.personalization_config.params.init_word = opt.init_word
config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt
if opt.init_word:
config.model.params.personalization_config.params.initializer_words = [opt.init_word]
if opt.actual_resume:
model = load_model_from_config(config, opt.actual_resume)
else:
model = instantiate_from_config(config.model)
# trainer and callbacks
trainer_kwargs = dict()
# default logger configs
def_logger = "csv"
def_logger_target = "CSVLogger"
default_logger_cfgs = {
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
},
},
def_logger: {
"target": "pytorch_lightning.loggers." + def_logger_target,
"params": {
"name": def_logger,
"save_dir": logdir,
},
},
}
default_logger_cfg = default_logger_cfgs[def_logger]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
},
}
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 1
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse("1.4.0"):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
},
},
"image_logger": {
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750,
"max_images": 4,
"clamp": True,
},
},
"learning_rate_logger": {
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
"cuda_callback": {"target": "main.CUDACallback"},
}
if version.parse(pl.__version__) >= version.parse("1.4.0"):
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
print(
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
)
default_metrics_over_trainsteps_ckpt_dict = {
"metrics_over_trainsteps_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
"save_top_k": -1,
"every_n_train_steps": 10000,
"save_weights_only": True,
},
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if "ignore_keys_callback" in callbacks_cfg and hasattr(trainer_opt, "resume_from_checkpoint"):
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = trainer_opt.resume_from_checkpoint
elif "ignore_keys_callback" in callbacks_cfg:
del callbacks_cfg["ignore_keys_callback"]
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer_kwargs["max_steps"] = trainer_opt.max_steps
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
trainer_opt.accelerator = "mps"
trainer_opt.detect_anomaly = False
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir
# data
config.data.params.train.params.data_root = opt.data_root
config.data.params.validation.params.data_root = opt.data_root
data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print("#### Data #####")
for k in data.datasets:
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = (
config.data.params.batch_size,
config.model.base_learning_rate,
)
if not cpu:
gpus = str(lightning_config.trainer.gpus).strip(", ").split(",")
ngpu = len(gpus)
else:
ngpu = 1
if "accumulate_grad_batches" in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate,
accumulate_grad_batches,
ngpu,
bs,
base_lr,
)
)
else:
model.learning_rate = base_lr
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGTERM, melk)
signal.signal(signal.SIGTERM, divein)
# run
if opt.train:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
if not opt.no_test and not trainer.interrupted:
trainer.test(model, data)
except Exception:
if opt.debug and trainer.global_rank == 0:
try:
import pudb as debugger
except ImportError:
import pdb as debugger
debugger.post_mortem()
raise
finally:
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank == 0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
# if trainer.global_rank == 0:
# print(trainer.profiler.summary())

View File

@ -1,130 +0,0 @@
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.invoke.globals import Globals
import argparse
from functools import partial
import torch
def get_placeholder_loop(placeholder_string, embedder, use_bert):
new_placeholder = None
while True:
if new_placeholder is None:
new_placeholder = input(
f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
)
else:
new_placeholder = input(
f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: "
)
token = (
get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
if use_bert
else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
)
if token is not None:
return new_placeholder, token
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(
string,
truncation=True,
max_length=77,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"]
if torch.count_nonzero(tokens - 49407) == 2:
return tokens[0, 1]
return None
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
if torch.count_nonzero(token) == 3:
return token[0, 1]
return None
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--root_dir",
type=str,
default=".",
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.",
)
parser.add_argument(
"--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged."
)
parser.add_argument(
"--output_path",
type=str,
required=True,
help="Output path for the merged manager",
)
parser.add_argument(
"-sd",
"--use_bert",
action="store_true",
help="Flag to denote that we are not merging stable diffusion embeddings",
)
args = parser.parse_args()
Globals.root = args.root_dir
if args.use_bert:
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
else:
embedder = FrozenCLIPEmbedder().cuda()
EmbeddingManager = partial(EmbeddingManager, embedder, ["*"])
string_to_token_dict = {}
string_to_param_dict = torch.nn.ParameterDict()
placeholder_to_src = {}
for manager_ckpt in args.manager_ckpts:
print(f"Parsing {manager_ckpt}...")
manager = EmbeddingManager()
manager.load(manager_ckpt)
for placeholder_string in manager.string_to_token_dict:
if placeholder_string not in string_to_token_dict:
string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
placeholder_to_src[placeholder_string] = manager_ckpt
else:
new_placeholder, new_token = get_placeholder_loop(placeholder_string, embedder, use_bert=args.use_bert)
string_to_token_dict[new_placeholder] = new_token
string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
placeholder_to_src[new_placeholder] = manager_ckpt
print("Saving combined manager...")
merged_manager = EmbeddingManager()
merged_manager.string_to_param_dict = string_to_param_dict
merged_manager.string_to_token_dict = string_to_token_dict
merged_manager.save(args.output_path)
print("Managers merged. Final list of placeholders: ")
print(placeholder_to_src)

View File

@ -1,305 +0,0 @@
import argparse
import datetime
import glob
import os
import sys
import time
import yaml
import torch
import numpy as np
from tqdm import trange
from omegaconf import OmegaConf
from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
def rescale(x: float) -> float:
return (x + 1.0) / 2.0
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1.0, 1.0)
x = (x + 1.0) / 2.0
x = x.permute(1, 2, 0).numpy()
x = (255 * x).astype(np.uint8)
x = Image.fromarray(x)
if not x.mode == "RGB":
x = x.convert("RGB")
return x
def custom_to_np(x):
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
sample = x.detach().cpu()
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1)
sample = sample.contiguous()
return sample
def logs2pil(logs, keys=["sample"]):
imgs = dict()
for k in logs:
try:
if len(logs[k].shape) == 4:
img = custom_to_pil(logs[k][0, ...])
elif len(logs[k].shape) == 3:
img = custom_to_pil(logs[k])
else:
print(f"Unknown format for key {k}. ")
img = None
except Exception:
img = None
imgs[k] = img
return imgs
@torch.no_grad()
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
if not make_prog_row:
return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)
else:
return model.progressive_denoising(None, shape, verbose=True)
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = ddim.sample(
steps,
batch_size=bs,
shape=shape,
eta=eta,
verbose=False,
)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(
model,
batch_size,
vanilla=False,
custom_steps=None,
eta=1.0,
):
log = dict()
shape = [
batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size,
]
with model.ema_scope("Plotting"):
t0 = time.time()
if vanilla:
sample, progrow = convsample(model, shape, make_prog_row=True)
else:
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
t1 = time.time()
x_sample = model.decode_first_stage(sample)
log["sample"] = x_sample
log["time"] = t1 - t0
log["throughput"] = sample.shape[0] / (t1 - t0)
print(f'Throughput for this batch: {log["throughput"]}')
return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
if vanilla:
print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
else:
print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")
tstart = time.time()
n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1
# path = logdir
if model.cond_stage_model is None:
all_images = []
print(f"Running unconditional sampling for {n_samples} samples")
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
logs = make_convolutional_sample(
model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
)
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
all_images.extend([custom_to_np(logs["sample"])])
if n_saved >= n_samples:
print(f"Finish after generating {n_saved} samples")
break
all_img = np.concatenate(all_images, axis=0)
all_img = all_img[:n_samples]
shape_str = "x".join([str(x) for x in all_img.shape])
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
np.savez(nppath, all_img)
else:
raise NotImplementedError("Currently only sampling for unconditional models supported.")
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
for k in logs:
if k == key:
batch = logs[key]
if np_path is None:
for x in batch:
img = custom_to_pil(x)
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
img.save(imgpath)
n_saved += 1
else:
npbatch = custom_to_np(batch)
shape_str = "x".join([str(x) for x in npbatch.shape])
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
np.savez(nppath, npbatch)
n_saved += npbatch.shape[0]
return n_saved
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0,
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action="store_true",
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
parser.add_argument(
"-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
)
parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)
return parser
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
def load_model(config, ckpt, gpu, eval_mode):
if ckpt:
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model, pl_sd["state_dict"])
return model, global_step
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
sys.path.append(os.getcwd())
command = " ".join(sys.argv)
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
# paths = opt.resume.split("/")
try:
logdir = "/".join(opt.resume.split("/")[:-1])
# idx = len(paths)-paths[::-1].index("logs")+1
print(f"Logdir is {logdir}")
except ValueError:
paths = opt.resume.split("/")
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "model.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
opt.base = base_configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
gpu = True
eval_mode = True
if opt.logdir != "none":
locallog = logdir.split(os.sep)[-1]
if locallog == "":
locallog = logdir.split(os.sep)[-2]
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
logdir = os.path.join(opt.logdir, locallog)
print(config)
model, global_step = load_model(config, ckpt, gpu, eval_mode)
print(f"global step: {global_step}")
print(75 * "=")
print("logging to:")
logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
imglogdir = os.path.join(logdir, "img")
numpylogdir = os.path.join(logdir, "numpy")
os.makedirs(imglogdir)
os.makedirs(numpylogdir)
print(logdir)
print(75 * "=")
# write config out
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)
with open(sampling_file, "w") as f:
yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)
run(
model,
imglogdir,
eta=opt.eta,
vanilla=opt.vanilla_sample,
n_samples=opt.n_samples,
custom_steps=opt.custom_steps,
batch_size=opt.batch_size,
nplog=numpylogdir,
)
print("done.")

View File

@ -1,169 +0,0 @@
import os
import sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm
from ldm.util import parallel_data_prefetch
def search_bruteforce(searcher):
return searcher.score_brute_force().build()
def search_partioned_ah(
searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
return (
searcher.tree(
num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize
)
.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
.reorder(reorder_k)
.build()
)
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
return (
searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
)
def load_datapool(dpath):
def load_single_file(saved_embeddings):
compressed = np.load(saved_embeddings)
database = {key: compressed[key] for key in compressed.files}
return database
def load_multi_files(data_archive):
database = {key: [] for key in data_archive[0].files}
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
database[key].append(d[key])
return database
print(f'Load saved patch embedding from "{dpath}"')
file_content = glob.glob(os.path.join(dpath, "*.npz"))
if len(file_content) == 1:
data_pool = load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(
load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
data_pool = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()
}
else:
raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
return data_pool
def train_searcher(
opt,
metric="dot_product",
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,
):
data_pool = load_datapool(opt.database)
k = opt.knn
if not reorder_k:
reorder_k = 2 * k
# normalize
# embeddings =
searcher = scann.scann_ops_pybind.builder(
data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric
)
pool_size = data_pool["embedding"].shape[0]
print(*(["#"] * 100))
print("Initializing scaNN searcher with the following values:")
print(f"k: {k}")
print(f"metric: {metric}")
print(f"reorder_k: {reorder_k}")
print(f"anisotropic_quantization_threshold: {aiq_thld}")
print(f"dims_per_block: {dims_per_block}")
print(*(["#"] * 100))
print("Start training searcher....")
print(f"N samples in pool is {pool_size}")
# this reflects the recommended design choices proposed at
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
if pool_size < 2e4:
print("Using brute force search.")
searcher = search_bruteforce(searcher)
elif 2e4 <= pool_size and pool_size < 1e5:
print("Using asymmetric hashing search and reordering.")
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
else:
print("Using using partioning, asymmetric hashing search and reordering.")
if not partioning_trainsize:
partioning_trainsize = data_pool["embedding"].shape[0] // 10
if not num_leaves:
num_leaves = int(np.sqrt(pool_size))
if not num_leaves_to_search:
num_leaves_to_search = max(num_leaves // 20, 1)
print("Partitioning params:")
print(f"num_leaves: {num_leaves}")
print(f"num_leaves_to_search: {num_leaves_to_search}")
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher = search_partioned_ah(
searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
)
print("Finish training searcher")
searcher_savedir = opt.target_path
os.makedirs(searcher_savedir, exist_ok=True)
searcher.serialize(searcher_savedir)
print(f'Saved trained searcher under "{searcher_savedir}"')
if __name__ == "__main__":
sys.path.append(os.getcwd())
parser = argparse.ArgumentParser()
parser.add_argument(
"--database",
"-d",
default="data/rdm/retrieval_databases/openimages",
type=str,
help="path to folder containing the clip feature of the database",
)
parser.add_argument(
"--target_path",
"-t",
default="data/rdm/searchers/openimages",
type=str,
help="path to the target folder where the searcher shall be stored.",
)
parser.add_argument(
"--knn",
"-k",
default=20,
type=int,
help="number of nearest neighbors, for which the searcher shall be optimized",
)
opt, _ = parser.parse_known_args()
train_searcher(
opt,
)

View File

@ -1,316 +0,0 @@
import argparse
import os
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.to(choose_torch_device())
model.eval()
return model
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render",
)
parser.add_argument(
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action="store_true",
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--plms",
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--klms",
action="store_true",
help="use klms sampling",
)
parser.add_argument(
"--laion400m",
action="store_true",
help="uses the LAION400M model",
)
parser.add_argument(
"--fixed_code",
action="store_true",
help="if enabled, uses the same starting code across samples ",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=2,
help="sample this often",
)
parser.add_argument(
"--H",
type=int,
default=512,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=512,
help="image width, in pixel space",
)
parser.add_argument(
"--C",
type=int,
default=4,
help="latent channels",
)
parser.add_argument(
"--f",
type=int,
default=8,
help="downsampling factor",
)
parser.add_argument(
"--n_samples",
type=int,
default=3,
help="how many samples to produce for each given prompt. A.k.a. batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=7.5,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/stable-diffusion/v1-inference.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
default="models/ldm/stable-diffusion-v1/model.ckpt",
help="path to checkpoint of model",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
opt = parser.parse_args()
if opt.laion400m:
print("Falling back to LAION 400M model...")
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
seed_everything(opt.seed)
device = torch.device(choose_torch_device())
model = model.to(device)
# for klms
model_wrap = K.external.CompVisDenoiser(model)
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
if opt.plms:
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
if len(data) >= batch_size:
data = list(chunk(data, batch_size))
else:
while len(data) < batch_size:
data.append(data[-1])
data = [data]
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
start_code = None
if opt.fixed_code:
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
if device.type == "mps":
start_code = torch.randn(shape, device="cpu").to(device)
else:
torch.randn(shape, device=device)
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ["mps", "cpu"]:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if not opt.klms:
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code,
)
else:
sigmas = model_wrap.get_sigmas(opt.ddim_steps)
if start_code:
x = start_code
else:
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {"cond": c, "uncond": uc, "cond_scale": opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save:
for x_sample in x_samples_ddim:
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
if not opt.skip_grid:
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
if __name__ == "__main__":
main()

View File

@ -1,37 +0,0 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Returns True if their metadata match.
"""
import sys
import argparse
import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
checkpoint_metadata = {}
for key, tensor in ckpt.items():
checkpoint_metadata[key] = list(tensor.shape)
with open(opt.template, "r") as f:
template = json.load(f)
if checkpoint_metadata == template:
print("True")
sys.exit(0)
else:
print("False")
sys.exit(-1)

View File

@ -1,3 +1,10 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvalidVersionError,
invocation,
invocation_output,
)
from .test_nodes import (
ImageToImageTestInvocation,
TextToImageTestInvocation,
@ -20,7 +27,7 @@ from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.invocations.image import ShowImageInvocation
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.primitives import IntegerInvocation
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
from invokeai.app.services.default_graphs import create_text_to_image
import pytest
@ -610,6 +617,79 @@ def test_graph_can_deserialize():
assert g2.edges[0].destination.field == "image"
def test_invocation_decorator():
invocation_type = "test_invocation"
title = "Test Invocation"
tags = ["first", "second", "third"]
category = "category"
version = "1.2.3"
@invocation(invocation_type, title=title, tags=tags, category=category, version=version)
class TestInvocation(BaseInvocation):
def invoke(self):
pass
schema = TestInvocation.schema()
assert schema.get("title") == title
assert schema.get("tags") == tags
assert schema.get("category") == category
assert schema.get("version") == version
assert TestInvocation(id="1").type == invocation_type # type: ignore (type is dynamically added)
def test_invocation_version_must_be_semver():
invocation_type = "test_invocation"
valid_version = "1.0.0"
invalid_version = "not_semver"
@invocation(invocation_type, version=valid_version)
class ValidVersionInvocation(BaseInvocation):
def invoke(self):
pass
with pytest.raises(InvalidVersionError):
@invocation(invocation_type, version=invalid_version)
class InvalidVersionInvocation(BaseInvocation):
def invoke(self):
pass
def test_invocation_output_decorator():
output_type = "test_output"
@invocation_output(output_type)
class TestOutput(BaseInvocationOutput):
pass
assert TestOutput().type == output_type # type: ignore (type is dynamically added)
def test_floats_accept_ints():
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = FloatInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
# Not throwing on this line is sufficient
g.add_edge(e)
def test_ints_do_not_accept_floats():
g = Graph()
n1 = FloatInvocation(id="1", value=1.0)
n2 = IntegerInvocation(id="2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id, "value", n2.id, "value")
with pytest.raises(InvalidEdgeError):
g.add_edge(e)
def test_graph_can_generate_schema():
# Not throwing on this line is sufficient
# NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation