Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Merge branch 'main' into bugfix/release-updater
Commit cd1b350dae

.github/stale.yaml (vendored, new file, 19 lines)
@@ -0,0 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 28
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 14
# Issues with these labels will never be considered stale
exemptLabels:
  - pinned
  - security
# Label to use when marking an issue as stale
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
  This issue has been automatically marked as stale because it has not had
  recent activity. It will be closed if no further activity occurs. Please
  update the ticket if this is still a problem on the latest release.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: >
  Due to inactivity, this issue has been automatically closed. If this is
  still a problem on the latest release, please recreate the issue.
@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
masking option:

```bash
-invoke> a fluffy cat eating a hotdot
+invoke> a fluffy cat eating a hotdog
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
@@ -3,6 +3,8 @@
import os
from argparse import Namespace

+from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
+
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
@@ -54,7 +56,9 @@ class ApiDependencies:
            os.path.join(os.path.dirname(__file__), "../../../../outputs")
        )

-        images = DiskImageStorage(output_folder)
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
+
+        images = DiskImageStorage(f'{output_folder}/images')

        # TODO: build a file/path manager?
        db_location = os.path.join(output_folder, "invokeai.db")
@@ -62,6 +66,7 @@ class ApiDependencies:
        services = InvocationServices(
            model_manager=get_model_manager(config),
            events=events,
+            latents=latents,
            images=images,
            queue=MemoryInvocationQueue(),
            graph_execution_manager=SqliteItemStorage[GraphExecutionState](
invokeai/app/api/routers/models.py (new file, 279 lines)
@@ -0,0 +1,279 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Annotated, Any, List, Literal, Optional, Union

from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as

from ..dependencies import ApiDependencies

models_router = APIRouter(prefix="/v1/models", tags=["models"])


class VaeRepo(BaseModel):
    repo_id: str = Field(description="The repo ID to use for this VAE")
    path: Optional[str] = Field(description="The path to the VAE")
    subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")


class ModelInfo(BaseModel):
    description: Optional[str] = Field(description="A description of the model")


class CkptModelInfo(ModelInfo):
    format: Literal['ckpt'] = 'ckpt'

    config: str = Field(description="The path to the model config")
    weights: str = Field(description="The path to the model weights")
    vae: str = Field(description="The path to the model VAE")
    width: Optional[int] = Field(description="The width of the model")
    height: Optional[int] = Field(description="The height of the model")


class DiffusersModelInfo(ModelInfo):
    format: Literal['diffusers'] = 'diffusers'

    vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
    repo_id: Optional[str] = Field(description="The repo ID to use for this model")
    path: Optional[str] = Field(description="The path to the model")


class ModelsList(BaseModel):
    models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]


@models_router.get(
    "/",
    operation_id="list_models",
    responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
    """Gets a list of models"""
    models_raw = ApiDependencies.invoker.services.model_manager.list_models()
    models = parse_obj_as(ModelsList, { "models": models_raw })
    return models

# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)

# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")

# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True

# print(f">> Adding New Model: {model_name}")

# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)

# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)

# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)

# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)

# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )

# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)

# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)

# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )

# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )

# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )

# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)

# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )

# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)

# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )

# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])

# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)

# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )

# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)

# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()

# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)
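For reference, a minimal sketch of the payload shape that `list_models` returns, built from the schema classes defined above; the model names, repo IDs, and file paths are made up for illustration:

```python
# Illustrative entries only -- names and paths are hypothetical.
example = ModelsList(models={
    "stable-diffusion-1.5": DiffusersModelInfo(
        description="Stable Diffusion 1.5 (diffusers format)",
        repo_id="runwayml/stable-diffusion-v1-5",
        vae=None,
        path=None,
    ),
    "my-ckpt": CkptModelInfo(
        description="A local checkpoint",
        config="configs/stable-diffusion/v1-inference.yaml",
        weights="models/ldm/stable-diffusion-v1/model.ckpt",
        vae="models/ldm/stable-diffusion-v1/vae-ft-mse.ckpt",
        width=512,
        height=512,
    ),
})
print(example.json(indent=2))  # pydantic v1 serialization; entries are tagged by "format"
```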
@@ -14,7 +14,7 @@ from pydantic.schema import schema

from ..backend import Args
from .api.dependencies import ApiDependencies
-from .api.routers import images, sessions
+from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
@@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")

app.include_router(images.images_router, prefix="/api")

+app.include_router(models.models_router, prefix="/api")
+
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
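Once the router is mounted, the new endpoint can be exercised over plain HTTP. A minimal sketch using only the standard library; the host and port are assumptions about a local setup and are not pinned down by this diff:

```python
import json
from urllib.request import urlopen

# Assumed local server address; adjust to wherever the API app is actually served.
with urlopen("http://127.0.0.1:9090/api/v1/models/") as resp:
    payload = json.load(resp)

# The response is keyed by model name, each entry tagged with its "format".
for name, info in payload["models"].items():
    print(name, info["format"])
```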
@@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
+import networkx as nx
+import matplotlib.pyplot as plt
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
@@ -46,7 +47,7 @@ def add_parsers(
                f"--{name}",
                dest=name,
                type=field_type,
-                default=field.default,
+                default=field.default if field.default_factory is None else field.default_factory(),
                choices=allowed_values,
                help=field.field_info.description,
            )
@@ -55,7 +56,7 @@ def add_parsers(
                f"--{name}",
                dest=name,
                type=field.type_,
-                default=field.default,
+                default=field.default if field.default_factory is None else field.default_factory(),
                help=field.field_info.description,
            )

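The `default_factory` branch matters because pydantic v1 leaves `field.default` as `None` whenever a factory is configured; a small standalone sketch of the behavior the new expression relies on (the `Example` model is illustrative):

```python
from typing import List
from pydantic import BaseModel, Field

class Example(BaseModel):
    tags: List[str] = Field(default_factory=list)

field = Example.__fields__["tags"]
print(field.default)           # None -- argparse would otherwise register None as the default
print(field.default_factory)   # <class 'list'>

# The expression used above: fall back to calling the factory when one is present.
default = field.default if field.default_factory is None else field.default_factory()
print(default)                 # []
```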
@@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
            del context.defaults[self.field]
        else:
            context.defaults[self.field] = self.value
+
+
+class DrawGraphCommand(BaseCommand):
+    """Debugs a graph"""
+    type: Literal['draw_graph'] = 'draw_graph'
+
+    def run(self, context: CliContext) -> None:
+        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
+        nxgraph = session.graph.nx_graph_flat()
+
+        # Draw the networkx graph
+        plt.figure(figsize=(20, 20))
+        pos = nx.spectral_layout(nxgraph)
+        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
+        nx.draw_networkx_edges(nxgraph, pos, width=2)
+        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
+        plt.axis("off")
+        plt.show()
+
+
+class DrawExecutionGraphCommand(BaseCommand):
+    """Debugs an execution graph"""
+    type: Literal['draw_xgraph'] = 'draw_xgraph'
+
+    def run(self, context: CliContext) -> None:
+        session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
+        nxgraph = session.execution_graph.nx_graph_flat()
+
+        # Draw the networkx graph
+        plt.figure(figsize=(20, 20))
+        pos = nx.spectral_layout(nxgraph)
+        nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
+        nx.draw_networkx_edges(nxgraph, pos, width=2)
+        nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
+        plt.axis("off")
+        plt.show()
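For context, the drawing recipe used by both new commands is plain networkx/matplotlib and can be tried on a toy graph without an InvokeAI session (the node names below are made up):

```python
import networkx as nx
import matplotlib.pyplot as plt

# A stand-in for session.graph.nx_graph_flat(): a tiny directed graph.
g = nx.DiGraph()
g.add_edges_from([("noise", "t2l"), ("t2l", "l2i")])

plt.figure(figsize=(4, 4))
pos = nx.spectral_layout(g)
nx.draw_networkx_nodes(g, pos, node_size=1000)
nx.draw_networkx_edges(g, pos, width=2)
nx.draw_networkx_labels(g, pos, font_size=10, font_family="sans-serif")
plt.axis("off")
plt.show()
```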
@@ -2,6 +2,7 @@

import argparse
import os
+import re
import shlex
import time
from typing import (
@@ -12,6 +13,8 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field

+from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
+
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
@@ -20,7 +23,7 @@ from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
-from .services.graph import Edge, EdgeConnection, GraphExecutionState
+from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@@ -44,7 +47,7 @@ def add_invocation_args(command_parser):
        "-l",
        action="append",
        nargs=3,
-        help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
+        help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
    )

    command_parser.add_argument(
@@ -94,6 +97,9 @@ def generate_matching_edges(
    invalid_fields = set(["type", "id"])
    matching_fields = matching_fields.difference(invalid_fields)

+    # Validate types
+    matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
+
    edges = [
        Edge(
            source=EdgeConnection(node_id=a.id, field=field),
@@ -149,7 +155,8 @@ def invoke_cli():
    services = InvocationServices(
        model_manager=model_manager,
        events=events,
-        images=DiskImageStorage(output_folder),
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
+        images=DiskImageStorage(f'{output_folder}/images'),
        queue=MemoryInvocationQueue(),
        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
@@ -162,6 +169,8 @@ def invoke_cli():
    session: GraphExecutionState = invoker.create_execution_state()
    parser = get_command_parser()

+    re_negid = re.compile('^-[0-9]+$')
+
    # Uncomment to print out previous sessions at startup
    # print(services.session_manager.list())

@@ -227,7 +236,11 @@ def invoke_cli():
                # Parse provided links
                if "link_node" in args and args["link_node"]:
                    for link in args["link_node"]:
-                        link_node = context.session.graph.get_node(link)
+                        node_id = link
+                        if re_negid.match(node_id):
+                            node_id = str(current_id + int(node_id))
+
+                        link_node = context.session.graph.get_node(node_id)
                        matching_edges = generate_matching_edges(
                            link_node, command.command
                        )
@@ -237,10 +250,15 @@ def invoke_cli():

                if "link" in args and args["link"]:
                    for link in args["link"]:
-                        edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
+                        edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
+
+                        node_id = link[0]
+                        if re_negid.match(node_id):
+                            node_id = str(current_id + int(node_id))
+
                        edges.append(
                            Edge(
-                                source=EdgeConnection(node_id=link[1], field=link[0]),
+                                source=EdgeConnection(node_id=node_id, field=link[1]),
                                destination=EdgeConnection(
                                    node_id=command.command.id, field=link[2]
                                )
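The new `re_negid` handling lets `--link` arguments refer to nodes relative to the end of the history. A standalone sketch of the resolution logic introduced above (the id values are illustrative):

```python
import re

re_negid = re.compile('^-[0-9]+$')

current_id = 7        # stand-in for the CLI's running node-id counter
node_id = "-1"        # a relative reference supplied via --link / --link_node
if re_negid.match(node_id):
    node_id = str(current_id + int(node_id))
print(node_id)        # -> "6"
```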
invokeai/app/invocations/collections.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal

import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field

from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput


class IntCollectionOutput(BaseInvocationOutput):
    """A collection of integers"""

    type: Literal["int_collection"] = "int_collection"

    # Outputs
    collection: list[int] = Field(default=[], description="The int collection")


class RangeInvocation(BaseInvocation):
    """Creates a range"""

    type: Literal["range"] = "range"

    # Inputs
    start: int = Field(default=0, description="The start of the range")
    stop: int = Field(default=10, description="The stop of the range")
    step: int = Field(default=1, description="The step of the range")

    def invoke(self, context: InvocationContext) -> IntCollectionOutput:
        return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))


class RandomRangeInvocation(BaseInvocation):
    """Creates a collection of random numbers"""

    type: Literal["random_range"] = "random_range"

    # Inputs
    low: int = Field(default=0, description="The inclusive low value")
    high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
    size: int = Field(default=1, description="The number of values to generate")

    def invoke(self, context: InvocationContext) -> IntCollectionOutput:
        return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
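A quick sanity check of the two collection invocations above. The `id` argument is assumed to be required by the `BaseInvocation` base class, and the context argument is unused by these `invoke()` implementations, so `None` suffices here:

```python
r = RangeInvocation(id="1", start=0, stop=10, step=2)
print(r.invoke(None).collection)      # [0, 2, 4, 6, 8] -- mirrors Python's range()

rnd = RandomRangeInvocation(id="2", low=0, high=10, size=3)
print(rnd.invoke(None).collection)    # e.g. [7, 1, 4] -- three ints drawn from [0, 10)
```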
invokeai/app/invocations/latent.py (new file, 321 lines)
@@ -0,0 +1,321 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional
from pydantic import BaseModel, Field
from torch import Tensor
import torch

from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import numpy as np
from accelerate.utils import set_seed
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline


class LatentsField(BaseModel):
    """A latents field used for passing latents between invocations"""

    latents_name: Optional[str] = Field(default=None, description="The name of the latents")


class LatentsOutput(BaseInvocationOutput):
    """Base class for invocations that output latents"""
    #fmt: off
    type: Literal["latent_output"] = "latent_output"
    latents: LatentsField = Field(default=None, description="The output latents")
    #fmt: on

class NoiseOutput(BaseInvocationOutput):
    """Invocation noise output"""
    #fmt: off
    type: Literal["noise_output"] = "noise_output"
    noise: LatentsField = Field(default=None, description="The output noise")
    #fmt: on


# TODO: this seems like a hack
scheduler_map = dict(
    ddim=diffusers.DDIMScheduler,
    dpmpp_2=diffusers.DPMSolverMultistepScheduler,
    k_dpm_2=diffusers.KDPM2DiscreteScheduler,
    k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
    k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
    k_euler=diffusers.EulerDiscreteScheduler,
    k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
    k_heun=diffusers.HeunDiscreteScheduler,
    k_lms=diffusers.LMSDiscreteScheduler,
    plms=diffusers.PNDMScheduler,
)


SAMPLER_NAME_VALUES = Literal[
    tuple(list(scheduler_map.keys()))
]


def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
    scheduler_class = scheduler_map.get(scheduler_name,'ddim')
    scheduler = scheduler_class.from_config(model.scheduler.config)
    # hack copied over from generate.py
    if not hasattr(scheduler, 'uses_inpainting_model'):
        scheduler.uses_inpainting_model = lambda: False
    return scheduler


def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
    # limit noise to only the diffusion image channels, not the mask channels
    input_channels = min(latent_channels, 4)
    use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
    generator = torch.Generator(device=use_device).manual_seed(seed)
    x = torch.randn(
        [
            1,
            input_channels,
            height // downsampling_factor,
            width // downsampling_factor,
        ],
        dtype=torch_dtype(device),
        device=use_device,
        generator=generator,
    ).to(device)
    # if self.perlin > 0.0:
    #     perlin_noise = self.get_perlin_noise(
    #         width // self.downsampling_factor, height // self.downsampling_factor
    #     )
    #     x = (1 - self.perlin) * x + self.perlin * perlin_noise
    return x


class NoiseInvocation(BaseInvocation):
    """Generates latent noise."""

    type: Literal["noise"] = "noise"

    # Inputs
    seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )

    def invoke(self, context: InvocationContext) -> NoiseOutput:
        device = torch.device(CUDA_DEVICE)
        noise = get_noise(self.width, self.height, device, self.seed)

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.set(name, noise)
        return NoiseOutput(
            noise=LatentsField(latents_name=name)
        )


# Text to image
class TextToLatentsInvocation(BaseInvocation):
    """Generates latents from a prompt."""

    type: Literal["t2l"] = "t2l"

    # Inputs
    # TODO: consider making prompt optional to enable providing prompt through a link
    # fmt: off
    prompt: Optional[str] = Field(description="The prompt to generate an image from")
    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    model: str = Field(default="", description="The model to use (currently ignored)")
    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
    # fmt: on

    # TODO: pass this an emitter method or something? or a session for dispatching?
    def dispatch_progress(
        self, context: InvocationContext, sample: Tensor, step: int
    ) -> None:
        # TODO: only output a preview image when requested
        image = Generator.sample_to_lowres_estimated_image(sample)

        (width, height) = image.size
        width *= 8
        height *= 8

        dataURL = image_to_dataURL(image, image_format="JPEG")

        context.services.events.emit_generator_progress(
            context.graph_execution_state_id,
            self.id,
            {
                "width": width,
                "height": height,
                "dataURL": dataURL
            },
            step,
            self.steps,
        )

    def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
        model_info = model_manager.get_model(self.model)
        model_name = model_info['model_name']
        model_hash = model_info['hash']
        model: StableDiffusionGeneratorPipeline = model_info['model']
        model.scheduler = get_scheduler(
            model=model,
            scheduler_name=self.sampler_name
        )

        if isinstance(model, DiffusionPipeline):
            for component in [model.unet, model.vae]:
                configure_model_padding(component,
                                        self.seamless,
                                        self.seamless_axes
                                        )
        else:
            configure_model_padding(model,
                                    self.seamless,
                                    self.seamless_axes
                                    )

        return model


    def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
        conditioning_data = ConditioningData(
            uc,
            c,
            self.cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=0.0,#threshold,
                warmup=0.2,#warmup,
                h_symmetry_time_pct=None,#h_symmetry_time_pct,
                v_symmetry_time_pct=None#v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
        return conditioning_data


    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)

        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, state.latents, state.step)

        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(model)

        # TODO: Verify the noise is the right size

        result_latents, result_attention_map_saver = model.latents_from_embeddings(
            latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.set(name, result_latents)
        return LatentsOutput(
            latents=LatentsField(latents_name=name)
        )


class LatentsToLatentsInvocation(TextToLatentsInvocation):
    """Generates latents using latents as base image."""

    type: Literal["l2l"] = "l2l"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
    strength: float = Field(default=0.5, description="The strength of the latents to use")

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)
        latent = context.services.latents.get(self.latents.latents_name)

        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, state.latents, state.step)

        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(model)

        # TODO: Verify the noise is the right size

        initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
            latent, device=model.device, dtype=latent.dtype
        )

        timesteps, _ = model.get_img2img_timesteps(
            self.steps,
            self.strength,
            device=model.device,
        )

        result_latents, result_attention_map_saver = model.latents_from_embeddings(
            latents=initial_latents,
            timesteps=timesteps,
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.set(name, result_latents)
        return LatentsOutput(
            latents=LatentsField(latents_name=name)
        )


# Latent to image
class LatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i"] = "l2i"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    model: str = Field(default="", description="The model to use")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO: this only really needs the vae
        model_info = context.services.model_manager.get_model(self.model)
        model: StableDiffusionGeneratorPipeline = model_info['model']

        with torch.inference_mode():
            np_image = model.decode_latents(latents)
            image = model.numpy_to_pil(np_image)[0]

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        context.services.images.save(image_type, image_name, image)
        return ImageOutput(
            image=ImageField(image_type=image_type, image_name=image_name)
        )
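As a shape check on `get_noise()` above: with the default `latent_channels=4` and `downsampling_factor=8`, a 512x512 request yields a 64x64 latent. A minimal sketch, assuming the module and its backend imports resolve in your environment:

```python
import torch

noise = get_noise(width=512, height=512, device=torch.device("cpu"), seed=42)
print(noise.shape)    # torch.Size([1, 4, 64, 64]) -- 512 // 8 on each spatial side
print(noise.dtype)    # chosen by torch_dtype() for the target device
```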
invokeai/app/invocations/math.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from datetime import datetime, timezone
from typing import Literal, Optional

import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field

from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext


class IntOutput(BaseInvocationOutput):
    """An integer output"""
    #fmt: off
    type: Literal["int_output"] = "int_output"
    a: int = Field(default=None, description="The output integer")
    #fmt: on


class AddInvocation(BaseInvocation):
    """Adds two numbers"""
    #fmt: off
    type: Literal["add"] = "add"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    #fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a + self.b)


class SubtractInvocation(BaseInvocation):
    """Subtracts two numbers"""
    #fmt: off
    type: Literal["sub"] = "sub"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    #fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a - self.b)


class MultiplyInvocation(BaseInvocation):
    """Multiplies two numbers"""
    #fmt: off
    type: Literal["mul"] = "mul"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    #fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a * self.b)


class DivideInvocation(BaseInvocation):
    """Divides two numbers"""
    #fmt: off
    type: Literal["div"] = "div"
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")
    #fmt: on

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=int(self.a / self.b))
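A quick check of the arithmetic invocations; note that `DivideInvocation` truncates via `int(a / b)`. The `id` argument is assumed to be required by `BaseInvocation`, and the context is unused by these implementations:

```python
print(AddInvocation(id="1", a=3, b=4).invoke(None).a)        # 7
print(SubtractInvocation(id="2", a=3, b=4).invoke(None).a)   # -1
print(MultiplyInvocation(id="3", a=3, b=4).invoke(None).a)   # 12
print(DivideInvocation(id="4", a=7, b=2).invoke(None).a)     # 3  (int(7 / 2))
```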
@@ -1069,9 +1069,8 @@ class GraphExecutionState(BaseModel):
                    n
                    for n in prepared_nodes
                    if all(
-                       pit
+                       nx.has_path(execution_graph, pit[0], n)
                        for pit in parent_iterators
-                       if nx.has_path(execution_graph, pit[0], n)
                    )
                ),
                None,
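The `all(...)` rewrite changes the predicate's meaning: the old form filtered out parent iterators that had no path to `n`, so they could never fail the check, while the new form requires a path from every parent iterator. A toy illustration with boolean stand-ins for `nx.has_path(...)`:

```python
has_path = [True, False, True]   # stand-ins for nx.has_path(execution_graph, pit[0], n)

old_result = all(p for p in has_path if p)   # True  -- the False entry is filtered away
new_result = all(p for p in has_path)        # False -- every parent must have a path
print(old_result, new_result)
```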
@@ -2,6 +2,7 @@
from invokeai.backend import ModelManager

from .events import EventServiceBase
+from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
@@ -11,6 +12,7 @@ class InvocationServices:
    """Services that can be used by invocations"""

    events: EventServiceBase
+    latents: LatentsStorageBase
    images: ImageStorageBase
    queue: InvocationQueueABC
    model_manager: ModelManager
@@ -24,6 +26,7 @@ class InvocationServices:
        self,
        model_manager: ModelManager,
        events: EventServiceBase,
+        latents: LatentsStorageBase,
        images: ImageStorageBase,
        queue: InvocationQueueABC,
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@@ -32,6 +35,7 @@ class InvocationServices:
    ):
        self.model_manager = model_manager
        self.events = events
+        self.latents = latents
        self.images = images
        self.queue = queue
        self.graph_execution_manager = graph_execution_manager
@ -33,7 +33,6 @@ class Invoker:
|
|||||||
self.services.graph_execution_manager.set(graph_execution_state)
|
self.services.graph_execution_manager.set(graph_execution_state)
|
||||||
|
|
||||||
# Queue the invocation
|
# Queue the invocation
|
||||||
print(f"queueing item {invocation.id}")
|
|
||||||
self.services.queue.put(
|
self.services.queue.put(
|
||||||
InvocationQueueItem(
|
InvocationQueueItem(
|
||||||
# session_id = session.id,
|
# session_id = session.id,
|
||||||
|
invokeai/app/services/latent_storage.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict

import torch

class LatentsStorageBase(ABC):
    """Responsible for storing and retrieving latents."""

    @abstractmethod
    def get(self, name: str) -> torch.Tensor:
        pass

    @abstractmethod
    def set(self, name: str, data: torch.Tensor) -> None:
        pass

    @abstractmethod
    def delete(self, name: str) -> None:
        pass


class ForwardCacheLatentsStorage(LatentsStorageBase):
    """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""

    __cache: Dict[str, torch.Tensor]
    __cache_ids: Queue
    __max_cache_size: int
    __underlying_storage: LatentsStorageBase

    def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
        self.__underlying_storage = underlying_storage
        self.__cache = dict()
        self.__cache_ids = Queue()
        self.__max_cache_size = max_cache_size

    def get(self, name: str) -> torch.Tensor:
        cache_item = self.__get_cache(name)
        if cache_item is not None:
            return cache_item

        latent = self.__underlying_storage.get(name)
        self.__set_cache(name, latent)
        return latent

    def set(self, name: str, data: torch.Tensor) -> None:
        self.__underlying_storage.set(name, data)
        self.__set_cache(name, data)

    def delete(self, name: str) -> None:
        self.__underlying_storage.delete(name)
        if name in self.__cache:
            del self.__cache[name]

    def __get_cache(self, name: str) -> torch.Tensor|None:
        return None if name not in self.__cache else self.__cache[name]

    def __set_cache(self, name: str, data: torch.Tensor):
        if not name in self.__cache:
            self.__cache[name] = data
            self.__cache_ids.put(name)
            if self.__cache_ids.qsize() > self.__max_cache_size:
                self.__cache.pop(self.__cache_ids.get())


class DiskLatentsStorage(LatentsStorageBase):
    """Stores latents in a folder on disk without caching"""

    __output_folder: str

    def __init__(self, output_folder: str):
        self.__output_folder = output_folder
        Path(output_folder).mkdir(parents=True, exist_ok=True)

    def get(self, name: str) -> torch.Tensor:
        latent_path = self.get_path(name)
        return torch.load(latent_path)

    def set(self, name: str, data: torch.Tensor) -> None:
        latent_path = self.get_path(name)
        torch.save(data, latent_path)

    def delete(self, name: str) -> None:
        latent_path = self.get_path(name)
        os.remove(latent_path)

    def get_path(self, name: str) -> str:
        return os.path.join(self.__output_folder, name)
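A minimal usage sketch of the two storage classes above, mirroring how the CLI and API wire them together; the folder path and tensor name are illustrative:

```python
import torch

storage = ForwardCacheLatentsStorage(DiskLatentsStorage("outputs/latents"))

name = "session123__node1"                      # same naming scheme the invocations use
storage.set(name, torch.zeros(1, 4, 64, 64))    # written through to disk and cached
print(storage.get(name).shape)                  # served from the in-memory cache
storage.delete(name)                            # removes the file and evicts the cache entry
```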
@@ -7,3 +7,4 @@ from .convert_ckpt_to_diffusers import (
)
from .model_manager import ModelManager

+
@@ -1,4 +1,4 @@
-"""
+"""enum
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
@@ -15,7 +15,7 @@ import sys
import textwrap
import time
import warnings
-from enum import Enum
+from enum import Enum, auto
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable
@@ -24,8 +24,12 @@ import safetensors
import safetensors.torch
import torch
import transformers
-from diffusers import AutoencoderKL
-from diffusers import logging as dlogging
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    SchedulerMixin,
+    logging as dlogging,
+)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -33,31 +37,52 @@ from picklescan.scanner import scan_file_path

from invokeai.backend.globals import Globals, global_cache_dir

-from ..stable_diffusion import StableDiffusionGeneratorPipeline
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from ..stable_diffusion import (
+    StableDiffusionGeneratorPipeline,
+)
from ..util import CUDA_DEVICE, ask_user, download_with_resume


class SDLegacyType(Enum):
-    V1 = 1
-    V1_INPAINT = 2
-    V2 = 3
-    V2_e = 4
-    V2_v = 5
-    UNKNOWN = 99
+    V1 = auto()
+    V1_INPAINT = auto()
+    V2 = auto()
+    V2_e = auto()
+    V2_v = auto()
+    UNKNOWN = auto()
+
+class SDModelComponent(Enum):
+    vae="vae"
+    text_encoder="text_encoder"
+    tokenizer="tokenizer"
+    unet="unet"
+    scheduler="scheduler"
+    safety_checker="safety_checker"
+    feature_extractor="feature_extractor"

DEFAULT_MAX_MODELS = 2

class ModelManager(object):
-    '''
+    """
    Model manager handles loading, caching, importing, deleting, converting, and editing models.
-    '''
+    """

    def __init__(
        self,
-        config: OmegaConf|Path,
+        config: OmegaConf | Path,
        device_type: torch.device = CUDA_DEVICE,
        precision: str = "float16",
        max_loaded_models=DEFAULT_MAX_MODELS,
        sequential_offload=False,
-        embedding_path: Path=None,
+        embedding_path: Path = None,
    ):
        """
        Initialize with the path to the models.yaml config file or
@@ -87,14 +112,24 @@ class ModelManager(object):
         """
         return model_name in self.config
 
-    def get_model(self, model_name: str=None)->dict:
-        """
-        Given a model named identified in models.yaml, return
-        the model object. If in RAM will load into GPU VRAM.
-        If on disk, will load from there.
+    def get_model(self, model_name: str = None) -> dict:
+        """Given a model name identified in models.yaml, return a dict
+        containing the model object and some of its key features. If
+        in RAM will load into GPU VRAM. If on disk, will load from
+        there.
+
+        The dict has the following keys:
+        'model': The StableDiffusionGeneratorPipeline object
+        'model_name': The name of the model in models.yaml
+        'width': The width of images trained by this model
+        'height': The height of images trained by this model
+        'hash': A unique hash of this model's files on disk.
         """
         if not model_name:
-            return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
+            return (
+                self.get_model(self.current_model)
+                if self.current_model
+                else self.get_model(self.default_model())
+            )
 
         if not self.valid_model(model_name):
             print(
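For orientation, a hypothetical sketch of how a caller might consume the dict documented above. `manager` is assumed to be an already-initialized `ModelManager`, and the model name is a placeholder for an entry in models.yaml.

```python
# Hypothetical usage; "my-sd-model" stands in for a real models.yaml entry.
model_info = manager.get_model("my-sd-model")

pipeline = model_info["model"]                    # the StableDiffusionGeneratorPipeline itself
print(model_info["model_name"])                   # name as listed in models.yaml
print(model_info["width"], model_info["height"])  # training resolution
print(model_info["hash"])                         # hash of the model files on disk
```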
@@ -135,6 +170,81 @@ class ModelManager(object):
             "hash": hash,
         }
 
+    def get_model_vae(self, model_name: str=None)->AutoencoderKL:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned VAE as an
+        AutoencoderKL object. If no model name is provided, return the
+        vae from the model currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.vae)
+
+    def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPTokenizer. If no
+        model name is provided, return the tokenizer from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.tokenizer)
+
+    def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned UNet2DConditionModel. If no model
+        name is provided, return the UNet from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.unet)
+
+    def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPTextModel. If no
+        model name is provided, return the text encoder from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.text_encoder)
+
+    def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPFeatureExtractor. If no
+        model name is provided, return the feature extractor from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
+
+    def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned scheduler. If no
+        model name is provided, return the scheduler from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.scheduler)
+
+    def _get_sub_model(
+        self,
+        model_name: str=None,
+        model_part: SDModelComponent=SDModelComponent.vae,
+    ) -> Union[
+        AutoencoderKL,
+        CLIPTokenizer,
+        CLIPFeatureExtractor,
+        UNet2DConditionModel,
+        CLIPTextModel,
+        StableDiffusionSafetyChecker,
+    ]:
+        """Given a model name identified in models.yaml, and the part of the
+        model you wish to retrieve, return that part. Parts are in an Enum
+        class named SDModelComponent, and consist of:
+        SDModelComponent.vae
+        SDModelComponent.text_encoder
+        SDModelComponent.tokenizer
+        SDModelComponent.unet
+        SDModelComponent.scheduler
+        SDModelComponent.safety_checker
+        SDModelComponent.feature_extractor
+        """
+        model_dict = self.get_model(model_name)
+        model = model_dict["model"]
+        return getattr(model, model_part.value)
+
     def default_model(self) -> str | None:
         """
         Returns the name of the default model, or None
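The component getters are thin wrappers: each calls `_get_sub_model`, which loads (or re-uses) the full pipeline via `get_model` and pulls the requested attribute off it with `getattr`. A hedged sketch of that equivalence, again assuming an initialized `manager`, a placeholder model name, and that `SDModelComponent` from the module above is in scope:

```python
# Both lookups should refer to the same object when the model is already cached.
vae = manager.get_model_vae("my-sd-model")

pipeline = manager.get_model("my-sd-model")["model"]
same_vae = getattr(pipeline, SDModelComponent.vae.value)  # what _get_sub_model does internally
```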
@@ -454,14 +564,18 @@ class ModelManager(object):
         from . import load_pipeline_from_original_stable_diffusion_ckpt
 
         try:
-            if self.list_models()[self.current_model]['status'] == 'active':
+            if self.list_models()[self.current_model]["status"] == "active":
                 self.offload_model(self.current_model)
         except Exception as e:
             pass
 
         vae_path = None
         if vae:
-            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
+            vae_path = (
+                vae
+                if os.path.isabs(vae)
+                else os.path.normpath(os.path.join(Globals.root, vae))
+            )
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -571,9 +685,7 @@ class ModelManager(object):
         models.yaml file.
         """
         model_name = model_name or Path(repo_or_path).stem
-        model_description = (
-            description or f"Imported diffusers model {model_name}"
-        )
+        model_description = description or f"Imported diffusers model {model_name}"
         new_config = dict(
             description=model_description,
             vae=vae,
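The collapsed one-liner relies on `or` short-circuiting: any falsy `description` (None or an empty string) falls back to the generated text. A tiny standalone example with placeholder values:

```python
model_name = "my-imported-model"  # placeholder
description = None
model_description = description or f"Imported diffusers model {model_name}"
print(model_description)  # Imported diffusers model my-imported-model
```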
@@ -602,7 +714,7 @@ class ModelManager(object):
         SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
         SDLegacyType.UNKNOWN
         """
-        global_step = checkpoint.get('global_step')
+        global_step = checkpoint.get("global_step")
         state_dict = checkpoint.get("state_dict") or checkpoint
 
         try:
@@ -628,13 +740,13 @@ class ModelManager(object):
             return SDLegacyType.UNKNOWN
 
     def heuristic_import(
         self,
         path_url_or_repo: str,
         model_name: str = None,
         description: str = None,
         model_config_file: Path = None,
         commit_to_conf: Path = None,
         config_file_callback: Callable[[Path], Path] = None,
     ) -> str:
         """Accept a string which could be:
            - a HF diffusers repo_id
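A hedged usage sketch for `heuristic_import`, assuming an initialized `manager`; the repo id and file path are placeholders, and only parameters shown in the signature above are used:

```python
# Import by Hugging Face repo_id (placeholder id).
name_a = manager.heuristic_import("some-org/some-diffusers-repo")

# Import a local checkpoint, letting the config-file heuristics run.
name_b = manager.heuristic_import(
    "/path/to/checkpoint.safetensors",
    description="Imported from a local file",
)
print(name_a, name_b)  # names the models were registered under
```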
@@ -738,8 +850,8 @@ class ModelManager(object):
 
         # another round of heuristics to guess the correct config file.
         checkpoint = None
-        if model_path.suffix in [".ckpt",".pt"]:
-            self.scan_model(model_path,model_path)
+        if model_path.suffix in [".ckpt", ".pt"]:
+            self.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
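The branch above dispatches on file suffix: pickle-based `.ckpt`/`.pt` files go through a malware scan and `torch.load`, everything else is treated as safetensors. A minimal standalone sketch of the same dispatch (without the scan), with a placeholder path:

```python
from pathlib import Path

import safetensors.torch
import torch


def load_checkpoint(model_path: Path) -> dict:
    """Return the raw state dict for a legacy checkpoint file."""
    if model_path.suffix in (".ckpt", ".pt"):
        # Pickle-based formats can execute code on load; the real code scans them first.
        return torch.load(model_path, map_location="cpu")
    # .safetensors files hold plain tensors and need no scan.
    return safetensors.torch.load_file(model_path)


state_dict = load_checkpoint(Path("/path/to/model.safetensors"))  # placeholder path
```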
@@ -761,19 +873,16 @@ class ModelManager(object):
         elif model_type == SDLegacyType.V1_INPAINT:
             print(" | SD-v1 inpainting model detected")
             model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
+                Globals.root,
+                "configs/stable-diffusion/v1-inpainting-inference.yaml",
             )
         elif model_type == SDLegacyType.V2_v:
-            print(
-                " | SD-v2-v model detected"
-            )
+            print(" | SD-v2-v model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
             )
         elif model_type == SDLegacyType.V2_e:
-            print(
-                " | SD-v2-e model detected"
-            )
+            print(" | SD-v2-e model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference.yaml"
             )
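The elif chain pairs each detected legacy type with a YAML config under `Globals.root`. The same pairing can be read as a lookup table; this is a rewritten sketch for clarity, not the project's code, and it only covers the branches visible in this hunk (it assumes `SDLegacyType` is importable from the module above):

```python
from pathlib import Path

# Mapping distilled from the branches above (V1 and UNKNOWN handling not shown here).
LEGACY_CONFIGS = {
    SDLegacyType.V1_INPAINT: "configs/stable-diffusion/v1-inpainting-inference.yaml",
    SDLegacyType.V2_v: "configs/stable-diffusion/v2-inference-v.yaml",
    SDLegacyType.V2_e: "configs/stable-diffusion/v2-inference.yaml",
}


def legacy_config_for(model_type, root: Path):
    relative = LEGACY_CONFIGS.get(model_type)
    return Path(root, relative) if relative else None
```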
@@ -820,16 +929,16 @@ class ModelManager(object):
         return model_name
 
     def convert_and_import(
         self,
         ckpt_path: Path,
         diffusers_path: Path,
         model_name=None,
         model_description=None,
-        vae:dict=None,
-        vae_path:Path=None,
+        vae: dict = None,
+        vae_path: Path = None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
-        scan_needed: bool=True,
+        scan_needed: bool = True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
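A hedged call sketch for `convert_and_import`, assuming an initialized `manager`; every path below is a placeholder:

```python
from pathlib import Path

new_name = manager.convert_and_import(
    ckpt_path=Path("/path/to/legacy-model.ckpt"),
    diffusers_path=Path("/path/to/converted/legacy-model"),
    model_description="Converted from a legacy checkpoint",
    commit_to_conf=Path("configs/models.yaml"),
)
print(new_name)  # name under which the converted diffusers model was registered
```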
@@ -857,10 +966,10 @@ class ModelManager(object):
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
-            vae_model=None
+            vae_model = None
             if vae:
-                vae_model=self._load_vae(vae)
-                vae_path=None
+                vae_model = self._load_vae(vae)
+                vae_path = None
             convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
@@ -976,15 +1085,15 @@ class ModelManager(object):
         legacy_locations = [
             Path(
                 models_dir,
-                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
+                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
             ),
             Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
             Path(
                 models_dir,
-                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
+                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
             ),
         ]
-        legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
+        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
 
         legacy_layout = False
         for model in legacy_locations:
@@ -1003,7 +1112,7 @@ class ModelManager(object):
 >> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
 >> Otherwise press <enter> to continue."""
             )
-            input('continue> ')
+            input("continue> ")
 
         # transformer files get moved into the hub directory
         if cls._is_huggingface_hub_directory_present():
@@ -576,6 +576,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
 
     elif command.startswith("!replay"):
         file_path = command.replace("!replay", "", 1).strip()
+        file_path = os.path.join(opt.outdir, file_path)
        if infile is None and os.path.isfile(file_path):
             infile = open(file_path, "r", encoding="utf-8")
             completer.add_history(command)
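The added line makes `!replay` look for the transcript inside the output directory. Note that on POSIX `os.path.join` discards the first argument when the second is absolute, so passing an absolute path to `!replay` still behaves as before. A small illustration with a placeholder output directory:

```python
import os

outdir = "outputs"  # stand-in for opt.outdir
print(os.path.join(outdir, "session.log"))       # outputs/session.log
print(os.path.join(outdir, "/tmp/session.log"))  # /tmp/session.log (absolute path wins)
```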
@@ -1,6 +1,8 @@
 from .test_invoker import create_edge
 from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+from invokeai.app.invocations.collections import RangeInvocation
+from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
 from invokeai.app.services.processor import DefaultInvocationProcessor
 from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
 from invokeai.app.services.invocation_queue import MemoryInvocationQueue
@@ -21,13 +23,14 @@ def simple_graph():
 def mock_services():
     # NOTE: none of these are actually called by the test invocations
     return InvocationServices(
-        model_manager = None,
-        events = None,
-        images = None,
+        model_manager = None, # type: ignore
+        events = None, # type: ignore
+        images = None, # type: ignore
+        latents = None, # type: ignore
         queue = MemoryInvocationQueue(),
         graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
         processor = DefaultInvocationProcessor(),
-        restoration = None,
+        restoration = None, # type: ignore
     )
 
 def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
@@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services):
 
 def test_graph_state_expands_iterator(mock_services):
     graph = Graph()
-    test_prompts = ["Banana sushi", "Cat sushi"]
-    graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
-    graph.add_node(IterateInvocation(id = "2"))
-    graph.add_node(ImageTestInvocation(id = "3"))
-    graph.add_edge(create_edge("1", "collection", "2", "collection"))
-    graph.add_edge(create_edge("2", "item", "3", "prompt"))
+    graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
+    graph.add_node(IterateInvocation(id = "1"))
+    graph.add_node(MultiplyInvocation(id = "2", b = 10))
+    graph.add_node(AddInvocation(id = "3", b = 1))
+    graph.add_edge(create_edge("0", "collection", "1", "collection"))
+    graph.add_edge(create_edge("1", "item", "2", "a"))
+    graph.add_edge(create_edge("2", "a", "3", "a"))
 
     g = GraphExecutionState(graph = graph)
-    n1 = invoke_next(g, mock_services)
-    n2 = invoke_next(g, mock_services)
-    n3 = invoke_next(g, mock_services)
-    n4 = invoke_next(g, mock_services)
-    n5 = invoke_next(g, mock_services)
-
-    assert g.prepared_source_mapping[n1[0].id] == "1"
-    assert g.prepared_source_mapping[n2[0].id] == "2"
-    assert g.prepared_source_mapping[n3[0].id] == "2"
-    assert g.prepared_source_mapping[n4[0].id] == "3"
-    assert g.prepared_source_mapping[n5[0].id] == "3"
-
-    assert isinstance(n4[0], ImageTestInvocation)
-    assert isinstance(n5[0], ImageTestInvocation)
-
-    prompts = [n4[0].prompt, n5[0].prompt]
-    assert sorted(prompts) == sorted(test_prompts)
+    while not g.is_complete():
+        invoke_next(g, mock_services)
+
+    prepared_add_nodes = g.source_prepared_mapping['3']
+    results = set([g.results[n].a for n in prepared_add_nodes])
+    expected = set([1, 11, 21])
+    assert results == expected
 
 def test_graph_state_collects(mock_services):
     graph = Graph()
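The rewritten test exercises the real Range/Iterate/Multiply/Add invocations instead of the prompt stubs, and the expected set follows by hand: Range(0, 3, 1) yields [0, 1, 2]; Multiply(b=10) maps these to [0, 10, 20]; Add(b=1) gives {1, 11, 21}. The same arithmetic as a plain-Python check, assuming the math invocations do what their names suggest:

```python
collection = list(range(0, 3, 1))          # RangeInvocation -> [0, 1, 2]
multiplied = [a * 10 for a in collection]  # MultiplyInvocation(b=10) -> [0, 10, 20]
added = {a + 1 for a in multiplied}        # AddInvocation(b=1)
assert added == {1, 11, 21}
```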
|
@ -24,10 +24,11 @@ def mock_services() -> InvocationServices:
|
|||||||
model_manager = None, # type: ignore
|
model_manager = None, # type: ignore
|
||||||
events = TestEventService(),
|
events = TestEventService(),
|
||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
|
latents = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None,
|
restoration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||