mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into dev/pytorch2
This commit is contained in:
commit
3c50448ccf
8
.github/CODEOWNERS
vendored
8
.github/CODEOWNERS
vendored
@ -1,16 +1,16 @@
|
|||||||
# continuous integration
|
# continuous integration
|
||||||
/.github/workflows/ @mauwii @lstein
|
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
||||||
|
|
||||||
# documentation
|
# documentation
|
||||||
/docs/ @lstein @mauwii @tildebyte
|
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
||||||
/mkdocs.yml @lstein @mauwii
|
/mkdocs.yml @lstein @mauwii @blessedcoolant
|
||||||
|
|
||||||
# nodes
|
# nodes
|
||||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||||
|
|
||||||
# installation and configuration
|
# installation and configuration
|
||||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
/pyproject.toml @mauwii @lstein @blessedcoolant
|
||||||
/docker/ @mauwii @lstein
|
/docker/ @mauwii @lstein @blessedcoolant
|
||||||
/scripts/ @ebr @lstein
|
/scripts/ @ebr @lstein
|
||||||
/installer/ @lstein @ebr
|
/installer/ @lstein @ebr
|
||||||
/invokeai/assets @lstein @ebr
|
/invokeai/assets @lstein @ebr
|
||||||
|
4
.github/workflows/build-container.yml
vendored
4
.github/workflows/build-container.yml
vendored
@ -16,6 +16,10 @@ on:
|
|||||||
- 'v*.*.*'
|
- 'v*.*.*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
packages: write
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
docker:
|
docker:
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false
|
||||||
|
3
.github/workflows/mkdocs-material.yml
vendored
3
.github/workflows/mkdocs-material.yml
vendored
@ -5,6 +5,9 @@ on:
|
|||||||
- 'main'
|
- 'main'
|
||||||
- 'development'
|
- 'development'
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
mkdocs-material:
|
mkdocs-material:
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false
|
||||||
|
1
.github/workflows/test-invoke-pip-skip.yml
vendored
1
.github/workflows/test-invoke-pip-skip.yml
vendored
@ -6,7 +6,6 @@ on:
|
|||||||
- '!pyproject.toml'
|
- '!pyproject.toml'
|
||||||
- '!invokeai/**'
|
- '!invokeai/**'
|
||||||
- 'invokeai/frontend/web/**'
|
- 'invokeai/frontend/web/**'
|
||||||
- '!invokeai/frontend/web/dist/**'
|
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
2
.github/workflows/test-invoke-pip.yml
vendored
2
.github/workflows/test-invoke-pip.yml
vendored
@ -7,13 +7,11 @@ on:
|
|||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
- 'invokeai/**'
|
- 'invokeai/**'
|
||||||
- '!invokeai/frontend/web/**'
|
- '!invokeai/frontend/web/**'
|
||||||
- 'invokeai/frontend/web/dist/**'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
- 'invokeai/**'
|
- 'invokeai/**'
|
||||||
- '!invokeai/frontend/web/**'
|
- '!invokeai/frontend/web/**'
|
||||||
- 'invokeai/frontend/web/dist/**'
|
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
- 'opened'
|
- 'opened'
|
||||||
|
@ -139,13 +139,13 @@ not supported.
|
|||||||
_For Windows/Linux with an NVIDIA GPU:_
|
_For Windows/Linux with an NVIDIA GPU:_
|
||||||
|
|
||||||
```terminal
|
```terminal
|
||||||
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
|
||||||
```
|
```
|
||||||
|
|
||||||
_For Linux with an AMD GPU:_
|
_For Linux with an AMD GPU:_
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||||
```
|
```
|
||||||
|
|
||||||
_For Macintoshes, either Intel or M1/M2:_
|
_For Macintoshes, either Intel or M1/M2:_
|
||||||
|
@ -168,11 +168,15 @@ used by Stable Diffusion 1.4 and 1.5.
|
|||||||
After installation, your `models.yaml` should contain an entry that looks like
|
After installation, your `models.yaml` should contain an entry that looks like
|
||||||
this one:
|
this one:
|
||||||
|
|
||||||
inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
```yml
|
||||||
description: SD inpainting v1.5 config:
|
inpainting-1.5:
|
||||||
configs/stable-diffusion/v1-inpainting-inference.yaml vae:
|
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
|
||||||
models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
|
description: SD inpainting v1.5
|
||||||
|
config: configs/stable-diffusion/v1-inpainting-inference.yaml
|
||||||
|
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
|
```
|
||||||
|
|
||||||
As shown in the example, you may include a VAE fine-tuning weights file as well.
|
As shown in the example, you may include a VAE fine-tuning weights file as well.
|
||||||
This is strongly recommended.
|
This is strongly recommended.
|
||||||
|
@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
|
|||||||
masking option:
|
masking option:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
invoke> a fluffy cat eating a hotdot
|
invoke> a fluffy cat eating a hotdog
|
||||||
Outputs:
|
Outputs:
|
||||||
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
|
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
|
||||||
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
|
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
|
||||||
|
@ -417,7 +417,7 @@ Then type the following commands:
|
|||||||
|
|
||||||
=== "AMD System"
|
=== "AMD System"
|
||||||
```bash
|
```bash
|
||||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||||
```
|
```
|
||||||
|
|
||||||
### Corrupted configuration file
|
### Corrupted configuration file
|
||||||
|
@ -154,7 +154,7 @@ manager, please follow these steps:
|
|||||||
=== "ROCm (AMD)"
|
=== "ROCm (AMD)"
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CPU (Intel Macs & non-GPU systems)"
|
=== "CPU (Intel Macs & non-GPU systems)"
|
||||||
@ -315,7 +315,7 @@ installation protocol (important!)
|
|||||||
|
|
||||||
=== "ROCm (AMD)"
|
=== "ROCm (AMD)"
|
||||||
```bash
|
```bash
|
||||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "CPU (Intel Macs & non-GPU systems)"
|
=== "CPU (Intel Macs & non-GPU systems)"
|
||||||
|
@ -110,7 +110,7 @@ recipes are available
|
|||||||
|
|
||||||
When installing torch and torchvision manually with `pip`, remember to provide
|
When installing torch and torchvision manually with `pip`, remember to provide
|
||||||
the argument `--extra-index-url
|
the argument `--extra-index-url
|
||||||
https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
|
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
|
||||||
Installation Guide](020_INSTALL_MANUAL.md).
|
Installation Guide](020_INSTALL_MANUAL.md).
|
||||||
|
|
||||||
This will be done automatically for you if you use the installer
|
This will be done automatically for you if you use the installer
|
||||||
|
@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
|
|||||||
optional_modules = None
|
optional_modules = None
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
if device == "rocm":
|
if device == "rocm":
|
||||||
url = "https://download.pytorch.org/whl/rocm5.2"
|
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||||
elif device == "cpu":
|
elif device == "cpu":
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
|
||||||
|
@ -24,9 +24,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
|
|||||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ "$0" != "bash" ]; then
|
||||||
while true
|
while true
|
||||||
do
|
do
|
||||||
if [ "$0" != "bash" ]; then
|
|
||||||
echo "Do you want to generate images using the"
|
echo "Do you want to generate images using the"
|
||||||
echo "1. command-line interface"
|
echo "1. command-line interface"
|
||||||
echo "2. browser-based UI"
|
echo "2. browser-based UI"
|
||||||
@ -87,9 +87,9 @@ if [ "$0" != "bash" ]; then
|
|||||||
echo "Invalid selection"
|
echo "Invalid selection"
|
||||||
exit;;
|
exit;;
|
||||||
esac
|
esac
|
||||||
|
done
|
||||||
else # in developer console
|
else # in developer console
|
||||||
python --version
|
python --version
|
||||||
echo "Press ^D to exit"
|
echo "Press ^D to exit"
|
||||||
export PS1="(InvokeAI) \u@\h \w> "
|
export PS1="(InvokeAI) \u@\h \w> "
|
||||||
fi
|
fi
|
||||||
done
|
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
from ..services.model_manager_initializer import get_model_manager
|
from ..services.model_manager_initializer import get_model_manager
|
||||||
from ..services.restoration_services import RestorationServices
|
from ..services.restoration_services import RestorationServices
|
||||||
@ -54,7 +56,9 @@ class ApiDependencies:
|
|||||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||||
)
|
)
|
||||||
|
|
||||||
images = DiskImageStorage(output_folder)
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||||
|
|
||||||
|
images = DiskImageStorage(f'{output_folder}/images')
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
@ -62,6 +66,7 @@ class ApiDependencies:
|
|||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=get_model_manager(config),
|
model_manager=get_model_manager(config),
|
||||||
events=events,
|
events=events,
|
||||||
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||||
|
@ -23,6 +23,16 @@ async def get_image(
|
|||||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||||
return FileResponse(filename)
|
return FileResponse(filename)
|
||||||
|
|
||||||
|
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
|
||||||
|
async def get_thumbnail(
|
||||||
|
image_type: ImageType = Path(description="The type of image to get"),
|
||||||
|
image_name: str = Path(description="The name of the image to get"),
|
||||||
|
):
|
||||||
|
"""Gets a thumbnail"""
|
||||||
|
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||||
|
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
|
||||||
|
return FileResponse(filename)
|
||||||
|
|
||||||
|
|
||||||
@images_router.post(
|
@images_router.post(
|
||||||
"/uploads/",
|
"/uploads/",
|
||||||
|
279
invokeai/app/api/routers/models.py
Normal file
279
invokeai/app/api/routers/models.py
Normal file
@ -0,0 +1,279 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
|
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
class VaeRepo(BaseModel):
|
||||||
|
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||||
|
path: Optional[str] = Field(description="The path to the VAE")
|
||||||
|
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo(BaseModel):
|
||||||
|
description: Optional[str] = Field(description="A description of the model")
|
||||||
|
|
||||||
|
|
||||||
|
class CkptModelInfo(ModelInfo):
|
||||||
|
format: Literal['ckpt'] = 'ckpt'
|
||||||
|
|
||||||
|
config: str = Field(description="The path to the model config")
|
||||||
|
weights: str = Field(description="The path to the model weights")
|
||||||
|
vae: str = Field(description="The path to the model VAE")
|
||||||
|
width: Optional[int] = Field(description="The width of the model")
|
||||||
|
height: Optional[int] = Field(description="The height of the model")
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusersModelInfo(ModelInfo):
|
||||||
|
format: Literal['diffusers'] = 'diffusers'
|
||||||
|
|
||||||
|
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||||
|
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||||
|
path: Optional[str] = Field(description="The path to the model")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelsList(BaseModel):
|
||||||
|
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/",
|
||||||
|
operation_id="list_models",
|
||||||
|
responses={200: {"model": ModelsList }},
|
||||||
|
)
|
||||||
|
async def list_models() -> ModelsList:
|
||||||
|
"""Gets a list of models"""
|
||||||
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
||||||
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
|
return models
|
||||||
|
|
||||||
|
# @socketio.on("requestSystemConfig")
|
||||||
|
# def handle_request_capabilities():
|
||||||
|
# print(">> System config requested")
|
||||||
|
# config = self.get_system_config()
|
||||||
|
# config["model_list"] = self.generate.model_manager.list_models()
|
||||||
|
# config["infill_methods"] = infill_methods()
|
||||||
|
# socketio.emit("systemConfig", config)
|
||||||
|
|
||||||
|
# @socketio.on("searchForModels")
|
||||||
|
# def handle_search_models(search_folder: str):
|
||||||
|
# try:
|
||||||
|
# if not search_folder:
|
||||||
|
# socketio.emit(
|
||||||
|
# "foundModels",
|
||||||
|
# {"search_folder": None, "found_models": None},
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# (
|
||||||
|
# search_folder,
|
||||||
|
# found_models,
|
||||||
|
# ) = self.generate.model_manager.search_models(search_folder)
|
||||||
|
# socketio.emit(
|
||||||
|
# "foundModels",
|
||||||
|
# {"search_folder": search_folder, "found_models": found_models},
|
||||||
|
# )
|
||||||
|
# except Exception as e:
|
||||||
|
# self.handle_exceptions(e)
|
||||||
|
# print("\n")
|
||||||
|
|
||||||
|
# @socketio.on("addNewModel")
|
||||||
|
# def handle_add_model(new_model_config: dict):
|
||||||
|
# try:
|
||||||
|
# model_name = new_model_config["name"]
|
||||||
|
# del new_model_config["name"]
|
||||||
|
# model_attributes = new_model_config
|
||||||
|
# if len(model_attributes["vae"]) == 0:
|
||||||
|
# del model_attributes["vae"]
|
||||||
|
# update = False
|
||||||
|
# current_model_list = self.generate.model_manager.list_models()
|
||||||
|
# if model_name in current_model_list:
|
||||||
|
# update = True
|
||||||
|
|
||||||
|
# print(f">> Adding New Model: {model_name}")
|
||||||
|
|
||||||
|
# self.generate.model_manager.add_model(
|
||||||
|
# model_name=model_name,
|
||||||
|
# model_attributes=model_attributes,
|
||||||
|
# clobber=True,
|
||||||
|
# )
|
||||||
|
# self.generate.model_manager.commit(opt.conf)
|
||||||
|
|
||||||
|
# new_model_list = self.generate.model_manager.list_models()
|
||||||
|
# socketio.emit(
|
||||||
|
# "newModelAdded",
|
||||||
|
# {
|
||||||
|
# "new_model_name": model_name,
|
||||||
|
# "model_list": new_model_list,
|
||||||
|
# "update": update,
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# print(f">> New Model Added: {model_name}")
|
||||||
|
# except Exception as e:
|
||||||
|
# self.handle_exceptions(e)
|
||||||
|
|
||||||
|
# @socketio.on("deleteModel")
|
||||||
|
# def handle_delete_model(model_name: str):
|
||||||
|
# try:
|
||||||
|
# print(f">> Deleting Model: {model_name}")
|
||||||
|
# self.generate.model_manager.del_model(model_name)
|
||||||
|
# self.generate.model_manager.commit(opt.conf)
|
||||||
|
# updated_model_list = self.generate.model_manager.list_models()
|
||||||
|
# socketio.emit(
|
||||||
|
# "modelDeleted",
|
||||||
|
# {
|
||||||
|
# "deleted_model_name": model_name,
|
||||||
|
# "model_list": updated_model_list,
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# print(f">> Model Deleted: {model_name}")
|
||||||
|
# except Exception as e:
|
||||||
|
# self.handle_exceptions(e)
|
||||||
|
|
||||||
|
# @socketio.on("requestModelChange")
|
||||||
|
# def handle_set_model(model_name: str):
|
||||||
|
# try:
|
||||||
|
# print(f">> Model change requested: {model_name}")
|
||||||
|
# model = self.generate.set_model(model_name)
|
||||||
|
# model_list = self.generate.model_manager.list_models()
|
||||||
|
# if model is None:
|
||||||
|
# socketio.emit(
|
||||||
|
# "modelChangeFailed",
|
||||||
|
# {"model_name": model_name, "model_list": model_list},
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# socketio.emit(
|
||||||
|
# "modelChanged",
|
||||||
|
# {"model_name": model_name, "model_list": model_list},
|
||||||
|
# )
|
||||||
|
# except Exception as e:
|
||||||
|
# self.handle_exceptions(e)
|
||||||
|
|
||||||
|
# @socketio.on("convertToDiffusers")
|
||||||
|
# def convert_to_diffusers(model_to_convert: dict):
|
||||||
|
# try:
|
||||||
|
# if model_info := self.generate.model_manager.model_info(
|
||||||
|
# model_name=model_to_convert["model_name"]
|
||||||
|
# ):
|
||||||
|
# if "weights" in model_info:
|
||||||
|
# ckpt_path = Path(model_info["weights"])
|
||||||
|
# original_config_file = Path(model_info["config"])
|
||||||
|
# model_name = model_to_convert["model_name"]
|
||||||
|
# model_description = model_info["description"]
|
||||||
|
# else:
|
||||||
|
# self.socketio.emit(
|
||||||
|
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# self.socketio.emit(
|
||||||
|
# "error", {"message": "Could not retrieve model info."}
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if not ckpt_path.is_absolute():
|
||||||
|
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||||
|
|
||||||
|
# if original_config_file and not original_config_file.is_absolute():
|
||||||
|
# original_config_file = Path(Globals.root, original_config_file)
|
||||||
|
|
||||||
|
# diffusers_path = Path(
|
||||||
|
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if model_to_convert["save_location"] == "root":
|
||||||
|
# diffusers_path = Path(
|
||||||
|
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if (
|
||||||
|
# model_to_convert["save_location"] == "custom"
|
||||||
|
# and model_to_convert["custom_location"] is not None
|
||||||
|
# ):
|
||||||
|
# diffusers_path = Path(
|
||||||
|
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if diffusers_path.exists():
|
||||||
|
# shutil.rmtree(diffusers_path)
|
||||||
|
|
||||||
|
# self.generate.model_manager.convert_and_import(
|
||||||
|
# ckpt_path,
|
||||||
|
# diffusers_path,
|
||||||
|
# model_name=model_name,
|
||||||
|
# model_description=model_description,
|
||||||
|
# vae=None,
|
||||||
|
# original_config_file=original_config_file,
|
||||||
|
# commit_to_conf=opt.conf,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# new_model_list = self.generate.model_manager.list_models()
|
||||||
|
# socketio.emit(
|
||||||
|
# "modelConverted",
|
||||||
|
# {
|
||||||
|
# "new_model_name": model_name,
|
||||||
|
# "model_list": new_model_list,
|
||||||
|
# "update": True,
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# print(f">> Model Converted: {model_name}")
|
||||||
|
# except Exception as e:
|
||||||
|
# self.handle_exceptions(e)
|
||||||
|
|
||||||
|
# @socketio.on("mergeDiffusersModels")
|
||||||
|
# def merge_diffusers_models(model_merge_info: dict):
|
||||||
|
# try:
|
||||||
|
# models_to_merge = model_merge_info["models_to_merge"]
|
||||||
|
# model_ids_or_paths = [
|
||||||
|
# self.generate.model_manager.model_name_or_path(x)
|
||||||
|
# for x in models_to_merge
|
||||||
|
# ]
|
||||||
|
# merged_pipe = merge_diffusion_models(
|
||||||
|
# model_ids_or_paths,
|
||||||
|
# model_merge_info["alpha"],
|
||||||
|
# model_merge_info["interp"],
|
||||||
|
# model_merge_info["force"],
|
||||||
|
# )
|
||||||
|
|
||||||
|
# dump_path = global_models_dir() / "merged_models"
|
||||||
|
# if model_merge_info["model_merge_save_path"] is not None:
|
||||||
|
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||||
|
|
||||||
|
# os.makedirs(dump_path, exist_ok=True)
|
||||||
|
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||||
|
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||||
|
|
||||||
|
# merged_model_config = dict(
|
||||||
|
# model_name=model_merge_info["merged_model_name"],
|
||||||
|
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||||
|
# commit_to_conf=opt.conf,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||||
|
# "vae", None
|
||||||
|
# ):
|
||||||
|
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||||
|
# merged_model_config.update(vae=vae)
|
||||||
|
|
||||||
|
# self.generate.model_manager.import_diffuser_model(
|
||||||
|
# dump_path, **merged_model_config
|
||||||
|
# )
|
||||||
|
# new_model_list = self.generate.model_manager.list_models()
|
||||||
|
|
||||||
|
# socketio.emit(
|
||||||
|
# "modelsMerged",
|
||||||
|
# {
|
||||||
|
# "merged_models": models_to_merge,
|
||||||
|
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||||
|
# "model_list": new_model_list,
|
||||||
|
# "update": True,
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# print(f">> Models Merged: {models_to_merge}")
|
||||||
|
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||||
|
# except Exception as e:
|
||||||
|
# self.handle_exceptions(e)
|
@ -51,7 +51,7 @@ async def list_sessions(
|
|||||||
query: str = Query(default="", description="The query string to search for"),
|
query: str = Query(default="", description="The query string to search for"),
|
||||||
) -> PaginatedResults[GraphExecutionState]:
|
) -> PaginatedResults[GraphExecutionState]:
|
||||||
"""Gets a list of sessions, optionally searching"""
|
"""Gets a list of sessions, optionally searching"""
|
||||||
if filter == "":
|
if query == "":
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||||
page, per_page
|
page, per_page
|
||||||
)
|
)
|
||||||
@ -270,3 +270,18 @@ async def invoke_session(
|
|||||||
|
|
||||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||||
return Response(status_code=202)
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.delete(
|
||||||
|
"/{session_id}/invoke",
|
||||||
|
operation_id="cancel_session_invoke",
|
||||||
|
responses={
|
||||||
|
202: {"description": "The invocation is canceled"}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def cancel_session_invoke(
|
||||||
|
session_id: str = Path(description="The id of the session to cancel"),
|
||||||
|
) -> None:
|
||||||
|
"""Invokes a session"""
|
||||||
|
ApiDependencies.invoker.cancel(session_id)
|
||||||
|
return Response(status_code=202)
|
||||||
|
@ -14,7 +14,7 @@ from pydantic.schema import schema
|
|||||||
|
|
||||||
from ..backend import Args
|
from ..backend import Args
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import images, sessions
|
from .api.routers import images, sessions, models
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")
|
|||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
|
@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
|
|||||||
import argparse
|
import argparse
|
||||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
from ..invocations.image import ImageField
|
from ..invocations.image import ImageField
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
@ -46,7 +47,7 @@ def add_parsers(
|
|||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field_type,
|
type=field_type,
|
||||||
default=field.default,
|
default=field.default if field.default_factory is None else field.default_factory(),
|
||||||
choices=allowed_values,
|
choices=allowed_values,
|
||||||
help=field.field_info.description,
|
help=field.field_info.description,
|
||||||
)
|
)
|
||||||
@ -55,7 +56,7 @@ def add_parsers(
|
|||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
type=field.type_,
|
type=field.type_,
|
||||||
default=field.default,
|
default=field.default if field.default_factory is None else field.default_factory(),
|
||||||
help=field.field_info.description,
|
help=field.field_info.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
|
|||||||
del context.defaults[self.field]
|
del context.defaults[self.field]
|
||||||
else:
|
else:
|
||||||
context.defaults[self.field] = self.value
|
context.defaults[self.field] = self.value
|
||||||
|
|
||||||
|
|
||||||
|
class DrawGraphCommand(BaseCommand):
|
||||||
|
"""Debugs a graph"""
|
||||||
|
type: Literal['draw_graph'] = 'draw_graph'
|
||||||
|
|
||||||
|
def run(self, context: CliContext) -> None:
|
||||||
|
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||||
|
nxgraph = session.graph.nx_graph_flat()
|
||||||
|
|
||||||
|
# Draw the networkx graph
|
||||||
|
plt.figure(figsize=(20, 20))
|
||||||
|
pos = nx.spectral_layout(nxgraph)
|
||||||
|
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||||
|
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||||
|
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
class DrawExecutionGraphCommand(BaseCommand):
|
||||||
|
"""Debugs an execution graph"""
|
||||||
|
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||||
|
|
||||||
|
def run(self, context: CliContext) -> None:
|
||||||
|
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||||
|
nxgraph = session.execution_graph.nx_graph_flat()
|
||||||
|
|
||||||
|
# Draw the networkx graph
|
||||||
|
plt.figure(figsize=(20, 20))
|
||||||
|
pos = nx.spectral_layout(nxgraph)
|
||||||
|
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||||
|
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||||
|
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.show()
|
||||||
|
167
invokeai/app/cli/completer.py
Normal file
167
invokeai/app/cli/completer.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
Readline helper functions for cli_app.py
|
||||||
|
You may import the global singleton `completer` to get access to the
|
||||||
|
completer object.
|
||||||
|
"""
|
||||||
|
import atexit
|
||||||
|
import readline
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||||
|
|
||||||
|
from ...backend import ModelManager, Globals
|
||||||
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
|
from .commands import BaseCommand
|
||||||
|
|
||||||
|
# singleton object, class variable
|
||||||
|
completer = None
|
||||||
|
|
||||||
|
class Completer(object):
|
||||||
|
|
||||||
|
def __init__(self, model_manager: ModelManager):
|
||||||
|
self.commands = self.get_commands()
|
||||||
|
self.matches = None
|
||||||
|
self.linebuffer = None
|
||||||
|
self.manager = model_manager
|
||||||
|
return
|
||||||
|
|
||||||
|
def complete(self, text, state):
|
||||||
|
"""
|
||||||
|
Complete commands and switches fromm the node CLI command line.
|
||||||
|
Switches are determined in a context-specific manner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buffer = readline.get_line_buffer()
|
||||||
|
if state == 0:
|
||||||
|
options = None
|
||||||
|
try:
|
||||||
|
current_command, current_switch = self.get_current_command(buffer)
|
||||||
|
options = self.get_command_options(current_command, current_switch)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
options = options or list(self.parse_commands().keys())
|
||||||
|
|
||||||
|
if not text: # first time
|
||||||
|
self.matches = options
|
||||||
|
else:
|
||||||
|
self.matches = [s for s in options if s and s.startswith(text)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
match = self.matches[state]
|
||||||
|
except IndexError:
|
||||||
|
match = None
|
||||||
|
return match
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_commands(self)->List[object]:
|
||||||
|
"""
|
||||||
|
Return a list of all the client commands and invocations.
|
||||||
|
"""
|
||||||
|
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||||
|
|
||||||
|
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Parse the readline buffer to find the most recent command and its switch.
|
||||||
|
"""
|
||||||
|
if len(buffer)==0:
|
||||||
|
return None, None
|
||||||
|
tokens = shlex.split(buffer)
|
||||||
|
command = None
|
||||||
|
switch = None
|
||||||
|
for t in tokens:
|
||||||
|
if t[0].isalpha():
|
||||||
|
if switch is None:
|
||||||
|
command = t
|
||||||
|
else:
|
||||||
|
switch = t
|
||||||
|
# don't try to autocomplete switches that are already complete
|
||||||
|
if switch and buffer.endswith(' '):
|
||||||
|
switch=None
|
||||||
|
return command or '', switch or ''
|
||||||
|
|
||||||
|
def parse_commands(self)->Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Return a dict in which the keys are the command name
|
||||||
|
and the values are the parameters the command takes.
|
||||||
|
"""
|
||||||
|
result = dict()
|
||||||
|
for command in self.commands:
|
||||||
|
hints = get_type_hints(command)
|
||||||
|
name = get_args(hints['type'])[0]
|
||||||
|
result.update({name:hints})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||||
|
"""
|
||||||
|
Return all the parameters that can be passed to the command as
|
||||||
|
command-line switches. Returns None if the command is unrecognized.
|
||||||
|
"""
|
||||||
|
parsed_commands = self.parse_commands()
|
||||||
|
if command not in parsed_commands:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# handle switches in the format "-foo=bar"
|
||||||
|
argument = None
|
||||||
|
if switch and '=' in switch:
|
||||||
|
switch, argument = switch.split('=')
|
||||||
|
|
||||||
|
parameter = switch.strip('-')
|
||||||
|
if parameter in parsed_commands[command]:
|
||||||
|
if argument is None:
|
||||||
|
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||||
|
else:
|
||||||
|
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||||
|
else:
|
||||||
|
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||||
|
|
||||||
|
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||||
|
"""
|
||||||
|
Given a parameter type (such as Literal), offers autocompletions.
|
||||||
|
"""
|
||||||
|
if get_origin(typehint) == Literal:
|
||||||
|
return get_args(typehint)
|
||||||
|
if parameter == 'model':
|
||||||
|
return self.manager.model_names()
|
||||||
|
|
||||||
|
def _pre_input_hook(self):
|
||||||
|
if self.linebuffer:
|
||||||
|
readline.insert_text(self.linebuffer)
|
||||||
|
readline.redisplay()
|
||||||
|
self.linebuffer = None
|
||||||
|
|
||||||
|
def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||||
|
global completer
|
||||||
|
|
||||||
|
if completer:
|
||||||
|
return completer
|
||||||
|
|
||||||
|
completer = Completer(model_manager)
|
||||||
|
|
||||||
|
readline.set_completer(completer.complete)
|
||||||
|
# pyreadline3 does not have a set_auto_history() method
|
||||||
|
try:
|
||||||
|
readline.set_auto_history(True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||||
|
readline.set_completer_delims(" ")
|
||||||
|
readline.parse_and_bind("tab: complete")
|
||||||
|
readline.parse_and_bind("set print-completions-horizontally off")
|
||||||
|
readline.parse_and_bind("set page-completions on")
|
||||||
|
readline.parse_and_bind("set skip-completed-text on")
|
||||||
|
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||||
|
|
||||||
|
histfile = Path(Globals.root, ".invoke_history")
|
||||||
|
try:
|
||||||
|
readline.read_history_file(histfile)
|
||||||
|
readline.set_history_length(1000)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except OSError: # file likely corrupted
|
||||||
|
newname = f"{histfile}.old"
|
||||||
|
print(
|
||||||
|
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||||
|
)
|
||||||
|
histfile.replace(Path(newname))
|
||||||
|
atexit.register(readline.write_history_file, histfile)
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -12,14 +13,17 @@ from typing import (
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
from ..backend import Args
|
from ..backend import Args
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||||
|
from .cli.completer import set_autocompleter
|
||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.model_manager_initializer import get_model_manager
|
from .services.model_manager_initializer import get_model_manager
|
||||||
from .services.restoration_services import RestorationServices
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
@ -43,7 +47,7 @@ def add_invocation_args(command_parser):
|
|||||||
"-l",
|
"-l",
|
||||||
action="append",
|
action="append",
|
||||||
nargs=3,
|
nargs=3,
|
||||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
|
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
|
||||||
)
|
)
|
||||||
|
|
||||||
command_parser.add_argument(
|
command_parser.add_argument(
|
||||||
@ -93,6 +97,9 @@ def generate_matching_edges(
|
|||||||
invalid_fields = set(["type", "id"])
|
invalid_fields = set(["type", "id"])
|
||||||
matching_fields = matching_fields.difference(invalid_fields)
|
matching_fields = matching_fields.difference(invalid_fields)
|
||||||
|
|
||||||
|
# Validate types
|
||||||
|
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||||
|
|
||||||
edges = [
|
edges = [
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=a.id, field=field),
|
source=EdgeConnection(node_id=a.id, field=field),
|
||||||
@ -130,6 +137,12 @@ def invoke_cli():
|
|||||||
config.parse_args()
|
config.parse_args()
|
||||||
model_manager = get_model_manager(config)
|
model_manager = get_model_manager(config)
|
||||||
|
|
||||||
|
# This initializes the autocompleter and returns it.
|
||||||
|
# Currently nothing is done with the returned Completer
|
||||||
|
# object, but the object can be used to change autocompletion
|
||||||
|
# behavior on the fly, if desired.
|
||||||
|
completer = set_autocompleter(model_manager)
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
output_folder = os.path.abspath(
|
output_folder = os.path.abspath(
|
||||||
@ -142,7 +155,8 @@ def invoke_cli():
|
|||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
images=DiskImageStorage(output_folder),
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||||
|
images=DiskImageStorage(f'{output_folder}/images'),
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
@ -155,6 +169,8 @@ def invoke_cli():
|
|||||||
session: GraphExecutionState = invoker.create_execution_state()
|
session: GraphExecutionState = invoker.create_execution_state()
|
||||||
parser = get_command_parser()
|
parser = get_command_parser()
|
||||||
|
|
||||||
|
re_negid = re.compile('^-[0-9]+$')
|
||||||
|
|
||||||
# Uncomment to print out previous sessions at startup
|
# Uncomment to print out previous sessions at startup
|
||||||
# print(services.session_manager.list())
|
# print(services.session_manager.list())
|
||||||
|
|
||||||
@ -162,8 +178,8 @@ def invoke_cli():
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
cmd_input = input("> ")
|
cmd_input = input("invoke> ")
|
||||||
except KeyboardInterrupt:
|
except (KeyboardInterrupt, EOFError):
|
||||||
# Ctrl-c exits
|
# Ctrl-c exits
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -220,7 +236,11 @@ def invoke_cli():
|
|||||||
# Parse provided links
|
# Parse provided links
|
||||||
if "link_node" in args and args["link_node"]:
|
if "link_node" in args and args["link_node"]:
|
||||||
for link in args["link_node"]:
|
for link in args["link_node"]:
|
||||||
link_node = context.session.graph.get_node(link)
|
node_id = link
|
||||||
|
if re_negid.match(node_id):
|
||||||
|
node_id = str(current_id + int(node_id))
|
||||||
|
|
||||||
|
link_node = context.session.graph.get_node(node_id)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
link_node, command.command
|
link_node, command.command
|
||||||
)
|
)
|
||||||
@ -230,10 +250,15 @@ def invoke_cli():
|
|||||||
|
|
||||||
if "link" in args and args["link"]:
|
if "link" in args and args["link"]:
|
||||||
for link in args["link"]:
|
for link in args["link"]:
|
||||||
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||||
|
|
||||||
|
node_id = link[0]
|
||||||
|
if re_negid.match(node_id):
|
||||||
|
node_id = str(current_id + int(node_id))
|
||||||
|
|
||||||
edges.append(
|
edges.append(
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=link[1], field=link[0]),
|
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||||
destination=EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=command.command.id, field=link[2]
|
node_id=command.command.id, field=link[2]
|
||||||
)
|
)
|
||||||
|
50
invokeai/app/invocations/collections.py
Normal file
50
invokeai/app/invocations/collections.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import cv2 as cv
|
||||||
|
import numpy as np
|
||||||
|
import numpy.random
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from ..services.image_storage import ImageType
|
||||||
|
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
|
||||||
|
from .image import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
|
class IntCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""A collection of integers"""
|
||||||
|
|
||||||
|
type: Literal["int_collection"] = "int_collection"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[int] = Field(default=[], description="The int collection")
|
||||||
|
|
||||||
|
|
||||||
|
class RangeInvocation(BaseInvocation):
|
||||||
|
"""Creates a range"""
|
||||||
|
|
||||||
|
type: Literal["range"] = "range"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
start: int = Field(default=0, description="The start of the range")
|
||||||
|
stop: int = Field(default=10, description="The stop of the range")
|
||||||
|
step: int = Field(default=1, description="The step of the range")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
|
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
|
"""Creates a collection of random numbers"""
|
||||||
|
|
||||||
|
type: Literal["random_range"] = "random_range"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
low: int = Field(default=0, description="The inclusive low value")
|
||||||
|
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
|
size: int = Field(default=1, description="The number of values to generate")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
|
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
|
@ -1,22 +1,19 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from functools import partial
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from PIL import Image
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from skimage.exposure.histogram_matching import match_histograms
|
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
tuple(InvokeAIGenerator.schedulers())
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
@ -45,32 +42,26 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, sample: Tensor, step: int
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: only output a preview image when requested
|
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
raise CanceledException
|
||||||
|
|
||||||
(width, height) = image.size
|
step = intermediate_state.step
|
||||||
width *= 8
|
if intermediate_state.predicted_original is not None:
|
||||||
height *= 8
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
|
||||||
context.graph_execution_state_id,
|
|
||||||
self.id,
|
|
||||||
{
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"dataURL": dataURL
|
|
||||||
},
|
|
||||||
step,
|
|
||||||
self.steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
def step_callback(state: PipelineIntermediateState):
|
# def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, state.latents, state.step)
|
# if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
|
# raise CanceledException
|
||||||
|
# self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
@ -79,7 +70,7 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
model= context.services.model_manager.get_model()
|
model= context.services.model_manager.get_model()
|
||||||
outputs = Txt2Img(model).generate(
|
outputs = Txt2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=step_callback,
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
@ -116,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def dispatch_progress(
|
||||||
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
|
) -> None:
|
||||||
|
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
|
raise CanceledException
|
||||||
|
|
||||||
|
step = intermediate_state.step
|
||||||
|
if intermediate_state.predicted_original is not None:
|
||||||
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
@ -126,24 +133,23 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
)
|
)
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
def step_callback(sample, step=0):
|
|
||||||
self.dispatch_progress(context, sample, step)
|
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
model = context.services.model_manager.get_model()
|
model = context.services.model_manager.get_model()
|
||||||
generator_output = next(
|
outputs = Img2Img(model).generate(
|
||||||
Img2Img(model).generate(
|
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_image=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
step_callback=step_callback,
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
|
# each time it is called. We only need the first one.
|
||||||
|
generator_output = next(outputs)
|
||||||
|
|
||||||
result_image = generator_output.image
|
result_image = generator_output.image
|
||||||
|
|
||||||
@ -173,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
description="The amount by which to replace masked areas with latent noise",
|
description="The amount by which to replace masked areas with latent noise",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def dispatch_progress(
|
||||||
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
|
) -> None:
|
||||||
|
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
|
raise CanceledException
|
||||||
|
|
||||||
|
step = intermediate_state.step
|
||||||
|
if intermediate_state.predicted_original is not None:
|
||||||
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
@ -187,24 +209,23 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
def step_callback(sample, step=0):
|
|
||||||
self.dispatch_progress(context, sample, step)
|
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
manager = context.services.model_manager.get_model()
|
model = context.services.model_manager.get_model()
|
||||||
generator_output = next(
|
outputs = Inpaint(model).generate(
|
||||||
Inpaint(model).generate(
|
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_img=image,
|
||||||
mask_image=mask,
|
init_mask=mask,
|
||||||
step_callback=step_callback,
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
|
# each time it is called. We only need the first one.
|
||||||
|
generator_output = next(outputs)
|
||||||
|
|
||||||
result_image = generator_output.image
|
result_image = generator_output.image
|
||||||
|
|
||||||
|
@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
image: ImageField = Field(default=None, description="The output image")
|
image: ImageField = Field(default=None, description="The output image")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'image',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
class MaskOutput(BaseInvocationOutput):
|
class MaskOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a mask"""
|
"""Base class for invocations that output a mask"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["mask"] = "mask"
|
type: Literal["mask"] = "mask"
|
||||||
mask: ImageField = Field(default=None, description="The output mask")
|
mask: ImageField = Field(default=None, description="The output mask")
|
||||||
#fomt: on
|
#fmt: on
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'mask',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: this isn't really necessary anymore
|
# TODO: this isn't really necessary anymore
|
||||||
class LoadImageInvocation(BaseInvocation):
|
class LoadImageInvocation(BaseInvocation):
|
||||||
|
321
invokeai/app/invocations/latent.py
Normal file
321
invokeai/app/invocations/latent.py
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from typing import Literal, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from torch import Tensor
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...backend.model_management.model_manager import ModelManager
|
||||||
|
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
|
||||||
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
|
import numpy as np
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
from ..services.image_storage import ImageType
|
||||||
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
|
from .image import ImageField, ImageOutput
|
||||||
|
from ...backend.generator import Generator
|
||||||
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
from ...backend.util.util import image_to_dataURL
|
||||||
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
import diffusers
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsField(BaseModel):
|
||||||
|
"""A latents field used for passing latents between invocations"""
|
||||||
|
|
||||||
|
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for invocations that output latents"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["latent_output"] = "latent_output"
|
||||||
|
latents: LatentsField = Field(default=None, description="The output latents")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
|
"""Invocation noise output"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["noise_output"] = "noise_output"
|
||||||
|
noise: LatentsField = Field(default=None, description="The output noise")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this seems like a hack
|
||||||
|
scheduler_map = dict(
|
||||||
|
ddim=diffusers.DDIMScheduler,
|
||||||
|
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||||
|
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||||
|
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_euler=diffusers.EulerDiscreteScheduler,
|
||||||
|
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||||
|
k_heun=diffusers.HeunDiscreteScheduler,
|
||||||
|
k_lms=diffusers.LMSDiscreteScheduler,
|
||||||
|
plms=diffusers.PNDMScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
|
tuple(list(scheduler_map.keys()))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
|
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||||
|
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||||
|
# hack copied over from generate.py
|
||||||
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
|
||||||
|
# limit noise to only the diffusion image channels, not the mask channels
|
||||||
|
input_channels = min(latent_channels, 4)
|
||||||
|
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
|
||||||
|
generator = torch.Generator(device=use_device).manual_seed(seed)
|
||||||
|
x = torch.randn(
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
input_channels,
|
||||||
|
height // downsampling_factor,
|
||||||
|
width // downsampling_factor,
|
||||||
|
],
|
||||||
|
dtype=torch_dtype(device),
|
||||||
|
device=use_device,
|
||||||
|
generator=generator,
|
||||||
|
).to(device)
|
||||||
|
# if self.perlin > 0.0:
|
||||||
|
# perlin_noise = self.get_perlin_noise(
|
||||||
|
# width // self.downsampling_factor, height // self.downsampling_factor
|
||||||
|
# )
|
||||||
|
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseInvocation(BaseInvocation):
|
||||||
|
"""Generates latent noise."""
|
||||||
|
|
||||||
|
type: Literal["noise"] = "noise"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||||
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||||
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||||
|
device = torch.device(CUDA_DEVICE)
|
||||||
|
noise = get_noise(self.width, self.height, device, self.seed)
|
||||||
|
|
||||||
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
|
context.services.latents.set(name, noise)
|
||||||
|
return NoiseOutput(
|
||||||
|
noise=LatentsField(latents_name=name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Text to image
|
||||||
|
class TextToLatentsInvocation(BaseInvocation):
|
||||||
|
"""Generates latents from a prompt."""
|
||||||
|
|
||||||
|
type: Literal["t2l"] = "t2l"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||||
|
# fmt: off
|
||||||
|
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||||
|
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||||
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||||
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||||
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
|
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||||
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
|
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
|
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
|
def dispatch_progress(
|
||||||
|
self, context: InvocationContext, sample: Tensor, step: int
|
||||||
|
) -> None:
|
||||||
|
# TODO: only output a preview image when requested
|
||||||
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
width *= 8
|
||||||
|
height *= 8
|
||||||
|
|
||||||
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
|
context.services.events.emit_generator_progress(
|
||||||
|
context.graph_execution_state_id,
|
||||||
|
self.id,
|
||||||
|
{
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"dataURL": dataURL
|
||||||
|
},
|
||||||
|
step,
|
||||||
|
self.steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||||
|
model_info = model_manager.get_model(self.model)
|
||||||
|
model_name = model_info['model_name']
|
||||||
|
model_hash = model_info['hash']
|
||||||
|
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
|
model.scheduler = get_scheduler(
|
||||||
|
model=model,
|
||||||
|
scheduler_name=self.sampler_name
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(model, DiffusionPipeline):
|
||||||
|
for component in [model.unet, model.vae]:
|
||||||
|
configure_model_padding(component,
|
||||||
|
self.seamless,
|
||||||
|
self.seamless_axes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configure_model_padding(model,
|
||||||
|
self.seamless,
|
||||||
|
self.seamless_axes
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||||
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||||
|
conditioning_data = ConditioningData(
|
||||||
|
uc,
|
||||||
|
c,
|
||||||
|
self.cfg_scale,
|
||||||
|
extra_conditioning_info,
|
||||||
|
postprocessing_settings=PostprocessingSettings(
|
||||||
|
threshold=0.0,#threshold,
|
||||||
|
warmup=0.2,#warmup,
|
||||||
|
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||||
|
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||||
|
),
|
||||||
|
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||||
|
return conditioning_data
|
||||||
|
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
|
def step_callback(state: PipelineIntermediateState):
|
||||||
|
self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
|
model = self.get_model(context.services.model_manager)
|
||||||
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
|
|
||||||
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
|
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||||
|
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||||
|
noise=noise,
|
||||||
|
num_inference_steps=self.steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
callback=step_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
|
context.services.latents.set(name, result_latents)
|
||||||
|
return LatentsOutput(
|
||||||
|
latents=LatentsField(latents_name=name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
|
"""Generates latents using latents as base image."""
|
||||||
|
|
||||||
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
|
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
def step_callback(state: PipelineIntermediateState):
|
||||||
|
self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
|
model = self.get_model(context.services.model_manager)
|
||||||
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
|
|
||||||
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
|
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||||
|
latent, device=model.device, dtype=latent.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps, _ = model.get_img2img_timesteps(
|
||||||
|
self.steps,
|
||||||
|
self.strength,
|
||||||
|
device=model.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||||
|
latents=initial_latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
noise=noise,
|
||||||
|
num_inference_steps=self.steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
callback=step_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
|
context.services.latents.set(name, result_latents)
|
||||||
|
return LatentsOutput(
|
||||||
|
latents=LatentsField(latents_name=name)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Latent to image
|
||||||
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
|
type: Literal["l2i"] = "l2i"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||||
|
model: str = Field(default="", description="The model to use")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
# TODO: this only really needs the vae
|
||||||
|
model_info = context.services.model_manager.get_model(self.model)
|
||||||
|
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
np_image = model.decode_latents(latents)
|
||||||
|
image = model.numpy_to_pil(np_image)[0]
|
||||||
|
|
||||||
|
image_type = ImageType.RESULT
|
||||||
|
image_name = context.services.images.create_name(
|
||||||
|
context.graph_execution_state_id, self.id
|
||||||
|
)
|
||||||
|
context.services.images.save(image_type, image_name, image)
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
|
)
|
68
invokeai/app/invocations/math.py
Normal file
68
invokeai/app/invocations/math.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..services.image_storage import ImageType
|
||||||
|
from ..services.invocation_services import InvocationServices
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
|
|
||||||
|
|
||||||
|
class IntOutput(BaseInvocationOutput):
|
||||||
|
"""An integer output"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["int_output"] = "int_output"
|
||||||
|
a: int = Field(default=None, description="The output integer")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class AddInvocation(BaseInvocation):
|
||||||
|
"""Adds two numbers"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["add"] = "add"
|
||||||
|
a: int = Field(default=0, description="The first number")
|
||||||
|
b: int = Field(default=0, description="The second number")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
return IntOutput(a=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
|
class SubtractInvocation(BaseInvocation):
|
||||||
|
"""Subtracts two numbers"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["sub"] = "sub"
|
||||||
|
a: int = Field(default=0, description="The first number")
|
||||||
|
b: int = Field(default=0, description="The second number")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
return IntOutput(a=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiplyInvocation(BaseInvocation):
|
||||||
|
"""Multiplies two numbers"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["mul"] = "mul"
|
||||||
|
a: int = Field(default=0, description="The first number")
|
||||||
|
b: int = Field(default=0, description="The second number")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
return IntOutput(a=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
|
class DivideInvocation(BaseInvocation):
|
||||||
|
"""Divides two numbers"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["div"] = "div"
|
||||||
|
a: int = Field(default=0, description="The first number")
|
||||||
|
b: int = Field(default=0, description="The second number")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
return IntOutput(a=int(self.a / self.b))
|
@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
prompt: str = Field(default=None, description="The output prompt")
|
prompt: str = Field(default=None, description="The output prompt")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'prompt',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
|
|||||||
class GraphInvocationOutput(BaseInvocationOutput):
|
class GraphInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal["graph_output"] = "graph_output"
|
type: Literal["graph_output"] = "graph_output"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'image',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class GraphInvocation(BaseInvocation):
|
class GraphInvocation(BaseInvocation):
|
||||||
@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
item: Any = Field(description="The item being iterated over")
|
item: Any = Field(description="The item being iterated over")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'item',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class IterateInvocation(BaseInvocation):
|
class IterateInvocation(BaseInvocation):
|
||||||
@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
collection: list[Any] = Field(description="The collection of input items")
|
collection: list[Any] = Field(description="The collection of input items")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'collection',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
"""Collects values into a collection"""
|
"""Collects values into a collection"""
|
||||||
@ -1048,9 +1069,8 @@ class GraphExecutionState(BaseModel):
|
|||||||
n
|
n
|
||||||
for n in prepared_nodes
|
for n in prepared_nodes
|
||||||
if all(
|
if all(
|
||||||
pit
|
nx.has_path(execution_graph, pit[0], n)
|
||||||
for pit in parent_iterators
|
for pit in parent_iterators
|
||||||
if nx.has_path(execution_graph, pit[0], n)
|
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
@ -9,6 +9,7 @@ from queue import Queue
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||||
|
|
||||||
from invokeai.backend.image_util import PngWriter
|
from invokeai.backend.image_util import PngWriter
|
||||||
|
|
||||||
@ -66,6 +67,9 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
|
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||||
|
parents=True, exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
@ -87,7 +91,11 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
self.__pngWriter.save_image_and_prompt_to_png(
|
self.__pngWriter.save_image_and_prompt_to_png(
|
||||||
image, "", image_subpath, None
|
image, "", image_subpath, None
|
||||||
) # TODO: just pass full path to png writer
|
) # TODO: just pass full path to png writer
|
||||||
|
save_thumbnail(
|
||||||
|
image=image,
|
||||||
|
filename=image_name,
|
||||||
|
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||||
|
)
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
# TODO: make this serializable
|
# TODO: make this serializable
|
||||||
@ -10,6 +11,7 @@ class InvocationQueueItem:
|
|||||||
graph_execution_state_id: str
|
graph_execution_state_id: str
|
||||||
invocation_id: str
|
invocation_id: str
|
||||||
invoke_all: bool
|
invoke_all: bool
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -22,6 +24,7 @@ class InvocationQueueItem:
|
|||||||
self.graph_execution_state_id = graph_execution_state_id
|
self.graph_execution_state_id = graph_execution_state_id
|
||||||
self.invocation_id = invocation_id
|
self.invocation_id = invocation_id
|
||||||
self.invoke_all = invoke_all
|
self.invoke_all = invoke_all
|
||||||
|
self.timestamp = time.time()
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueABC(ABC):
|
class InvocationQueueABC(ABC):
|
||||||
@ -35,15 +38,44 @@ class InvocationQueueABC(ABC):
|
|||||||
def put(self, item: InvocationQueueItem | None) -> None:
|
def put(self, item: InvocationQueueItem | None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel(self, graph_execution_state_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MemoryInvocationQueue(InvocationQueueABC):
|
class MemoryInvocationQueue(InvocationQueueABC):
|
||||||
__queue: Queue
|
__queue: Queue
|
||||||
|
__cancellations: dict[str, float]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.__queue = Queue()
|
self.__queue = Queue()
|
||||||
|
self.__cancellations = dict()
|
||||||
|
|
||||||
def get(self) -> InvocationQueueItem:
|
def get(self) -> InvocationQueueItem:
|
||||||
return self.__queue.get()
|
item = self.__queue.get()
|
||||||
|
|
||||||
|
while isinstance(item, InvocationQueueItem) \
|
||||||
|
and item.graph_execution_state_id in self.__cancellations \
|
||||||
|
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
|
||||||
|
item = self.__queue.get()
|
||||||
|
|
||||||
|
# Clear old items
|
||||||
|
for graph_execution_state_id in list(self.__cancellations.keys()):
|
||||||
|
if self.__cancellations[graph_execution_state_id] < item.timestamp:
|
||||||
|
del self.__cancellations[graph_execution_state_id]
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
def put(self, item: InvocationQueueItem | None) -> None:
|
def put(self, item: InvocationQueueItem | None) -> None:
|
||||||
self.__queue.put(item)
|
self.__queue.put(item)
|
||||||
|
|
||||||
|
def cancel(self, graph_execution_state_id: str) -> None:
|
||||||
|
if graph_execution_state_id not in self.__cancellations:
|
||||||
|
self.__cancellations[graph_execution_state_id] = time.time()
|
||||||
|
|
||||||
|
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||||
|
return graph_execution_state_id in self.__cancellations
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from invokeai.backend import ModelManager
|
from invokeai.backend import ModelManager
|
||||||
|
|
||||||
from .events import EventServiceBase
|
from .events import EventServiceBase
|
||||||
|
from .latent_storage import LatentsStorageBase
|
||||||
from .image_storage import ImageStorageBase
|
from .image_storage import ImageStorageBase
|
||||||
from .restoration_services import RestorationServices
|
from .restoration_services import RestorationServices
|
||||||
from .invocation_queue import InvocationQueueABC
|
from .invocation_queue import InvocationQueueABC
|
||||||
@ -11,6 +12,7 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
events: EventServiceBase
|
events: EventServiceBase
|
||||||
|
latents: LatentsStorageBase
|
||||||
images: ImageStorageBase
|
images: ImageStorageBase
|
||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
model_manager: ModelManager
|
model_manager: ModelManager
|
||||||
@ -24,6 +26,7 @@ class InvocationServices:
|
|||||||
self,
|
self,
|
||||||
model_manager: ModelManager,
|
model_manager: ModelManager,
|
||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
|
latents: LatentsStorageBase,
|
||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
@ -32,6 +35,7 @@ class InvocationServices:
|
|||||||
):
|
):
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.events = events
|
self.events = events
|
||||||
|
self.latents = latents
|
||||||
self.images = images
|
self.images = images
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
@ -33,7 +33,6 @@ class Invoker:
|
|||||||
self.services.graph_execution_manager.set(graph_execution_state)
|
self.services.graph_execution_manager.set(graph_execution_state)
|
||||||
|
|
||||||
# Queue the invocation
|
# Queue the invocation
|
||||||
print(f"queueing item {invocation.id}")
|
|
||||||
self.services.queue.put(
|
self.services.queue.put(
|
||||||
InvocationQueueItem(
|
InvocationQueueItem(
|
||||||
# session_id = session.id,
|
# session_id = session.id,
|
||||||
@ -51,6 +50,10 @@ class Invoker:
|
|||||||
self.services.graph_execution_manager.set(new_state)
|
self.services.graph_execution_manager.set(new_state)
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
|
def cancel(self, graph_execution_state_id: str) -> None:
|
||||||
|
"""Cancels the given execution state"""
|
||||||
|
self.services.queue.cancel(graph_execution_state_id)
|
||||||
|
|
||||||
def __start_service(self, service) -> None:
|
def __start_service(self, service) -> None:
|
||||||
# Call start() method on any services that have it
|
# Call start() method on any services that have it
|
||||||
start_op = getattr(service, "start", None)
|
start_op = getattr(service, "start", None)
|
||||||
|
93
invokeai/app/services/latent_storage.py
Normal file
93
invokeai/app/services/latent_storage.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Queue
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class LatentsStorageBase(ABC):
|
||||||
|
"""Responsible for storing and retrieving latents."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||||
|
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||||
|
|
||||||
|
__cache: Dict[str, torch.Tensor]
|
||||||
|
__cache_ids: Queue
|
||||||
|
__max_cache_size: int
|
||||||
|
__underlying_storage: LatentsStorageBase
|
||||||
|
|
||||||
|
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||||
|
self.__underlying_storage = underlying_storage
|
||||||
|
self.__cache = dict()
|
||||||
|
self.__cache_ids = Queue()
|
||||||
|
self.__max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
cache_item = self.__get_cache(name)
|
||||||
|
if cache_item is not None:
|
||||||
|
return cache_item
|
||||||
|
|
||||||
|
latent = self.__underlying_storage.get(name)
|
||||||
|
self.__set_cache(name, latent)
|
||||||
|
return latent
|
||||||
|
|
||||||
|
def set(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__underlying_storage.set(name, data)
|
||||||
|
self.__set_cache(name, data)
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
self.__underlying_storage.delete(name)
|
||||||
|
if name in self.__cache:
|
||||||
|
del self.__cache[name]
|
||||||
|
|
||||||
|
def __get_cache(self, name: str) -> torch.Tensor|None:
|
||||||
|
return None if name not in self.__cache else self.__cache[name]
|
||||||
|
|
||||||
|
def __set_cache(self, name: str, data: torch.Tensor):
|
||||||
|
if not name in self.__cache:
|
||||||
|
self.__cache[name] = data
|
||||||
|
self.__cache_ids.put(name)
|
||||||
|
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||||
|
self.__cache.pop(self.__cache_ids.get())
|
||||||
|
|
||||||
|
|
||||||
|
class DiskLatentsStorage(LatentsStorageBase):
|
||||||
|
"""Stores latents in a folder on disk without caching"""
|
||||||
|
|
||||||
|
__output_folder: str
|
||||||
|
|
||||||
|
def __init__(self, output_folder: str):
|
||||||
|
self.__output_folder = output_folder
|
||||||
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
return torch.load(latent_path)
|
||||||
|
|
||||||
|
def set(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
torch.save(data, latent_path)
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
os.remove(latent_path)
|
||||||
|
|
||||||
|
def get_path(self, name: str) -> str:
|
||||||
|
return os.path.join(self.__output_folder, name)
|
||||||
|
|
@ -4,7 +4,7 @@ from threading import Event, Thread
|
|||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
|
from ..util.util import CanceledException
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check queue to see if this is canceled, and skip if so
|
||||||
|
if self.__invoker.services.queue.is_canceled(
|
||||||
|
graph_execution_state.id
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
graph_execution_state.complete(invocation.id, outputs)
|
graph_execution_state.complete(invocation.id, outputs)
|
||||||
|
|
||||||
@ -76,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
except CanceledException:
|
||||||
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
|
|
||||||
@ -96,6 +105,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Check queue to see if this is canceled, and skip if so
|
||||||
|
if self.__invoker.services.queue.is_canceled(
|
||||||
|
graph_execution_state.id
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# Queue any further commands if invoking all
|
# Queue any further commands if invoking all
|
||||||
is_complete = graph_execution_state.is_complete()
|
is_complete = graph_execution_state.is_complete()
|
||||||
if queue_item.invoke_all and not is_complete:
|
if queue_item.invoke_all and not is_complete:
|
||||||
|
@ -59,6 +59,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||||
(item.json(),),
|
(item.json(),),
|
||||||
)
|
)
|
||||||
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
self._on_changed(item)
|
self._on_changed(item)
|
||||||
@ -84,6 +85,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
)
|
)
|
||||||
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
self._on_deleted(id)
|
self._on_deleted(id)
|
||||||
|
25
invokeai/app/util/save_thumbnail.py
Normal file
25
invokeai/app/util/save_thumbnail.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def save_thumbnail(
|
||||||
|
image: Image.Image,
|
||||||
|
filename: str,
|
||||||
|
path: str,
|
||||||
|
size: int = 256,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Saves a thumbnail of an image, returning its path.
|
||||||
|
"""
|
||||||
|
base_filename = os.path.splitext(filename)[0]
|
||||||
|
thumbnail_path = os.path.join(path, base_filename + ".webp")
|
||||||
|
|
||||||
|
if os.path.exists(thumbnail_path):
|
||||||
|
return thumbnail_path
|
||||||
|
|
||||||
|
image_copy = image.copy()
|
||||||
|
image_copy.thumbnail(size=(size, size))
|
||||||
|
|
||||||
|
image_copy.save(thumbnail_path, "WEBP")
|
||||||
|
|
||||||
|
return thumbnail_path
|
42
invokeai/app/util/util.py
Normal file
42
invokeai/app/util/util.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
|
from ...backend.util.util import image_to_dataURL
|
||||||
|
from ...backend.generator.base import Generator
|
||||||
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
|
||||||
|
class CanceledException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
|
||||||
|
# TODO: only output a preview image when requested
|
||||||
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
width *= 8
|
||||||
|
height *= 8
|
||||||
|
|
||||||
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
|
context.services.events.emit_generator_progress(
|
||||||
|
context.graph_execution_state_id,
|
||||||
|
id,
|
||||||
|
{
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"dataURL": dataURL
|
||||||
|
},
|
||||||
|
step,
|
||||||
|
steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||||
|
"""
|
||||||
|
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
||||||
|
This adapter grabs the needed data and passes it along to the callback function.
|
||||||
|
"""
|
||||||
|
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||||
|
progress_state: PipelineIntermediateState = cb_args[0]
|
||||||
|
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
|
||||||
|
else:
|
||||||
|
return fast_latents_step_callback(*cb_args, **kwargs)
|
@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from typing import List, Iterator, Type
|
from typing import Callable, List, Iterator, Optional, Type
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
@ -35,23 +35,23 @@ downsampling = 8
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorBasicParams:
|
class InvokeAIGeneratorBasicParams:
|
||||||
seed: int=None
|
seed: Optional[int]=None
|
||||||
width: int=512
|
width: int=512
|
||||||
height: int=512
|
height: int=512
|
||||||
cfg_scale: int=7.5
|
cfg_scale: float=7.5
|
||||||
steps: int=20
|
steps: int=20
|
||||||
ddim_eta: float=0.0
|
ddim_eta: float=0.0
|
||||||
scheduler: int='ddim'
|
scheduler: str='ddim'
|
||||||
precision: str='float16'
|
precision: str='float16'
|
||||||
perlin: float=0.0
|
perlin: float=0.0
|
||||||
threshold: int=0.0
|
threshold: float=0.0
|
||||||
seamless: bool=False
|
seamless: bool=False
|
||||||
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||||
h_symmetry_time_pct: float=None
|
h_symmetry_time_pct: Optional[float]=None
|
||||||
v_symmetry_time_pct: float=None
|
v_symmetry_time_pct: Optional[float]=None
|
||||||
variation_amount: float = 0.0
|
variation_amount: float = 0.0
|
||||||
with_variations: list=field(default_factory=list)
|
with_variations: list=field(default_factory=list)
|
||||||
safety_checker: SafetyChecker=None
|
safety_checker: Optional[SafetyChecker]=None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorOutput:
|
class InvokeAIGeneratorOutput:
|
||||||
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
|
|||||||
and the model hash, as well as all the generate() parameters that went into
|
and the model hash, as well as all the generate() parameters that went into
|
||||||
generating the image (in .params, also available as attributes)
|
generating the image (in .params, also available as attributes)
|
||||||
'''
|
'''
|
||||||
image: Image
|
image: Image.Image
|
||||||
seed: int
|
seed: int
|
||||||
model_hash: str
|
model_hash: str
|
||||||
attention_maps_images: List[Image]
|
attention_maps_images: List[Image.Image]
|
||||||
params: Namespace
|
params: Namespace
|
||||||
|
|
||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
prompt: str='',
|
prompt: str='',
|
||||||
callback: callable=None,
|
callback: Optional[Callable]=None,
|
||||||
step_callback: callable=None,
|
step_callback: Optional[Callable]=None,
|
||||||
iterations: int=1,
|
iterations: int=1,
|
||||||
**keyword_args,
|
**keyword_args,
|
||||||
)->Iterator[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
|
|||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
init_image: Image | torch.FloatTensor,
|
init_image: Image.Image | torch.FloatTensor,
|
||||||
strength: float=0.75,
|
strength: float=0.75,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)->List[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
return super().generate(init_image=init_image,
|
return super().generate(init_image=init_image,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
|
|||||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
mask_image: Image | torch.FloatTensor,
|
mask_image: Image.Image | torch.FloatTensor,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 0,
|
||||||
seam_blur: int = 0,
|
seam_blur: int = 0,
|
||||||
@ -236,7 +236,7 @@ class Inpaint(Img2Img):
|
|||||||
inpaint_height=None,
|
inpaint_height=None,
|
||||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)->List[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
return super().generate(
|
return super().generate(
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
seam_size=seam_size,
|
seam_size=seam_size,
|
||||||
@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
|
|||||||
embiggen: list=None,
|
embiggen: list=None,
|
||||||
embiggen_tiles: list = None,
|
embiggen_tiles: list = None,
|
||||||
strength: float=0.75,
|
strength: float=0.75,
|
||||||
**kwargs)->List[InvokeAIGeneratorOutput]:
|
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
return super().generate(embiggen=embiggen,
|
return super().generate(embiggen=embiggen,
|
||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
@ -378,16 +378,26 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("model.diffusion_model"):
|
if key.startswith("model.diffusion_model"):
|
||||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||||
|
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
|
||||||
|
if flat_ema_key in checkpoint:
|
||||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||||
flat_ema_key
|
flat_ema_key
|
||||||
)
|
)
|
||||||
|
elif flat_ema_key_alt in checkpoint:
|
||||||
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||||
|
flat_ema_key_alt
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
|
||||||
|
key
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||||
)
|
)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith(unet_key):
|
if key.startswith("model.diffusion_model") and key in checkpoint:
|
||||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||||
|
|
||||||
new_checkpoint = {}
|
new_checkpoint = {}
|
||||||
@ -1026,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||||||
|
|
||||||
return text_model
|
return text_model
|
||||||
|
|
||||||
|
def replace_checkpoint_vae(checkpoint, vae_path:str):
|
||||||
|
if vae_path.endswith(".safetensors"):
|
||||||
|
vae_ckpt = load_file(vae_path)
|
||||||
|
else:
|
||||||
|
vae_ckpt = torch.load(vae_path, map_location="cpu")
|
||||||
|
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
|
||||||
|
for vae_key in state_dict:
|
||||||
|
new_key = f'first_stage_model.{vae_key}'
|
||||||
|
checkpoint[new_key] = state_dict[vae_key]
|
||||||
|
|
||||||
def load_pipeline_from_original_stable_diffusion_ckpt(
|
def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
@ -1038,8 +1057,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
extract_ema: bool = True,
|
extract_ema: bool = True,
|
||||||
upcast_attn: bool = False,
|
upcast_attn: bool = False,
|
||||||
vae: AutoencoderKL = None,
|
vae: AutoencoderKL = None,
|
||||||
|
vae_path: str = None,
|
||||||
precision: torch.dtype = torch.float32,
|
precision: torch.dtype = torch.float32,
|
||||||
return_generator_pipeline: bool = False,
|
return_generator_pipeline: bool = False,
|
||||||
|
scan_needed:bool=True,
|
||||||
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
|
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
|
||||||
"""
|
"""
|
||||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||||
@ -1067,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
||||||
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
||||||
running stable diffusion 2.1.
|
running stable diffusion 2.1.
|
||||||
|
:param vae: A diffusers VAE to load into the pipeline.
|
||||||
|
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@ -1074,12 +1097,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
dlogging.set_verbosity_error()
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
checkpoint = (
|
if Path(checkpoint_path).suffix == '.ckpt':
|
||||||
torch.load(checkpoint_path)
|
if scan_needed:
|
||||||
if Path(checkpoint_path).suffix == ".ckpt"
|
ModelManager.scan_model(checkpoint_path,checkpoint_path)
|
||||||
else load_file(checkpoint_path)
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
else:
|
||||||
|
checkpoint = load_file(checkpoint_path)
|
||||||
|
|
||||||
)
|
|
||||||
cache_dir = global_cache_dir("hub")
|
cache_dir = global_cache_dir("hub")
|
||||||
pipeline_class = (
|
pipeline_class = (
|
||||||
StableDiffusionGeneratorPipeline
|
StableDiffusionGeneratorPipeline
|
||||||
@ -1202,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
unet.load_state_dict(converted_unet_checkpoint)
|
unet.load_state_dict(converted_unet_checkpoint)
|
||||||
|
|
||||||
# Convert the VAE model, or use the one passed
|
# If a replacement VAE path was specified, we'll incorporate that into
|
||||||
if not vae:
|
# the checkpoint model and then convert it
|
||||||
|
if vae_path:
|
||||||
|
print(f" | Converting VAE {vae_path}")
|
||||||
|
replace_checkpoint_vae(checkpoint,vae_path)
|
||||||
|
# otherwise we use the original VAE, provided that
|
||||||
|
# an externally loaded diffusers VAE was not passed
|
||||||
|
elif not vae:
|
||||||
print(" | Using checkpoint model's original VAE")
|
print(" | Using checkpoint model's original VAE")
|
||||||
|
|
||||||
|
if vae:
|
||||||
|
print(" | Using replacement diffusers VAE")
|
||||||
|
else: # convert the original or replacement VAE
|
||||||
vae_config = create_vae_diffusers_config(
|
vae_config = create_vae_diffusers_config(
|
||||||
original_config, image_size=image_size
|
original_config, image_size=image_size
|
||||||
)
|
)
|
||||||
@ -1214,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
else:
|
|
||||||
print(" | Using external VAE specified in config")
|
|
||||||
|
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
model_type = pipeline_type
|
model_type = pipeline_type
|
||||||
@ -1232,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae,
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model,
|
text_encoder=text_model.to(precision),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet,
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
|
@ -18,7 +18,7 @@ import warnings
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import move, rmtree
|
from shutil import move, rmtree
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union, Callable
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
from invokeai.backend.globals import Globals, global_cache_dir
|
from invokeai.backend.globals import Globals, global_cache_dir
|
||||||
|
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
class SDLegacyType(Enum):
|
class SDLegacyType(Enum):
|
||||||
V1 = 1
|
V1 = 1
|
||||||
@ -45,9 +45,6 @@ class SDLegacyType(Enum):
|
|||||||
UNKNOWN = 99
|
UNKNOWN = 99
|
||||||
|
|
||||||
DEFAULT_MAX_MODELS = 2
|
DEFAULT_MAX_MODELS = 2
|
||||||
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
|
||||||
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
|
|
||||||
}
|
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
'''
|
'''
|
||||||
@ -285,13 +282,13 @@ class ModelManager(object):
|
|||||||
self.stack.remove(model_name)
|
self.stack.remove(model_name)
|
||||||
if delete_files:
|
if delete_files:
|
||||||
if weights:
|
if weights:
|
||||||
print(f"** deleting file {weights}")
|
print(f"** Deleting file {weights}")
|
||||||
Path(weights).unlink(missing_ok=True)
|
Path(weights).unlink(missing_ok=True)
|
||||||
elif path:
|
elif path:
|
||||||
print(f"** deleting directory {path}")
|
print(f"** Deleting directory {path}")
|
||||||
rmtree(path, ignore_errors=True)
|
rmtree(path, ignore_errors=True)
|
||||||
elif repo_id:
|
elif repo_id:
|
||||||
print(f"** deleting the cached model directory for {repo_id}")
|
print(f"** Deleting the cached model directory for {repo_id}")
|
||||||
self._delete_model_from_cache(repo_id)
|
self._delete_model_from_cache(repo_id)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
@ -435,7 +432,6 @@ class ModelManager(object):
|
|||||||
# square images???
|
# square images???
|
||||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||||
height = width
|
height = width
|
||||||
|
|
||||||
print(f" | Default image dimensions = {width} x {height}")
|
print(f" | Default image dimensions = {width} x {height}")
|
||||||
|
|
||||||
return pipeline, width, height, model_hash
|
return pipeline, width, height, model_hash
|
||||||
@ -457,15 +453,21 @@ class ModelManager(object):
|
|||||||
|
|
||||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.list_models()[self.current_model]['status'] == 'active':
|
||||||
self.offload_model(self.current_model)
|
self.offload_model(self.current_model)
|
||||||
if vae_config := self._choose_diffusers_vae(model_name):
|
except Exception as e:
|
||||||
vae = self._load_vae(vae_config)
|
pass
|
||||||
|
|
||||||
|
vae_path = None
|
||||||
|
if vae:
|
||||||
|
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path=weights,
|
checkpoint_path=weights,
|
||||||
original_config_file=config,
|
original_config_file=config,
|
||||||
vae=vae,
|
vae_path=vae_path,
|
||||||
return_generator_pipeline=True,
|
return_generator_pipeline=True,
|
||||||
precision=torch.float16 if self.precision == "float16" else torch.float32,
|
precision=torch.float16 if self.precision == "float16" else torch.float32,
|
||||||
)
|
)
|
||||||
@ -473,7 +475,6 @@ class ModelManager(object):
|
|||||||
pipeline.enable_offload_submodels(self.device)
|
pipeline.enable_offload_submodels(self.device)
|
||||||
else:
|
else:
|
||||||
pipeline.to(self.device)
|
pipeline.to(self.device)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
pipeline,
|
pipeline,
|
||||||
width,
|
width,
|
||||||
@ -512,18 +513,20 @@ class ModelManager(object):
|
|||||||
print(f">> Offloading {model_name} to CPU")
|
print(f">> Offloading {model_name} to CPU")
|
||||||
model = self.models[model_name]["model"]
|
model = self.models[model_name]["model"]
|
||||||
model.offload_all()
|
model.offload_all()
|
||||||
|
self.current_model = None
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def scan_model(self, model_name, checkpoint):
|
def scan_model(self, model_name, checkpoint):
|
||||||
"""
|
"""
|
||||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||||
and option to exit if an infected file is identified.
|
and option to exit if an infected file is identified.
|
||||||
"""
|
"""
|
||||||
# scan model
|
# scan model
|
||||||
print(f">> Scanning Model: {model_name}")
|
print(f" | Scanning Model: {model_name}")
|
||||||
scan_result = scan_file_path(checkpoint)
|
scan_result = scan_file_path(checkpoint)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
@ -546,7 +549,7 @@ class ModelManager(object):
|
|||||||
print("### Exiting InvokeAI")
|
print("### Exiting InvokeAI")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
print(">> Model scanned ok")
|
print(" | Model scanned ok")
|
||||||
|
|
||||||
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
@ -627,14 +630,13 @@ class ModelManager(object):
|
|||||||
def heuristic_import(
|
def heuristic_import(
|
||||||
self,
|
self,
|
||||||
path_url_or_repo: str,
|
path_url_or_repo: str,
|
||||||
convert: bool = True,
|
|
||||||
model_name: str = None,
|
model_name: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
model_config_file: Path = None,
|
model_config_file: Path = None,
|
||||||
commit_to_conf: Path = None,
|
commit_to_conf: Path = None,
|
||||||
|
config_file_callback: Callable[[Path], Path] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""Accept a string which could be:
|
||||||
Accept a string which could be:
|
|
||||||
- a HF diffusers repo_id
|
- a HF diffusers repo_id
|
||||||
- a URL pointing to a legacy .ckpt or .safetensors file
|
- a URL pointing to a legacy .ckpt or .safetensors file
|
||||||
- a local path pointing to a legacy .ckpt or .safetensors file
|
- a local path pointing to a legacy .ckpt or .safetensors file
|
||||||
@ -648,16 +650,20 @@ class ModelManager(object):
|
|||||||
The model_name and/or description can be provided. If not, they will
|
The model_name and/or description can be provided. If not, they will
|
||||||
be generated automatically.
|
be generated automatically.
|
||||||
|
|
||||||
If convert is true, legacy models will be converted to diffusers
|
|
||||||
before importing.
|
|
||||||
|
|
||||||
If commit_to_conf is provided, the newly loaded model will be written
|
If commit_to_conf is provided, the newly loaded model will be written
|
||||||
to the `models.yaml` file at the indicated path. Otherwise, the changes
|
to the `models.yaml` file at the indicated path. Otherwise, the changes
|
||||||
will only remain in memory.
|
will only remain in memory.
|
||||||
|
|
||||||
The (potentially derived) name of the model is returned on success, or None
|
The routine will do its best to figure out the config file
|
||||||
on failure. When multiple models are added from a directory, only the last
|
needed to convert legacy checkpoint file, but if it can't it
|
||||||
imported one is returned.
|
will call the config_file_callback routine, if provided. The
|
||||||
|
callback accepts a single argument, the Path to the checkpoint
|
||||||
|
file, and returns a Path to the config file to use.
|
||||||
|
|
||||||
|
The (potentially derived) name of the model is returned on
|
||||||
|
success, or None on failure. When multiple models are added
|
||||||
|
from a directory, only the last imported one is returned.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model_path: Path = None
|
model_path: Path = None
|
||||||
thing = path_url_or_repo # to save typing
|
thing = path_url_or_repo # to save typing
|
||||||
@ -704,7 +710,7 @@ class ModelManager(object):
|
|||||||
Path(thing).rglob("*.safetensors")
|
Path(thing).rglob("*.safetensors")
|
||||||
):
|
):
|
||||||
if model_name := self.heuristic_import(
|
if model_name := self.heuristic_import(
|
||||||
str(m), convert, commit_to_conf=commit_to_conf
|
str(m), commit_to_conf=commit_to_conf
|
||||||
):
|
):
|
||||||
print(f" >> {model_name} successfully imported")
|
print(f" >> {model_name} successfully imported")
|
||||||
return model_name
|
return model_name
|
||||||
@ -731,14 +737,21 @@ class ModelManager(object):
|
|||||||
return model_path.stem
|
return model_path.stem
|
||||||
|
|
||||||
# another round of heuristics to guess the correct config file.
|
# another round of heuristics to guess the correct config file.
|
||||||
checkpoint = (
|
checkpoint = None
|
||||||
torch.load(model_path)
|
if model_path.suffix in [".ckpt",".pt"]:
|
||||||
if model_path.suffix == ".ckpt"
|
self.scan_model(model_path,model_path)
|
||||||
else safetensors.torch.load_file(model_path)
|
checkpoint = torch.load(model_path)
|
||||||
)
|
else:
|
||||||
|
checkpoint = safetensors.torch.load_file(model_path)
|
||||||
|
|
||||||
# additional probing needed if no config file provided
|
# additional probing needed if no config file provided
|
||||||
if model_config_file is None:
|
if model_config_file is None:
|
||||||
|
# look for a like-named .yaml file in same directory
|
||||||
|
if model_path.with_suffix(".yaml").exists():
|
||||||
|
model_config_file = model_path.with_suffix(".yaml")
|
||||||
|
print(f" | Using config file {model_config_file.name}")
|
||||||
|
|
||||||
|
else:
|
||||||
model_type = self.probe_model_type(checkpoint)
|
model_type = self.probe_model_type(checkpoint)
|
||||||
if model_type == SDLegacyType.V1:
|
if model_type == SDLegacyType.V1:
|
||||||
print(" | SD-v1 model detected")
|
print(" | SD-v1 model detected")
|
||||||
@ -752,20 +765,18 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_v:
|
elif model_type == SDLegacyType.V2_v:
|
||||||
print(
|
print(
|
||||||
" | SD-v2-v model detected; model will be converted to diffusers format"
|
" | SD-v2-v model detected"
|
||||||
)
|
)
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||||
)
|
)
|
||||||
convert = True
|
|
||||||
elif model_type == SDLegacyType.V2_e:
|
elif model_type == SDLegacyType.V2_e:
|
||||||
print(
|
print(
|
||||||
" | SD-v2-e model detected; model will be converted to diffusers format"
|
" | SD-v2-e model detected"
|
||||||
)
|
)
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||||
)
|
)
|
||||||
convert = True
|
|
||||||
elif model_type == SDLegacyType.V2:
|
elif model_type == SDLegacyType.V2:
|
||||||
print(
|
print(
|
||||||
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||||
@ -777,17 +788,34 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not model_config_file and config_file_callback:
|
||||||
|
model_config_file = config_file_callback(model_path)
|
||||||
|
|
||||||
|
# despite our best efforts, we could not find a model config file, so give up
|
||||||
|
if not model_config_file:
|
||||||
|
return
|
||||||
|
|
||||||
|
# look for a custom vae, a like-named file ending with .vae in the same directory
|
||||||
|
vae_path = None
|
||||||
|
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||||
|
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||||
|
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||||
|
print(f" | Using VAE file {vae_path.name}")
|
||||||
|
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||||
|
|
||||||
diffuser_path = Path(
|
diffuser_path = Path(
|
||||||
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
|
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
|
||||||
)
|
)
|
||||||
model_name = self.convert_and_import(
|
model_name = self.convert_and_import(
|
||||||
model_path,
|
model_path,
|
||||||
diffusers_path=diffuser_path,
|
diffusers_path=diffuser_path,
|
||||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
vae=vae,
|
||||||
|
vae_path=str(vae_path),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_description=description,
|
model_description=description,
|
||||||
original_config_file=model_config_file,
|
original_config_file=model_config_file,
|
||||||
commit_to_conf=commit_to_conf,
|
commit_to_conf=commit_to_conf,
|
||||||
|
scan_needed=False,
|
||||||
)
|
)
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
@ -797,9 +825,11 @@ class ModelManager(object):
|
|||||||
diffusers_path: Path,
|
diffusers_path: Path,
|
||||||
model_name=None,
|
model_name=None,
|
||||||
model_description=None,
|
model_description=None,
|
||||||
vae=None,
|
vae:dict=None,
|
||||||
|
vae_path:Path=None,
|
||||||
original_config_file: Path = None,
|
original_config_file: Path = None,
|
||||||
commit_to_conf: Path = None,
|
commit_to_conf: Path = None,
|
||||||
|
scan_needed: bool=True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a legacy ckpt weights file to diffuser model and import
|
Convert a legacy ckpt weights file to diffuser model and import
|
||||||
@ -822,21 +852,26 @@ class ModelManager(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
model_name = model_name or diffusers_path.name
|
model_name = model_name or diffusers_path.name
|
||||||
model_description = model_description or f"Optimized version of {model_name}"
|
model_description = model_description or f"Converted version of {model_name}"
|
||||||
print(f">> Optimizing {model_name} (30-60s)")
|
print(f" | Converting {model_name} to diffusers (30-60s)")
|
||||||
try:
|
try:
|
||||||
# By passing the specified VAE to the conversion function, the autoencoder
|
# By passing the specified VAE to the conversion function, the autoencoder
|
||||||
# will be built into the model rather than tacked on afterward via the config file
|
# will be built into the model rather than tacked on afterward via the config file
|
||||||
vae_model = self._load_vae(vae) if vae else None
|
vae_model=None
|
||||||
|
if vae:
|
||||||
|
vae_model=self._load_vae(vae)
|
||||||
|
vae_path=None
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
diffusers_path,
|
diffusers_path,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
original_config_file=original_config_file,
|
original_config_file=original_config_file,
|
||||||
vae=vae_model,
|
vae=vae_model,
|
||||||
|
vae_path=vae_path,
|
||||||
|
scan_needed=scan_needed,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f" | Success. Optimized model is now located at {str(diffusers_path)}"
|
f" | Success. Converted model is now located at {str(diffusers_path)}"
|
||||||
)
|
)
|
||||||
print(f" | Writing new config file entry for {model_name}")
|
print(f" | Writing new config file entry for {model_name}")
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
@ -849,7 +884,7 @@ class ModelManager(object):
|
|||||||
self.add_model(model_name, new_config, True)
|
self.add_model(model_name, new_config, True)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
print(">> Conversion succeeded")
|
print(" | Conversion succeeded")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"** Conversion failed: {str(e)}")
|
print(f"** Conversion failed: {str(e)}")
|
||||||
print(
|
print(
|
||||||
@ -879,36 +914,6 @@ class ModelManager(object):
|
|||||||
|
|
||||||
return search_folder, found_models
|
return search_folder, found_models
|
||||||
|
|
||||||
def _choose_diffusers_vae(
|
|
||||||
self, model_name: str, vae: str = None
|
|
||||||
) -> Union[dict, str]:
|
|
||||||
# In the event that the original entry is using a custom ckpt VAE, we try to
|
|
||||||
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
|
|
||||||
# I would prefer to do this differently: We load the ckpt model into memory, swap the
|
|
||||||
# VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
|
|
||||||
# VAE is built into the model. However, when I tried this I got obscure key errors.
|
|
||||||
if vae:
|
|
||||||
return vae
|
|
||||||
if model_name in self.config and (
|
|
||||||
vae_ckpt_path := self.model_info(model_name).get("vae", None)
|
|
||||||
):
|
|
||||||
vae_basename = Path(vae_ckpt_path).stem
|
|
||||||
diffusers_vae = None
|
|
||||||
if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
|
|
||||||
print(
|
|
||||||
f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
|
|
||||||
)
|
|
||||||
vae = {"repo_id": diffusers_vae}
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
'** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
|
|
||||||
)
|
|
||||||
vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
|
|
||||||
return vae
|
|
||||||
|
|
||||||
def _make_cache_room(self) -> None:
|
def _make_cache_room(self) -> None:
|
||||||
num_loaded_models = len(self.models)
|
num_loaded_models = len(self.models)
|
||||||
if num_loaded_models >= self.max_loaded_models:
|
if num_loaded_models >= self.max_loaded_models:
|
||||||
@ -1208,7 +1213,7 @@ class ModelManager(object):
|
|||||||
hashes_to_delete.add(revision.commit_hash)
|
hashes_to_delete.add(revision.commit_hash)
|
||||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||||
print(
|
print(
|
||||||
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||||
)
|
)
|
||||||
strategy.execute()
|
strategy.execute()
|
||||||
|
|
||||||
|
@ -1,16 +1,26 @@
|
|||||||
import os
|
|
||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, List
|
||||||
|
|
||||||
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingInfo:
|
||||||
|
name: str
|
||||||
|
embedding: torch.Tensor
|
||||||
|
num_vectors_per_token: int
|
||||||
|
token_dim: int
|
||||||
|
trained_steps: int = None
|
||||||
|
trained_model_name: str = None
|
||||||
|
trained_model_checksum: str = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TextualInversion:
|
class TextualInversion:
|
||||||
@ -72,37 +82,17 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if str(ckpt_path).endswith(".DS_Store"):
|
if str(ckpt_path).endswith(".DS_Store"):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||||
scan_result = scan_file_path(str(ckpt_path))
|
for embedding_info in embedding_list:
|
||||||
if scan_result.infected_files == 1:
|
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||||
print(
|
print(
|
||||||
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
|
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||||
)
|
)
|
||||||
print("### For your safety, InvokeAI will not load this embed.")
|
continue
|
||||||
return
|
|
||||||
except Exception:
|
|
||||||
print(
|
|
||||||
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
embedding_info = self._parse_embedding(str(ckpt_path))
|
|
||||||
|
|
||||||
if embedding_info is None:
|
|
||||||
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
|
|
||||||
return
|
|
||||||
elif (
|
|
||||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
|
||||||
!= embedding_info["token_dim"]
|
|
||||||
):
|
|
||||||
print(
|
|
||||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Resolve the situation in which an earlier embedding has claimed the same
|
# Resolve the situation in which an earlier embedding has claimed the same
|
||||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||||
trigger_str = embedding_info["name"]
|
trigger_str = embedding_info.name
|
||||||
sourcefile = (
|
sourcefile = (
|
||||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||||
if ckpt_path.name == "learned_embeds.bin"
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
@ -123,7 +113,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
try:
|
try:
|
||||||
self._add_textual_inversion(
|
self._add_textual_inversion(
|
||||||
trigger_str,
|
trigger_str,
|
||||||
embedding_info["embedding"],
|
embedding_info.embedding,
|
||||||
defer_injecting_tokens=defer_injecting_tokens,
|
defer_injecting_tokens=defer_injecting_tokens,
|
||||||
)
|
)
|
||||||
# remember which source file claims this trigger
|
# remember which source file claims this trigger
|
||||||
@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
return token_id
|
return token_id
|
||||||
|
|
||||||
def _parse_embedding(self, embedding_file: str):
|
|
||||||
file_type = embedding_file.split(".")[-1]
|
|
||||||
if file_type == "pt":
|
|
||||||
return self._parse_embedding_pt(embedding_file)
|
|
||||||
elif file_type == "bin":
|
|
||||||
return self._parse_embedding_bin(embedding_file)
|
|
||||||
else:
|
|
||||||
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _parse_embedding_pt(self, embedding_file):
|
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
||||||
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
|
suffix = Path(embedding_file).suffix
|
||||||
embedding_info = {}
|
|
||||||
|
|
||||||
# Check if valid embedding file
|
|
||||||
if "string_to_token" and "string_to_param" in embedding_ckpt:
|
|
||||||
# Catch variants that do not have the expected keys or values.
|
|
||||||
try:
|
try:
|
||||||
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
|
if suffix in [".pt",".ckpt",".bin"]:
|
||||||
os.path.splitext(embedding_file)[0]
|
scan_result = scan_file_path(embedding_file)
|
||||||
|
if scan_result.infected_files > 0:
|
||||||
|
print(
|
||||||
|
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||||
)
|
)
|
||||||
|
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||||
|
return list()
|
||||||
|
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||||
|
else:
|
||||||
|
ckpt = safetensors.torch.load_file(embedding_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||||
|
return list()
|
||||||
|
|
||||||
# Check num of embeddings and warn user only the first will be used
|
# try to figure out what kind of embedding file it is and parse accordingly
|
||||||
embedding_info["num_of_embeddings"] = len(
|
keys = list(ckpt.keys())
|
||||||
embedding_ckpt["string_to_token"]
|
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
||||||
)
|
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
||||||
if embedding_info["num_of_embeddings"] > 1:
|
|
||||||
print(">> More than 1 embedding found. Will use the first one")
|
|
||||||
|
|
||||||
embedding = list(embedding_ckpt["string_to_param"].values())[0]
|
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
||||||
except (AttributeError, KeyError):
|
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
||||||
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
|
||||||
|
|
||||||
embedding_info["embedding"] = embedding
|
elif 'emb_params' in keys:
|
||||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
||||||
embedding_info["token_dim"] = embedding.size()[1]
|
|
||||||
|
|
||||||
try:
|
|
||||||
embedding_info["trained_steps"] = embedding_ckpt["step"]
|
|
||||||
embedding_info["trained_model_name"] = embedding_ckpt[
|
|
||||||
"sd_checkpoint_name"
|
|
||||||
]
|
|
||||||
embedding_info["trained_model_checksum"] = embedding_ckpt[
|
|
||||||
"sd_checkpoint"
|
|
||||||
]
|
|
||||||
except AttributeError:
|
|
||||||
print(">> No Training Details Found. Passing ...")
|
|
||||||
|
|
||||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
|
||||||
# They are actually .bin files
|
|
||||||
elif len(embedding_ckpt.keys()) == 1:
|
|
||||||
embedding_info = self._parse_embedding_bin(embedding_file)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(">> Invalid embedding format")
|
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||||
embedding_info = None
|
|
||||||
|
|
||||||
return embedding_info
|
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||||
|
basename = Path(file_path).stem
|
||||||
|
print(f' | Loading v1 embedding file: {basename}')
|
||||||
|
|
||||||
def _parse_embedding_bin(self, embedding_file):
|
embeddings = list()
|
||||||
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
|
token_counter = -1
|
||||||
embedding_info = {}
|
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||||
|
if token_counter < 0:
|
||||||
if list(embedding_ckpt.keys()) == 0:
|
trigger = embedding_ckpt["name"]
|
||||||
print(">> Invalid concepts file")
|
elif token_counter == 0:
|
||||||
embedding_info = None
|
trigger = f'<basename>'
|
||||||
else:
|
else:
|
||||||
for token in list(embedding_ckpt.keys()):
|
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||||
embedding_info["name"] = (
|
token_counter += 1
|
||||||
token
|
embedding_info = EmbeddingInfo(
|
||||||
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
name = trigger,
|
||||||
|
embedding = embedding,
|
||||||
|
num_vectors_per_token = embedding.size()[0],
|
||||||
|
token_dim = embedding.size()[1],
|
||||||
|
trained_steps = embedding_ckpt["step"],
|
||||||
|
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
||||||
|
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
||||||
)
|
)
|
||||||
embedding_info["embedding"] = embedding_ckpt[token]
|
embeddings.append(embedding_info)
|
||||||
embedding_info[
|
return embeddings
|
||||||
"num_vectors_per_token"
|
|
||||||
] = 1 # All Concepts seem to default to 1
|
|
||||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
|
||||||
|
|
||||||
return embedding_info
|
def _parse_embedding_v2 (
|
||||||
|
self, embedding_ckpt: dict, file_path: str
|
||||||
|
) -> List[EmbeddingInfo]:
|
||||||
|
"""
|
||||||
|
This handles embedding .pt file variant #2.
|
||||||
|
"""
|
||||||
|
basename = Path(file_path).stem
|
||||||
|
print(f' | Loading v2 embedding file: {basename}')
|
||||||
|
embeddings = list()
|
||||||
|
|
||||||
def _handle_broken_pt_variants(
|
|
||||||
self, embedding_ckpt: dict, embedding_file: str
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
This handles the broken .pt file variants. We only know of one at present.
|
|
||||||
"""
|
|
||||||
embedding_info = {}
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||||
):
|
):
|
||||||
for token in list(embedding_ckpt["string_to_token"].keys()):
|
token_counter = 0
|
||||||
embedding_info["name"] = (
|
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||||
token
|
trigger = token if token != '*' \
|
||||||
if token != "*"
|
else f'<{basename}>' if token_counter == 0 \
|
||||||
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
||||||
|
embedding_info = EmbeddingInfo(
|
||||||
|
name = trigger,
|
||||||
|
embedding = embedding,
|
||||||
|
num_vectors_per_token = embedding.size()[0],
|
||||||
|
token_dim = embedding.size()[1],
|
||||||
)
|
)
|
||||||
embedding_info["embedding"] = embedding_ckpt[
|
embeddings.append(embedding_info)
|
||||||
"string_to_param"
|
|
||||||
].state_dict()[token]
|
|
||||||
embedding_info["num_vectors_per_token"] = embedding_info[
|
|
||||||
"embedding"
|
|
||||||
].shape[0]
|
|
||||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
|
|
||||||
else:
|
else:
|
||||||
print(">> Invalid embedding format")
|
print(f" ** {basename}: Unrecognized embedding format")
|
||||||
embedding_info = None
|
|
||||||
|
|
||||||
return embedding_info
|
return embeddings
|
||||||
|
|
||||||
|
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||||
|
"""
|
||||||
|
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||||
|
"""
|
||||||
|
basename = Path(file_path).stem
|
||||||
|
print(f' | Loading v3 embedding file: {basename}')
|
||||||
|
embedding = embedding_ckpt['emb_params']
|
||||||
|
embedding_info = EmbeddingInfo(
|
||||||
|
name = f'<{basename}>',
|
||||||
|
embedding = embedding,
|
||||||
|
num_vectors_per_token = embedding.size()[0],
|
||||||
|
token_dim = embedding.size()[1],
|
||||||
|
)
|
||||||
|
return [embedding_info]
|
||||||
|
|
||||||
|
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
||||||
|
"""
|
||||||
|
Parse 'version 4' of the textual inversion embedding files. This one
|
||||||
|
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||||
|
"""
|
||||||
|
basename = Path(filepath).stem
|
||||||
|
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||||
|
|
||||||
|
print(f' | Loading v4 embedding file: {short_path}')
|
||||||
|
|
||||||
|
embeddings = list()
|
||||||
|
if list(embedding_ckpt.keys()) == 0:
|
||||||
|
print(f" ** Invalid embeddings file: {short_path}")
|
||||||
|
else:
|
||||||
|
for token,embedding in embedding_ckpt.items():
|
||||||
|
embedding_info = EmbeddingInfo(
|
||||||
|
name = token or f"<{basename}>",
|
||||||
|
embedding = embedding,
|
||||||
|
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
||||||
|
token_dim = embedding.size()[0],
|
||||||
|
)
|
||||||
|
embeddings.append(embedding_info)
|
||||||
|
return embeddings
|
||||||
|
@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
|
|||||||
"RGB"
|
"RGB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def image_progress(sample, step):
|
def image_progress(intermediate_state: PipelineIntermediateState):
|
||||||
if self.canceled.is_set():
|
if self.canceled.is_set():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
|
|||||||
nonlocal generation_parameters
|
nonlocal generation_parameters
|
||||||
nonlocal progress
|
nonlocal progress
|
||||||
|
|
||||||
|
step = intermediate_state.step
|
||||||
|
if intermediate_state.predicted_original is not None:
|
||||||
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
generation_messages = {
|
generation_messages = {
|
||||||
"txt2img": "common.statusGeneratingTextToImage",
|
"txt2img": "common.statusGeneratingTextToImage",
|
||||||
"img2img": "common.statusGeneratingImageToImage",
|
"img2img": "common.statusGeneratingImageToImage",
|
||||||
@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
progress.set_current_iteration(progress.current_iteration + 1)
|
progress.set_current_iteration(progress.current_iteration + 1)
|
||||||
|
|
||||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
|
||||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
|
||||||
progress_state: PipelineIntermediateState = cb_args[0]
|
|
||||||
return image_progress(progress_state.latents, progress_state.step)
|
|
||||||
else:
|
|
||||||
return image_progress(*cb_args, **kwargs)
|
|
||||||
|
|
||||||
self.generate.prompt2image(
|
self.generate.prompt2image(
|
||||||
**generation_parameters,
|
**generation_parameters,
|
||||||
step_callback=diffusers_step_callback_adapter,
|
step_callback=image_progress,
|
||||||
image_callback=image_done,
|
image_callback=image_done,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -626,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
|
|||||||
completer.set_default_dir(opt.outdir)
|
completer.set_default_dir(opt.outdir)
|
||||||
|
|
||||||
|
|
||||||
def import_model(model_path: str, gen, opt, completer, convert=False):
|
def import_model(model_path: str, gen, opt, completer):
|
||||||
"""
|
"""
|
||||||
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
|
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
|
||||||
(3) a huggingface repository id; or (4) a local directory containing a
|
(3) a huggingface repository id; or (4) a local directory containing a
|
||||||
@ -657,7 +657,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
|
|||||||
model_path,
|
model_path,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
description=model_desc,
|
description=model_desc,
|
||||||
convert=convert,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not imported_name:
|
if not imported_name:
|
||||||
@ -666,7 +665,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
|
|||||||
model_path,
|
model_path,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
description=model_desc,
|
description=model_desc,
|
||||||
convert=convert,
|
|
||||||
model_config_file=config_file,
|
model_config_file=config_file,
|
||||||
)
|
)
|
||||||
if not imported_name:
|
if not imported_name:
|
||||||
@ -757,7 +755,6 @@ def _get_model_name_and_desc(
|
|||||||
)
|
)
|
||||||
return model_name, model_description
|
return model_name, model_description
|
||||||
|
|
||||||
|
|
||||||
def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||||
model_name_or_path = model_name_or_path.replace("\\", "/") # windows
|
model_name_or_path = model_name_or_path.replace("\\", "/") # windows
|
||||||
manager = gen.model_manager
|
manager = gen.model_manager
|
||||||
@ -772,16 +769,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
original_config_file = Path(model_info["config"])
|
original_config_file = Path(model_info["config"])
|
||||||
model_name = model_name_or_path
|
model_name = model_name_or_path
|
||||||
model_description = model_info["description"]
|
model_description = model_info["description"]
|
||||||
vae = model_info["vae"]
|
vae_path = model_info.get("vae")
|
||||||
else:
|
else:
|
||||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||||
return
|
return
|
||||||
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
|
|
||||||
Path(vae).stem
|
|
||||||
):
|
|
||||||
vae_repo = dict(repo_id=vae_repo)
|
|
||||||
else:
|
|
||||||
vae_repo = None
|
|
||||||
model_name = manager.convert_and_import(
|
model_name = manager.convert_and_import(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
diffusers_path=Path(
|
diffusers_path=Path(
|
||||||
@ -790,11 +781,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_description=model_description,
|
model_description=model_description,
|
||||||
original_config_file=original_config_file,
|
original_config_file=original_config_file,
|
||||||
vae=vae_repo,
|
vae_path=vae_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
import_model(model_name_or_path, gen, opt, completer, convert=True)
|
import_model(model_name_or_path, gen, opt, completer)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "داكن",
|
"darkTheme": "داكن",
|
||||||
"lightTheme": "فاتح",
|
"lightTheme": "فاتح",
|
||||||
"greenTheme": "أخضر",
|
"greenTheme": "أخضر",
|
||||||
"text2img": "نص إلى صورة",
|
|
||||||
"img2img": "صورة إلى صورة",
|
"img2img": "صورة إلى صورة",
|
||||||
"unifiedCanvas": "لوحة موحدة",
|
"unifiedCanvas": "لوحة موحدة",
|
||||||
"nodes": "عقد",
|
"nodes": "عقد",
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
"darkTheme": "Dunkel",
|
"darkTheme": "Dunkel",
|
||||||
"lightTheme": "Hell",
|
"lightTheme": "Hell",
|
||||||
"greenTheme": "Grün",
|
"greenTheme": "Grün",
|
||||||
"text2img": "Text zu Bild",
|
|
||||||
"img2img": "Bild zu Bild",
|
"img2img": "Bild zu Bild",
|
||||||
"nodes": "Knoten",
|
"nodes": "Knoten",
|
||||||
"langGerman": "Deutsch",
|
"langGerman": "Deutsch",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Oscuro",
|
"darkTheme": "Oscuro",
|
||||||
"lightTheme": "Claro",
|
"lightTheme": "Claro",
|
||||||
"greenTheme": "Verde",
|
"greenTheme": "Verde",
|
||||||
"text2img": "Texto a Imagen",
|
|
||||||
"img2img": "Imagen a Imagen",
|
"img2img": "Imagen a Imagen",
|
||||||
"unifiedCanvas": "Lienzo Unificado",
|
"unifiedCanvas": "Lienzo Unificado",
|
||||||
"nodes": "Nodos",
|
"nodes": "Nodos",
|
||||||
@ -70,7 +69,11 @@
|
|||||||
"langHebrew": "Hebreo",
|
"langHebrew": "Hebreo",
|
||||||
"pinOptionsPanel": "Pin del panel de opciones",
|
"pinOptionsPanel": "Pin del panel de opciones",
|
||||||
"loading": "Cargando",
|
"loading": "Cargando",
|
||||||
"loadingInvokeAI": "Cargando invocar a la IA"
|
"loadingInvokeAI": "Cargando invocar a la IA",
|
||||||
|
"postprocessing": "Tratamiento posterior",
|
||||||
|
"txt2img": "De texto a imagen",
|
||||||
|
"accept": "Aceptar",
|
||||||
|
"cancel": "Cancelar"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Generaciones",
|
"generations": "Generaciones",
|
||||||
@ -404,7 +407,8 @@
|
|||||||
"none": "ninguno",
|
"none": "ninguno",
|
||||||
"pickModelType": "Elige el tipo de modelo",
|
"pickModelType": "Elige el tipo de modelo",
|
||||||
"v2_768": "v2 (768px)",
|
"v2_768": "v2 (768px)",
|
||||||
"addDifference": "Añadir una diferencia"
|
"addDifference": "Añadir una diferencia",
|
||||||
|
"scanForModels": "Buscar modelos"
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"images": "Imágenes",
|
"images": "Imágenes",
|
||||||
@ -574,7 +578,7 @@
|
|||||||
"autoSaveToGallery": "Guardar automáticamente en galería",
|
"autoSaveToGallery": "Guardar automáticamente en galería",
|
||||||
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
|
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
|
||||||
"limitStrokesToBox": "Limitar trazos a la caja",
|
"limitStrokesToBox": "Limitar trazos a la caja",
|
||||||
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
|
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
|
||||||
"clearCanvasHistory": "Limpiar historial de lienzo",
|
"clearCanvasHistory": "Limpiar historial de lienzo",
|
||||||
"clearHistory": "Limpiar historial",
|
"clearHistory": "Limpiar historial",
|
||||||
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
|
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Sombre",
|
"darkTheme": "Sombre",
|
||||||
"lightTheme": "Clair",
|
"lightTheme": "Clair",
|
||||||
"greenTheme": "Vert",
|
"greenTheme": "Vert",
|
||||||
"text2img": "Texte en image",
|
|
||||||
"img2img": "Image en image",
|
"img2img": "Image en image",
|
||||||
"unifiedCanvas": "Canvas unifié",
|
"unifiedCanvas": "Canvas unifié",
|
||||||
"nodes": "Nœuds",
|
"nodes": "Nœuds",
|
||||||
@ -47,7 +46,19 @@
|
|||||||
"statusLoadingModel": "Chargement du modèle",
|
"statusLoadingModel": "Chargement du modèle",
|
||||||
"statusModelChanged": "Modèle changé",
|
"statusModelChanged": "Modèle changé",
|
||||||
"discordLabel": "Discord",
|
"discordLabel": "Discord",
|
||||||
"githubLabel": "Github"
|
"githubLabel": "Github",
|
||||||
|
"accept": "Accepter",
|
||||||
|
"statusMergingModels": "Mélange des modèles",
|
||||||
|
"loadingInvokeAI": "Chargement de Invoke AI",
|
||||||
|
"cancel": "Annuler",
|
||||||
|
"langEnglish": "Anglais",
|
||||||
|
"statusConvertingModel": "Conversion du modèle",
|
||||||
|
"statusModelConverted": "Modèle converti",
|
||||||
|
"loading": "Chargement",
|
||||||
|
"pinOptionsPanel": "Épingler la page d'options",
|
||||||
|
"statusMergedModels": "Modèles mélangés",
|
||||||
|
"txt2img": "Texte vers image",
|
||||||
|
"postprocessing": "Post-Traitement"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Générations",
|
"generations": "Générations",
|
||||||
@ -518,5 +529,15 @@
|
|||||||
"betaDarkenOutside": "Assombrir à l'extérieur",
|
"betaDarkenOutside": "Assombrir à l'extérieur",
|
||||||
"betaLimitToBox": "Limiter à la boîte",
|
"betaLimitToBox": "Limiter à la boîte",
|
||||||
"betaPreserveMasked": "Conserver masqué"
|
"betaPreserveMasked": "Conserver masqué"
|
||||||
|
},
|
||||||
|
"accessibility": {
|
||||||
|
"uploadImage": "Charger une image",
|
||||||
|
"reset": "Réinitialiser",
|
||||||
|
"nextImage": "Image suivante",
|
||||||
|
"previousImage": "Image précédente",
|
||||||
|
"useThisParameter": "Utiliser ce paramètre",
|
||||||
|
"zoomIn": "Zoom avant",
|
||||||
|
"zoomOut": "Zoom arrière",
|
||||||
|
"showOptionsPanel": "Montrer la page d'options"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -125,7 +125,6 @@
|
|||||||
"langSimplifiedChinese": "סינית",
|
"langSimplifiedChinese": "סינית",
|
||||||
"langUkranian": "אוקראינית",
|
"langUkranian": "אוקראינית",
|
||||||
"langSpanish": "ספרדית",
|
"langSpanish": "ספרדית",
|
||||||
"text2img": "טקסט לתמונה",
|
|
||||||
"img2img": "תמונה לתמונה",
|
"img2img": "תמונה לתמונה",
|
||||||
"unifiedCanvas": "קנבס מאוחד",
|
"unifiedCanvas": "קנבס מאוחד",
|
||||||
"nodes": "צמתים",
|
"nodes": "צמתים",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Scuro",
|
"darkTheme": "Scuro",
|
||||||
"lightTheme": "Chiaro",
|
"lightTheme": "Chiaro",
|
||||||
"greenTheme": "Verde",
|
"greenTheme": "Verde",
|
||||||
"text2img": "Testo a Immagine",
|
|
||||||
"img2img": "Immagine a Immagine",
|
"img2img": "Immagine a Immagine",
|
||||||
"unifiedCanvas": "Tela unificata",
|
"unifiedCanvas": "Tela unificata",
|
||||||
"nodes": "Nodi",
|
"nodes": "Nodi",
|
||||||
@ -70,7 +69,11 @@
|
|||||||
"loading": "Caricamento in corso",
|
"loading": "Caricamento in corso",
|
||||||
"oceanTheme": "Oceano",
|
"oceanTheme": "Oceano",
|
||||||
"langHebrew": "Ebraico",
|
"langHebrew": "Ebraico",
|
||||||
"loadingInvokeAI": "Caricamento Invoke AI"
|
"loadingInvokeAI": "Caricamento Invoke AI",
|
||||||
|
"postprocessing": "Post Elaborazione",
|
||||||
|
"txt2img": "Testo a Immagine",
|
||||||
|
"accept": "Accetta",
|
||||||
|
"cancel": "Annulla"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Generazioni",
|
"generations": "Generazioni",
|
||||||
@ -404,7 +407,8 @@
|
|||||||
"v2_768": "v2 (768px)",
|
"v2_768": "v2 (768px)",
|
||||||
"none": "niente",
|
"none": "niente",
|
||||||
"addDifference": "Aggiungi differenza",
|
"addDifference": "Aggiungi differenza",
|
||||||
"pickModelType": "Scegli il tipo di modello"
|
"pickModelType": "Scegli il tipo di modello",
|
||||||
|
"scanForModels": "Cerca modelli"
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"images": "Immagini",
|
"images": "Immagini",
|
||||||
@ -574,7 +578,7 @@
|
|||||||
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
|
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
|
||||||
"saveBoxRegionOnly": "Salva solo l'area di selezione",
|
"saveBoxRegionOnly": "Salva solo l'area di selezione",
|
||||||
"limitStrokesToBox": "Limita i tratti all'area di selezione",
|
"limitStrokesToBox": "Limita i tratti all'area di selezione",
|
||||||
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
|
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
|
||||||
"clearCanvasHistory": "Cancella cronologia Tela",
|
"clearCanvasHistory": "Cancella cronologia Tela",
|
||||||
"clearHistory": "Cancella la cronologia",
|
"clearHistory": "Cancella la cronologia",
|
||||||
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
|
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
|
||||||
@ -612,7 +616,7 @@
|
|||||||
"copyMetadataJson": "Copia i metadati JSON",
|
"copyMetadataJson": "Copia i metadati JSON",
|
||||||
"exitViewer": "Esci dal visualizzatore",
|
"exitViewer": "Esci dal visualizzatore",
|
||||||
"zoomIn": "Zoom avanti",
|
"zoomIn": "Zoom avanti",
|
||||||
"zoomOut": "Zoom Indietro",
|
"zoomOut": "Zoom indietro",
|
||||||
"rotateCounterClockwise": "Ruotare in senso antiorario",
|
"rotateCounterClockwise": "Ruotare in senso antiorario",
|
||||||
"rotateClockwise": "Ruotare in senso orario",
|
"rotateClockwise": "Ruotare in senso orario",
|
||||||
"flipHorizontally": "Capovolgi orizzontalmente",
|
"flipHorizontally": "Capovolgi orizzontalmente",
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
"langArabic": "العربية",
|
"langArabic": "العربية",
|
||||||
"langEnglish": "English",
|
"langEnglish": "English",
|
||||||
"langDutch": "Nederlands",
|
"langDutch": "Nederlands",
|
||||||
"text2img": "텍스트->이미지",
|
|
||||||
"unifiedCanvas": "통합 캔버스",
|
"unifiedCanvas": "통합 캔버스",
|
||||||
"langFrench": "Français",
|
"langFrench": "Français",
|
||||||
"langGerman": "Deutsch",
|
"langGerman": "Deutsch",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Donker",
|
"darkTheme": "Donker",
|
||||||
"lightTheme": "Licht",
|
"lightTheme": "Licht",
|
||||||
"greenTheme": "Groen",
|
"greenTheme": "Groen",
|
||||||
"text2img": "Tekst naar afbeelding",
|
|
||||||
"img2img": "Afbeelding naar afbeelding",
|
"img2img": "Afbeelding naar afbeelding",
|
||||||
"unifiedCanvas": "Centraal canvas",
|
"unifiedCanvas": "Centraal canvas",
|
||||||
"nodes": "Knooppunten",
|
"nodes": "Knooppunten",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Ciemny",
|
"darkTheme": "Ciemny",
|
||||||
"lightTheme": "Jasny",
|
"lightTheme": "Jasny",
|
||||||
"greenTheme": "Zielony",
|
"greenTheme": "Zielony",
|
||||||
"text2img": "Tekst na obraz",
|
|
||||||
"img2img": "Obraz na obraz",
|
"img2img": "Obraz na obraz",
|
||||||
"unifiedCanvas": "Tryb uniwersalny",
|
"unifiedCanvas": "Tryb uniwersalny",
|
||||||
"nodes": "Węzły",
|
"nodes": "Węzły",
|
||||||
|
@ -20,7 +20,6 @@
|
|||||||
"langSpanish": "Espanhol",
|
"langSpanish": "Espanhol",
|
||||||
"langRussian": "Русский",
|
"langRussian": "Русский",
|
||||||
"langUkranian": "Украї́нська",
|
"langUkranian": "Украї́нська",
|
||||||
"text2img": "Texto para Imagem",
|
|
||||||
"img2img": "Imagem para Imagem",
|
"img2img": "Imagem para Imagem",
|
||||||
"unifiedCanvas": "Tela Unificada",
|
"unifiedCanvas": "Tela Unificada",
|
||||||
"nodes": "Nós",
|
"nodes": "Nós",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Noite",
|
"darkTheme": "Noite",
|
||||||
"lightTheme": "Dia",
|
"lightTheme": "Dia",
|
||||||
"greenTheme": "Verde",
|
"greenTheme": "Verde",
|
||||||
"text2img": "Texto Para Imagem",
|
|
||||||
"img2img": "Imagem Para Imagem",
|
"img2img": "Imagem Para Imagem",
|
||||||
"unifiedCanvas": "Tela Unificada",
|
"unifiedCanvas": "Tela Unificada",
|
||||||
"nodes": "Nódulos",
|
"nodes": "Nódulos",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Темная",
|
"darkTheme": "Темная",
|
||||||
"lightTheme": "Светлая",
|
"lightTheme": "Светлая",
|
||||||
"greenTheme": "Зеленая",
|
"greenTheme": "Зеленая",
|
||||||
"text2img": "Изображение из текста (text2img)",
|
|
||||||
"img2img": "Изображение в изображение (img2img)",
|
"img2img": "Изображение в изображение (img2img)",
|
||||||
"unifiedCanvas": "Универсальный холст",
|
"unifiedCanvas": "Универсальный холст",
|
||||||
"nodes": "Ноды",
|
"nodes": "Ноды",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "Темна",
|
"darkTheme": "Темна",
|
||||||
"lightTheme": "Світла",
|
"lightTheme": "Світла",
|
||||||
"greenTheme": "Зелена",
|
"greenTheme": "Зелена",
|
||||||
"text2img": "Зображення із тексту (text2img)",
|
|
||||||
"img2img": "Зображення із зображення (img2img)",
|
"img2img": "Зображення із зображення (img2img)",
|
||||||
"unifiedCanvas": "Універсальне полотно",
|
"unifiedCanvas": "Універсальне полотно",
|
||||||
"nodes": "Вузли",
|
"nodes": "Вузли",
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
"darkTheme": "暗色",
|
"darkTheme": "暗色",
|
||||||
"lightTheme": "亮色",
|
"lightTheme": "亮色",
|
||||||
"greenTheme": "绿色",
|
"greenTheme": "绿色",
|
||||||
"text2img": "文字到图像",
|
|
||||||
"img2img": "图像到图像",
|
"img2img": "图像到图像",
|
||||||
"unifiedCanvas": "统一画布",
|
"unifiedCanvas": "统一画布",
|
||||||
"nodes": "节点",
|
"nodes": "节点",
|
||||||
|
@ -33,7 +33,6 @@
|
|||||||
"langBrPortuguese": "巴西葡萄牙語",
|
"langBrPortuguese": "巴西葡萄牙語",
|
||||||
"langRussian": "俄語",
|
"langRussian": "俄語",
|
||||||
"langSpanish": "西班牙語",
|
"langSpanish": "西班牙語",
|
||||||
"text2img": "文字到圖像",
|
|
||||||
"unifiedCanvas": "統一畫布"
|
"unifiedCanvas": "統一畫布"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<BiZoomIn />}
|
icon={<BiZoomIn />}
|
||||||
aria-label={t('accessibility.zoomIn')}
|
aria-label={t('accessibility.zoomIn')}
|
||||||
tooltip="Zoom In"
|
tooltip={t('accessibility.zoomIn')}
|
||||||
onClick={() => zoomIn()}
|
onClick={() => zoomIn()}
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
@ -42,7 +42,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<BiZoomOut />}
|
icon={<BiZoomOut />}
|
||||||
aria-label={t('accessibility.zoomOut')}
|
aria-label={t('accessibility.zoomOut')}
|
||||||
tooltip="Zoom Out"
|
tooltip={t('accessibility.zoomOut')}
|
||||||
onClick={() => zoomOut()}
|
onClick={() => zoomOut()}
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
@ -50,7 +50,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<BiRotateLeft />}
|
icon={<BiRotateLeft />}
|
||||||
aria-label={t('accessibility.rotateCounterClockwise')}
|
aria-label={t('accessibility.rotateCounterClockwise')}
|
||||||
tooltip="Rotate Counter-Clockwise"
|
tooltip={t('accessibility.rotateCounterClockwise')}
|
||||||
onClick={rotateCounterClockwise}
|
onClick={rotateCounterClockwise}
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
@ -58,7 +58,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<BiRotateRight />}
|
icon={<BiRotateRight />}
|
||||||
aria-label={t('accessibility.rotateClockwise')}
|
aria-label={t('accessibility.rotateClockwise')}
|
||||||
tooltip="Rotate Clockwise"
|
tooltip={t('accessibility.rotateClockwise')}
|
||||||
onClick={rotateClockwise}
|
onClick={rotateClockwise}
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
@ -66,7 +66,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<MdFlip />}
|
icon={<MdFlip />}
|
||||||
aria-label={t('accessibility.flipHorizontally')}
|
aria-label={t('accessibility.flipHorizontally')}
|
||||||
tooltip="Flip Horizontally"
|
tooltip={t('accessibility.flipHorizontally')}
|
||||||
onClick={flipHorizontally}
|
onClick={flipHorizontally}
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
@ -74,7 +74,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<MdFlip style={{ transform: 'rotate(90deg)' }} />}
|
icon={<MdFlip style={{ transform: 'rotate(90deg)' }} />}
|
||||||
aria-label={t('accessibility.flipVertically')}
|
aria-label={t('accessibility.flipVertically')}
|
||||||
tooltip="Flip Vertically"
|
tooltip={t('accessibility.flipVertically')}
|
||||||
onClick={flipVertically}
|
onClick={flipVertically}
|
||||||
fontSize={20}
|
fontSize={20}
|
||||||
/>
|
/>
|
||||||
@ -82,7 +82,7 @@ const ReactPanZoomButtons = ({
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<BiReset />}
|
icon={<BiReset />}
|
||||||
aria-label={t('accessibility.reset')}
|
aria-label={t('accessibility.reset')}
|
||||||
tooltip="Reset"
|
tooltip={t('accessibility.reset')}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
resetTransform();
|
resetTransform();
|
||||||
reset();
|
reset();
|
||||||
|
@ -1,8 +1,23 @@
|
|||||||
import i18n from 'i18next';
|
import i18n from 'i18next';
|
||||||
import LanguageDetector from 'i18next-browser-languagedetector';
|
import LanguageDetector from 'i18next-browser-languagedetector';
|
||||||
import Backend from 'i18next-http-backend';
|
import Backend from 'i18next-http-backend';
|
||||||
|
|
||||||
import { initReactI18next } from 'react-i18next';
|
import { initReactI18next } from 'react-i18next';
|
||||||
|
|
||||||
|
import translationEN from '../dist/locales/en.json';
|
||||||
|
|
||||||
|
if (import.meta.env.MODE === 'package') {
|
||||||
|
i18n.use(initReactI18next).init({
|
||||||
|
lng: 'en',
|
||||||
|
resources: {
|
||||||
|
en: { translation: translationEN },
|
||||||
|
},
|
||||||
|
debug: false,
|
||||||
|
interpolation: {
|
||||||
|
escapeValue: false,
|
||||||
|
},
|
||||||
|
returnNull: false,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
i18n
|
i18n
|
||||||
.use(Backend)
|
.use(Backend)
|
||||||
.use(LanguageDetector)
|
.use(LanguageDetector)
|
||||||
@ -18,5 +33,6 @@ i18n
|
|||||||
},
|
},
|
||||||
returnNull: false,
|
returnNull: false,
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
export default i18n;
|
export default i18n;
|
||||||
|
@ -38,14 +38,14 @@ dependencies = [
|
|||||||
"albumentations",
|
"albumentations",
|
||||||
"click",
|
"click",
|
||||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||||
"compel==1.0.4",
|
"compel==1.0.5",
|
||||||
"datasets",
|
"datasets",
|
||||||
"diffusers[torch]~=0.14",
|
"diffusers[torch]~=0.14",
|
||||||
"dnspython==2.2.1",
|
"dnspython==2.2.1",
|
||||||
"einops",
|
"einops",
|
||||||
"eventlet",
|
"eventlet",
|
||||||
"facexlib",
|
"facexlib",
|
||||||
"fastapi==0.94.1",
|
"fastapi==0.88.0",
|
||||||
"fastapi-events==0.8.0",
|
"fastapi-events==0.8.0",
|
||||||
"fastapi-socketio==0.0.10",
|
"fastapi-socketio==0.0.10",
|
||||||
"flask==2.1.3",
|
"flask==2.1.3",
|
||||||
@ -156,4 +156,3 @@ output = "coverage/index.xml"
|
|||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from .test_invoker import create_edge
|
from .test_invoker import create_edge
|
||||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
|
from invokeai.app.invocations.collections import RangeInvocation
|
||||||
|
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -21,13 +23,14 @@ def simple_graph():
|
|||||||
def mock_services():
|
def mock_services():
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
model_manager = None,
|
model_manager = None, # type: ignore
|
||||||
events = None,
|
events = None, # type: ignore
|
||||||
images = None,
|
images = None, # type: ignore
|
||||||
|
latents = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None,
|
restoration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||||
@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services):
|
|||||||
|
|
||||||
def test_graph_state_expands_iterator(mock_services):
|
def test_graph_state_expands_iterator(mock_services):
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
|
||||||
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
|
graph.add_node(IterateInvocation(id = "1"))
|
||||||
graph.add_node(IterateInvocation(id = "2"))
|
graph.add_node(MultiplyInvocation(id = "2", b = 10))
|
||||||
graph.add_node(ImageTestInvocation(id = "3"))
|
graph.add_node(AddInvocation(id = "3", b = 1))
|
||||||
graph.add_edge(create_edge("1", "collection", "2", "collection"))
|
graph.add_edge(create_edge("0", "collection", "1", "collection"))
|
||||||
graph.add_edge(create_edge("2", "item", "3", "prompt"))
|
graph.add_edge(create_edge("1", "item", "2", "a"))
|
||||||
|
graph.add_edge(create_edge("2", "a", "3", "a"))
|
||||||
|
|
||||||
g = GraphExecutionState(graph = graph)
|
g = GraphExecutionState(graph = graph)
|
||||||
n1 = invoke_next(g, mock_services)
|
while not g.is_complete():
|
||||||
n2 = invoke_next(g, mock_services)
|
invoke_next(g, mock_services)
|
||||||
n3 = invoke_next(g, mock_services)
|
|
||||||
n4 = invoke_next(g, mock_services)
|
|
||||||
n5 = invoke_next(g, mock_services)
|
|
||||||
|
|
||||||
assert g.prepared_source_mapping[n1[0].id] == "1"
|
prepared_add_nodes = g.source_prepared_mapping['3']
|
||||||
assert g.prepared_source_mapping[n2[0].id] == "2"
|
results = set([g.results[n].a for n in prepared_add_nodes])
|
||||||
assert g.prepared_source_mapping[n3[0].id] == "2"
|
expected = set([1, 11, 21])
|
||||||
assert g.prepared_source_mapping[n4[0].id] == "3"
|
assert results == expected
|
||||||
assert g.prepared_source_mapping[n5[0].id] == "3"
|
|
||||||
|
|
||||||
assert isinstance(n4[0], ImageTestInvocation)
|
|
||||||
assert isinstance(n5[0], ImageTestInvocation)
|
|
||||||
|
|
||||||
prompts = [n4[0].prompt, n5[0].prompt]
|
|
||||||
assert sorted(prompts) == sorted(test_prompts)
|
|
||||||
|
|
||||||
def test_graph_state_collects(mock_services):
|
def test_graph_state_collects(mock_services):
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
|
@ -24,10 +24,11 @@ def mock_services() -> InvocationServices:
|
|||||||
model_manager = None, # type: ignore
|
model_manager = None, # type: ignore
|
||||||
events = TestEventService(),
|
events = TestEventService(),
|
||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
|
latents = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor(),
|
processor = DefaultInvocationProcessor(),
|
||||||
restoration = None,
|
restoration = None, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
Loading…
Reference in New Issue
Block a user