add save_image node

Several improvements:

- New save_image node will make a copy of the input image and save it to
  indicated path as PNG
- Sort nodes and commands alphabetically in help message
- Intercept command ValidationErrors and print, rather than crash out.
This commit is contained in:
Lincoln Stein 2023-04-09 13:22:20 -04:00
parent 307cfc075d
commit b7d81f96f8
4 changed files with 45 additions and 14 deletions

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints, Union
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
@ -21,7 +21,7 @@ def add_parsers(
"""Adds parsers for each command to the subparsers"""
# Create subparsers for each command
for command in commands:
for command in sorted(commands, key=lambda x: get_args(get_type_hints(x)[command_field])[0]):
hints = get_type_hints(command)
cmd_name = get_args(hints[command_field])[0]
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)

View File

@ -9,8 +9,9 @@ from typing import (
Union,
get_type_hints,
)
from pydantic import BaseModel
from argparse import HelpFormatter
from operator import attrgetter
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -57,10 +58,14 @@ def add_invocation_args(command_parser):
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
)
class SortedHelpFormatter(HelpFormatter):
def add_arguments(self, actions):
actions = sorted(actions, key=attrgetter('option_strings'))
super(SortedHelpFormatter, self).add_arguments(actions)
def get_command_parser() -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
def exit(*args, **kwargs):
raise InvalidArgs
@ -287,6 +292,9 @@ def invoke_cli():
print("Session error: creating a new session")
context.session = context.invoker.create_execution_state()
except ValidationError as e:
print(f'Validation error: {str(e)}')
except ExitCli:
break

View File

@ -67,6 +67,29 @@ class LoadImageInvocation(BaseInvocation):
image=ImageField(image_type=self.image_type, image_name=self.image_name)
)
class SaveImageInvocation(BaseInvocation):
"""Take an image as input and save it as a PNG file to a filename."""
#fmt: off
type: Literal["save_image"] = "save_image"
# Inputs
image: ImageField = Field(default=None, description="The image to save")
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The filename or path to save the image to")
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
context.services.images.save(
self.image_type, self.image_name, image
)
return ImageOutput(
image=ImageField(
image_type=self.image_type, image_name=self.image_name
)
)
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline."""

View File

@ -90,17 +90,17 @@ class DiskImageStorage(ImageStorageBase):
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
path = self.get_path(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
image, "", path, None
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
# LS: Save_thumbnail() should be a separate invocation, shouldn't it?
# save_thumbnail(
# image=image,
# filename=image_name,
# path=os.path.join(self.__output_folder, image_type, "thumbnails"),
# )
self.__set_cache(path, image)
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)