Compare commits

...

5 Commits

SHA1 | Message | Date

87798cc8fa | Merge branch 'enhance/latents-handling' of github.com:invoke-ai/InvokeAI into enhance/latents-handling | 2023-04-09 13:24:01 -04:00

b7d81f96f8 | add save_image node | 2023-04-09 13:22:20 -04:00

Several improvements:

- New save_image node will make a copy of the input image and save it to the indicated path as a PNG
- Sort nodes and commands alphabetically in the help message
- Intercept command ValidationErrors and print them rather than crashing out

fe3f9d41fc | fix module names in comments | 2023-04-08 11:07:15 -04:00

307cfc075d | add i2l invocation | 2023-04-08 10:44:21 -04:00

This round trip now works:

```
load_image --image_name ./test.png --image_type local | i2l | l2i | show_image
```

cf7adb1815 | t2l and l2i now working as expected | 2023-04-08 09:47:51 -04:00

- Added code to generate t2l noise if the noise parameter is not explicitly passed.
- Fixed autocomplete to propose CLI commands after the | symbol

This works as expected:

    t2l --prompt 'banana sushi' | l2i | show_image
7 changed files with 127 additions and 40 deletions

View File

@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 import argparse
-from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
+from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints, Union
 from pydantic import BaseModel, Field
 import networkx as nx
 import matplotlib.pyplot as plt

@@ -21,7 +21,7 @@ def add_parsers(
     """Adds parsers for each command to the subparsers"""
     # Create subparsers for each command
-    for command in commands:
+    for command in sorted(commands, key=lambda x: get_args(get_type_hints(x)[command_field])[0]):
         hints = get_type_hints(command)
         cmd_name = get_args(hints[command_field])[0]
         command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
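
The sort key works because every command class encodes its name in a `Literal` type annotation on its command field; `get_type_hints` plus `get_args` recover that string at runtime. A minimal sketch of the mechanism (class and field names are illustrative, not from the codebase):

```
from typing import Literal, get_args, get_type_hints

class HelpCommand:
    # The command name lives in the Literal annotation itself.
    type: Literal["help"] = "help"

hints = get_type_hints(HelpCommand)  # {'type': Literal['help']}
print(get_args(hints["type"])[0])    # -> 'help'
```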

View File

@@ -4,6 +4,7 @@ You may import the global singleton `completer` to get access to the
 completer object.
 """
 import atexit
+import re
 import readline
 import shlex

@@ -66,6 +67,8 @@ class Completer(object):
         """
         if len(buffer)==0:
             return None, None
+        if re.search('\|\s*[a-zA-Z0-9]*$',buffer): # reset command on pipe symbol
+            return None,None
         tokens = shlex.split(buffer)
         command = None
         switch = None
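
The new guard fires whenever the buffer ends in a pipe, optionally followed by the start of a fresh token, so the completer falls back to proposing top-level command names instead of switches for the previous command. The pattern can be checked in isolation (the buffers below are illustrative):

```
import re

pattern = re.compile(r'\|\s*[a-zA-Z0-9]*$')

print(bool(pattern.search("t2l --prompt 'banana sushi' | l2")))  # True: completing a new command
print(bool(pattern.search("load_image --image_name ./test")))    # False: still inside one command
```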

View File

@@ -9,8 +9,9 @@ from typing import (
     Union,
     get_type_hints,
 )
-from pydantic import BaseModel
+from argparse import HelpFormatter
+from operator import attrgetter
+from pydantic import BaseModel, ValidationError
 from pydantic.fields import Field

 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

@@ -57,10 +58,14 @@ def add_invocation_args(command_parser):
         help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
     )

+class SortedHelpFormatter(HelpFormatter):
+    def add_arguments(self, actions):
+        actions = sorted(actions, key=attrgetter('option_strings'))
+        super(SortedHelpFormatter, self).add_arguments(actions)
+
 def get_command_parser() -> argparse.ArgumentParser:
     # Create invocation parser
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)

     def exit(*args, **kwargs):
         raise InvalidArgs

@@ -287,6 +292,9 @@ def invoke_cli():
             print("Session error: creating a new session")
             context.session = context.invoker.create_execution_state()

+        except ValidationError as e:
+            print(f'Validation error: {str(e)}')
+
         except ExitCli:
             break
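
`HelpFormatter.add_arguments` receives the parser's actions in definition order, so overriding it is the standard argparse seam for reordering `--help` output; the subclass sorts actions by their option strings before delegating. A self-contained demonstration of the same pattern:

```
import argparse
from argparse import HelpFormatter
from operator import attrgetter

class SortedHelpFormatter(HelpFormatter):
    def add_arguments(self, actions):
        # Alphabetize flags before the formatter renders them.
        actions = sorted(actions, key=attrgetter('option_strings'))
        super().add_arguments(actions)

parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
parser.add_argument('--zebra')
parser.add_argument('--apple')
parser.print_help()  # --apple is listed before --zebra
```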

View File

@@ -67,6 +67,29 @@ class LoadImageInvocation(BaseInvocation):
             image=ImageField(image_type=self.image_type, image_name=self.image_name)
         )

+class SaveImageInvocation(BaseInvocation):
+    """Take an image as input and save it as a PNG file to a filename."""
+    #fmt: off
+    type: Literal["save_image"] = "save_image"
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to save")
+    image_type: ImageType = Field(description="The type of the image")
+    image_name: str = Field(description="The filename or path to save the image to")
+    #fmt: on
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
+        context.services.images.save(
+            self.image_type, self.image_name, image
+        )
+        return ImageOutput(
+            image=ImageField(
+                image_type=self.image_type, image_name=self.image_name
+            )
+        )
+
 class ShowImageInvocation(BaseInvocation):
     """Displays a provided image, and passes it forward in the pipeline."""

View File

@@ -1,28 +1,28 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

+from PIL import Image
 from typing import Literal, Optional
 from pydantic import BaseModel, Field
 from torch import Tensor
+import einops
 import torch

 from ...backend.model_management.model_manager import ModelManager
-from ...backend.util.devices import CUDA_DEVICE, torch_dtype
+from ...backend.util.devices import torch_dtype, choose_torch_device
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
+from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 import numpy as np
+from accelerate.utils import set_seed
 from ..services.image_storage import ImageType
-from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
 from ...backend.generator import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
 from diffusers.schedulers import SchedulerMixin as Scheduler
 import diffusers
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, AutoencoderKL

 class LatentsField(BaseModel):
@@ -110,7 +110,7 @@ class NoiseInvocation(BaseInvocation):
     height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )

     def invoke(self, context: InvocationContext) -> NoiseOutput:
-        device = torch.device(CUDA_DEVICE)
+        device = choose_torch_device()
         noise = get_noise(self.width, self.height, device, self.seed)

         name = f'{context.graph_execution_state_id}__{self.id}'
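
Swapping the hard-coded `CUDA_DEVICE` for `choose_torch_device()` is what lets noise generation run on CPU-only and Apple Silicon machines as well. The helper is imported from `invokeai.backend.util.devices`; a rough sketch of what such a picker typically does (the real implementation may differ):

```
import torch

def choose_torch_device() -> torch.device:
    # Prefer CUDA, then Apple's Metal backend, then plain CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```
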
@@ -170,8 +170,6 @@ class TextToLatentsInvocation(BaseInvocation):
     def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
         model_info = model_manager.get_model(self.model)
-        model_name = model_info['model_name']
-        model_hash = model_info['hash']
         model: StableDiffusionGeneratorPipeline = model_info['model']
         model.scheduler = get_scheduler(
             model=model,
@@ -211,7 +209,10 @@
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        noise = context.services.latents.get(self.noise.latents_name)
+        if self.noise:
+            noise = context.services.latents.get(self.noise.latents_name)
+        else:
+            noise = get_noise(self.width, self.height, choose_torch_device(), self.seed)

         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, state.latents, state.step)
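
This fallback is what the commit message means by generating t2l noise when the parameter is not explicitly passed: the invocation seeds its own noise instead of failing on a missing input. `get_noise` already exists in this file; assuming the usual Stable Diffusion latent geometry (4 channels at 1/8 the pixel resolution), its shape contract would be roughly:

```
import torch

def get_noise(width: int, height: int, device: torch.device, seed: int = 0) -> torch.Tensor:
    # SD latents: batch of 1, 4 channels, spatial dims downsampled 8x.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(
        (1, 4, height // 8, width // 8),
        generator=generator,
    ).to(device)
```

A sketch only; the in-tree function may seed and place tensors differently.
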
@@ -230,7 +231,8 @@ class TextToLatentsInvocation(BaseInvocation):
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
+        if torch.has_cuda:
+            torch.cuda.empty_cache()

         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
@@ -280,7 +282,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         )

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
+        if torch.has_cuda:
+            torch.cuda.empty_cache()

         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
@@ -302,14 +305,11 @@ class LatentsToImageInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)
-
-        # TODO: this only really needs the vae
-        model_info = context.services.model_manager.get_model(self.model)
-        model: StableDiffusionGeneratorPipeline = model_info['model']
+        vae = context.services.model_manager.get_model_vae(self.model)

         with torch.inference_mode():
-            np_image = model.decode_latents(latents)
-            image = model.numpy_to_pil(np_image)[0]
+            np_image = self._decode_latents(vae,latents)
+            image = StableDiffusionGeneratorPipeline.numpy_to_pil(np_image)[0]

         image_type = ImageType.RESULT
         image_name = context.services.images.create_name(
@@ -319,3 +319,54 @@ class LatentsToImageInvocation(BaseInvocation):
         return ImageOutput(
             image=ImageField(image_type=image_type, image_name=image_name)
         )
+
+    # this should be refactored - duplicated code in diffusers.pipelines.stable_diffusion
+    @classmethod
+    def _decode_latents(self, vae:AutoencoderKL, latents:torch.Tensor)->Image:
+        latents = 1 / vae.config.scaling_factor * latents
+        image = vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+# Image to latent
+class ImageToLatentsInvocation(BaseInvocation):
+    """Generates latents from an image."""
+
+    type: Literal["i2l"] = "i2l"
+
+    # Inputs
+    image: Optional[ImageField] = Field(description="The image to generate latents from")
+    model: str = Field(default="", description="The model to use")
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.services.images.get(self.image.image_type,self.image.image_name)
+        vae = context.services.model_manager.get_model_vae(self.model)
+
+        with torch.inference_mode():
+            result_latents = self._encode_latents(vae,image)
+
+        name = f'{context.graph_execution_state_id}__{self.id}'
+        context.services.latents.set(name, result_latents)
+        return LatentsOutput(
+            latents=LatentsField(latents_name=name)
+        )
+
+    # this should be refactored - similar code in invokeai.backend.stable_diffusion.diffusers_pipeline
+    @classmethod
+    def _encode_latents(self, vae:AutoencoderKL, image:Image)->torch.Tensor:
+        device = choose_torch_device()
+        dtype = torch_dtype(device)
+        image_tensor = image_resized_to_grid_as_tensor(image)
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+        image_tensor = image_tensor.to(device=device, dtype=dtype)
+        init_latent_dist = vae.encode(image_tensor).latent_dist
+        init_latents = init_latent_dist.sample().to(
+            dtype=dtype
+        )
+        init_latents = 0.18215 * init_latents
+        return init_latents
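
One wrinkle worth noting: `_encode_latents` multiplies by the literal `0.18215` while `_decode_latents` divides by `vae.config.scaling_factor`. For SD 1.x VAEs those agree (the config value is 0.18215), so the pair round-trips; a VAE with a different scaling factor would break the symmetry. A quick sanity check of the encode/decode pair against any `AutoencoderKL` (the model id is illustrative):

```
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

image = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in image scaled to [-1, 1]
with torch.inference_mode():
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    decoded = vae.decode(latents / vae.config.scaling_factor).sample
print(latents.shape, decoded.shape)  # (1, 4, 64, 64) and (1, 3, 512, 512)
```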

View File

@@ -1,13 +1,13 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

 import datetime
+import PIL
 import os
 from abc import ABC, abstractmethod
 from enum import Enum
 from pathlib import Path
 from queue import Queue
 from typing import Dict

 from PIL.Image import Image
 from invokeai.app.util.save_thumbnail import save_thumbnail
@@ -18,7 +18,7 @@ class ImageType(str, Enum):
     RESULT = "results"
     INTERMEDIATE = "intermediates"
     UPLOAD = "uploads"
+    LOCAL = "local" # a local path, relative to cwd or absolute

 class ImageStorageBase(ABC):
     """Responsible for storing and retrieving images."""
@@ -77,27 +77,30 @@ class DiskImageStorage(ImageStorageBase):
         if cache_item:
             return cache_item

-        image = Image.open(image_path)
+        image = PIL.Image.open(image_path)
         self.__set_cache(image_path, image)
         return image

     # TODO: make this a bit more flexible for e.g. cloud storage
     def get_path(self, image_type: ImageType, image_name: str) -> str:
-        path = os.path.join(self.__output_folder, image_type, image_name)
+        if image_type == ImageType.LOCAL:
+            path = image_name
+        else:
+            path = os.path.join(self.__output_folder, image_type, image_name)
         return path

     def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
-        image_subpath = os.path.join(image_type, image_name)
+        path = self.get_path(image_type, image_name)
         self.__pngWriter.save_image_and_prompt_to_png(
-            image, "", image_subpath, None
+            image, "", path, None
         ) # TODO: just pass full path to png writer
-        save_thumbnail(
-            image=image,
-            filename=image_name,
-            path=os.path.join(self.__output_folder, image_type, "thumbnails"),
-        )
-        image_path = self.get_path(image_type, image_name)
-        self.__set_cache(image_path, image)
+        # LS: Save_thumbnail() should be a separate invocation, shouldn't it?
+        # save_thumbnail(
+        #     image=image,
+        #     filename=image_name,
+        #     path=os.path.join(self.__output_folder, image_type, "thumbnails"),
+        # )
+        self.__set_cache(path, image)

     def delete(self, image_type: ImageType, image_name: str) -> None:
         image_path = self.get_path(image_type, image_name)
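
The `LOCAL` branch is what lets `load_image` and `save_image` address arbitrary filesystem paths while every other image type stays inside the managed outputs tree. The dispatch in isolation (folder layout is illustrative):

```
import os
from enum import Enum

class ImageType(str, Enum):
    RESULT = "results"
    LOCAL = "local"  # a local path, relative to cwd or absolute

def get_path(output_folder: str, image_type: ImageType, image_name: str) -> str:
    # LOCAL names are used verbatim; everything else is namespaced.
    if image_type == ImageType.LOCAL:
        return image_name
    return os.path.join(output_folder, image_type, image_name)

print(get_path("outputs", ImageType.RESULT, "0001.png"))   # outputs/results/0001.png
print(get_path("outputs", ImageType.LOCAL, "./test.png"))  # ./test.png
```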

View File

@@ -72,11 +72,10 @@ def get_model_manager(config: Args) -> ModelManager:
     # try to autoconvert new models
     # autoimport new .ckpt files
     if path := config.autoconvert:
-        model_manager.autoconvert_weights(
-            conf_path=config.conf,
-            weights_directory=path,
+        model_manager.heuristic_import(
+            str(path), commit_to_conf=config.conf
         )

     return model_manager

 def report_model_error(opt: Namespace, e: Exception):