mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
5 Commits
release/ad
...
enhance/la
Author | SHA1 | Date | |
---|---|---|---|
87798cc8fa | |||
b7d81f96f8 | |||
fe3f9d41fc | |||
307cfc075d | |||
cf7adb1815 |
@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints, Union
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
@ -21,7 +21,7 @@ def add_parsers(
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
# Create subparsers for each command
|
||||
for command in commands:
|
||||
for command in sorted(commands, key=lambda x: get_args(get_type_hints(x)[command_field])[0]):
|
||||
hints = get_type_hints(command)
|
||||
cmd_name = get_args(hints[command_field])[0]
|
||||
command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
|
||||
|
@ -4,6 +4,7 @@ You may import the global singleton `completer` to get access to the
|
||||
completer object.
|
||||
"""
|
||||
import atexit
|
||||
import re
|
||||
import readline
|
||||
import shlex
|
||||
|
||||
@ -66,6 +67,8 @@ class Completer(object):
|
||||
"""
|
||||
if len(buffer)==0:
|
||||
return None, None
|
||||
if re.search('\|\s*[a-zA-Z0-9]*$',buffer): # reset command on pipe symbol
|
||||
return None,None
|
||||
tokens = shlex.split(buffer)
|
||||
command = None
|
||||
switch = None
|
||||
|
@ -9,8 +9,9 @@ from typing import (
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from argparse import HelpFormatter
|
||||
from operator import attrgetter
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
@ -57,10 +58,14 @@ def add_invocation_args(command_parser):
|
||||
help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
class SortedHelpFormatter(HelpFormatter):
    """Help formatter that lists a group's arguments ordered by option strings."""

    def add_arguments(self, actions):
        """Order ``actions`` by ``option_strings``, then defer to ``HelpFormatter``."""
        # list.sort is stable, so actions with equal option strings keep
        # their original relative order.
        ordered = list(actions)
        ordered.sort(key=attrgetter('option_strings'))
        super().add_arguments(ordered)
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
@ -287,6 +292,9 @@ def invoke_cli():
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
|
||||
except ValidationError as e:
|
||||
print(f'Validation error: {str(e)}')
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
|
@ -67,6 +67,29 @@ class LoadImageInvocation(BaseInvocation):
|
||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||
)
|
||||
|
||||
class SaveImageInvocation(BaseInvocation):
    """Take an image as input and save it as a PNG file to a filename."""
    #fmt: off
    type: Literal["save_image"] = "save_image"

    # Inputs
    image: ImageField = Field(default=None, description="The image to save")
    image_type: ImageType = Field(description="The type of the image")
    image_name: str = Field(description="The filename or path to save the image to")
    #fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Fetch the input image from the image service and save it again
        under this node's ``image_type`` / ``image_name``.

        Returns an ``ImageOutput`` whose field refers to the newly saved copy.
        """
        image_service = context.services.images
        source = self.image

        # Read the image referenced by the input field ...
        loaded = image_service.get(source.image_type, source.image_name)
        # ... and store it under the destination given on this node.
        image_service.save(self.image_type, self.image_name, loaded)

        saved_field = ImageField(
            image_type=self.image_type, image_name=self.image_name
        )
        return ImageOutput(image=saved_field)
|
||||
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
||||
|
@ -1,28 +1,28 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from PIL import Image
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from torch import Tensor
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
|
||||
from ...backend.util.devices import torch_dtype, choose_torch_device
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
import numpy as np
|
||||
from accelerate.utils import set_seed
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import DiffusionPipeline, AutoencoderKL
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
@ -110,7 +110,7 @@ class NoiseInvocation(BaseInvocation):
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(CUDA_DEVICE)
|
||||
device = choose_torch_device()
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
@ -170,8 +170,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = model_manager.get_model(self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
@ -211,7 +209,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
if self.noise:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
else:
|
||||
noise = get_noise(self.width, self.height, choose_torch_device(), self.seed)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state.latents, state.step)
|
||||
@ -230,7 +231,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
if torch.has_cuda:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
@ -280,7 +282,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
if torch.has_cuda:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
@ -302,14 +305,11 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = context.services.model_manager.get_model(self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
vae = context.services.model_manager.get_model_vae(self.model)
|
||||
|
||||
with torch.inference_mode():
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
np_image = self._decode_latents(vae,latents)
|
||||
image = StableDiffusionGeneratorPipeline.numpy_to_pil(np_image)[0]
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
@ -319,3 +319,54 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
# this should be refactored - duplicated code in diffusers.pipelines.stable_diffusion
@classmethod
def _decode_latents(cls, vae: "AutoencoderKL", latents: "torch.Tensor") -> "np.ndarray":
    """Decode a batch of latents with ``vae`` and return it as a NumPy array.

    Args:
        vae: Autoencoder providing ``config.scaling_factor`` and
            ``decode(latents).sample``.
        latents: Latent tensor to decode.

    Returns:
        The decoded batch as a float NumPy array with values clamped to
        ``[0, 1]`` and axes reordered by ``permute(0, 2, 3, 1)``
        (presumably NCHW -> NHWC; confirm against the VAE's output layout).
    """
    # Undo the scaling applied when the latents were produced.
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents).sample
    # Shift by 0.5 after halving, then clamp into [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    # Move to CPU and cast with .float() before converting to NumPy.
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image
|
||||
|
||||
|
||||
# Image to latent
class ImageToLatentsInvocation(BaseInvocation):
    """Generates latents from an image."""

    type: Literal["i2l"] = "i2l"

    # Inputs
    image: Optional[ImageField] = Field(description="The image to generate latents from")
    model: str = Field(default="", description="The model to use")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Encode the input image with the model's VAE and store the latents.

        The latents are saved in the latents service under
        ``<graph_execution_state_id>__<node id>`` and returned by name.
        """
        image = context.services.images.get(self.image.image_type, self.image.image_name)
        vae = context.services.model_manager.get_model_vae(self.model)

        with torch.inference_mode():
            result_latents = self._encode_latents(vae, image)
            name = f'{context.graph_execution_state_id}__{self.id}'
            context.services.latents.set(name, result_latents)

        return LatentsOutput(
            latents=LatentsField(latents_name=name)
        )

    # this should be refactored - similar code in invokeai.backend.stable_diffusion.diffusers_pipeline
    @classmethod
    def _encode_latents(cls, vae: AutoencoderKL, image: Image.Image) -> torch.Tensor:
        """Encode ``image`` into scaled latents using ``vae``.

        Args:
            vae: Autoencoder providing ``encode(...).latent_dist`` and
                ``config.scaling_factor``.
            image: Image to encode; converted to a tensor by
                ``image_resized_to_grid_as_tensor``.

        Returns:
            A sample from the VAE's latent distribution, multiplied by the
            VAE's scaling factor.
        """
        device = choose_torch_device()
        dtype = torch_dtype(device)

        image_tensor = image_resized_to_grid_as_tensor(image)
        if image_tensor.dim() == 3:
            # Add a leading batch axis to a single (c, h, w) image.
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
        image_tensor = image_tensor.to(device=device, dtype=dtype)

        init_latent_dist = vae.encode(image_tensor).latent_dist
        init_latents = init_latent_dist.sample().to(dtype=dtype)

        # Scale with the VAE's own factor instead of a hard-coded 0.18215, so
        # encoding mirrors the `1 / vae.config.scaling_factor` used when
        # latents are decoded.
        init_latents = vae.config.scaling_factor * init_latents
        return init_latents
|
||||
|
||||
|
||||
|
@ -1,13 +1,13 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import datetime
|
||||
import PIL
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
|
||||
from PIL.Image import Image
|
||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||
|
||||
@ -18,7 +18,7 @@ class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
LOCAL = "local" # a local path, relative to cwd or absolute
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
@ -77,27 +77,30 @@ class DiskImageStorage(ImageStorageBase):
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
image = PIL.Image.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
if image_type == ImageType.LOCAL:
|
||||
path = image_name
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
image_subpath = os.path.join(image_type, image_name)
|
||||
path = self.get_path(image_type, image_name)
|
||||
self.__pngWriter.save_image_and_prompt_to_png(
|
||||
image, "", image_subpath, None
|
||||
image, "", path, None
|
||||
) # TODO: just pass full path to png writer
|
||||
save_thumbnail(
|
||||
image=image,
|
||||
filename=image_name,
|
||||
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||
)
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
self.__set_cache(image_path, image)
|
||||
# LS: Save_thumbnail() should be a separate invocation, shouldn't it?
|
||||
# save_thumbnail(
|
||||
# image=image,
|
||||
# filename=image_name,
|
||||
# path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||
# )
|
||||
self.__set_cache(path, image)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
|
@ -72,11 +72,10 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := config.autoconvert:
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
model_manager.heuristic_import(
|
||||
str(path), commit_to_conf=config.conf
|
||||
)
|
||||
|
||||
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
|
Reference in New Issue
Block a user