Merge branch 'main' into dev/pytorch2

This commit is contained in:
Lincoln Stein, 2023-04-06 21:47:46 -04:00 (committed by GitHub)
commit 3c50448ccf
63 changed files with 1851 additions and 509 deletions

.github/CODEOWNERS vendored

@@ -1,16 +1,16 @@
 # continuous integration
-/.github/workflows/ @mauwii @lstein
+/.github/workflows/ @mauwii @lstein @blessedcoolant

 # documentation
-/docs/ @lstein @mauwii @tildebyte
+/docs/ @lstein @mauwii @tildebyte @blessedcoolant
-/mkdocs.yml @lstein @mauwii
+/mkdocs.yml @lstein @mauwii @blessedcoolant

 # nodes
 /invokeai/app/ @Kyle0654 @blessedcoolant

 # installation and configuration
 /pyproject.toml @mauwii @lstein @blessedcoolant
-/docker/ @mauwii @lstein
+/docker/ @mauwii @lstein @blessedcoolant
 /scripts/ @ebr @lstein
 /installer/ @lstein @ebr
 /invokeai/assets @lstein @ebr


@@ -16,6 +16,10 @@ on:
       - 'v*.*.*'
   workflow_dispatch:

+permissions:
+  contents: write
+  packages: write
+
 jobs:
   docker:
     if: github.event.pull_request.draft == false


@@ -5,6 +5,9 @@ on:
       - 'main'
       - 'development'

+permissions:
+  contents: write
+
 jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false


@@ -6,7 +6,6 @@ on:
       - '!pyproject.toml'
       - '!invokeai/**'
       - 'invokeai/frontend/web/**'
-      - '!invokeai/frontend/web/dist/**'
   merge_group:
   workflow_dispatch:


@@ -7,13 +7,11 @@ on:
       - 'pyproject.toml'
       - 'invokeai/**'
       - '!invokeai/frontend/web/**'
-      - 'invokeai/frontend/web/dist/**'
   pull_request:
     paths:
       - 'pyproject.toml'
       - 'invokeai/**'
       - '!invokeai/frontend/web/**'
-      - 'invokeai/frontend/web/dist/**'
     types:
       - 'ready_for_review'
       - 'opened'


@@ -139,13 +139,13 @@ not supported.
 _For Windows/Linux with an NVIDIA GPU:_
 ```terminal
-pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
 ```

 _For Linux with an AMD GPU:_
 ```sh
-pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
 ```

 _For Macintoshes, either Intel or M1/M2:_


@@ -168,11 +168,15 @@ used by Stable Diffusion 1.4 and 1.5.
 After installation, your `models.yaml` should contain an entry that looks like
 this one:

-inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
-description: SD inpainting v1.5 config:
-configs/stable-diffusion/v1-inpainting-inference.yaml vae:
-models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
+```yml
+inpainting-1.5:
+    weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+    description: SD inpainting v1.5
+    config: configs/stable-diffusion/v1-inpainting-inference.yaml
+    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    width: 512
     height: 512
+```

 As shown in the example, you may include a VAE fine-tuning weights file as well.
 This is strongly recommended.


@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
 masking option:

 ```bash
-invoke> a fluffy cat eating a hotdot
+invoke> a fluffy cat eating a hotdog
 Outputs:
 [1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
 invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat


@@ -417,7 +417,7 @@ Then type the following commands:
 === "AMD System"

     ```bash
-    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
+    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
     ```

 ### Corrupted configuration file


@@ -154,7 +154,7 @@ manager, please follow these steps:
 === "ROCm (AMD)"

     ```bash
-    pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+    pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
     ```

 === "CPU (Intel Macs & non-GPU systems)"
@@ -315,7 +315,7 @@ installation protocol (important!)
 === "ROCm (AMD)"

     ```bash
-    pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+    pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
     ```

 === "CPU (Intel Macs & non-GPU systems)"


@@ -110,7 +110,7 @@ recipes are available
 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
+https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).

 This will be done automatically for you if you use the installer


@@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
     optional_modules = None
     if OS == "Linux":
         if device == "rocm":
-            url = "https://download.pytorch.org/whl/rocm5.2"
+            url = "https://download.pytorch.org/whl/rocm5.4.2"
         elif device == "cpu":
             url = "https://download.pytorch.org/whl/cpu"


@@ -24,9 +24,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
     export PYTORCH_ENABLE_MPS_FALLBACK=1
 fi

+if [ "$0" != "bash" ]; then
 while true
 do
-    if [ "$0" != "bash" ]; then
     echo "Do you want to generate images using the"
     echo "1. command-line interface"
     echo "2. browser-based UI"
@@ -87,9 +87,9 @@ if [ "$0" != "bash" ]; then
         echo "Invalid selection"
         exit;;
     esac
+done
 else # in developer console
     python --version
     echo "Press ^D to exit"
     export PS1="(InvokeAI) \u@\h \w> "
 fi
-done


@@ -3,6 +3,8 @@
 import os
 from argparse import Namespace

+from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
+
 from ...backend import Globals
 from ..services.model_manager_initializer import get_model_manager
 from ..services.restoration_services import RestorationServices
@@ -54,7 +56,9 @@ class ApiDependencies:
             os.path.join(os.path.dirname(__file__), "../../../../outputs")
         )

-        images = DiskImageStorage(output_folder)
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
+        images = DiskImageStorage(f'{output_folder}/images')

         # TODO: build a file/path manager?
         db_location = os.path.join(output_folder, "invokeai.db")
@@ -62,6 +66,7 @@ class ApiDependencies:
         services = InvocationServices(
             model_manager=get_model_manager(config),
             events=events,
+            latents=latents,
             images=images,
             queue=MemoryInvocationQueue(),
             graph_execution_manager=SqliteItemStorage[GraphExecutionState](


@@ -23,6 +23,16 @@ async def get_image(
     filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
     return FileResponse(filename)

+@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
+async def get_thumbnail(
+    image_type: ImageType = Path(description="The type of image to get"),
+    image_name: str = Path(description="The name of the image to get"),
+):
+    """Gets a thumbnail"""
+    # TODO: This is not really secure at all. At least make sure only output results are served
+    filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
+    return FileResponse(filename)
+
 @images_router.post(
     "/uploads/",


@@ -0,0 +1,279 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")
# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
# print(f">> Adding New Model: {model_name}")
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)


@@ -51,7 +51,7 @@ async def list_sessions(
     query: str = Query(default="", description="The query string to search for"),
 ) -> PaginatedResults[GraphExecutionState]:
     """Gets a list of sessions, optionally searching"""
-    if filter == "":
+    if query == "":
         result = ApiDependencies.invoker.services.graph_execution_manager.list(
             page, per_page
         )
@@ -270,3 +270,18 @@ async def invoke_session(
     ApiDependencies.invoker.invoke(session, invoke_all=all)
     return Response(status_code=202)

+@session_router.delete(
+    "/{session_id}/invoke",
+    operation_id="cancel_session_invoke",
+    responses={
+        202: {"description": "The invocation is canceled"}
+    },
+)
+async def cancel_session_invoke(
+    session_id: str = Path(description="The id of the session to cancel"),
+) -> None:
+    """Invokes a session"""
+    ApiDependencies.invoker.cancel(session_id)
+    return Response(status_code=202)
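For orientation, here is a minimal client-side sketch (not part of this commit) of how the new cancellation endpoint could be exercised. The host, port, and the `/v1/sessions` prefix are assumptions; the diff only shows that the session router is mounted under `/api`.

```python
import requests  # hypothetical client dependency, used only for this sketch

# Assumed local server address and API prefix; adjust for your deployment.
BASE_URL = "http://localhost:9090/api/v1/sessions"
session_id = "example-session-id"  # placeholder id

# DELETE /{session_id}/invoke asks the invoker to cancel the running session.
response = requests.delete(f"{BASE_URL}/{session_id}/invoke")
print(response.status_code)  # 202 indicates the cancellation was accepted
```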


@@ -14,7 +14,7 @@ from pydantic.schema import schema
 from ..backend import Args
 from .api.dependencies import ApiDependencies
-from .api.routers import images, sessions
+from .api.routers import images, sessions, models
 from .api.sockets import SocketIO
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
@@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")
 app.include_router(images.images_router, prefix="/api")

+app.include_router(models.models_router, prefix="/api")
+
 # Build a custom OpenAPI to include all outputs
 # TODO: can outputs be included on metadata of invocation schemas somehow?


@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
 import argparse
 from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
 from pydantic import BaseModel, Field
+import networkx as nx
+import matplotlib.pyplot as plt

 from ..invocations.image import ImageField
 from ..services.graph import GraphExecutionState
 from ..services.invoker import Invoker
@@ -46,7 +47,7 @@ def add_parsers(
                 f"--{name}",
                 dest=name,
                 type=field_type,
-                default=field.default,
+                default=field.default if field.default_factory is None else field.default_factory(),
                 choices=allowed_values,
                 help=field.field_info.description,
             )
@@ -55,7 +56,7 @@ def add_parsers(
                 f"--{name}",
                 dest=name,
                 type=field.type_,
-                default=field.default,
+                default=field.default if field.default_factory is None else field.default_factory(),
                 help=field.field_info.description,
             )
@@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
             del context.defaults[self.field]
         else:
             context.defaults[self.field] = self.value

+
+class DrawGraphCommand(BaseCommand):
+    """Debugs a graph"""
+    type: Literal['draw_graph'] = 'draw_graph'
+
+    def run(self, context: CliContext) -> None:
+        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
+        nxgraph = session.graph.nx_graph_flat()
+
+        # Draw the networkx graph
+        plt.figure(figsize=(20, 20))
+        pos = nx.spectral_layout(nxgraph)
+        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
+        nx.draw_networkx_edges(nxgraph, pos, width=2)
+        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
+        plt.axis("off")
+        plt.show()
+
+
+class DrawExecutionGraphCommand(BaseCommand):
+    """Debugs an execution graph"""
+    type: Literal['draw_xgraph'] = 'draw_xgraph'
+
+    def run(self, context: CliContext) -> None:
+        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
+        nxgraph = session.execution_graph.nx_graph_flat()
+
+        # Draw the networkx graph
+        plt.figure(figsize=(20, 20))
+        pos = nx.spectral_layout(nxgraph)
+        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
+        nx.draw_networkx_edges(nxgraph, pos, width=2)
+        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
+        plt.axis("off")
+        plt.show()


@@ -0,0 +1,167 @@
"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex
from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
# singleton object, class variable
completer = None
class Completer(object):
def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands()
self.matches = None
self.linebuffer = None
self.manager = model_manager
return
def complete(self, text, state):
"""
Complete commands and switches fromm the node CLI command line.
Switches are determined in a context-specific manner.
"""
buffer = readline.get_line_buffer()
if state == 0:
options = None
try:
current_command, current_switch = self.get_current_command(buffer)
options = self.get_command_options(current_command, current_switch)
except IndexError:
pass
options = options or list(self.parse_commands().keys())
if not text: # first time
self.matches = options
else:
self.matches = [s for s in options if s and s.startswith(text)]
try:
match = self.matches[state]
except IndexError:
match = None
return match
@classmethod
def get_commands(self)->List[object]:
"""
Return a list of all the client commands and invocations.
"""
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
def get_current_command(self, buffer: str)->tuple[str, str]:
"""
Parse the readline buffer to find the most recent command and its switch.
"""
if len(buffer)==0:
return None, None
tokens = shlex.split(buffer)
command = None
switch = None
for t in tokens:
if t[0].isalpha():
if switch is None:
command = t
else:
switch = t
# don't try to autocomplete switches that are already complete
if switch and buffer.endswith(' '):
switch=None
return command or '', switch or ''
def parse_commands(self)->Dict[str, List[str]]:
"""
Return a dict in which the keys are the command name
and the values are the parameters the command takes.
"""
result = dict()
for command in self.commands:
hints = get_type_hints(command)
name = get_args(hints['type'])[0]
result.update({name:hints})
return result
def get_command_options(self, command: str, switch: str)->List[str]:
"""
Return all the parameters that can be passed to the command as
command-line switches. Returns None if the command is unrecognized.
"""
parsed_commands = self.parse_commands()
if command not in parsed_commands:
return None
# handle switches in the format "-foo=bar"
argument = None
if switch and '=' in switch:
switch, argument = switch.split('=')
parameter = switch.strip('-')
if parameter in parsed_commands[command]:
if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else:
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
else:
return [f"--{x}" for x in parsed_commands[command].keys()]
def get_parameter_options(self, parameter: str, typehint)->List[str]:
"""
Given a parameter type (such as Literal), offers autocompletions.
"""
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == 'model':
return self.manager.model_names()
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def set_autocompleter(model_manager: ModelManager) -> Completer:
global completer
if completer:
return completer
completer = Completer(model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(True)
except:
pass
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(Globals.root, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)


@@ -2,6 +2,7 @@
 import argparse
 import os
+import re
 import shlex
 import time
 from typing import (
@@ -12,14 +13,17 @@ from typing import (
 from pydantic import BaseModel
 from pydantic.fields import Field

+from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
+
 from ..backend import Args
 from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
+from .cli.completer import set_autocompleter
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
 from .services.model_manager_initializer import get_model_manager
 from .services.restoration_services import RestorationServices
-from .services.graph import Edge, EdgeConnection, GraphExecutionState
+from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
 from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue
 from .services.invocation_services import InvocationServices
@@ -43,7 +47,7 @@ def add_invocation_args(command_parser):
         "-l",
         action="append",
         nargs=3,
-        help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
+        help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
     )

     command_parser.add_argument(
@@ -93,6 +97,9 @@ def generate_matching_edges(
     invalid_fields = set(["type", "id"])
     matching_fields = matching_fields.difference(invalid_fields)

+    # Validate types
+    matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
+
     edges = [
         Edge(
             source=EdgeConnection(node_id=a.id, field=field),
@@ -130,6 +137,12 @@ def invoke_cli():
     config.parse_args()
     model_manager = get_model_manager(config)

+    # This initializes the autocompleter and returns it.
+    # Currently nothing is done with the returned Completer
+    # object, but the object can be used to change autocompletion
+    # behavior on the fly, if desired.
+    completer = set_autocompleter(model_manager)
+
     events = EventServiceBase()

     output_folder = os.path.abspath(
@@ -142,7 +155,8 @@ def invoke_cli():
     services = InvocationServices(
         model_manager=model_manager,
         events=events,
-        images=DiskImageStorage(output_folder),
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
+        images=DiskImageStorage(f'{output_folder}/images'),
         queue=MemoryInvocationQueue(),
         graph_execution_manager=SqliteItemStorage[GraphExecutionState](
             filename=db_location, table_name="graph_executions"
@@ -155,6 +169,8 @@ def invoke_cli():
     session: GraphExecutionState = invoker.create_execution_state()
     parser = get_command_parser()

+    re_negid = re.compile('^-[0-9]+$')
+
     # Uncomment to print out previous sessions at startup
     # print(services.session_manager.list())
@@ -162,8 +178,8 @@ def invoke_cli():
     while True:
         try:
-            cmd_input = input("> ")
-        except KeyboardInterrupt:
+            cmd_input = input("invoke> ")
+        except (KeyboardInterrupt, EOFError):
             # Ctrl-c exits
             break
@@ -220,7 +236,11 @@ def invoke_cli():
                 # Parse provided links
                 if "link_node" in args and args["link_node"]:
                     for link in args["link_node"]:
-                        link_node = context.session.graph.get_node(link)
+                        node_id = link
+                        if re_negid.match(node_id):
+                            node_id = str(current_id + int(node_id))
+
+                        link_node = context.session.graph.get_node(node_id)
                         matching_edges = generate_matching_edges(
                             link_node, command.command
                         )
@@ -230,10 +250,15 @@ def invoke_cli():
                 if "link" in args and args["link"]:
                     for link in args["link"]:
-                        edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
+                        edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
+
+                        node_id = link[0]
+                        if re_negid.match(node_id):
+                            node_id = str(current_id + int(node_id))
+
                         edges.append(
                             Edge(
-                                source=EdgeConnection(node_id=link[1], field=link[0]),
+                                source=EdgeConnection(node_id=node_id, field=link[1]),
                                 destination=EdgeConnection(
                                     node_id=command.command.id, field=link[2]
                                 )


@@ -0,0 +1,50 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class RangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["range"] = "range"
# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
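A quick illustration (not part of the diff) of what these collection nodes return when called directly; the `InvocationContext` argument is unused by them, so `None` stands in for it here purely as an assumption for the sketch.

```python
# Sketch only: both invocations ignore their InvocationContext argument.
out = RangeInvocation(start=0, stop=10, step=2).invoke(None)
print(out.collection)  # -> [0, 2, 4, 6, 8]

rand = RandomRangeInvocation(low=0, high=100, size=5).invoke(None)
print(rand.collection)  # five pseudo-random integers in [0, 100)
```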


@@ -1,22 +1,19 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from datetime import datetime, timezone
-from typing import Any, Literal, Optional, Union
+from functools import partial
+from typing import Literal, Optional, Union

 import numpy as np
 from torch import Tensor
-from PIL import Image
 from pydantic import Field
-from skimage.exposure.histogram_matching import match_histograms

 from ..services.image_storage import ImageType
-from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
-from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
+from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
 from ...backend.stable_diffusion import PipelineIntermediateState
-from ...backend.util.util import image_to_dataURL
+from ..util.util import diffusers_step_callback_adapter, CanceledException

 SAMPLER_NAME_VALUES = Literal[
     tuple(InvokeAIGenerator.schedulers())
@@ -45,32 +42,26 @@ class TextToImageInvocation(BaseInvocation):
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-        self, context: InvocationContext, sample: Tensor, step: int
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
     ) -> None:
-        # TODO: only output a preview image when requested
-        image = Generator.sample_to_lowres_estimated_image(sample)
-
-        (width, height) = image.size
-        width *= 8
-        height *= 8
-
-        dataURL = image_to_dataURL(image, image_format="JPEG")
-
-        context.services.events.emit_generator_progress(
-            context.graph_execution_state_id,
-            self.id,
-            {
-                "width": width,
-                "height": height,
-                "dataURL": dataURL
-            },
-            step,
-            self.steps,
-        )
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)

     def invoke(self, context: InvocationContext) -> ImageOutput:
-        def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, state.latents, state.step)
+        # def step_callback(state: PipelineIntermediateState):
+        #     if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+        #         raise CanceledException
+        #     self.dispatch_progress(context, state.latents, state.step)

         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
@@ -79,7 +70,7 @@ class TextToImageInvocation(BaseInvocation):
         model= context.services.model_manager.get_model()
         outputs = Txt2Img(model).generate(
             prompt=self.prompt,
-            step_callback=step_callback,
+            step_callback=partial(self.dispatch_progress, context),
             **self.dict(
                 exclude={"prompt"}
             ), # Shorthand for passing all of the parameters above manually
@@ -116,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
         description="Whether or not the result should be fit to the aspect ratio of the input image",
     )

+    def dispatch_progress(
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -126,24 +133,23 @@ class ImageToImageInvocation(TextToImageInvocation):
         )
         mask = None

-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
         model = context.services.model_manager.get_model()
-        generator_output = next(
-            Img2Img(model).generate(
-                prompt=self.prompt,
-                init_image=image,
-                init_mask=mask,
-                step_callback=step_callback,
-                **self.dict(
-                    exclude={"prompt", "image", "mask"}
-                ), # Shorthand for passing all of the parameters above manually
-            )
-        )
+        outputs = Img2Img(model).generate(
+            prompt=self.prompt,
+            init_image=image,
+            init_mask=mask,
+            step_callback=partial(self.dispatch_progress, context),
+            **self.dict(
+                exclude={"prompt", "image", "mask"}
+            ), # Shorthand for passing all of the parameters above manually
+        )
+
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generator_output = next(outputs)

         result_image = generator_output.image
@@ -173,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
         description="The amount by which to replace masked areas with latent noise",
     )

+    def dispatch_progress(
+        self, context: InvocationContext, intermediate_state: PipelineIntermediateState
+    ) -> None:
+        if (context.services.queue.is_canceled(context.graph_execution_state_id)):
+            raise CanceledException
+
+        step = intermediate_state.step
+        if intermediate_state.predicted_original is not None:
+            # Some schedulers report not only the noisy latents at the current timestep,
+            # but also their estimate so far of what the de-noised latents will be.
+            sample = intermediate_state.predicted_original
+        else:
+            sample = intermediate_state.latents
+
+        diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
             None
@@ -187,24 +209,23 @@ class InpaintInvocation(ImageToImageInvocation):
             else context.services.images.get(self.mask.image_type, self.mask.image_name)
         )

-        def step_callback(sample, step=0):
-            self.dispatch_progress(context, sample, step)
-
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
-        manager = context.services.model_manager.get_model()
-        generator_output = next(
-            Inpaint(model).generate(
-                prompt=self.prompt,
-                init_image=image,
-                mask_image=mask,
-                step_callback=step_callback,
-                **self.dict(
-                    exclude={"prompt", "image", "mask"}
-                ), # Shorthand for passing all of the parameters above manually
-            )
-        )
+        model = context.services.model_manager.get_model()
+        outputs = Inpaint(model).generate(
+            prompt=self.prompt,
+            init_img=image,
+            init_mask=mask,
+            step_callback=partial(self.dispatch_progress, context),
+            **self.dict(
+                exclude={"prompt", "image", "mask"}
+            ), # Shorthand for passing all of the parameters above manually
+        )
+
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generator_output = next(outputs)

         result_image = generator_output.image


@@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
     image: ImageField = Field(default=None, description="The output image")
     #fmt: on

+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'image',
+            ]
+        }
+

 class MaskOutput(BaseInvocationOutput):
     """Base class for invocations that output a mask"""
     #fmt: off
     type: Literal["mask"] = "mask"
     mask: ImageField = Field(default=None, description="The output mask")
-    #fomt: on
+    #fmt: on

+    class Config:
+        schema_extra = {
+            'required': [
+                'type',
+                'mask',
+            ]
+        }
+

 # TODO: this isn't really necessary anymore
 class LoadImageInvocation(BaseInvocation):


@@ -0,0 +1,321 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
from pydantic import BaseModel, Field
from torch import Tensor
import torch
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import numpy as np
from accelerate.utils import set_seed
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
latents: LatentsField = Field(default=None, description="The output latents")
#fmt: on
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
noise: LatentsField = Field(default=None, description="The output noise")
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(CUDA_DEVICE)
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from a prompt."""
type: Literal["t2l"] = "t2l"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = model_manager.get_model(self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.sampler_name
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
return conditioning_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = context.services.model_manager.get_model(self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)


@@ -0,0 +1,68 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class IntOutput(BaseInvocationOutput):
"""An integer output"""
#fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
#fmt: on
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
#fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
#fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
#fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
#fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
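These arithmetic nodes are normally wired into a graph and run by the invocation processor, but their `invoke()` implementations ignore the context, so they can be exercised directly. A quick hedged sketch (it assumes `BaseInvocation` requires an `id` field, and passes `None` in place of a real `InvocationContext` purely for illustration):

```python
# Hedged usage sketch: these nodes never touch their context, so None stands in for a
# real InvocationContext. The id field is assumed to be required by BaseInvocation.
add = AddInvocation(id="1", a=2, b=3)
mul = MultiplyInvocation(id="2", a=add.invoke(None).a, b=10)
div = DivideInvocation(id="3", a=mul.invoke(None).a, b=4)
print(div.invoke(None).a)  # 12, via int(50 / 4)
```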

View File

@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
prompt: str = Field(default=None, description="The output prompt") prompt: str = Field(default=None, description="The output prompt")
#fmt: on #fmt: on
class Config:
schema_extra = {
'required': [
'type',
'prompt',
]
}

View File

@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
class GraphInvocationOutput(BaseInvocationOutput): class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output" type: Literal["graph_output"] = "graph_output"
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation): class GraphInvocation(BaseInvocation):
@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = Field(description="The item being iterated over") item: Any = Field(description="The item being iterated over")
class Config:
schema_extra = {
'required': [
'type',
'item',
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation): class IterateInvocation(BaseInvocation):
@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = Field(description="The collection of input items") collection: list[Any] = Field(description="The collection of input items")
class Config:
schema_extra = {
'required': [
'type',
'collection',
]
}
class CollectInvocation(BaseInvocation): class CollectInvocation(BaseInvocation):
"""Collects values into a collection""" """Collects values into a collection"""
@ -1048,9 +1069,8 @@ class GraphExecutionState(BaseModel):
n n
for n in prepared_nodes for n in prepared_nodes
if all( if all(
pit nx.has_path(execution_graph, pit[0], n)
for pit in parent_iterators for pit in parent_iterators
if nx.has_path(execution_graph, pit[0], n)
) )
), ),
None, None,

View File

@ -9,6 +9,7 @@ from queue import Queue
from typing import Dict from typing import Dict
from PIL.Image import Image from PIL.Image import Image
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter from invokeai.backend.image_util import PngWriter
@ -66,6 +67,9 @@ class DiskImageStorage(ImageStorageBase):
Path(os.path.join(output_folder, image_type)).mkdir( Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> Image: def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
@ -87,7 +91,11 @@ class DiskImageStorage(ImageStorageBase):
self.__pngWriter.save_image_and_prompt_to_png( self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None image, "", image_subpath, None
) # TODO: just pass full path to png writer ) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
)
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image) self.__set_cache(image_path, image)

View File

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
import time
# TODO: make this serializable # TODO: make this serializable
@ -10,6 +11,7 @@ class InvocationQueueItem:
graph_execution_state_id: str graph_execution_state_id: str
invocation_id: str invocation_id: str
invoke_all: bool invoke_all: bool
timestamp: float
def __init__( def __init__(
self, self,
@ -22,6 +24,7 @@ class InvocationQueueItem:
self.graph_execution_state_id = graph_execution_state_id self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id self.invocation_id = invocation_id
self.invoke_all = invoke_all self.invoke_all = invoke_all
self.timestamp = time.time()
class InvocationQueueABC(ABC): class InvocationQueueABC(ABC):
@ -35,15 +38,44 @@ class InvocationQueueABC(ABC):
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: InvocationQueueItem | None) -> None:
pass pass
@abstractmethod
def cancel(self, graph_execution_state_id: str) -> None:
pass
@abstractmethod
def is_canceled(self, graph_execution_state_id: str) -> bool:
pass
class MemoryInvocationQueue(InvocationQueueABC): class MemoryInvocationQueue(InvocationQueueABC):
__queue: Queue __queue: Queue
__cancellations: dict[str, float]
def __init__(self): def __init__(self):
self.__queue = Queue() self.__queue = Queue()
self.__cancellations = dict()
def get(self) -> InvocationQueueItem: def get(self) -> InvocationQueueItem:
return self.__queue.get() item = self.__queue.get()
while isinstance(item, InvocationQueueItem) \
and item.graph_execution_state_id in self.__cancellations \
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
item = self.__queue.get()
# Clear old items
for graph_execution_state_id in list(self.__cancellations.keys()):
if self.__cancellations[graph_execution_state_id] < item.timestamp:
del self.__cancellations[graph_execution_state_id]
return item
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: InvocationQueueItem | None) -> None:
self.__queue.put(item) self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None:
if graph_execution_state_id not in self.__cancellations:
self.__cancellations[graph_execution_state_id] = time.time()
def is_canceled(self, graph_execution_state_id: str) -> bool:
return graph_execution_state_id in self.__cancellations
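The cancellation scheme deserves a note: `cancel()` only records a timestamp per graph execution state, and `get()` silently drops any queued item whose timestamp predates that cancellation, pruning stale cancellation records as it goes. A standalone sketch of the same pattern (simplified names, not the InvokeAI API):

```python
import time
from queue import Queue
from typing import Any, Dict, Tuple

# Standalone sketch of the cancel-by-timestamp pattern used by MemoryInvocationQueue:
# items enqueued before cancel() was called for their session are skipped on get().
class CancelableQueue:
    def __init__(self) -> None:
        self._queue: "Queue[Tuple[str, float, Any]]" = Queue()
        self._cancellations: Dict[str, float] = {}

    def put(self, session_id: str, payload: Any) -> None:
        self._queue.put((session_id, time.time(), payload))

    def cancel(self, session_id: str) -> None:
        self._cancellations.setdefault(session_id, time.time())

    def get(self) -> Any:
        # (pruning of old cancellation records, done in the original, is omitted here)
        session_id, queued_at, payload = self._queue.get()
        while session_id in self._cancellations and self._cancellations[session_id] > queued_at:
            session_id, queued_at, payload = self._queue.get()
        return payload
```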

View File

@ -2,6 +2,7 @@
from invokeai.backend import ModelManager from invokeai.backend import ModelManager
from .events import EventServiceBase from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC from .invocation_queue import InvocationQueueABC
@ -11,6 +12,7 @@ class InvocationServices:
"""Services that can be used by invocations""" """Services that can be used by invocations"""
events: EventServiceBase events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase images: ImageStorageBase
queue: InvocationQueueABC queue: InvocationQueueABC
model_manager: ModelManager model_manager: ModelManager
@ -24,6 +26,7 @@ class InvocationServices:
self, self,
model_manager: ModelManager, model_manager: ModelManager,
events: EventServiceBase, events: EventServiceBase,
latents: LatentsStorageBase,
images: ImageStorageBase, images: ImageStorageBase,
queue: InvocationQueueABC, queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@ -32,6 +35,7 @@ class InvocationServices:
): ):
self.model_manager = model_manager self.model_manager = model_manager
self.events = events self.events = events
self.latents = latents
self.images = images self.images = images
self.queue = queue self.queue = queue
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager

View File

@ -33,7 +33,6 @@ class Invoker:
self.services.graph_execution_manager.set(graph_execution_state) self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation # Queue the invocation
print(f"queueing item {invocation.id}")
self.services.queue.put( self.services.queue.put(
InvocationQueueItem( InvocationQueueItem(
# session_id = session.id, # session_id = session.id,
@ -51,6 +50,10 @@ class Invoker:
self.services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)
return new_state return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
def __start_service(self, service) -> None: def __start_service(self, service) -> None:
# Call start() method on any services that have it # Call start() method on any services that have it
start_op = getattr(service, "start", None) start_op = getattr(service, "start", None)

View File

@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
return cache_item
latent = self.__underlying_storage.get(name)
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if not name in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
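In practice the two stores are composed: the write-through cache wraps the disk store so recently produced latents are served from memory while still being persisted. A hedged usage sketch (the import path is assumed from the `latent_storage` import added to invocation_services.py above):

```python
import torch

# Import path assumed from the `from .latent_storage import LatentsStorageBase` line above.
from invokeai.app.services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

latents_store = ForwardCacheLatentsStorage(DiskLatentsStorage("outputs/latents"), max_cache_size=20)
latents_store.set("a1b2c3__noise_node", torch.randn(1, 4, 64, 64))  # written through to disk
print(latents_store.get("a1b2c3__noise_node").shape)                # served from the in-memory cache
```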

View File

@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC): class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread __invoker_thread: Thread
@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
) )
) )
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
):
continue
# Save outputs and history # Save outputs and history
graph_execution_state.complete(invocation.id, outputs) graph_execution_state.complete(invocation.id, outputs)
@ -76,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
except CanceledException:
pass
except Exception as e: except Exception as e:
error = traceback.format_exc() error = traceback.format_exc()
@ -96,6 +105,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
pass pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
):
continue
# Queue any further commands if invoking all # Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete() is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete: if queue_item.invoke_all and not is_complete:

View File

@ -59,6 +59,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),), (item.json(),),
) )
self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
@ -84,6 +85,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute( self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
) )
self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
self._on_deleted(id) self._on_deleted(id)

View File

@ -0,0 +1,25 @@
import os
from PIL import Image
def save_thumbnail(
image: Image.Image,
filename: str,
path: str,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path
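A small usage example for the thumbnail helper (paths here are illustrative; the target folder must already exist, which `DiskImageStorage` ensures by creating a `thumbnails/` directory at startup):

```python
import os
from PIL import Image

# Import path taken from the DiskImageStorage changes above; output paths are illustrative only.
from invokeai.app.util.save_thumbnail import save_thumbnail

os.makedirs("outputs/thumbnails", exist_ok=True)    # the helper expects the folder to exist
img = Image.new("RGB", (512, 768), color="gray")    # stand-in for a generated image
thumb = save_thumbnail(image=img, filename="example.png", path="outputs/thumbnails", size=256)
print(thumb)  # outputs/thumbnails/example.webp (aspect ratio preserved, longest side <= 256)
```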

42
invokeai/app/util/util.py Normal file
View File

@ -0,0 +1,42 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
class CanceledException(Exception):
pass
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
steps,
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callback, while img2img gives us a PipelineIntermediateState.
This adapter grabs the needed data and passes it along to the callback function.
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
else:
return fast_latents_step_callback(*cb_args, **kwargs)

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import List, Iterator, Type from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -35,23 +35,23 @@ downsampling = 8
@dataclass @dataclass
class InvokeAIGeneratorBasicParams: class InvokeAIGeneratorBasicParams:
seed: int=None seed: Optional[int]=None
width: int=512 width: int=512
height: int=512 height: int=512
cfg_scale: int=7.5 cfg_scale: float=7.5
steps: int=20 steps: int=20
ddim_eta: float=0.0 ddim_eta: float=0.0
scheduler: int='ddim' scheduler: str='ddim'
precision: str='float16' precision: str='float16'
perlin: float=0.0 perlin: float=0.0
threshold: int=0.0 threshold: float=0.0
seamless: bool=False seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: float=None v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0 variation_amount: float = 0.0
with_variations: list=field(default_factory=list) with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None safety_checker: Optional[SafetyChecker]=None
@dataclass @dataclass
class InvokeAIGeneratorOutput: class InvokeAIGeneratorOutput:
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
and the model hash, as well as all the generate() parameters that went into and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes) generating the image (in .params, also available as attributes)
''' '''
image: Image image: Image.Image
seed: int seed: int
model_hash: str model_hash: str
attention_maps_images: List[Image] attention_maps_images: List[Image.Image]
params: Namespace params: Namespace
# we are interposing a wrapper around the original Generator classes so that # we are interposing a wrapper around the original Generator classes so that
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def generate(self, def generate(self,
prompt: str='', prompt: str='',
callback: callable=None, callback: Optional[Callable]=None,
step_callback: callable=None, step_callback: Optional[Callable]=None,
iterations: int=1, iterations: int=1,
**keyword_args, **keyword_args,
)->Iterator[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(self,
init_image: Image | torch.FloatTensor, init_image: Image.Image | torch.FloatTensor,
strength: float=0.75, strength: float=0.75,
**keyword_args **keyword_args
)->List[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image, return super().generate(init_image=init_image,
strength=strength, strength=strength,
**keyword_args **keyword_args
@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image | torch.FloatTensor, mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 0, seam_size: int = 0,
seam_blur: int = 0, seam_blur: int = 0,
@ -236,7 +236,7 @@ class Inpaint(Img2Img):
inpaint_height=None, inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args **keyword_args
)->List[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
return super().generate( return super().generate(
mask_image=mask_image, mask_image=mask_image,
seam_size=seam_size, seam_size=seam_size,
@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
embiggen: list=None, embiggen: list=None,
embiggen_tiles: list = None, embiggen_tiles: list = None,
strength: float=0.75, strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]: **kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen, return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles, embiggen_tiles=embiggen_tiles,
strength=strength, strength=strength,

View File

@ -378,16 +378,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
for key in keys: for key in keys:
if key.startswith("model.diffusion_model"): if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
if flat_ema_key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key flat_ema_key
) )
elif flat_ema_key_alt in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key_alt
)
else:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
key
)
else: else:
print( print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)" " | Extracting only the non-EMA weights (usually better for fine-tuning)"
) )
for key in keys: for key in keys:
if key.startswith(unet_key): if key.startswith("model.diffusion_model") and key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {} new_checkpoint = {}
@ -1026,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
return text_model return text_model
def replace_checkpoint_vae(checkpoint, vae_path:str):
if vae_path.endswith(".safetensors"):
vae_ckpt = load_file(vae_path)
else:
vae_ckpt = torch.load(vae_path, map_location="cpu")
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
for vae_key in state_dict:
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
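The effect of `replace_checkpoint_vae` is to overlay the external VAE's weights onto the checkpoint under the `first_stage_model.` prefix, so the VAE conversion further down treats them as if they had shipped with the model. A toy illustration with made-up keys and values, showing only the shape of the transformation:

```python
# Toy illustration of the re-keying above; keys and values are made up.
checkpoint = {"first_stage_model.encoder.conv_in.weight": "original vae weight"}
vae_state_dict = {"encoder.conv_in.weight": "replacement vae weight"}

for vae_key, value in vae_state_dict.items():
    checkpoint[f"first_stage_model.{vae_key}"] = value

print(checkpoint["first_stage_model.encoder.conv_in.weight"])  # replacement vae weight
```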
def load_pipeline_from_original_stable_diffusion_ckpt( def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path: str,
@ -1038,8 +1057,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
extract_ema: bool = True, extract_ema: bool = True,
upcast_attn: bool = False, upcast_attn: bool = False,
vae: AutoencoderKL = None, vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32, precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False, return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]: ) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
""" """
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@ -1067,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1. running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
""" """
with warnings.catch_warnings(): with warnings.catch_warnings():
@ -1074,12 +1097,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity() verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error() dlogging.set_verbosity_error()
checkpoint = ( if Path(checkpoint_path).suffix == '.ckpt':
torch.load(checkpoint_path) if scan_needed:
if Path(checkpoint_path).suffix == ".ckpt" ModelManager.scan_model(checkpoint_path,checkpoint_path)
else load_file(checkpoint_path) checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub") cache_dir = global_cache_dir("hub")
pipeline_class = ( pipeline_class = (
StableDiffusionGeneratorPipeline StableDiffusionGeneratorPipeline
@ -1202,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model, or use the one passed # If a replacement VAE path was specified, we'll incorporate that into
if not vae: # the checkpoint model and then convert it
if vae_path:
print(f" | Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
print(" | Using checkpoint model's original VAE") print(" | Using checkpoint model's original VAE")
if vae:
print(" | Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config( vae_config = create_vae_diffusers_config(
original_config, image_size=image_size original_config, image_size=image_size
) )
@ -1214,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
else:
print(" | Using external VAE specified in config")
# Convert the text model. # Convert the text model.
model_type = pipeline_type model_type = pipeline_type
@ -1232,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
cache_dir=cache_dir, cache_dir=cache_dir,
) )
pipe = pipeline_class( pipe = pipeline_class(
vae=vae, vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
safety_checker=None, safety_checker=None,
feature_extractor=None, feature_extractor=None,

View File

@ -18,7 +18,7 @@ import warnings
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from shutil import move, rmtree from shutil import move, rmtree
from typing import Any, Optional, Union from typing import Any, Optional, Union, Callable
import safetensors import safetensors
import safetensors.torch import safetensors.torch
@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum): class SDLegacyType(Enum):
V1 = 1 V1 = 1
@ -45,9 +45,6 @@ class SDLegacyType(Enum):
UNKNOWN = 99 UNKNOWN = 99
DEFAULT_MAX_MODELS = 2 DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
}
class ModelManager(object): class ModelManager(object):
''' '''
@ -285,13 +282,13 @@ class ModelManager(object):
self.stack.remove(model_name) self.stack.remove(model_name)
if delete_files: if delete_files:
if weights: if weights:
print(f"** deleting file {weights}") print(f"** Deleting file {weights}")
Path(weights).unlink(missing_ok=True) Path(weights).unlink(missing_ok=True)
elif path: elif path:
print(f"** deleting directory {path}") print(f"** Deleting directory {path}")
rmtree(path, ignore_errors=True) rmtree(path, ignore_errors=True)
elif repo_id: elif repo_id:
print(f"** deleting the cached model directory for {repo_id}") print(f"** Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id) self._delete_model_from_cache(repo_id)
def add_model( def add_model(
@ -435,7 +432,6 @@ class ModelManager(object):
# square images??? # square images???
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width height = width
print(f" | Default image dimensions = {width} x {height}") print(f" | Default image dimensions = {width} x {height}")
return pipeline, width, height, model_hash return pipeline, width, height, model_hash
@ -457,15 +453,21 @@ class ModelManager(object):
from . import load_pipeline_from_original_stable_diffusion_ckpt from . import load_pipeline_from_original_stable_diffusion_ckpt
try:
if self.list_models()[self.current_model]['status'] == 'active':
self.offload_model(self.current_model) self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name): except Exception as e:
vae = self._load_vae(vae_config) pass
vae_path = None
if vae:
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
if self._has_cuda(): if self._has_cuda():
torch.cuda.empty_cache() torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt( pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights, checkpoint_path=weights,
original_config_file=config, original_config_file=config,
vae=vae, vae_path=vae_path,
return_generator_pipeline=True, return_generator_pipeline=True,
precision=torch.float16 if self.precision == "float16" else torch.float32, precision=torch.float16 if self.precision == "float16" else torch.float32,
) )
@ -473,7 +475,6 @@ class ModelManager(object):
pipeline.enable_offload_submodels(self.device) pipeline.enable_offload_submodels(self.device)
else: else:
pipeline.to(self.device) pipeline.to(self.device)
return ( return (
pipeline, pipeline,
width, width,
@ -512,18 +513,20 @@ class ModelManager(object):
print(f">> Offloading {model_name} to CPU") print(f">> Offloading {model_name} to CPU")
model = self.models[model_name]["model"] model = self.models[model_name]["model"]
model.offload_all() model.offload_all()
self.current_model = None
gc.collect() gc.collect()
if self._has_cuda(): if self._has_cuda():
torch.cuda.empty_cache() torch.cuda.empty_cache()
@classmethod
def scan_model(self, model_name, checkpoint): def scan_model(self, model_name, checkpoint):
""" """
Apply picklescanner to the indicated checkpoint and issue a warning Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified. and option to exit if an infected file is identified.
""" """
# scan model # scan model
print(f">> Scanning Model: {model_name}") print(f" | Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint) scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0: if scan_result.infected_files != 0:
if scan_result.infected_files == 1: if scan_result.infected_files == 1:
@ -546,7 +549,7 @@ class ModelManager(object):
print("### Exiting InvokeAI") print("### Exiting InvokeAI")
sys.exit() sys.exit()
else: else:
print(">> Model scanned ok") print(" | Model scanned ok")
def import_diffuser_model( def import_diffuser_model(
self, self,
@ -627,14 +630,13 @@ class ModelManager(object):
def heuristic_import( def heuristic_import(
self, self,
path_url_or_repo: str, path_url_or_repo: str,
convert: bool = True,
model_name: str = None, model_name: str = None,
description: str = None, description: str = None,
model_config_file: Path = None, model_config_file: Path = None,
commit_to_conf: Path = None, commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
) -> str: ) -> str:
""" """Accept a string which could be:
Accept a string which could be:
- a HF diffusers repo_id - a HF diffusers repo_id
- a URL pointing to a legacy .ckpt or .safetensors file - a URL pointing to a legacy .ckpt or .safetensors file
- a local path pointing to a legacy .ckpt or .safetensors file - a local path pointing to a legacy .ckpt or .safetensors file
@ -648,16 +650,20 @@ class ModelManager(object):
The model_name and/or description can be provided. If not, they will The model_name and/or description can be provided. If not, they will
be generated automatically. be generated automatically.
If convert is true, legacy models will be converted to diffusers
before importing.
If commit_to_conf is provided, the newly loaded model will be written If commit_to_conf is provided, the newly loaded model will be written
to the `models.yaml` file at the indicated path. Otherwise, the changes to the `models.yaml` file at the indicated path. Otherwise, the changes
will only remain in memory. will only remain in memory.
The (potentially derived) name of the model is returned on success, or None The routine will do its best to figure out the config file
on failure. When multiple models are added from a directory, only the last needed to convert legacy checkpoint file, but if it can't it
imported one is returned. will call the config_file_callback routine, if provided. The
callback accepts a single argument, the Path to the checkpoint
file, and returns a Path to the config file to use.
The (potentially derived) name of the model is returned on
success, or None on failure. When multiple models are added
from a directory, only the last imported one is returned.
""" """
model_path: Path = None model_path: Path = None
thing = path_url_or_repo # to save typing thing = path_url_or_repo # to save typing
@ -704,7 +710,7 @@ class ModelManager(object):
Path(thing).rglob("*.safetensors") Path(thing).rglob("*.safetensors")
): ):
if model_name := self.heuristic_import( if model_name := self.heuristic_import(
str(m), convert, commit_to_conf=commit_to_conf str(m), commit_to_conf=commit_to_conf
): ):
print(f" >> {model_name} successfully imported") print(f" >> {model_name} successfully imported")
return model_name return model_name
@ -731,14 +737,21 @@ class ModelManager(object):
return model_path.stem return model_path.stem
# another round of heuristics to guess the correct config file. # another round of heuristics to guess the correct config file.
checkpoint = ( checkpoint = None
torch.load(model_path) if model_path.suffix in [".ckpt",".pt"]:
if model_path.suffix == ".ckpt" self.scan_model(model_path,model_path)
else safetensors.torch.load_file(model_path) checkpoint = torch.load(model_path)
) else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided # additional probing needed if no config file provided
if model_config_file is None: if model_config_file is None:
# look for a like-named .yaml file in same directory
if model_path.with_suffix(".yaml").exists():
model_config_file = model_path.with_suffix(".yaml")
print(f" | Using config file {model_config_file.name}")
else:
model_type = self.probe_model_type(checkpoint) model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1: if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected") print(" | SD-v1 model detected")
@ -752,20 +765,18 @@ class ModelManager(object):
) )
elif model_type == SDLegacyType.V2_v: elif model_type == SDLegacyType.V2_v:
print( print(
" | SD-v2-v model detected; model will be converted to diffusers format" " | SD-v2-v model detected"
) )
model_config_file = Path( model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
) )
convert = True
elif model_type == SDLegacyType.V2_e: elif model_type == SDLegacyType.V2_e:
print( print(
" | SD-v2-e model detected; model will be converted to diffusers format" " | SD-v2-e model detected"
) )
model_config_file = Path( model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml" Globals.root, "configs/stable-diffusion/v2-inference.yaml"
) )
convert = True
elif model_type == SDLegacyType.V2: elif model_type == SDLegacyType.V2:
print( print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@ -777,17 +788,34 @@ class ModelManager(object):
) )
return return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
# despite our best efforts, we could not find a model config file, so give up
if not model_config_file:
return
# look for a custom vae, a like-named file ending with .vae in the same directory
vae_path = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (model_path.with_suffix(f".vae.{suffix}")).exists():
vae_path = model_path.with_suffix(f".vae.{suffix}")
print(f" | Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = Path( diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
) )
model_name = self.convert_and_import( model_name = self.convert_and_import(
model_path, model_path,
diffusers_path=diffuser_path, diffusers_path=diffuser_path,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), vae=vae,
vae_path=str(vae_path),
model_name=model_name, model_name=model_name,
model_description=description, model_description=description,
original_config_file=model_config_file, original_config_file=model_config_file,
commit_to_conf=commit_to_conf, commit_to_conf=commit_to_conf,
scan_needed=False,
) )
return model_name return model_name
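To summarize the sidecar-file heuristics added above: a like-named `.yaml` next to the checkpoint is taken as its config file, and a like-named `.vae.pt`, `.vae.ckpt`, or `.vae.safetensors` is taken as its VAE, with `stabilityai/sd-vae-ft-mse` used as the fallback when no VAE file is found. A hedged sketch of that lookup as a standalone helper (hypothetical function, not part of this commit):

```python
from pathlib import Path
from typing import Optional, Tuple

# Hypothetical helper mirroring the sidecar-file lookup in heuristic_import above.
def find_sidecar_files(model_path: Path) -> Tuple[Optional[Path], Optional[Path]]:
    config = model_path.with_suffix(".yaml")
    config = config if config.exists() else None
    vae = None
    for suffix in ("pt", "ckpt", "safetensors"):
        candidate = model_path.with_suffix(f".vae.{suffix}")
        if candidate.exists():
            vae = candidate
            break
    return config, vae
```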
@ -797,9 +825,11 @@ class ModelManager(object):
diffusers_path: Path, diffusers_path: Path,
model_name=None, model_name=None,
model_description=None, model_description=None,
vae=None, vae:dict=None,
vae_path:Path=None,
original_config_file: Path = None, original_config_file: Path = None,
commit_to_conf: Path = None, commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str: ) -> str:
""" """
Convert a legacy ckpt weights file to diffuser model and import Convert a legacy ckpt weights file to diffuser model and import
@ -822,21 +852,26 @@ class ModelManager(object):
return return
model_name = model_name or diffusers_path.name model_name = model_name or diffusers_path.name
model_description = model_description or f"Optimized version of {model_name}" model_description = model_description or f"Converted version of {model_name}"
print(f">> Optimizing {model_name} (30-60s)") print(f" | Converting {model_name} to diffusers (30-60s)")
try: try:
# By passing the specified VAE to the conversion function, the autoencoder # By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file # will be built into the model rather than tacked on afterward via the config file
vae_model = self._load_vae(vae) if vae else None vae_model=None
if vae:
vae_model=self._load_vae(vae)
vae_path=None
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(
ckpt_path, ckpt_path,
diffusers_path, diffusers_path,
extract_ema=True, extract_ema=True,
original_config_file=original_config_file, original_config_file=original_config_file,
vae=vae_model, vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
) )
print( print(
f" | Success. Optimized model is now located at {str(diffusers_path)}" f" | Success. Converted model is now located at {str(diffusers_path)}"
) )
print(f" | Writing new config file entry for {model_name}") print(f" | Writing new config file entry for {model_name}")
new_config = dict( new_config = dict(
@ -849,7 +884,7 @@ class ModelManager(object):
self.add_model(model_name, new_config, True) self.add_model(model_name, new_config, True)
if commit_to_conf: if commit_to_conf:
self.commit(commit_to_conf) self.commit(commit_to_conf)
print(">> Conversion succeeded") print(" | Conversion succeeded")
except Exception as e: except Exception as e:
print(f"** Conversion failed: {str(e)}") print(f"** Conversion failed: {str(e)}")
print( print(
@ -879,36 +914,6 @@ class ModelManager(object):
return search_folder, found_models return search_folder, found_models
def _choose_diffusers_vae(
self, model_name: str, vae: str = None
) -> Union[dict, str]:
# In the event that the original entry is using a custom ckpt VAE, we try to
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
# I would prefer to do this differently: We load the ckpt model into memory, swap the
# VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
# VAE is built into the model. However, when I tried this I got obscure key errors.
if vae:
return vae
if model_name in self.config and (
vae_ckpt_path := self.model_info(model_name).get("vae", None)
):
vae_basename = Path(vae_ckpt_path).stem
diffusers_vae = None
if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
print(
f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
)
vae = {"repo_id": diffusers_vae}
else:
print(
f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
)
print(
'** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
)
vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
return vae
def _make_cache_room(self) -> None: def _make_cache_room(self) -> None:
num_loaded_models = len(self.models) num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models: if num_loaded_models >= self.max_loaded_models:
@ -1208,7 +1213,7 @@ class ModelManager(object):
hashes_to_delete.add(revision.commit_hash) hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete) strategy = cache_info.delete_revisions(*hashes_to_delete)
print( print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}" f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
) )
strategy.execute() strategy.execute()

View File

@ -1,16 +1,26 @@
import os
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union, List
import safetensors.torch
import torch import torch
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from .concepts_lib import HuggingFaceConceptsLibrary from .concepts_lib import HuggingFaceConceptsLibrary
@dataclass
class EmbeddingInfo:
name: str
embedding: torch.Tensor
num_vectors_per_token: int
token_dim: int
trained_steps: int = None
trained_model_name: str = None
trained_model_checksum: str = None
@dataclass @dataclass
class TextualInversion: class TextualInversion:
@ -72,37 +82,17 @@ class TextualInversionManager(BaseTextualInversionManager):
if str(ckpt_path).endswith(".DS_Store"): if str(ckpt_path).endswith(".DS_Store"):
return return
try: embedding_list = self._parse_embedding(str(ckpt_path))
scan_result = scan_file_path(str(ckpt_path)) for embedding_info in embedding_list:
if scan_result.infected_files == 1: if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
print( print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}" f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
) )
print("### For your safety, InvokeAI will not load this embed.") continue
return
except Exception:
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path))
if embedding_info is None:
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["token_dim"]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
# Resolve the situation in which an earlier embedding has claimed the same # Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to. # trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info["name"] trigger_str = embedding_info.name
sourcefile = ( sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}" f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin" if ckpt_path.name == "learned_embeds.bin"
@ -123,7 +113,7 @@ class TextualInversionManager(BaseTextualInversionManager):
try: try:
self._add_textual_inversion( self._add_textual_inversion(
trigger_str, trigger_str,
embedding_info["embedding"], embedding_info.embedding,
defer_injecting_tokens=defer_injecting_tokens, defer_injecting_tokens=defer_injecting_tokens,
) )
# remember which source file claims this trigger # remember which source file claims this trigger
@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
return token_id return token_id
def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split(".")[-1]
if file_type == "pt":
return self._parse_embedding_pt(embedding_file)
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
else:
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
return None
def _parse_embedding_pt(self, embedding_file): def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
embedding_ckpt = torch.load(embedding_file, map_location="cpu") suffix = Path(embedding_file).suffix
embedding_info = {}
# Check if valid embedding file
if "string_to_token" and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try: try:
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename( if suffix in [".pt",".ckpt",".bin"]:
os.path.splitext(embedding_file)[0] scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
print(
f" ** Security Issues Found in Model: {scan_result.issues_count}"
) )
print(" ** For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# Check num of embeddings and warn user only the first will be used # try to figure out what kind of embedding file it is and parse accordingly
embedding_info["num_of_embeddings"] = len( keys = list(ckpt.keys())
embedding_ckpt["string_to_token"] if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
) return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
if embedding_info["num_of_embeddings"] > 1:
print(">> More than 1 embedding found. Will use the first one")
embedding = list(embedding_ckpt["string_to_param"].values())[0] elif all(x in keys for x in ['string_to_token','string_to_param']):
except (AttributeError, KeyError): return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
embedding_info["embedding"] = embedding elif 'emb_params' in keys:
embedding_info["num_vectors_per_token"] = embedding.size()[0] return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
embedding_info["token_dim"] = embedding.size()[1]
try:
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
except AttributeError:
print(">> No Training Details Found. Passing ...")
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys()) == 1:
embedding_info = self._parse_embedding_bin(embedding_file)
else: else:
print(">> Invalid embedding format") return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
embedding_info = None
return embedding_info def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
basename = Path(file_path).stem
print(f' | Loading v1 embedding file: {basename}')
def _parse_embedding_bin(self, embedding_file): embeddings = list()
embedding_ckpt = torch.load(embedding_file, map_location="cpu") token_counter = -1
embedding_info = {} for token,embedding in embedding_ckpt["string_to_param"].items():
if token_counter < 0:
if list(embedding_ckpt.keys()) == 0: trigger = embedding_ckpt["name"]
print(">> Invalid concepts file") elif token_counter == 0:
embedding_info = None trigger = f'<basename>'
else: else:
for token in list(embedding_ckpt.keys()): trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
embedding_info["name"] = ( token_counter += 1
token embedding_info = EmbeddingInfo(
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
trained_steps = embedding_ckpt["step"],
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
) )
embedding_info["embedding"] = embedding_ckpt[token] embeddings.append(embedding_info)
embedding_info[ return embeddings
"num_vectors_per_token"
] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
return embedding_info def _parse_embedding_v2 (
self, embedding_ckpt: dict, file_path: str
) -> List[EmbeddingInfo]:
"""
This handles embedding .pt file variant #2.
"""
basename = Path(file_path).stem
print(f' | Loading v2 embedding file: {basename}')
embeddings = list()
def _handle_broken_pt_variants(
self, embedding_ckpt: dict, embedding_file: str
) -> dict:
"""
This handles the broken .pt file variants. We only know of one at present.
"""
embedding_info = {}
if isinstance( if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
): ):
for token in list(embedding_ckpt["string_to_token"].keys()): token_counter = 0
embedding_info["name"] = ( for token,embedding in embedding_ckpt["string_to_param"].items():
token trigger = token if token != '*' \
if token != "*" else f'<{basename}>' if token_counter == 0 \
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" else f'<{basename}-{int(token_counter:=token_counter+1)}>'
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
) )
embedding_info["embedding"] = embedding_ckpt[ embeddings.append(embedding_info)
"string_to_param"
].state_dict()[token]
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
else: else:
print(">> Invalid embedding format") print(f" ** {basename}: Unrecognized embedding format")
embedding_info = None
return embedding_info return embeddings
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
"""
Parse 'version 3' of the .pt textual inversion embedding files.
"""
basename = Path(file_path).stem
print(f' | Loading v3 embedding file: {basename}')
embedding = embedding_ckpt['emb_params']
embedding_info = EmbeddingInfo(
name = f'<{basename}>',
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
)
return [embedding_info]
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
"""
Parse 'version 4' of the textual inversion embedding files. This one
is usually associated with .bin files trained by HuggingFace diffusers.
"""
basename = Path(filepath).stem
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
print(f' | Loading v4 embedding file: {short_path}')
embeddings = list()
if list(embedding_ckpt.keys()) == 0:
print(f" ** Invalid embeddings file: {short_path}")
else:
for token,embedding in embedding_ckpt.items():
embedding_info = EmbeddingInfo(
name = token or f"<{basename}>",
embedding = embedding,
num_vectors_per_token = 1, # All Concepts seem to default to 1
token_dim = embedding.size()[0],
)
embeddings.append(embedding_info)
return embeddings
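The four parsers are selected purely by which keys appear in the loaded file, per the dispatch in `_parse_embedding` above. A compact standalone restatement of that classification, handy when triaging an embedding that fails to load:

```python
# Standalone restatement of the key-based dispatch in _parse_embedding above.
def classify_embedding_format(ckpt: dict) -> str:
    keys = set(ckpt.keys())
    if {"string_to_token", "string_to_param", "name", "step"} <= keys:
        return "v1"  # e.g. rem_rezero.pt
    if {"string_to_token", "string_to_param"} <= keys:
        return "v2"  # e.g. midj-strong.pt
    if "emb_params" in keys:
        return "v3"  # e.g. easynegative.safetensors
    return "v4"      # typically a diffusers-trained .bin (token -> tensor mapping)
```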

View File

@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
"RGB" "RGB"
) )
def image_progress(sample, step): def image_progress(intermediate_state: PipelineIntermediateState):
if self.canceled.is_set(): if self.canceled.is_set():
raise CanceledException raise CanceledException
@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
nonlocal generation_parameters nonlocal generation_parameters
nonlocal progress nonlocal progress
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
generation_messages = { generation_messages = {
"txt2img": "common.statusGeneratingTextToImage", "txt2img": "common.statusGeneratingTextToImage",
"img2img": "common.statusGeneratingImageToImage", "img2img": "common.statusGeneratingImageToImage",
@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
progress.set_current_iteration(progress.current_iteration + 1) progress.set_current_iteration(progress.current_iteration + 1)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return image_progress(progress_state.latents, progress_state.step)
else:
return image_progress(*cb_args, **kwargs)
self.generate.prompt2image( self.generate.prompt2image(
**generation_parameters, **generation_parameters,
step_callback=diffusers_step_callback_adapter, step_callback=image_progress,
image_callback=image_done, image_callback=image_done,
) )

View File

@@ -626,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
     completer.set_default_dir(opt.outdir)

-def import_model(model_path: str, gen, opt, completer, convert=False):
+def import_model(model_path: str, gen, opt, completer):
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
@@ -657,7 +657,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
             model_path,
             model_name=model_name,
             description=model_desc,
-            convert=convert,
         )
         if not imported_name:
@@ -666,7 +665,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
                 model_path,
                 model_name=model_name,
                 description=model_desc,
-                convert=convert,
                 model_config_file=config_file,
             )
             if not imported_name:
@@ -757,7 +755,6 @@ def _get_model_name_and_desc(
     )
     return model_name, model_description

 def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace("\\", "/")  # windows
     manager = gen.model_manager
@@ -772,16 +769,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
            original_config_file = Path(model_info["config"])
            model_name = model_name_or_path
            model_description = model_info["description"]
-           vae = model_info["vae"]
+           vae_path = model_info.get("vae")
        else:
            print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
            return
-       if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
-           Path(vae).stem
-       ):
-           vae_repo = dict(repo_id=vae_repo)
-       else:
-           vae_repo = None
        model_name = manager.convert_and_import(
            ckpt_path,
            diffusers_path=Path(
@@ -790,11 +781,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
            model_name=model_name,
            model_description=model_description,
            original_config_file=original_config_file,
-           vae=vae_repo,
+           vae_path=vae_path,
        )
    else:
        try:
-           import_model(model_name_or_path, gen, opt, completer, convert=True)
+           import_model(model_name_or_path, gen, opt, completer)
        except KeyboardInterrupt:
            return
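
Illustration only (the paths below are invented): the keyword shape of `manager.convert_and_import()` after this change, with `vae_path` carrying a local checkpoint path read from `models.yaml` rather than a repo-id dict.

```python
from pathlib import Path

convert_kwargs = dict(
    diffusers_path=Path("models/converted/my-model"),            # hypothetical output dir
    model_name="my-model",
    model_description="converted legacy checkpoint",
    original_config_file=Path("configs/stable-diffusion/v1-inference.yaml"),
    vae_path=Path("models/ldm/stable-diffusion-v1/vae.ckpt"),     # hypothetical local VAE
)
# model_name = manager.convert_and_import(ckpt_path, **convert_kwargs)
```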

View File

@@ -8,7 +8,6 @@
     "darkTheme": "داكن",
     "lightTheme": "فاتح",
     "greenTheme": "أخضر",
-    "text2img": "نص إلى صورة",
     "img2img": "صورة إلى صورة",
     "unifiedCanvas": "لوحة موحدة",
     "nodes": "عقد",

View File

@@ -7,7 +7,6 @@
     "darkTheme": "Dunkel",
     "lightTheme": "Hell",
     "greenTheme": "Grün",
-    "text2img": "Text zu Bild",
     "img2img": "Bild zu Bild",
     "nodes": "Knoten",
     "langGerman": "Deutsch",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Oscuro",
     "lightTheme": "Claro",
     "greenTheme": "Verde",
-    "text2img": "Texto a Imagen",
     "img2img": "Imagen a Imagen",
     "unifiedCanvas": "Lienzo Unificado",
     "nodes": "Nodos",
@@ -70,7 +69,11 @@
     "langHebrew": "Hebreo",
     "pinOptionsPanel": "Pin del panel de opciones",
     "loading": "Cargando",
-    "loadingInvokeAI": "Cargando invocar a la IA"
+    "loadingInvokeAI": "Cargando invocar a la IA",
+    "postprocessing": "Tratamiento posterior",
+    "txt2img": "De texto a imagen",
+    "accept": "Aceptar",
+    "cancel": "Cancelar"
   },
   "gallery": {
     "generations": "Generaciones",
@@ -404,7 +407,8 @@
     "none": "ninguno",
     "pickModelType": "Elige el tipo de modelo",
     "v2_768": "v2 (768px)",
-    "addDifference": "Añadir una diferencia"
+    "addDifference": "Añadir una diferencia",
+    "scanForModels": "Buscar modelos"
   },
   "parameters": {
     "images": "Imágenes",
@@ -574,7 +578,7 @@
     "autoSaveToGallery": "Guardar automáticamente en galería",
     "saveBoxRegionOnly": "Guardar solo región dentro de la caja",
     "limitStrokesToBox": "Limitar trazos a la caja",
-    "showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
+    "showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
     "clearCanvasHistory": "Limpiar historial de lienzo",
     "clearHistory": "Limpiar historial",
     "clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Sombre",
     "lightTheme": "Clair",
     "greenTheme": "Vert",
-    "text2img": "Texte en image",
     "img2img": "Image en image",
     "unifiedCanvas": "Canvas unifié",
     "nodes": "Nœuds",
@@ -47,7 +46,19 @@
     "statusLoadingModel": "Chargement du modèle",
     "statusModelChanged": "Modèle changé",
     "discordLabel": "Discord",
-    "githubLabel": "Github"
+    "githubLabel": "Github",
+    "accept": "Accepter",
+    "statusMergingModels": "Mélange des modèles",
+    "loadingInvokeAI": "Chargement de Invoke AI",
+    "cancel": "Annuler",
+    "langEnglish": "Anglais",
+    "statusConvertingModel": "Conversion du modèle",
+    "statusModelConverted": "Modèle converti",
+    "loading": "Chargement",
+    "pinOptionsPanel": "Épingler la page d'options",
+    "statusMergedModels": "Modèles mélangés",
+    "txt2img": "Texte vers image",
+    "postprocessing": "Post-Traitement"
   },
   "gallery": {
     "generations": "Générations",
@@ -518,5 +529,15 @@
     "betaDarkenOutside": "Assombrir à l'extérieur",
     "betaLimitToBox": "Limiter à la boîte",
     "betaPreserveMasked": "Conserver masqué"
+  },
+  "accessibility": {
+    "uploadImage": "Charger une image",
+    "reset": "Réinitialiser",
+    "nextImage": "Image suivante",
+    "previousImage": "Image précédente",
+    "useThisParameter": "Utiliser ce paramètre",
+    "zoomIn": "Zoom avant",
+    "zoomOut": "Zoom arrière",
+    "showOptionsPanel": "Montrer la page d'options"
   }
 }

View File

@@ -125,7 +125,6 @@
     "langSimplifiedChinese": "סינית",
     "langUkranian": "אוקראינית",
     "langSpanish": "ספרדית",
-    "text2img": "טקסט לתמונה",
     "img2img": "תמונה לתמונה",
     "unifiedCanvas": "קנבס מאוחד",
     "nodes": "צמתים",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Scuro",
     "lightTheme": "Chiaro",
     "greenTheme": "Verde",
-    "text2img": "Testo a Immagine",
     "img2img": "Immagine a Immagine",
     "unifiedCanvas": "Tela unificata",
     "nodes": "Nodi",
@@ -70,7 +69,11 @@
     "loading": "Caricamento in corso",
     "oceanTheme": "Oceano",
     "langHebrew": "Ebraico",
-    "loadingInvokeAI": "Caricamento Invoke AI"
+    "loadingInvokeAI": "Caricamento Invoke AI",
+    "postprocessing": "Post Elaborazione",
+    "txt2img": "Testo a Immagine",
+    "accept": "Accetta",
+    "cancel": "Annulla"
   },
   "gallery": {
     "generations": "Generazioni",
@@ -404,7 +407,8 @@
     "v2_768": "v2 (768px)",
     "none": "niente",
     "addDifference": "Aggiungi differenza",
-    "pickModelType": "Scegli il tipo di modello"
+    "pickModelType": "Scegli il tipo di modello",
+    "scanForModels": "Cerca modelli"
   },
   "parameters": {
     "images": "Immagini",
@@ -574,7 +578,7 @@
     "autoSaveToGallery": "Salvataggio automatico nella Galleria",
     "saveBoxRegionOnly": "Salva solo l'area di selezione",
     "limitStrokesToBox": "Limita i tratti all'area di selezione",
-    "showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
+    "showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
     "clearCanvasHistory": "Cancella cronologia Tela",
     "clearHistory": "Cancella la cronologia",
     "clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
@@ -612,7 +616,7 @@
     "copyMetadataJson": "Copia i metadati JSON",
     "exitViewer": "Esci dal visualizzatore",
     "zoomIn": "Zoom avanti",
-    "zoomOut": "Zoom Indietro",
+    "zoomOut": "Zoom indietro",
     "rotateCounterClockwise": "Ruotare in senso antiorario",
     "rotateClockwise": "Ruotare in senso orario",
     "flipHorizontally": "Capovolgi orizzontalmente",

View File

@@ -11,7 +11,6 @@
     "langArabic": "العربية",
     "langEnglish": "English",
     "langDutch": "Nederlands",
-    "text2img": "텍스트->이미지",
     "unifiedCanvas": "통합 캔버스",
     "langFrench": "Français",
     "langGerman": "Deutsch",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Donker",
     "lightTheme": "Licht",
     "greenTheme": "Groen",
-    "text2img": "Tekst naar afbeelding",
     "img2img": "Afbeelding naar afbeelding",
     "unifiedCanvas": "Centraal canvas",
     "nodes": "Knooppunten",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Ciemny",
     "lightTheme": "Jasny",
     "greenTheme": "Zielony",
-    "text2img": "Tekst na obraz",
     "img2img": "Obraz na obraz",
     "unifiedCanvas": "Tryb uniwersalny",
     "nodes": "Węzły",

View File

@@ -20,7 +20,6 @@
     "langSpanish": "Espanhol",
     "langRussian": "Русский",
     "langUkranian": "Украї́нська",
-    "text2img": "Texto para Imagem",
     "img2img": "Imagem para Imagem",
     "unifiedCanvas": "Tela Unificada",
     "nodes": "Nós",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Noite",
     "lightTheme": "Dia",
     "greenTheme": "Verde",
-    "text2img": "Texto Para Imagem",
     "img2img": "Imagem Para Imagem",
     "unifiedCanvas": "Tela Unificada",
     "nodes": "Nódulos",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Темная",
     "lightTheme": "Светлая",
     "greenTheme": "Зеленая",
-    "text2img": "Изображение из текста (text2img)",
     "img2img": "Изображение в изображение (img2img)",
     "unifiedCanvas": "Универсальный холст",
     "nodes": "Ноды",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "Темна",
     "lightTheme": "Світла",
     "greenTheme": "Зелена",
-    "text2img": "Зображення із тексту (text2img)",
     "img2img": "Зображення із зображення (img2img)",
     "unifiedCanvas": "Універсальне полотно",
     "nodes": "Вузли",

View File

@@ -8,7 +8,6 @@
     "darkTheme": "暗色",
     "lightTheme": "亮色",
     "greenTheme": "绿色",
-    "text2img": "文字到图像",
     "img2img": "图像到图像",
     "unifiedCanvas": "统一画布",
     "nodes": "节点",

View File

@@ -33,7 +33,6 @@
     "langBrPortuguese": "巴西葡萄牙語",
     "langRussian": "俄語",
     "langSpanish": "西班牙語",
-    "text2img": "文字到圖像",
     "unifiedCanvas": "統一畫布"
   }
 }

View File

@@ -34,7 +34,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiZoomIn />}
         aria-label={t('accessibility.zoomIn')}
-        tooltip="Zoom In"
+        tooltip={t('accessibility.zoomIn')}
         onClick={() => zoomIn()}
         fontSize={20}
       />
@@ -42,7 +42,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiZoomOut />}
         aria-label={t('accessibility.zoomOut')}
-        tooltip="Zoom Out"
+        tooltip={t('accessibility.zoomOut')}
         onClick={() => zoomOut()}
         fontSize={20}
       />
@@ -50,7 +50,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiRotateLeft />}
         aria-label={t('accessibility.rotateCounterClockwise')}
-        tooltip="Rotate Counter-Clockwise"
+        tooltip={t('accessibility.rotateCounterClockwise')}
         onClick={rotateCounterClockwise}
         fontSize={20}
       />
@@ -58,7 +58,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiRotateRight />}
         aria-label={t('accessibility.rotateClockwise')}
-        tooltip="Rotate Clockwise"
+        tooltip={t('accessibility.rotateClockwise')}
         onClick={rotateClockwise}
         fontSize={20}
       />
@@ -66,7 +66,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<MdFlip />}
         aria-label={t('accessibility.flipHorizontally')}
-        tooltip="Flip Horizontally"
+        tooltip={t('accessibility.flipHorizontally')}
         onClick={flipHorizontally}
         fontSize={20}
       />
@@ -74,7 +74,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<MdFlip style={{ transform: 'rotate(90deg)' }} />}
         aria-label={t('accessibility.flipVertically')}
-        tooltip="Flip Vertically"
+        tooltip={t('accessibility.flipVertically')}
         onClick={flipVertically}
         fontSize={20}
       />
@@ -82,7 +82,7 @@ const ReactPanZoomButtons = ({
       <IAIIconButton
         icon={<BiReset />}
         aria-label={t('accessibility.reset')}
-        tooltip="Reset"
+        tooltip={t('accessibility.reset')}
         onClick={() => {
           resetTransform();
           reset();

View File

@@ -1,8 +1,23 @@
 import i18n from 'i18next';
 import LanguageDetector from 'i18next-browser-languagedetector';
 import Backend from 'i18next-http-backend';
 import { initReactI18next } from 'react-i18next';
+import translationEN from '../dist/locales/en.json';

+if (import.meta.env.MODE === 'package') {
+  i18n.use(initReactI18next).init({
+    lng: 'en',
+    resources: {
+      en: { translation: translationEN },
+    },
+    debug: false,
+    interpolation: {
+      escapeValue: false,
+    },
+    returnNull: false,
+  });
+} else {
 i18n
   .use(Backend)
   .use(LanguageDetector)
@@ -18,5 +33,6 @@ i18n
     },
     returnNull: false,
   });
+}

 export default i18n;

View File

@@ -38,14 +38,14 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-    "compel==1.0.4",
+    "compel==1.0.5",
     "datasets",
     "diffusers[torch]~=0.14",
     "dnspython==2.2.1",
     "einops",
     "eventlet",
     "facexlib",
-    "fastapi==0.94.1",
+    "fastapi==0.88.0",
     "fastapi-events==0.8.0",
     "fastapi-socketio==0.0.10",
     "flask==2.1.3",
@@ -156,4 +156,3 @@ output = "coverage/index.xml"
 [flake8]
 max-line-length = 120

View File

@@ -1,6 +1,8 @@
 from .test_invoker import create_edge
 from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from invokeai.app.invocations.collections import RangeInvocation
+from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
 from invokeai.app.services.processor import DefaultInvocationProcessor
 from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@@ -21,13 +23,14 @@ def simple_graph():
 def mock_services():
     # NOTE: none of these are actually called by the test invocations
     return InvocationServices(
-        model_manager = None,
-        events = None,
-        images = None,
+        model_manager = None, # type: ignore
+        events = None, # type: ignore
+        images = None, # type: ignore
+        latents = None, # type: ignore
         queue = MemoryInvocationQueue(),
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
         processor = DefaultInvocationProcessor(),
-        restoration = None,
+        restoration = None, # type: ignore
     )

 def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
@@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services):
 def test_graph_state_expands_iterator(mock_services):
     graph = Graph()
-    test_prompts = ["Banana sushi", "Cat sushi"]
-    graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
-    graph.add_node(IterateInvocation(id = "2"))
-    graph.add_node(ImageTestInvocation(id = "3"))
-    graph.add_edge(create_edge("1", "collection", "2", "collection"))
-    graph.add_edge(create_edge("2", "item", "3", "prompt"))
+    graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
+    graph.add_node(IterateInvocation(id = "1"))
+    graph.add_node(MultiplyInvocation(id = "2", b = 10))
+    graph.add_node(AddInvocation(id = "3", b = 1))
+    graph.add_edge(create_edge("0", "collection", "1", "collection"))
+    graph.add_edge(create_edge("1", "item", "2", "a"))
+    graph.add_edge(create_edge("2", "a", "3", "a"))

     g = GraphExecutionState(graph = graph)
-    n1 = invoke_next(g, mock_services)
-    n2 = invoke_next(g, mock_services)
-    n3 = invoke_next(g, mock_services)
-    n4 = invoke_next(g, mock_services)
-    n5 = invoke_next(g, mock_services)
+    while not g.is_complete():
+        invoke_next(g, mock_services)

-    assert g.prepared_source_mapping[n1[0].id] == "1"
-    assert g.prepared_source_mapping[n2[0].id] == "2"
-    assert g.prepared_source_mapping[n3[0].id] == "2"
-    assert g.prepared_source_mapping[n4[0].id] == "3"
-    assert g.prepared_source_mapping[n5[0].id] == "3"
-
-    assert isinstance(n4[0], ImageTestInvocation)
-    assert isinstance(n5[0], ImageTestInvocation)
-
-    prompts = [n4[0].prompt, n5[0].prompt]
-    assert sorted(prompts) == sorted(test_prompts)
+    prepared_add_nodes = g.source_prepared_mapping['3']
+    results = set([g.results[n].a for n in prepared_add_nodes])
+    expected = set([1, 11, 21])
+    assert results == expected

 def test_graph_state_collects(mock_services):
     graph = Graph()
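
Not part of the diff, but the expected set in the rewritten iterator-expansion test can be checked by hand: the range node fans out 0, 1, 2; multiplying each by 10 and adding 1 gives 1, 11, 21.

```python
# Plain-Python check of the expectation used in the rewritten test above.
expected = {item * 10 + 1 for item in range(0, 3, 1)}
assert expected == {1, 11, 21}
```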

View File

@@ -24,10 +24,11 @@ def mock_services() -> InvocationServices:
         model_manager = None, # type: ignore
         events = TestEventService(),
         images = None, # type: ignore
+        latents = None, # type: ignore
         queue = MemoryInvocationQueue(),
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
         processor = DefaultInvocationProcessor(),
-        restoration = None,
+        restoration = None, # type: ignore
     )

 @pytest.fixture()