From b7d81f96f86d26e4ce66648e2049cf4c20c9a2fe Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 9 Apr 2023 13:22:20 -0400 Subject: [PATCH] add save_image node Several improvements: - New save_image node will make a copy of the input image and save it to indicated path as PNG - Sort nodes and commands alphabetically in help message - Intercept command ValidationErrors and print, rather than crash out. --- invokeai/app/cli/commands.py | 4 ++-- invokeai/app/cli_app.py | 14 +++++++++++--- invokeai/app/invocations/image.py | 23 +++++++++++++++++++++++ invokeai/app/services/image_storage.py | 18 +++++++++--------- 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index 5f4da73303..d0fec893a7 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod import argparse -from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints +from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints, Union from pydantic import BaseModel, Field import networkx as nx import matplotlib.pyplot as plt @@ -21,7 +21,7 @@ def add_parsers( """Adds parsers for each command to the subparsers""" # Create subparsers for each command - for command in commands: + for command in sorted(commands, key=lambda x: get_args(get_type_hints(x)[command_field])[0]): hints = get_type_hints(command) cmd_name = get_args(hints[command_field])[0] command_parser = subparsers.add_parser(cmd_name, help=command.__doc__) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index a257825dcc..0c27feb826 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -9,8 +9,9 @@ from typing import ( Union, get_type_hints, ) - -from pydantic import BaseModel +from argparse import HelpFormatter +from operator import attrgetter +from pydantic import BaseModel, ValidationError from pydantic.fields import Field from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage @@ -57,10 +58,14 @@ def add_invocation_args(command_parser): help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)", ) +class SortedHelpFormatter(HelpFormatter): + def add_arguments(self, actions): + actions = sorted(actions, key=attrgetter('option_strings')) + super(SortedHelpFormatter, self).add_arguments(actions) def get_command_parser() -> argparse.ArgumentParser: # Create invocation parser - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter) def exit(*args, **kwargs): raise InvalidArgs @@ -287,6 +292,9 @@ def invoke_cli(): print("Session error: creating a new session") context.session = context.invoker.create_execution_state() + except ValidationError as e: + print(f'Validation error: {str(e)}') + except ExitCli: break diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 65ea4c3edb..759c382a03 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -67,6 +67,29 @@ class LoadImageInvocation(BaseInvocation): image=ImageField(image_type=self.image_type, image_name=self.image_name) ) +class SaveImageInvocation(BaseInvocation): + """Take an image as input and save it as a PNG file to a filename.""" + #fmt: off + type: Literal["save_image"] = "save_image" + + # Inputs + image: ImageField = Field(default=None, description="The image to save") + image_type: ImageType = Field(description="The type of the image") + image_name: str = Field(description="The filename or path to save the image to") + #fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get( + self.image.image_type, self.image.image_name + ) + context.services.images.save( + self.image_type, self.image_name, image + ) + return ImageOutput( + image=ImageField( + image_type=self.image_type, image_name=self.image_name + ) + ) class ShowImageInvocation(BaseInvocation): """Displays a provided image, and passes it forward in the pipeline.""" diff --git a/invokeai/app/services/image_storage.py b/invokeai/app/services/image_storage.py index 44c33ede92..faddabfbd9 100644 --- a/invokeai/app/services/image_storage.py +++ b/invokeai/app/services/image_storage.py @@ -90,17 +90,17 @@ class DiskImageStorage(ImageStorageBase): return path def save(self, image_type: ImageType, image_name: str, image: Image) -> None: - image_subpath = os.path.join(image_type, image_name) + path = self.get_path(image_type, image_name) self.__pngWriter.save_image_and_prompt_to_png( - image, "", image_subpath, None + image, "", path, None ) # TODO: just pass full path to png writer - save_thumbnail( - image=image, - filename=image_name, - path=os.path.join(self.__output_folder, image_type, "thumbnails"), - ) - image_path = self.get_path(image_type, image_name) - self.__set_cache(image_path, image) + # LS: Save_thumbnail() should be a separate invocation, shouldn't it? + # save_thumbnail( + # image=image, + # filename=image_name, + # path=os.path.join(self.__output_folder, image_type, "thumbnails"), + # ) + self.__set_cache(path, image) def delete(self, image_type: ImageType, image_name: str) -> None: image_path = self.get_path(image_type, image_name)