Compare commits

...

25 Commits

Author SHA1 Message Date
3c50448ccf Merge branch 'main' into dev/pytorch2 2023-04-06 21:47:46 -04:00
76bcd4d44f Fix typo (#3133)
'hotdot' to 'hotdog'; the world's least important PR :)
2023-04-07 12:38:05 +12:00
50f5e1bc83 Fix typo
'hotdot' to 'hotdog'; the world's least important PR :)
2023-04-06 16:47:57 -07:00
85b020f76c [nodes] Add latent nodes, storage, and fix iteration bugs (#3091)
* Add latents nodes.
* Fix iteration expansion.
* Add collection generator nodes, math nodes.
* Add noise node.
* Add some graph debug commands to the CLI.
* Fix negative id linking in CLI.
* Fix a CLI bug with multiple links per node.
2023-04-06 04:06:05 +00:00
a7833cc9a9 [api] Add models router and list model API. 2023-04-05 23:59:07 -04:00
919294e977 fix build-container.yml (#3117)
Add permission to write packages to GITHUB_TOKEN
2023-04-06 00:25:00 +02:00
7640acfb1f update build-container.yml
- add packages write permission
2023-04-05 15:44:26 +02:00
aed9ecef2a feat(nodes): add thumbnail generation to DiskImageStorage 2023-04-05 08:22:23 +10:00
18cddd7972 Right link on pytorch installer for linux rocm (#3084)
Right link on pytorch installer for linux rocm
2023-04-04 17:40:42 -04:00
e6b25f4ae3 Merge branch 'main' into patch-1 2023-04-04 17:40:12 -04:00
d1c0050e65 fix(nodes): fix typo in list_sessions handler (#3109)
The typo happened not to affect functionality: when `query == ""`, it still
called `search()`, which matched everything because of the empty query and
then paginated the results, so it behaved the same as `list()`.

Still, fix it.
2023-04-03 21:24:48 -04:00
ecdfa136a0 fix(nodes): fix typo in list_sessions handler 2023-04-04 00:34:32 +10:00
5cd513ee63 [deps] bump compel version to fix crash on invalid (auto111) syntax (#3107)
Currently, if users input e.g. `happy (camper:0.3)`, it gets parsed
incorrectly, which causes crashes when it appears in the negative prompt.
Bumping to compel 1.0.5 fixes the parser to avoid this (note the weight is
parsed as plain text; it is not converted to proper Invoke syntax).
2023-04-04 02:30:17 +12:00
ab45086546 Merge branch 'main' into deps_bump_compel 2023-04-04 02:05:40 +12:00
77ba7359f4 fix(nodes): commit changes to db 2023-04-03 19:09:49 +10:00
8cbe2e14d9 bump compel version to fix on invalid (auto111) syntax 2023-04-03 10:37:01 +02:00
ee86eedf01 Right link on pytorch installer for linux rocm
Right link on pytorch installer for linux rocm
2023-03-31 17:22:00 -03:00
c4e6511a59 Add support for yet another TI embedding format (main version) (#3050)
- This PR adds support for embedding files that contain a single key
"emb_params". The only example I know of this format is the
"EasyNegative" embedding on HuggingFace, but there are certainly others.

- This PR also adds support for loading embedding files that have been
saved in safetensors format.

- It also cleans up the code so that the logic of probing for and
selecting the right format parser is clear.

- This is the same as #3045, which is on the 2.3 branch.
2023-03-31 03:57:57 -04:00
44843be4c8 Merge branch 'main' into enhance/support-another-embedding-format-main 2023-03-30 23:16:52 -04:00
b9df9e26f2 Merge branch 'main' into enhance/support-another-embedding-format-main 2023-03-30 07:51:23 -04:00
e11c1d66ab handle multiple tokens and embeddings in single file 2023-03-29 22:05:06 -04:00
cdb3616dca Merge branch 'main' into enhance/support-another-embedding-format-main 2023-03-28 21:03:06 -04:00
abe4dc8ac1 Add support for yet another textual inversion embedding format
- This PR adds support for embedding files that contain a single key
  "emb_params". The only example I know of this format is the
  "EasyNegative" embedding on HuggingFace, but there are certainly
  others.

- This PR also adds support for loading embedding files that have been
  saved in safetensors format.

- It also cleans up the code so that the logic of probing for and
  selecting the right format parser is clear.
2023-03-27 09:39:03 -04:00
5dec5b6f51 Merge branch 'main' into dev/pytorch2 2023-03-23 23:31:21 -04:00
e158ad8534 deps: upgrade to PyTorch 2.0 (replaces xformers) 2023-03-15 15:45:48 -07:00
25 changed files with 1128 additions and 206 deletions

View File

@ -18,6 +18,7 @@ on:
permissions:
  contents: write
+  packages: write
jobs:
  docker:

View File

@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
masking option:
```bash
-invoke> a fluffy cat eating a hotdot
+invoke> a fluffy cat eating a hotdog
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat

View File

@ -456,13 +456,12 @@ def get_torch_source() -> (Union[str, None],str):
    optional_modules = None
    if OS == "Linux":
        if device == "rocm":
-            url = "https://download.pytorch.org/whl/rocm5.2"
+            url = "https://download.pytorch.org/whl/rocm5.4.2"
        elif device == "cpu":
            url = "https://download.pytorch.org/whl/cpu"
    if device == 'cuda':
-        url = 'https://download.pytorch.org/whl/cu117'
-        optional_modules = '[xformers]'
+        url = 'https://download.pytorch.org/whl/cu118'
    # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -3,6 +3,8 @@
import os
from argparse import Namespace

+from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
@ -54,7 +56,9 @@ class ApiDependencies:
            os.path.join(os.path.dirname(__file__), "../../../../outputs")
        )

-        images = DiskImageStorage(output_folder)
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
+        images = DiskImageStorage(f'{output_folder}/images')

        # TODO: build a file/path manager?
        db_location = os.path.join(output_folder, "invokeai.db")
@ -62,6 +66,7 @@ class ApiDependencies:
        services = InvocationServices(
            model_manager=get_model_manager(config),
            events=events,
+            latents=latents,
            images=images,
            queue=MemoryInvocationQueue(),
            graph_execution_manager=SqliteItemStorage[GraphExecutionState](

View File

@ -23,6 +23,16 @@ async def get_image(
    filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
    return FileResponse(filename)

+@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
+async def get_thumbnail(
+    image_type: ImageType = Path(description="The type of image to get"),
+    image_name: str = Path(description="The name of the image to get"),
+):
+    """Gets a thumbnail"""
+    # TODO: This is not really secure at all. At least make sure only output results are served
+    filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
+    return FileResponse(filename)
+
@images_router.post(
    "/uploads/",

View File

@ -0,0 +1,279 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")
# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
# print(f">> Adding New Model: {model_name}")
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)
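With the router mounted under `/api` (see the `api_app.py` hunk below), the new list endpoint is served at `/api/v1/models/`. A minimal client sketch, assuming a locally running nodes API; the host and port are illustrative:

```python
import json
import urllib.request

# GET /api/v1/models/ returns {"models": {<name>: {"format": "ckpt"|"diffusers", ...}, ...}}
with urllib.request.urlopen("http://127.0.0.1:9090/api/v1/models/") as resp:
    models = json.load(resp)["models"]

for name, info in models.items():
    print(name, info["format"], info.get("description"))
```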

View File

@ -51,7 +51,7 @@ async def list_sessions(
    query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
    """Gets a list of sessions, optionally searching"""
-    if filter == "":
+    if query == "":
        result = ApiDependencies.invoker.services.graph_execution_manager.list(
            page, per_page
        )

View File

@ -14,7 +14,7 @@ from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
-from .api.routers import images, sessions
+from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
+app.include_router(models.models_router, prefix="/api")

# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?

View File

@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field

+import networkx as nx
+import matplotlib.pyplot as plt
+
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
@ -46,7 +47,7 @@ def add_parsers(
                f"--{name}",
                dest=name,
                type=field_type,
-                default=field.default,
+                default=field.default if field.default_factory is None else field.default_factory(),
                choices=allowed_values,
                help=field.field_info.description,
            )
@ -55,7 +56,7 @@ def add_parsers(
                f"--{name}",
                dest=name,
                type=field.type_,
-                default=field.default,
+                default=field.default if field.default_factory is None else field.default_factory(),
                help=field.field_info.description,
            )
@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
            del context.defaults[self.field]
        else:
            context.defaults[self.field] = self.value
class DrawGraphCommand(BaseCommand):
"""Debugs a graph"""
type: Literal['draw_graph'] = 'draw_graph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph"""
type: Literal['draw_xgraph'] = 'draw_xgraph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.execution_graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
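The `default=field.default if field.default_factory is None else field.default_factory()` change in `add_parsers` above matters because, under pydantic v1, a field declared with `default_factory` reports `field.default` as `None`, so argparse would otherwise pick up the wrong default. A minimal sketch of the distinction (hypothetical model, assuming pydantic v1):

```python
from typing import List

from pydantic import BaseModel, Field


class ExampleInvocation(BaseModel):
    # A list-valued field has to use default_factory, so .default is None.
    tags: List[str] = Field(default_factory=list)


field = ExampleInvocation.__fields__["tags"]
print(field.default)            # None -- not what argparse should use
print(field.default_factory())  # []   -- the real default
```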

View File

@ -2,6 +2,7 @@
import argparse
import os
+import re
import shlex
import time
from typing import (
@ -12,6 +13,8 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field

+from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
@ -20,7 +23,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
-from .services.graph import Edge, EdgeConnection, GraphExecutionState
+from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@ -44,7 +47,7 @@ def add_invocation_args(command_parser):
        "-l",
        action="append",
        nargs=3,
-        help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
+        help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
    )

    command_parser.add_argument(
@ -94,6 +97,9 @@ def generate_matching_edges(
    invalid_fields = set(["type", "id"])
    matching_fields = matching_fields.difference(invalid_fields)

+    # Validate types
+    matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
+
    edges = [
        Edge(
            source=EdgeConnection(node_id=a.id, field=field),
@ -149,7 +155,8 @@ def invoke_cli():
    services = InvocationServices(
        model_manager=model_manager,
        events=events,
-        images=DiskImageStorage(output_folder),
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
+        images=DiskImageStorage(f'{output_folder}/images'),
        queue=MemoryInvocationQueue(),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
@ -162,6 +169,8 @@ def invoke_cli():
    session: GraphExecutionState = invoker.create_execution_state()
    parser = get_command_parser()

+    re_negid = re.compile('^-[0-9]+$')
+
    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())
@ -227,7 +236,11 @@ def invoke_cli():
            # Parse provided links
            if "link_node" in args and args["link_node"]:
                for link in args["link_node"]:
-                    link_node = context.session.graph.get_node(link)
+                    node_id = link
+                    if re_negid.match(node_id):
+                        node_id = str(current_id + int(node_id))
+
+                    link_node = context.session.graph.get_node(node_id)
                    matching_edges = generate_matching_edges(
                        link_node, command.command
                    )
@ -237,10 +250,15 @@ def invoke_cli():
            if "link" in args and args["link"]:
                for link in args["link"]:
-                    edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
+                    edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
+
+                    node_id = link[0]
+                    if re_negid.match(node_id):
+                        node_id = str(current_id + int(node_id))
+
                    edges.append(
                        Edge(
-                            source=EdgeConnection(node_id=link[1], field=link[0]),
+                            source=EdgeConnection(node_id=node_id, field=link[1]),
                            destination=EdgeConnection(
                                node_id=command.command.id, field=link[2]
                            )
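The `--link` handling above now resolves negative, history-relative source node ids before building the edge. A small worked example of that resolution (the ids and field names are illustrative):

```python
import re

re_negid = re.compile('^-[0-9]+$')

current_id = 5                      # id the node being added will receive
link = ("-1", "image", "image")     # source_node, source_field, dest_field

node_id = link[0]
if re_negid.match(node_id):
    node_id = str(current_id + int(node_id))

print(node_id)  # "4" -- one step back in the session history
```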

View File

@ -0,0 +1,50 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class RangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["range"] = "range"
# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
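Outside the node framework, the two collection nodes above boil down to standard calls; a quick sketch of their default outputs:

```python
import numpy as np
import numpy.random

# RangeInvocation defaults (start=0, stop=10, step=1)
print(list(range(0, 10, 1)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# RandomRangeInvocation defaults (low=0, high=np.iinfo(np.int32).max, size=1)
print(list(numpy.random.randint(0, np.iinfo(np.int32).max, size=1)))  # e.g. [1680533367]
```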

View File

@ -0,0 +1,321 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
from pydantic import BaseModel, Field
from torch import Tensor
import torch
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import numpy as np
from accelerate.utils import set_seed
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
latents: LatentsField = Field(default=None, description="The output latents")
#fmt: on
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
noise: LatentsField = Field(default=None, description="The output noise")
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(CUDA_DEVICE)
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from a prompt."""
type: Literal["t2l"] = "t2l"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = model_manager.get_model(self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.sampler_name
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
return conditioning_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = context.services.model_manager.get_model(self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
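As a sanity check on `get_noise` above: with the default 512×512 size, 4 latent channels and `downsampling_factor=8`, the noise (and hence the latents passed between these nodes) is a `[1, 4, 64, 64]` tensor:

```python
import torch

width, height, latent_channels, downsampling_factor = 512, 512, 4, 8
x = torch.randn([1, latent_channels, height // downsampling_factor, width // downsampling_factor])
print(x.shape)  # torch.Size([1, 4, 64, 64])
```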

View File

@ -0,0 +1,68 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class IntOutput(BaseInvocationOutput):
"""An integer output"""
#fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
#fmt: on
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
#fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
#fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
#fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
#fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))

View File

@ -1069,9 +1069,8 @@ class GraphExecutionState(BaseModel):
                        n
                        for n in prepared_nodes
                        if all(
-                            pit
+                            nx.has_path(execution_graph, pit[0], n)
                            for pit in parent_iterators
-                            if nx.has_path(execution_graph, pit[0], n)
                        )
                    ),
                    None,
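The fix above addresses a classic generator pitfall: filtering inside `all()` silently skips the elements that fail the test instead of failing the whole check (and is vacuously true when nothing passes). A standalone illustration:

```python
# Suppose only one of two parent iterators has a path to the candidate node.
has_path = {"iter_a": True, "iter_b": False}
parent_iterators = ["iter_a", "iter_b"]

# Old form: the filter drops "iter_b", so all() only sees truthy survivors.
print(all(p for p in parent_iterators if has_path[p]))   # True (wrong)

# New form: every parent iterator must actually have a path.
print(all(has_path[p] for p in parent_iterators))        # False (correct)
```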

View File

@ -9,6 +9,7 @@ from queue import Queue
from typing import Dict

from PIL.Image import Image

+from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
@ -66,6 +67,9 @@ class DiskImageStorage(ImageStorageBase):
            Path(os.path.join(output_folder, image_type)).mkdir(
                parents=True, exist_ok=True
            )
+            Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
+                parents=True, exist_ok=True
+            )

    def get(self, image_type: ImageType, image_name: str) -> Image:
        image_path = self.get_path(image_type, image_name)
@ -87,7 +91,11 @@ class DiskImageStorage(ImageStorageBase):
        self.__pngWriter.save_image_and_prompt_to_png(
            image, "", image_subpath, None
        )  # TODO: just pass full path to png writer
+        save_thumbnail(
+            image=image,
+            filename=image_name,
+            path=os.path.join(self.__output_folder, image_type, "thumbnails"),
+        )

        image_path = self.get_path(image_type, image_name)
        self.__set_cache(image_path, image)

View File

@ -2,6 +2,7 @@
from invokeai.backend import ModelManager

from .events import EventServiceBase
+from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
@ -11,6 +12,7 @@ class InvocationServices:
    """Services that can be used by invocations"""

    events: EventServiceBase
+    latents: LatentsStorageBase
    images: ImageStorageBase
    queue: InvocationQueueABC
    model_manager: ModelManager
@ -24,6 +26,7 @@ class InvocationServices:
        self,
        model_manager: ModelManager,
        events: EventServiceBase,
+        latents: LatentsStorageBase,
        images: ImageStorageBase,
        queue: InvocationQueueABC,
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@ -32,6 +35,7 @@ class InvocationServices:
    ):
        self.model_manager = model_manager
        self.events = events
+        self.latents = latents
        self.images = images
        self.queue = queue
        self.graph_execution_manager = graph_execution_manager

View File

@ -33,7 +33,6 @@ class Invoker:
        self.services.graph_execution_manager.set(graph_execution_state)

        # Queue the invocation
-        print(f"queueing item {invocation.id}")
        self.services.queue.put(
            InvocationQueueItem(
                # session_id = session.id,

View File

@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
return cache_item
latent = self.__underlying_storage.get(name)
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if not name in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
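As wired up in `dependencies.py` and the CLI (earlier in this diff), the disk store is wrapped in the write-through cache. A minimal usage sketch; the folder and latents name are illustrative (the invocations use `{graph_execution_state_id}__{node_id}` names):

```python
import torch

from invokeai.app.services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

# Keeps up to 20 tensors in memory (FIFO eviction) and writes everything through to disk.
latents = ForwardCacheLatentsStorage(DiskLatentsStorage("outputs/latents"))

noise = torch.randn([1, 4, 64, 64])
latents.set("session123__node4", noise)                      # saved to outputs/latents/ and cached
assert torch.equal(latents.get("session123__node4"), noise)  # served from the in-memory cache
```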

View File

@ -59,6 +59,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),), (item.json(),),
) )
self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
@ -84,6 +85,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute( self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
) )
self._conn.commit()
finally: finally:
self._lock.release() self._lock.release()
self._on_deleted(id) self._on_deleted(id)
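The added `self._conn.commit()` calls matter because Python's `sqlite3` module opens an implicit transaction around INSERT/DELETE statements; without committing, the changes are never persisted for other connections or across restarts. A standalone illustration (not InvokeAI code):

```python
import sqlite3

conn = sqlite3.connect("items.db")
conn.execute("CREATE TABLE IF NOT EXISTS items (id TEXT PRIMARY KEY, item TEXT)")
conn.execute("INSERT OR REPLACE INTO items (id, item) VALUES (?, ?);", ("1", "{}"))
conn.commit()  # without this, the row stays in an open, unpersisted transaction

other = sqlite3.connect("items.db")
print(other.execute("SELECT COUNT(*) FROM items;").fetchone())  # (1,)
```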

View File

@ -0,0 +1,25 @@
import os
from PIL import Image
def save_thumbnail(
image: Image.Image,
filename: str,
path: str,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path
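The call site added in `DiskImageStorage.save` (earlier in this diff) uses it like this; the paths and filename are illustrative, and in the app `DiskImageStorage` creates the thumbnails folder itself:

```python
import os

from PIL import Image

from invokeai.app.util.save_thumbnail import save_thumbnail

thumb_dir = "outputs/images/results/thumbnails"
os.makedirs(thumb_dir, exist_ok=True)

image = Image.new("RGB", (512, 768))
thumb_path = save_thumbnail(image=image, filename="000001.1234.png", path=thumb_dir)
print(thumb_path)  # .../thumbnails/000001.1234.webp, scaled to fit within 256x256
```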

View File

@ -531,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        run_id: str = None,
        additional_guidance: List[Callable] = None,
    ):
-        self._adjust_memory_efficient_attention(latents)
+        # FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms?
+        # self._adjust_memory_efficient_attention(latents)
        if run_id is None:
            run_id = secrets.token_urlsafe(self.ID_LENGTH)
        if additional_guidance is None:

View File

@ -1,16 +1,26 @@
import os
import traceback
from dataclasses import dataclass
from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List

+import safetensors.torch
import torch

from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer

from .concepts_lib import HuggingFaceConceptsLibrary

+@dataclass
+class EmbeddingInfo:
+    name: str
+    embedding: torch.Tensor
+    num_vectors_per_token: int
+    token_dim: int
+    trained_steps: int = None
+    trained_model_name: str = None
+    trained_model_checksum: str = None
+
@dataclass
class TextualInversion:
@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
if str(ckpt_path).endswith(".DS_Store"): if str(ckpt_path).endswith(".DS_Store"):
return return
try: embedding_list = self._parse_embedding(str(ckpt_path))
scan_result = scan_file_path(str(ckpt_path)) for embedding_info in embedding_list:
if scan_result.infected_files == 1: if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
print( print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}" f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
) )
print("### For your safety, InvokeAI will not load this embed.") continue
return
except Exception:
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
embedding_info = self._parse_embedding(str(ckpt_path)) # Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
if embedding_info is None: trigger_str = embedding_info.name
# We've already put out an error message about the bad embedding in _parse_embedding, so just return. sourcefile = (
return f"{ckpt_path.parent.name}/{ckpt_path.name}"
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["token_dim"]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info["name"]
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else ckpt_path.name
)
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
if ckpt_path.name == "learned_embeds.bin" if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>" else ckpt_path.name
) )
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
try: if trigger_str in self.trigger_to_sourcefile:
self._add_textual_inversion( replacement_trigger_str = (
trigger_str, f"<{ckpt_path.parent.name}>"
embedding_info["embedding"], if ckpt_path.name == "learned_embeds.bin"
defer_injecting_tokens=defer_injecting_tokens, else f"<{ckpt_path.stem}>"
) )
# remember which source file claims this trigger print(
self.trigger_to_sourcefile[trigger_str] = sourcefile f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
except ValueError as e: try:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}') self._add_textual_inversion(
print(f" | The error was {str(e)}") trigger_str,
embedding_info.embedding,
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
def _add_textual_inversion( def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False self, trigger_str, embedding, defer_injecting_tokens=False
@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
return token_id return token_id
def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split(".")[-1] def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
if file_type == "pt": suffix = Path(embedding_file).suffix
return self._parse_embedding_pt(embedding_file) try:
elif file_type == "bin": if suffix in [".pt",".ckpt",".bin"]:
return self._parse_embedding_bin(embedding_file) scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
print(
f" ** Security Issues Found in Model: {scan_result.issues_count}"
)
print(" ** For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# try to figure out what kind of embedding file it is and parse accordingly
keys = list(ckpt.keys())
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
elif all(x in keys for x in ['string_to_token','string_to_param']):
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
elif 'emb_params' in keys:
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
else: else:
print(f"** Notice: unrecognized embedding file format: {embedding_file}") return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
return None
-    def _parse_embedding_pt(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        # Check if valid embedding file
-        if "string_to_token" and "string_to_param" in embedding_ckpt:
-            # Catch variants that do not have the expected keys or values.
-            try:
-                embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
-                    os.path.splitext(embedding_file)[0]
-                )
-
-                # Check num of embeddings and warn user only the first will be used
-                embedding_info["num_of_embeddings"] = len(
-                    embedding_ckpt["string_to_token"]
-                )
-                if embedding_info["num_of_embeddings"] > 1:
-                    print(">> More than 1 embedding found. Will use the first one")
-
-                embedding = list(embedding_ckpt["string_to_param"].values())[0]
-            except (AttributeError, KeyError):
-                return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
-
-            embedding_info["embedding"] = embedding
-            embedding_info["num_vectors_per_token"] = embedding.size()[0]
-            embedding_info["token_dim"] = embedding.size()[1]
-
-            try:
-                embedding_info["trained_steps"] = embedding_ckpt["step"]
-                embedding_info["trained_model_name"] = embedding_ckpt[
-                    "sd_checkpoint_name"
-                ]
-                embedding_info["trained_model_checksum"] = embedding_ckpt[
-                    "sd_checkpoint"
-                ]
-            except AttributeError:
-                print(">> No Training Details Found. Passing ...")
-
-        # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
-        # They are actually .bin files
-        elif len(embedding_ckpt.keys()) == 1:
-            embedding_info = self._parse_embedding_bin(embedding_file)
-
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
-
-    def _parse_embedding_bin(self, embedding_file):
-        embedding_ckpt = torch.load(embedding_file, map_location="cpu")
-        embedding_info = {}
-
-        if list(embedding_ckpt.keys()) == 0:
-            print(">> Invalid concepts file")
-            embedding_info = None
-        else:
-            for token in list(embedding_ckpt.keys()):
-                embedding_info["name"] = (
-                    token
-                    or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[token]
-                embedding_info[
-                    "num_vectors_per_token"
-                ] = 1  # All Concepts seem to default to 1
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
-        return embedding_info
-
-    def _handle_broken_pt_variants(
-        self, embedding_ckpt: dict, embedding_file: str
-    ) -> dict:
-        """
-        This handles the broken .pt file variants. We only know of one at present.
-        """
-        embedding_info = {}
-        if isinstance(
-            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
-        ):
-            for token in list(embedding_ckpt["string_to_token"].keys()):
-                embedding_info["name"] = (
-                    token
-                    if token != "*"
-                    else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
-                )
-                embedding_info["embedding"] = embedding_ckpt[
-                    "string_to_param"
-                ].state_dict()[token]
-                embedding_info["num_vectors_per_token"] = embedding_info[
-                    "embedding"
-                ].shape[0]
-                embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
-        else:
-            print(">> Invalid embedding format")
-            embedding_info = None
-
-        return embedding_info
+    def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
+        basename = Path(file_path).stem
+        print(f' | Loading v1 embedding file: {basename}')
+
+        embeddings = list()
+        token_counter = -1
+        for token,embedding in embedding_ckpt["string_to_param"].items():
+            if token_counter < 0:
+                trigger = embedding_ckpt["name"]
+            elif token_counter == 0:
+                trigger = f'<basename>'
+            else:
+                trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
+            token_counter += 1
+            embedding_info = EmbeddingInfo(
+                name = trigger,
+                embedding = embedding,
+                num_vectors_per_token = embedding.size()[0],
+                token_dim = embedding.size()[1],
+                trained_steps = embedding_ckpt["step"],
+                trained_model_name = embedding_ckpt["sd_checkpoint_name"],
+                trained_model_checksum = embedding_ckpt["sd_checkpoint"]
+            )
+            embeddings.append(embedding_info)
+        return embeddings
+
+    def _parse_embedding_v2 (
+        self, embedding_ckpt: dict, file_path: str
+    ) -> List[EmbeddingInfo]:
+        """
+        This handles embedding .pt file variant #2.
+        """
+        basename = Path(file_path).stem
+        print(f' | Loading v2 embedding file: {basename}')
+        embeddings = list()
+
+        if isinstance(
+            list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
+        ):
+            token_counter = 0
+            for token,embedding in embedding_ckpt["string_to_param"].items():
+                trigger = token if token != '*' \
+                    else f'<{basename}>' if token_counter == 0 \
+                    else f'<{basename}-{int(token_counter:=token_counter+1)}>'
+                embedding_info = EmbeddingInfo(
+                    name = trigger,
+                    embedding = embedding,
+                    num_vectors_per_token = embedding.size()[0],
+                    token_dim = embedding.size()[1],
+                )
+                embeddings.append(embedding_info)
+        else:
+            print(f" ** {basename}: Unrecognized embedding format")
+
+        return embeddings
+
+    def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
+        """
+        Parse 'version 3' of the .pt textual inversion embedding files.
+        """
+        basename = Path(file_path).stem
+        print(f' | Loading v3 embedding file: {basename}')
+        embedding = embedding_ckpt['emb_params']
+        embedding_info = EmbeddingInfo(
+            name = f'<{basename}>',
+            embedding = embedding,
+            num_vectors_per_token = embedding.size()[0],
+            token_dim = embedding.size()[1],
+        )
+        return [embedding_info]
+
+    def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
+        """
+        Parse 'version 4' of the textual inversion embedding files. This one
+        is usually associated with .bin files trained by HuggingFace diffusers.
+        """
+        basename = Path(filepath).stem
+        short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
+
+        print(f' | Loading v4 embedding file: {short_path}')
+
+        embeddings = list()
+        if list(embedding_ckpt.keys()) == 0:
+            print(f" ** Invalid embeddings file: {short_path}")
+        else:
+            for token,embedding in embedding_ckpt.items():
+                embedding_info = EmbeddingInfo(
+                    name = token or f"<{basename}>",
+                    embedding = embedding,
+                    num_vectors_per_token = 1, # All Concepts seem to default to 1
+                    token_dim = embedding.size()[0],
+                )
+                embeddings.append(embedding_info)
+        return embeddings
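
Note: every parser above returns a list of EmbeddingInfo records, but the class itself is defined elsewhere in the branch and does not appear in this diff. Below is a minimal sketch of a container that would satisfy the constructor calls shown, assuming a plain dataclass; the field names come from the diff, while the Optional defaults for the training metadata are an assumption (only the v1 parser supplies those values).

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class EmbeddingInfo:
    name: str                                    # trigger string, e.g. "<easynegative>"
    embedding: torch.Tensor                      # shape [num_vectors_per_token, token_dim]
    num_vectors_per_token: int
    token_dim: int
    trained_steps: Optional[int] = None          # filled in by the v1 parser only
    trained_model_name: Optional[str] = None
    trained_model_checksum: Optional[str] = None

With a container like this, a caller can treat all four formats uniformly: iterate the returned list and register one trigger string per entry, regardless of whether the source was a v1 .pt, a v2 .pt, an 'emb_params' safetensors file, or a diffusers-style .bin.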

View File

@@ -38,7 +38,7 @@ dependencies = [
  "albumentations",
  "click",
  "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
- "compel==1.0.4",
+ "compel==1.0.5",
  "datasets",
  "diffusers[torch]~=0.14",
  "dnspython==2.2.1",
@@ -71,10 +71,10 @@ dependencies = [
  "scikit-image>=0.19",
  "send2trash",
  "test-tube>=0.7.5",
- "torch>=1.13.1",
+ "torch~=2.0",
  "torchvision>=0.14.1",
  "torchmetrics",
- "transformers~=4.26",
+ "transformers~=4.27",
  "uvicorn[standard]==0.21.1",
  "windows-curses; sys_platform=='win32'",
]
@@ -90,10 +90,6 @@ dependencies = [
  "pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov"]
-"xformers" = [
-  "xformers~=0.0.16; sys_platform!='darwin'",
-  "triton; sys_platform=='linux'",
-]

[project.scripts]
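
The hunks above bump compel to 1.0.5, move torch to ~=2.0 and transformers to ~=4.27, and drop the optional xformers extra. A quick, throwaway way to confirm that a local environment already satisfies the new pins is sketched below; the helper is illustrative only and is not part of the project.

from importlib.metadata import PackageNotFoundError, version
from packaging.specifiers import SpecifierSet

# Pins copied from the hunk above.
PINS = {
    "compel": "==1.0.5",
    "torch": "~=2.0",
    "transformers": "~=4.27",
}

for name, spec in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    ok = installed in SpecifierSet(spec)
    print(f"{name} {installed} {'satisfies' if ok else 'does not satisfy'} {spec}")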

View File

@@ -1,6 +1,8 @@
from .test_invoker import create_edge
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from invokeai.app.invocations.collections import RangeInvocation
+from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@@ -21,13 +23,14 @@ def simple_graph():
def mock_services():
    # NOTE: none of these are actually called by the test invocations
    return InvocationServices(
-        model_manager = None,
-        events = None,
-        images = None,
+        model_manager = None, # type: ignore
+        events = None, # type: ignore
+        images = None, # type: ignore
+        latents = None, # type: ignore
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor(),
-        restoration = None,
+        restoration = None, # type: ignore
    )

def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
@@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services):
def test_graph_state_expands_iterator(mock_services):
    graph = Graph()
-    test_prompts = ["Banana sushi", "Cat sushi"]
-    graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
-    graph.add_node(IterateInvocation(id = "2"))
-    graph.add_node(ImageTestInvocation(id = "3"))
-    graph.add_edge(create_edge("1", "collection", "2", "collection"))
-    graph.add_edge(create_edge("2", "item", "3", "prompt"))
+    graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
+    graph.add_node(IterateInvocation(id = "1"))
+    graph.add_node(MultiplyInvocation(id = "2", b = 10))
+    graph.add_node(AddInvocation(id = "3", b = 1))
+    graph.add_edge(create_edge("0", "collection", "1", "collection"))
+    graph.add_edge(create_edge("1", "item", "2", "a"))
+    graph.add_edge(create_edge("2", "a", "3", "a"))

    g = GraphExecutionState(graph = graph)
-    n1 = invoke_next(g, mock_services)
-    n2 = invoke_next(g, mock_services)
-    n3 = invoke_next(g, mock_services)
-    n4 = invoke_next(g, mock_services)
-    n5 = invoke_next(g, mock_services)
-
-    assert g.prepared_source_mapping[n1[0].id] == "1"
-    assert g.prepared_source_mapping[n2[0].id] == "2"
-    assert g.prepared_source_mapping[n3[0].id] == "2"
-    assert g.prepared_source_mapping[n4[0].id] == "3"
-    assert g.prepared_source_mapping[n5[0].id] == "3"
-    assert isinstance(n4[0], ImageTestInvocation)
-    assert isinstance(n5[0], ImageTestInvocation)
-    prompts = [n4[0].prompt, n5[0].prompt]
-    assert sorted(prompts) == sorted(test_prompts)
+    while not g.is_complete():
+        invoke_next(g, mock_services)
+
+    prepared_add_nodes = g.source_prepared_mapping['3']
+    results = set([g.results[n].a for n in prepared_add_nodes])
+    expected = set([1, 11, 21])
+    assert results == expected

def test_graph_state_collects(mock_services):
    graph = Graph()
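
The expected set {1, 11, 21} in the rewritten iterator test follows directly from the graph wiring: the range node emits 0, 1 and 2, each item is multiplied by 10, then 1 is added. A standalone check of that arithmetic, independent of the graph machinery and purely illustrative:

# Mirrors RangeInvocation(start=0, stop=3, step=1) -> MultiplyInvocation(b=10) -> AddInvocation(b=1)
expected = {x * 10 + 1 for x in range(0, 3, 1)}
assert expected == {1, 11, 21}
print(sorted(expected))  # [1, 11, 21]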

View File

@@ -24,10 +24,11 @@ def mock_services() -> InvocationServices:
        model_manager = None, # type: ignore
        events = TestEventService(),
        images = None, # type: ignore
+        latents = None, # type: ignore
        queue = MemoryInvocationQueue(),
        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
        processor = DefaultInvocationProcessor(),
-        restoration = None,
+        restoration = None, # type: ignore
    )

@pytest.fixture()
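
A side note on the recurring "# type: ignore" comments in both fixtures: passing None for services the tests never touch is intentional, but it would be flagged by a static type checker, presumably because the InvocationServices constructor declares concrete service types. A hypothetical, trimmed illustration of the pattern (not InvokeAI's actual class):

class Queue:
    pass

class Services:
    def __init__(self, queue: Queue) -> None:
        self.queue = queue

# Fine at runtime because annotations are not enforced; the comment silences mypy.
services = Services(queue=None)  # type: ignore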