mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add save_image node
Several improvements: - New save_image node will make a copy of the input image and save it to indicated path as PNG - Sort nodes and commands alphabetically in help message - Intercept command ValidationErrors and print, rather than crash out.
This commit is contained in:
parent
307cfc075d
commit
b7d81f96f8
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -21,7 +21,7 @@ def add_parsers(
|
|||||||
"""Adds parsers for each command to the subparsers"""
|
"""Adds parsers for each command to the subparsers"""
|
||||||
|
|
||||||
# Create subparsers for each command
|
# Create subparsers for each command
|
||||||
for command in commands:
|
for command in sorted(commands, key=lambda x: get_args(get_type_hints(x)[command_field])[0]):
|
||||||
hints = get_type_hints(command)
|
hints = get_type_hints(command)
|
||||||
cmd_name = get_args(hints[command_field])[0]
|
cmd_name = get_args(hints[command_field])[0]
|
||||||
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
||||||
|
@ -9,8 +9,9 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
from argparse import HelpFormatter
|
||||||
from pydantic import BaseModel
|
from operator import attrgetter
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
@ -57,10 +58,14 @@ def add_invocation_args(command_parser):
|
|||||||
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
|
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class SortedHelpFormatter(HelpFormatter):
|
||||||
|
def add_arguments(self, actions):
|
||||||
|
actions = sorted(actions, key=attrgetter('option_strings'))
|
||||||
|
super(SortedHelpFormatter, self).add_arguments(actions)
|
||||||
|
|
||||||
def get_command_parser() -> argparse.ArgumentParser:
|
def get_command_parser() -> argparse.ArgumentParser:
|
||||||
# Create invocation parser
|
# Create invocation parser
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||||
|
|
||||||
def exit(*args, **kwargs):
|
def exit(*args, **kwargs):
|
||||||
raise InvalidArgs
|
raise InvalidArgs
|
||||||
@ -287,6 +292,9 @@ def invoke_cli():
|
|||||||
print("Session error: creating a new session")
|
print("Session error: creating a new session")
|
||||||
context.session = context.invoker.create_execution_state()
|
context.session = context.invoker.create_execution_state()
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
print(f'Validation error: {str(e)}')
|
||||||
|
|
||||||
except ExitCli:
|
except ExitCli:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -67,6 +67,29 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class SaveImageInvocation(BaseInvocation):
|
||||||
|
"""Take an image as input and save it as a PNG file to a filename."""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["save_image"] = "save_image"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = Field(default=None, description="The image to save")
|
||||||
|
image_type: ImageType = Field(description="The type of the image")
|
||||||
|
image_name: str = Field(description="The filename or path to save the image to")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
context.services.images.save(
|
||||||
|
self.image_type, self.image_name, image
|
||||||
|
)
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_type=self.image_type, image_name=self.image_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
class ShowImageInvocation(BaseInvocation):
|
class ShowImageInvocation(BaseInvocation):
|
||||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||||
|
@ -90,17 +90,17 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||||
image_subpath = os.path.join(image_type, image_name)
|
path = self.get_path(image_type, image_name)
|
||||||
self.__pngWriter.save_image_and_prompt_to_png(
|
self.__pngWriter.save_image_and_prompt_to_png(
|
||||||
image, "", image_subpath, None
|
image, "", path, None
|
||||||
) # TODO: just pass full path to png writer
|
) # TODO: just pass full path to png writer
|
||||||
save_thumbnail(
|
# LS: Save_thumbnail() should be a separate invocation, shouldn't it?
|
||||||
image=image,
|
# save_thumbnail(
|
||||||
filename=image_name,
|
# image=image,
|
||||||
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
# filename=image_name,
|
||||||
)
|
# path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||||
image_path = self.get_path(image_type, image_name)
|
# )
|
||||||
self.__set_cache(image_path, image)
|
self.__set_cache(path, image)
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user