mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: release/ad... → enhance/la... (5 commits)

SHA1:
87798cc8fa
b7d81f96f8
fe3f9d41fc
307cfc075d
cf7adb1815
@@ -2,7 +2,7 @@
 
 from abc import ABC, abstractmethod
 import argparse
-from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
+from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints, Union
 from pydantic import BaseModel, Field
 import networkx as nx
 import matplotlib.pyplot as plt
@@ -21,7 +21,7 @@ def add_parsers(
     """Adds parsers for each command to the subparsers"""
 
     # Create subparsers for each command
-    for command in commands:
+    for command in sorted(commands, key=lambda x: get_args(get_type_hints(x)[command_field])[0]):
         hints = get_type_hints(command)
         cmd_name = get_args(hints[command_field])[0]
         command_parser = subparsers.add_parser(cmd_name, help=command.__doc__)
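The new sort key pulls each command's name out of its `Literal` type annotation before comparing, the same lookup the loop body already performs. A minimal standalone sketch of that mechanism (`DummyCommand` and the field name `type` are hypothetical, for illustration only):

    from typing import Literal, get_args, get_type_hints

    class DummyCommand:
        type: Literal["dummy"] = "dummy"

    hints = get_type_hints(DummyCommand)   # {'type': Literal['dummy']}
    cmd_name = get_args(hints["type"])[0]  # 'dummy' -- the value used as the sort key
    print(cmd_name)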
@@ -4,6 +4,7 @@ You may import the global singleton `completer` to get access to the
 completer object.
 """
 import atexit
+import re
 import readline
 import shlex
 
@@ -66,6 +67,8 @@ class Completer(object):
         """
         if len(buffer)==0:
            return None, None
+        if re.search('\|\s*[a-zA-Z0-9]*$',buffer): # reset command on pipe symbol
+            return None,None
        tokens = shlex.split(buffer)
        command = None
        switch = None
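The added check resets completion whenever the buffer ends in a pipe symbol, optionally followed by the start of a new word, so the completer offers command names again instead of switches for the previous command. A rough sketch of the pattern's behavior (the `t2i` buffers are made-up examples; a raw string is used here to avoid the escape-sequence warning the non-raw original can trigger):

    import re

    for buffer in ("t2i --steps 20", "t2i --steps 20 | sho"):
        if re.search(r'\|\s*[a-zA-Z0-9]*$', buffer):
            print(f"{buffer!r}: reset -> complete command names")
        else:
            print(f"{buffer!r}: complete switches for the current command")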
@@ -9,8 +9,9 @@ from typing import (
     Union,
     get_type_hints,
 )
-from pydantic import BaseModel
+from argparse import HelpFormatter
+from operator import attrgetter
+from pydantic import BaseModel, ValidationError
 from pydantic.fields import Field
 
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -57,10 +58,14 @@ def add_invocation_args(command_parser):
         help="A link from all fields in the specified node. Node can be relative to history (e.g. -1)",
     )
 
+class SortedHelpFormatter(HelpFormatter):
+    def add_arguments(self, actions):
+        actions = sorted(actions, key=attrgetter('option_strings'))
+        super(SortedHelpFormatter, self).add_arguments(actions)
+
 def get_command_parser() -> argparse.ArgumentParser:
     # Create invocation parser
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
 
     def exit(*args, **kwargs):
         raise InvalidArgs
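`SortedHelpFormatter` makes `--help` list options sorted by their flag strings rather than in registration order. A standalone sketch of the same idea:

    import argparse
    from argparse import HelpFormatter
    from operator import attrgetter

    class SortedHelpFormatter(HelpFormatter):
        def add_arguments(self, actions):
            # sort actions by their option strings before the base class renders them
            super().add_arguments(sorted(actions, key=attrgetter('option_strings')))

    parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
    parser.add_argument('--zeta')
    parser.add_argument('--alpha')
    parser.print_help()  # --alpha is listed before --zeta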
@@ -287,6 +292,9 @@ def invoke_cli():
            print("Session error: creating a new session")
            context.session = context.invoker.create_execution_state()
 
+        except ValidationError as e:
+            print(f'Validation error: {str(e)}')
+
        except ExitCli:
            break
 
@@ -67,6 +67,29 @@ class LoadImageInvocation(BaseInvocation):
             image=ImageField(image_type=self.image_type, image_name=self.image_name)
         )
 
+class SaveImageInvocation(BaseInvocation):
+    """Take an image as input and save it as a PNG file to a filename."""
+    #fmt: off
+    type: Literal["save_image"] = "save_image"
+
+    # Inputs
+    image: ImageField = Field(default=None, description="The image to save")
+    image_type: ImageType = Field(description="The type of the image")
+    image_name: str = Field(description="The filename or path to save the image to")
+    #fmt: on
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
+        context.services.images.save(
+            self.image_type, self.image_name, image
+        )
+        return ImageOutput(
+            image=ImageField(
+                image_type=self.image_type, image_name=self.image_name
+            )
+        )
+
 class ShowImageInvocation(BaseInvocation):
     """Displays a provided image, and passes it forward in the pipeline."""
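`SaveImageInvocation` follows the shape every invocation here uses: a pydantic model whose `type` field is a `Literal` naming the command, plus `Field`-annotated inputs, plus an `invoke()` that returns an output model. A minimal sketch of that pattern (`PrintInvocation` is hypothetical, for illustration only):

    from typing import Literal
    from pydantic import BaseModel, Field

    class PrintInvocation(BaseModel):
        """Prints a message."""
        type: Literal["print"] = "print"
        message: str = Field(default="", description="The message to print")

    inv = PrintInvocation(message="hello")
    print(inv.type, inv.message)  # print hello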
@@ -1,28 +1,28 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
+from PIL import Image
 from typing import Literal, Optional
 from pydantic import BaseModel, Field
 from torch import Tensor
+import einops
 import torch
 
 from ...backend.model_management.model_manager import ModelManager
-from ...backend.util.devices import CUDA_DEVICE, torch_dtype
+from ...backend.util.devices import torch_dtype, choose_torch_device
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
 from ...backend.prompting.conditioning import get_uc_and_c_and_ec
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
+from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 import numpy as np
-from accelerate.utils import set_seed
 from ..services.image_storage import ImageType
-from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
 from ...backend.generator import Generator
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
 from diffusers.schedulers import SchedulerMixin as Scheduler
 import diffusers
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, AutoencoderKL
 
 
 class LatentsField(BaseModel):
@@ -110,7 +110,7 @@ class NoiseInvocation(BaseInvocation):
     height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
 
     def invoke(self, context: InvocationContext) -> NoiseOutput:
-        device = torch.device(CUDA_DEVICE)
+        device = choose_torch_device()
         noise = get_noise(self.width, self.height, device, self.seed)
 
         name = f'{context.graph_execution_state_id}__{self.id}'
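`choose_torch_device()` replaces the hard-coded CUDA device so noise generation also works on CPU-only and Apple-silicon machines. A rough approximation of what such a helper does, assuming the usual CUDA → MPS → CPU preference order (the real function lives in invokeai.backend.util.devices; this is a sketch, not its code):

    import torch

    def choose_torch_device() -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    print(choose_torch_device())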
@@ -170,8 +170,6 @@ class TextToLatentsInvocation(BaseInvocation):
 
     def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
         model_info = model_manager.get_model(self.model)
-        model_name = model_info['model_name']
-        model_hash = model_info['hash']
         model: StableDiffusionGeneratorPipeline = model_info['model']
         model.scheduler = get_scheduler(
             model=model,
@@ -211,7 +209,10 @@ class TextToLatentsInvocation(BaseInvocation):
 
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        noise = context.services.latents.get(self.noise.latents_name)
+        if self.noise:
+            noise = context.services.latents.get(self.noise.latents_name)
+        else:
+            noise = get_noise(self.width, self.height, choose_torch_device(), self.seed)
 
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, state.latents, state.step)
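With this guard the `noise` input becomes optional: when no upstream noise node is linked, the invocation generates its own from its width, height, and seed. A hedged sketch of the fallback pattern, with a stand-in `make_noise()` in place of the real `get_noise()`:

    import torch

    def make_noise(width: int, height: int, seed: int) -> torch.Tensor:
        # Stable Diffusion latents are 1/8 the pixel resolution, with 4 channels
        generator = torch.Generator().manual_seed(seed)
        return torch.randn((1, 4, height // 8, width // 8), generator=generator)

    noise = None  # e.g. no upstream NoiseInvocation was linked
    latents = noise if noise is not None else make_noise(512, 512, seed=42)
    print(latents.shape)  # torch.Size([1, 4, 64, 64])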
@@ -230,7 +231,8 @@ class TextToLatentsInvocation(BaseInvocation):
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
+        if torch.has_cuda:
+            torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
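Guarding the call keeps CPU and MPS machines from touching the CUDA allocator, which would otherwise raise. The same check is more commonly spelled with `torch.cuda.is_available()`:

    import torch

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached GPU blocks between pipeline stages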
@@ -280,7 +282,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
+        if torch.has_cuda:
+            torch.cuda.empty_cache()
 
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
@@ -302,14 +305,11 @@ class LatentsToImageInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)
-
-        # TODO: this only really needs the vae
-        model_info = context.services.model_manager.get_model(self.model)
-        model: StableDiffusionGeneratorPipeline = model_info['model']
+        vae = context.services.model_manager.get_model_vae(self.model)
 
         with torch.inference_mode():
-            np_image = model.decode_latents(latents)
-            image = model.numpy_to_pil(np_image)[0]
+            np_image = self._decode_latents(vae,latents)
+            image = StableDiffusionGeneratorPipeline.numpy_to_pil(np_image)[0]
 
         image_type = ImageType.RESULT
         image_name = context.services.images.create_name(
@@ -319,3 +319,54 @@ class LatentsToImageInvocation(BaseInvocation):
         return ImageOutput(
             image=ImageField(image_type=image_type, image_name=image_name)
         )
+
+    # this should be refactored - duplicated code in diffusers.pipelines.stable_diffusion
+    @classmethod
+    def _decode_latents(self, vae:AutoencoderKL, latents:torch.Tensor)->Image:
+        latents = 1 / vae.config.scaling_factor * latents
+        image = vae.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        return image
+
+
+# Image to latent
+class ImageToLatentsInvocation(BaseInvocation):
+    """Generates latents from an image."""
+
+    type: Literal["i2l"] = "i2l"
+
+    # Inputs
+    image: Optional[ImageField] = Field(description="The image to generate latents from")
+    model: str = Field(default="", description="The model to use")
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.services.images.get(self.image.image_type,self.image.image_name)
+        vae = context.services.model_manager.get_model_vae(self.model)
+
+        with torch.inference_mode():
+            result_latents = self._encode_latents(vae,image)
+
+        name = f'{context.graph_execution_state_id}__{self.id}'
+        context.services.latents.set(name, result_latents)
+        return LatentsOutput(
+            latents=LatentsField(latents_name=name)
+        )
+
+    # this should be refactored - similar code in invokeai.backend.stable_diffusion.diffusers_pipeline
+    @classmethod
+    def _encode_latents(self, vae:AutoencoderKL, image:Image)->torch.Tensor:
+        device = choose_torch_device()
+        dtype = torch_dtype(device)
+        image_tensor = image_resized_to_grid_as_tensor(image)
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+        image_tensor = image_tensor.to(device=device, dtype=dtype)
+        init_latent_dist = vae.encode(image_tensor).latent_dist
+        init_latents = init_latent_dist.sample().to(
+            dtype=dtype
+        )
+        init_latents = 0.18215 * init_latents
+        return init_latents
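The two scaling constants are inverses of each other: SD 1.x VAEs ship with `vae.config.scaling_factor == 0.18215`, so `_encode_latents` multiplies sampled latents by it (hard-coded) and `_decode_latents` divides it back out before decoding to pixels. A sketch of the round trip, assuming `vae` is a loaded diffusers AutoencoderKL and `x` is an image tensor in [-1, 1] of shape (1, 3, H, W):

    import torch

    def encode(vae, x: torch.Tensor) -> torch.Tensor:
        latents = vae.encode(x).latent_dist.sample()
        return latents * vae.config.scaling_factor  # 0.18215 for SD 1.x

    def decode(vae, latents: torch.Tensor) -> torch.Tensor:
        image = vae.decode(latents / vae.config.scaling_factor).sample
        return (image / 2 + 0.5).clamp(0, 1)        # [-1, 1] -> [0, 1] pixels

Reading the factor from `vae.config`, as `_decode_latents` does, is safer than hard-coding 0.18215, since other VAEs use different factors.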
@@ -1,13 +1,13 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 import datetime
+import PIL
 import os
 from abc import ABC, abstractmethod
 from enum import Enum
 from pathlib import Path
 from queue import Queue
 from typing import Dict
 
 from PIL.Image import Image
 from invokeai.app.util.save_thumbnail import save_thumbnail
@@ -18,7 +18,7 @@ class ImageType(str, Enum):
     RESULT = "results"
     INTERMEDIATE = "intermediates"
     UPLOAD = "uploads"
+    LOCAL = "local" # a local path, relative to cwd or absolute
 
 class ImageStorageBase(ABC):
     """Responsible for storing and retrieving images."""
@@ -77,27 +77,30 @@ class DiskImageStorage(ImageStorageBase):
         if cache_item:
             return cache_item
 
-        image = Image.open(image_path)
+        image = PIL.Image.open(image_path)
         self.__set_cache(image_path, image)
         return image
 
     # TODO: make this a bit more flexible for e.g. cloud storage
     def get_path(self, image_type: ImageType, image_name: str) -> str:
-        path = os.path.join(self.__output_folder, image_type, image_name)
+        if image_type == ImageType.LOCAL:
+            path = image_name
+        else:
+            path = os.path.join(self.__output_folder, image_type, image_name)
         return path
 
     def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
-        image_subpath = os.path.join(image_type, image_name)
+        path = self.get_path(image_type, image_name)
         self.__pngWriter.save_image_and_prompt_to_png(
-            image, "", image_subpath, None
+            image, "", path, None
         ) # TODO: just pass full path to png writer
-        save_thumbnail(
-            image=image,
-            filename=image_name,
-            path=os.path.join(self.__output_folder, image_type, "thumbnails"),
-        )
-        image_path = self.get_path(image_type, image_name)
-        self.__set_cache(image_path, image)
+        # LS: Save_thumbnail() should be a separate invocation, shouldn't it?
+        # save_thumbnail(
+        #     image=image,
+        #     filename=image_name,
+        #     path=os.path.join(self.__output_folder, image_type, "thumbnails"),
+        # )
+        self.__set_cache(path, image)
 
     def delete(self, image_type: ImageType, image_name: str) -> None:
         image_path = self.get_path(image_type, image_name)
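The new `ImageType.LOCAL` branch lets `get_path()` treat the image name as a path in its own right, while every other type still resolves under the output folder. A standalone sketch of the two branches:

    import os

    def get_path(output_folder: str, image_type: str, image_name: str) -> str:
        if image_type == "local":
            return image_name  # relative to cwd, or absolute
        return os.path.join(output_folder, image_type, image_name)

    print(get_path("outputs", "results", "000001.png"))  # outputs/results/000001.png
    print(get_path("outputs", "local", "/tmp/in.png"))   # /tmp/in.png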
@@ -72,11 +72,10 @@ def get_model_manager(config: Args) -> ModelManager:
     # try to autoconvert new models
     # autoimport new .ckpt files
     if path := config.autoconvert:
-        model_manager.autoconvert_weights(
-            conf_path=config.conf,
-            weights_directory=path,
+        model_manager.heuristic_import(
+            str(path), commit_to_conf=config.conf
         )
 
     return model_manager
 
 
 def report_model_error(opt: Namespace, e: Exception):