diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 3abca42b1a..17facf4155 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,16 +1,16 @@
 # continuous integration
-/.github/workflows/ @mauwii @lstein
+/.github/workflows/ @mauwii @lstein @blessedcoolant
 
 # documentation
-/docs/ @lstein @mauwii @tildebyte
-/mkdocs.yml @lstein @mauwii
+/docs/ @lstein @mauwii @tildebyte @blessedcoolant
+/mkdocs.yml @lstein @mauwii @blessedcoolant
 
 # nodes
 /invokeai/app/ @Kyle0654 @blessedcoolant
 
 # installation and configuration
 /pyproject.toml @mauwii @lstein @blessedcoolant
-/docker/ @mauwii @lstein
+/docker/ @mauwii @lstein @blessedcoolant
 /scripts/ @ebr @lstein
 /installer/ @lstein @ebr
 /invokeai/assets @lstein @ebr
diff --git a/.github/workflows/build-container.yml b/.github/workflows/build-container.yml
index 0fabbdf038..23d7c82fe3 100644
--- a/.github/workflows/build-container.yml
+++ b/.github/workflows/build-container.yml
@@ -16,6 +16,10 @@ on:
       - 'v*.*.*'
   workflow_dispatch:
 
+permissions:
+  contents: write
+  packages: write
+
 jobs:
   docker:
     if: github.event.pull_request.draft == false
diff --git a/.github/workflows/mkdocs-material.yml b/.github/workflows/mkdocs-material.yml
index 26a46c1328..c8e55f0b1b 100644
--- a/.github/workflows/mkdocs-material.yml
+++ b/.github/workflows/mkdocs-material.yml
@@ -5,6 +5,9 @@ on:
       - 'main'
       - 'development'
 
+permissions:
+  contents: write
+
 jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false
diff --git a/.github/workflows/test-invoke-pip-skip.yml b/.github/workflows/test-invoke-pip-skip.yml
index c2347e5ce3..d4c9d9fc00 100644
--- a/.github/workflows/test-invoke-pip-skip.yml
+++ b/.github/workflows/test-invoke-pip-skip.yml
@@ -6,7 +6,6 @@ on:
       - '!pyproject.toml'
       - '!invokeai/**'
       - 'invokeai/frontend/web/**'
-      - '!invokeai/frontend/web/dist/**'
   merge_group:
   workflow_dispatch:
 
diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml
index 30ed05379c..c5e4d10bfd 100644
--- a/.github/workflows/test-invoke-pip.yml
+++ b/.github/workflows/test-invoke-pip.yml
@@ -7,13 +7,11 @@ on:
       - 'pyproject.toml'
       - 'invokeai/**'
       - '!invokeai/frontend/web/**'
-      - 'invokeai/frontend/web/dist/**'
   pull_request:
     paths:
       - 'pyproject.toml'
      - 'invokeai/**'
       - '!invokeai/frontend/web/**'
-      - 'invokeai/frontend/web/dist/**'
     types:
       - 'ready_for_review'
       - 'opened'
diff --git a/README.md b/README.md
index 5857b2d898..1b02ced2c9 100644
--- a/README.md
+++ b/README.md
@@ -139,13 +139,13 @@ not supported.
    _For Windows/Linux with an NVIDIA GPU:_
 
    ```terminal
-   pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
+   pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
    ```
 
    _For Linux with an AMD GPU:_
 
    ```sh
-   pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+   pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
    ```
 
    _For Macintoshes, either Intel or M1/M2:_
diff --git a/docs/features/INPAINTING.md b/docs/features/INPAINTING.md
index f3a879b190..bebf9bc229 100644
--- a/docs/features/INPAINTING.md
+++ b/docs/features/INPAINTING.md
@@ -168,11 +168,15 @@ used by Stable Diffusion 1.4 and 1.5.
 After installation, your `models.yaml` should contain an entry that looks like
 this one:
 
-inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
-description: SD inpainting v1.5 config:
-configs/stable-diffusion/v1-inpainting-inference.yaml vae:
-models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
-height: 512
+```yml
+inpainting-1.5:
+  weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
+  description: SD inpainting v1.5
+  config: configs/stable-diffusion/v1-inpainting-inference.yaml
+  vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+  width: 512
+  height: 512
+```
 
 As shown in the example, you may include a VAE fine-tuning weights file as
 well. This is strongly recommended.
diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md
index 85919a5b29..045e0d658a 100644
--- a/docs/features/PROMPTS.md
+++ b/docs/features/PROMPTS.md
@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
 masking option:
 
 ```bash
-invoke> a fluffy cat eating a hotdot
+invoke> a fluffy cat eating a hotdog
 Outputs:
 [1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
 invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
diff --git a/docs/installation/010_INSTALL_AUTOMATED.md b/docs/installation/010_INSTALL_AUTOMATED.md
index 228c0ae9a4..83b4415394 100644
--- a/docs/installation/010_INSTALL_AUTOMATED.md
+++ b/docs/installation/010_INSTALL_AUTOMATED.md
@@ -417,7 +417,7 @@ Then type the following commands:
 === "AMD System"
 
     ```bash
-    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
+    pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
     ```
 
 ### Corrupted configuration file
diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md
index 401560e76c..657e3f055d 100644
--- a/docs/installation/020_INSTALL_MANUAL.md
+++ b/docs/installation/020_INSTALL_MANUAL.md
@@ -154,7 +154,7 @@ manager, please follow these steps:
     === "ROCm (AMD)"
 
         ```bash
-        pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+        pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
         ```
 
     === "CPU (Intel Macs & non-GPU systems)"
@@ -315,7 +315,7 @@ installation protocol (important!)
     === "ROCm (AMD)"
 
         ```bash
-        pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
+        pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
         ```
 
     === "CPU (Intel Macs & non-GPU systems)"
diff --git a/docs/installation/030_INSTALL_CUDA_AND_ROCM.md b/docs/installation/030_INSTALL_CUDA_AND_ROCM.md
index 8ce690ca64..3d3445e3a0 100644
--- a/docs/installation/030_INSTALL_CUDA_AND_ROCM.md
+++ b/docs/installation/030_INSTALL_CUDA_AND_ROCM.md
@@ -110,7 +110,7 @@ recipes are available
 
 When installing torch and torchvision manually with `pip`, remember to provide
 the argument `--extra-index-url
-https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
+https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
 Installation Guide](020_INSTALL_MANUAL.md).
 
This will be done automatically for you if you use the installer diff --git a/installer/lib/installer.py b/installer/lib/installer.py index 344fa12046..14fc657011 100644 --- a/installer/lib/installer.py +++ b/installer/lib/installer.py @@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str): optional_modules = None if OS == "Linux": if device == "rocm": - url = "https://download.pytorch.org/whl/rocm5.2" + url = "https://download.pytorch.org/whl/rocm5.4.2" elif device == "cpu": url = "https://download.pytorch.org/whl/cpu" diff --git a/installer/templates/invoke.sh.in b/installer/templates/invoke.sh.in index 812bcba458..4576c7172f 100644 --- a/installer/templates/invoke.sh.in +++ b/installer/templates/invoke.sh.in @@ -24,9 +24,9 @@ if [ "$(uname -s)" == "Darwin" ]; then export PYTORCH_ENABLE_MPS_FALLBACK=1 fi -while true -do if [ "$0" != "bash" ]; then + while true + do echo "Do you want to generate images using the" echo "1. command-line interface" echo "2. browser-based UI" @@ -67,29 +67,29 @@ if [ "$0" != "bash" ]; then ;; 7) invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only - ;; - 8) - echo "Developer Console:" + ;; + 8) + echo "Developer Console:" file_name=$(basename "${BASH_SOURCE[0]}") bash --init-file "$file_name" ;; 9) - echo "Update:" + echo "Update:" invokeai-update ;; 10) invokeai --help ;; - [qQ]) + [qQ]) exit 0 ;; *) echo "Invalid selection" exit;; esac + done else # in developer console python --version echo "Press ^D to exit" export PS1="(InvokeAI) \u@\h \w> " fi -done diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 347fba7e97..5698d25758 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -3,6 +3,8 @@ import os from argparse import Namespace +from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage + from ...backend import Globals from ..services.model_manager_initializer import get_model_manager from ..services.restoration_services import RestorationServices @@ -54,7 +56,9 @@ class ApiDependencies: os.path.join(os.path.dirname(__file__), "../../../../outputs") ) - images = DiskImageStorage(output_folder) + latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')) + + images = DiskImageStorage(f'{output_folder}/images') # TODO: build a file/path manager? db_location = os.path.join(output_folder, "invokeai.db") @@ -62,6 +66,7 @@ class ApiDependencies: services = InvocationServices( model_manager=get_model_manager(config), events=events, + latents=latents, images=images, queue=MemoryInvocationQueue(), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 55f1a2f036..453c114a28 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -23,6 +23,16 @@ async def get_image( filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name) return FileResponse(filename) +@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail") +async def get_thumbnail( + image_type: ImageType = Path(description="The type of image to get"), + image_name: str = Path(description="The name of the image to get"), +): + """Gets a thumbnail""" + # TODO: This is not really secure at all. 
At least make sure only output results are served + filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name) + return FileResponse(filename) + @images_router.post( "/uploads/", diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py new file mode 100644 index 0000000000..5b3fbebddd --- /dev/null +++ b/invokeai/app/api/routers/models.py @@ -0,0 +1,279 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Annotated, Any, List, Literal, Optional, Union + +from fastapi.routing import APIRouter +from pydantic import BaseModel, Field, parse_obj_as + +from ..dependencies import ApiDependencies + +models_router = APIRouter(prefix="/v1/models", tags=["models"]) + + +class VaeRepo(BaseModel): + repo_id: str = Field(description="The repo ID to use for this VAE") + path: Optional[str] = Field(description="The path to the VAE") + subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") + + +class ModelInfo(BaseModel): + description: Optional[str] = Field(description="A description of the model") + + +class CkptModelInfo(ModelInfo): + format: Literal['ckpt'] = 'ckpt' + + config: str = Field(description="The path to the model config") + weights: str = Field(description="The path to the model weights") + vae: str = Field(description="The path to the model VAE") + width: Optional[int] = Field(description="The width of the model") + height: Optional[int] = Field(description="The height of the model") + + +class DiffusersModelInfo(ModelInfo): + format: Literal['diffusers'] = 'diffusers' + + vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model") + repo_id: Optional[str] = Field(description="The repo ID to use for this model") + path: Optional[str] = Field(description="The path to the model") + + +class ModelsList(BaseModel): + models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] + + + +@models_router.get( + "/", + operation_id="list_models", + responses={200: {"model": ModelsList }}, +) +async def list_models() -> ModelsList: + """Gets a list of models""" + models_raw = ApiDependencies.invoker.services.model_manager.list_models() + models = parse_obj_as(ModelsList, { "models": models_raw }) + return models + + # @socketio.on("requestSystemConfig") + # def handle_request_capabilities(): + # print(">> System config requested") + # config = self.get_system_config() + # config["model_list"] = self.generate.model_manager.list_models() + # config["infill_methods"] = infill_methods() + # socketio.emit("systemConfig", config) + + # @socketio.on("searchForModels") + # def handle_search_models(search_folder: str): + # try: + # if not search_folder: + # socketio.emit( + # "foundModels", + # {"search_folder": None, "found_models": None}, + # ) + # else: + # ( + # search_folder, + # found_models, + # ) = self.generate.model_manager.search_models(search_folder) + # socketio.emit( + # "foundModels", + # {"search_folder": search_folder, "found_models": found_models}, + # ) + # except Exception as e: + # self.handle_exceptions(e) + # print("\n") + + # @socketio.on("addNewModel") + # def handle_add_model(new_model_config: dict): + # try: + # model_name = new_model_config["name"] + # del new_model_config["name"] + # model_attributes = new_model_config + # if len(model_attributes["vae"]) == 0: + # del model_attributes["vae"] + # update = False + # current_model_list = self.generate.model_manager.list_models() + # if 
model_name in current_model_list: + # update = True + + # print(f">> Adding New Model: {model_name}") + + # self.generate.model_manager.add_model( + # model_name=model_name, + # model_attributes=model_attributes, + # clobber=True, + # ) + # self.generate.model_manager.commit(opt.conf) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "newModelAdded", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": update, + # }, + # ) + # print(f">> New Model Added: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("deleteModel") + # def handle_delete_model(model_name: str): + # try: + # print(f">> Deleting Model: {model_name}") + # self.generate.model_manager.del_model(model_name) + # self.generate.model_manager.commit(opt.conf) + # updated_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelDeleted", + # { + # "deleted_model_name": model_name, + # "model_list": updated_model_list, + # }, + # ) + # print(f">> Model Deleted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("requestModelChange") + # def handle_set_model(model_name: str): + # try: + # print(f">> Model change requested: {model_name}") + # model = self.generate.set_model(model_name) + # model_list = self.generate.model_manager.list_models() + # if model is None: + # socketio.emit( + # "modelChangeFailed", + # {"model_name": model_name, "model_list": model_list}, + # ) + # else: + # socketio.emit( + # "modelChanged", + # {"model_name": model_name, "model_list": model_list}, + # ) + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("convertToDiffusers") + # def convert_to_diffusers(model_to_convert: dict): + # try: + # if model_info := self.generate.model_manager.model_info( + # model_name=model_to_convert["model_name"] + # ): + # if "weights" in model_info: + # ckpt_path = Path(model_info["weights"]) + # original_config_file = Path(model_info["config"]) + # model_name = model_to_convert["model_name"] + # model_description = model_info["description"] + # else: + # self.socketio.emit( + # "error", {"message": "Model is not a valid checkpoint file"} + # ) + # else: + # self.socketio.emit( + # "error", {"message": "Could not retrieve model info."} + # ) + + # if not ckpt_path.is_absolute(): + # ckpt_path = Path(Globals.root, ckpt_path) + + # if original_config_file and not original_config_file.is_absolute(): + # original_config_file = Path(Globals.root, original_config_file) + + # diffusers_path = Path( + # ckpt_path.parent.absolute(), f"{model_name}_diffusers" + # ) + + # if model_to_convert["save_location"] == "root": + # diffusers_path = Path( + # global_converted_ckpts_dir(), f"{model_name}_diffusers" + # ) + + # if ( + # model_to_convert["save_location"] == "custom" + # and model_to_convert["custom_location"] is not None + # ): + # diffusers_path = Path( + # model_to_convert["custom_location"], f"{model_name}_diffusers" + # ) + + # if diffusers_path.exists(): + # shutil.rmtree(diffusers_path) + + # self.generate.model_manager.convert_and_import( + # ckpt_path, + # diffusers_path, + # model_name=model_name, + # model_description=model_description, + # vae=None, + # original_config_file=original_config_file, + # commit_to_conf=opt.conf, + # ) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelConverted", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": 
True, + # }, + # ) + # print(f">> Model Converted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("mergeDiffusersModels") + # def merge_diffusers_models(model_merge_info: dict): + # try: + # models_to_merge = model_merge_info["models_to_merge"] + # model_ids_or_paths = [ + # self.generate.model_manager.model_name_or_path(x) + # for x in models_to_merge + # ] + # merged_pipe = merge_diffusion_models( + # model_ids_or_paths, + # model_merge_info["alpha"], + # model_merge_info["interp"], + # model_merge_info["force"], + # ) + + # dump_path = global_models_dir() / "merged_models" + # if model_merge_info["model_merge_save_path"] is not None: + # dump_path = Path(model_merge_info["model_merge_save_path"]) + + # os.makedirs(dump_path, exist_ok=True) + # dump_path = dump_path / model_merge_info["merged_model_name"] + # merged_pipe.save_pretrained(dump_path, safe_serialization=1) + + # merged_model_config = dict( + # model_name=model_merge_info["merged_model_name"], + # description=f'Merge of models {", ".join(models_to_merge)}', + # commit_to_conf=opt.conf, + # ) + + # if vae := self.generate.model_manager.config[models_to_merge[0]].get( + # "vae", None + # ): + # print(f">> Using configured VAE assigned to {models_to_merge[0]}") + # merged_model_config.update(vae=vae) + + # self.generate.model_manager.import_diffuser_model( + # dump_path, **merged_model_config + # ) + # new_model_list = self.generate.model_manager.list_models() + + # socketio.emit( + # "modelsMerged", + # { + # "merged_models": models_to_merge, + # "merged_model_name": model_merge_info["merged_model_name"], + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Models Merged: {models_to_merge}") + # print(f">> New Model Added: {model_merge_info['merged_model_name']}") + # except Exception as e: + # self.handle_exceptions(e) \ No newline at end of file diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index 67e3c840c0..0316398088 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -51,7 +51,7 @@ async def list_sessions( query: str = Query(default="", description="The query string to search for"), ) -> PaginatedResults[GraphExecutionState]: """Gets a list of sessions, optionally searching""" - if filter == "": + if query == "": result = ApiDependencies.invoker.services.graph_execution_manager.list( page, per_page ) @@ -270,3 +270,18 @@ async def invoke_session( ApiDependencies.invoker.invoke(session, invoke_all=all) return Response(status_code=202) + + +@session_router.delete( + "/{session_id}/invoke", + operation_id="cancel_session_invoke", + responses={ + 202: {"description": "The invocation is canceled"} + }, +) +async def cancel_session_invoke( + session_id: str = Path(description="The id of the session to cancel"), +) -> None: + """Invokes a session""" + ApiDependencies.invoker.cancel(session_id) + return Response(status_code=202) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 0ce2386557..ab05cb3344 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -14,7 +14,7 @@ from pydantic.schema import schema from ..backend import Args from .api.dependencies import ApiDependencies -from .api.routers import images, sessions +from .api.routers import images, sessions, models from .api.sockets import SocketIO from .invocations import * from .invocations.baseinvocation import BaseInvocation @@ -76,6 +76,8 @@ 
app.include_router(sessions.session_router, prefix="/api") app.include_router(images.images_router, prefix="/api") +app.include_router(models.models_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index 21e65291e9..5f4da73303 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod import argparse from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints from pydantic import BaseModel, Field - +import networkx as nx +import matplotlib.pyplot as plt from ..invocations.image import ImageField from ..services.graph import GraphExecutionState from ..services.invoker import Invoker @@ -46,7 +47,7 @@ def add_parsers( f"--{name}", dest=name, type=field_type, - default=field.default, + default=field.default if field.default_factory is None else field.default_factory(), choices=allowed_values, help=field.field_info.description, ) @@ -55,7 +56,7 @@ def add_parsers( f"--{name}", dest=name, type=field.type_, - default=field.default, + default=field.default if field.default_factory is None else field.default_factory(), help=field.field_info.description, ) @@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand): del context.defaults[self.field] else: context.defaults[self.field] = self.value + + +class DrawGraphCommand(BaseCommand): + """Debugs a graph""" + type: Literal['draw_graph'] = 'draw_graph' + + def run(self, context: CliContext) -> None: + session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) + nxgraph = session.graph.nx_graph_flat() + + # Draw the networkx graph + plt.figure(figsize=(20, 20)) + pos = nx.spectral_layout(nxgraph) + nx.draw_networkx_nodes(nxgraph, pos, node_size=1000) + nx.draw_networkx_edges(nxgraph, pos, width=2) + nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif") + plt.axis("off") + plt.show() + + +class DrawExecutionGraphCommand(BaseCommand): + """Debugs an execution graph""" + type: Literal['draw_xgraph'] = 'draw_xgraph' + + def run(self, context: CliContext) -> None: + session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) + nxgraph = session.execution_graph.nx_graph_flat() + + # Draw the networkx graph + plt.figure(figsize=(20, 20)) + pos = nx.spectral_layout(nxgraph) + nx.draw_networkx_nodes(nxgraph, pos, node_size=1000) + nx.draw_networkx_edges(nxgraph, pos, width=2) + nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif") + plt.axis("off") + plt.show() diff --git a/invokeai/app/cli/completer.py b/invokeai/app/cli/completer.py new file mode 100644 index 0000000000..86d3e100c3 --- /dev/null +++ b/invokeai/app/cli/completer.py @@ -0,0 +1,167 @@ +""" +Readline helper functions for cli_app.py +You may import the global singleton `completer` to get access to the +completer object. 
+""" +import atexit +import readline +import shlex + +from pathlib import Path +from typing import List, Dict, Literal, get_args, get_type_hints, get_origin + +from ...backend import ModelManager, Globals +from ..invocations.baseinvocation import BaseInvocation +from .commands import BaseCommand + +# singleton object, class variable +completer = None + +class Completer(object): + + def __init__(self, model_manager: ModelManager): + self.commands = self.get_commands() + self.matches = None + self.linebuffer = None + self.manager = model_manager + return + + def complete(self, text, state): + """ + Complete commands and switches fromm the node CLI command line. + Switches are determined in a context-specific manner. + """ + + buffer = readline.get_line_buffer() + if state == 0: + options = None + try: + current_command, current_switch = self.get_current_command(buffer) + options = self.get_command_options(current_command, current_switch) + except IndexError: + pass + options = options or list(self.parse_commands().keys()) + + if not text: # first time + self.matches = options + else: + self.matches = [s for s in options if s and s.startswith(text)] + + try: + match = self.matches[state] + except IndexError: + match = None + return match + + @classmethod + def get_commands(self)->List[object]: + """ + Return a list of all the client commands and invocations. + """ + return BaseCommand.get_commands() + BaseInvocation.get_invocations() + + def get_current_command(self, buffer: str)->tuple[str, str]: + """ + Parse the readline buffer to find the most recent command and its switch. + """ + if len(buffer)==0: + return None, None + tokens = shlex.split(buffer) + command = None + switch = None + for t in tokens: + if t[0].isalpha(): + if switch is None: + command = t + else: + switch = t + # don't try to autocomplete switches that are already complete + if switch and buffer.endswith(' '): + switch=None + return command or '', switch or '' + + def parse_commands(self)->Dict[str, List[str]]: + """ + Return a dict in which the keys are the command name + and the values are the parameters the command takes. + """ + result = dict() + for command in self.commands: + hints = get_type_hints(command) + name = get_args(hints['type'])[0] + result.update({name:hints}) + return result + + def get_command_options(self, command: str, switch: str)->List[str]: + """ + Return all the parameters that can be passed to the command as + command-line switches. Returns None if the command is unrecognized. + """ + parsed_commands = self.parse_commands() + if command not in parsed_commands: + return None + + # handle switches in the format "-foo=bar" + argument = None + if switch and '=' in switch: + switch, argument = switch.split('=') + + parameter = switch.strip('-') + if parameter in parsed_commands[command]: + if argument is None: + return self.get_parameter_options(parameter, parsed_commands[command][parameter]) + else: + return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])] + else: + return [f"--{x}" for x in parsed_commands[command].keys()] + + def get_parameter_options(self, parameter: str, typehint)->List[str]: + """ + Given a parameter type (such as Literal), offers autocompletions. 
+ """ + if get_origin(typehint) == Literal: + return get_args(typehint) + if parameter == 'model': + return self.manager.model_names() + + def _pre_input_hook(self): + if self.linebuffer: + readline.insert_text(self.linebuffer) + readline.redisplay() + self.linebuffer = None + +def set_autocompleter(model_manager: ModelManager) -> Completer: + global completer + + if completer: + return completer + + completer = Completer(model_manager) + + readline.set_completer(completer.complete) + # pyreadline3 does not have a set_auto_history() method + try: + readline.set_auto_history(True) + except: + pass + readline.set_pre_input_hook(completer._pre_input_hook) + readline.set_completer_delims(" ") + readline.parse_and_bind("tab: complete") + readline.parse_and_bind("set print-completions-horizontally off") + readline.parse_and_bind("set page-completions on") + readline.parse_and_bind("set skip-completed-text on") + readline.parse_and_bind("set show-all-if-ambiguous on") + + histfile = Path(Globals.root, ".invoke_history") + try: + readline.read_history_file(histfile) + readline.set_history_length(1000) + except FileNotFoundError: + pass + except OSError: # file likely corrupted + newname = f"{histfile}.old" + print( + f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}" + ) + histfile.replace(Path(newname)) + atexit.register(readline.write_history_file, histfile) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 6390253250..a257825dcc 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -2,6 +2,7 @@ import argparse import os +import re import shlex import time from typing import ( @@ -12,14 +13,17 @@ from typing import ( from pydantic import BaseModel from pydantic.fields import Field +from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage + from ..backend import Args from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history +from .cli.completer import set_autocompleter from .invocations import * from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase from .services.model_manager_initializer import get_model_manager from .services.restoration_services import RestorationServices -from .services.graph import Edge, EdgeConnection, GraphExecutionState +from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices @@ -43,7 +47,7 @@ def add_invocation_args(command_parser): "-l", action="append", nargs=3, - help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)", + help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)", ) command_parser.add_argument( @@ -93,6 +97,9 @@ def generate_matching_edges( invalid_fields = set(["type", "id"]) matching_fields = matching_fields.difference(invalid_fields) + # Validate types + matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])] + edges = [ Edge( source=EdgeConnection(node_id=a.id, field=field), @@ -130,6 +137,12 @@ def invoke_cli(): config.parse_args() model_manager = get_model_manager(config) + # This initializes the autocompleter and returns it. 
+ # Currently nothing is done with the returned Completer + # object, but the object can be used to change autocompletion + # behavior on the fly, if desired. + completer = set_autocompleter(model_manager) + events = EventServiceBase() output_folder = os.path.abspath( @@ -142,7 +155,8 @@ def invoke_cli(): services = InvocationServices( model_manager=model_manager, events=events, - images=DiskImageStorage(output_folder), + latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), + images=DiskImageStorage(f'{output_folder}/images'), queue=MemoryInvocationQueue(), graph_execution_manager=SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" @@ -155,6 +169,8 @@ def invoke_cli(): session: GraphExecutionState = invoker.create_execution_state() parser = get_command_parser() + re_negid = re.compile('^-[0-9]+$') + # Uncomment to print out previous sessions at startup # print(services.session_manager.list()) @@ -162,8 +178,8 @@ def invoke_cli(): while True: try: - cmd_input = input("> ") - except KeyboardInterrupt: + cmd_input = input("invoke> ") + except (KeyboardInterrupt, EOFError): # Ctrl-c exits break @@ -220,7 +236,11 @@ def invoke_cli(): # Parse provided links if "link_node" in args and args["link_node"]: for link in args["link_node"]: - link_node = context.session.graph.get_node(link) + node_id = link + if re_negid.match(node_id): + node_id = str(current_id + int(node_id)) + + link_node = context.session.graph.get_node(node_id) matching_edges = generate_matching_edges( link_node, command.command ) @@ -230,10 +250,15 @@ def invoke_cli(): if "link" in args and args["link"]: for link in args["link"]: - edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]] + edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]] + + node_id = link[0] + if re_negid.match(node_id): + node_id = str(current_id + int(node_id)) + edges.append( Edge( - source=EdgeConnection(node_id=link[1], field=link[0]), + source=EdgeConnection(node_id=node_id, field=link[1]), destination=EdgeConnection( node_id=command.command.id, field=link[2] ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py new file mode 100644 index 0000000000..c68b7449cc --- /dev/null +++ b/invokeai/app/invocations/collections.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal + +import cv2 as cv +import numpy as np +import numpy.random +from PIL import Image, ImageOps +from pydantic import Field + +from ..services.image_storage import ImageType +from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput +from .image import ImageField, ImageOutput + + +class IntCollectionOutput(BaseInvocationOutput): + """A collection of integers""" + + type: Literal["int_collection"] = "int_collection" + + # Outputs + collection: list[int] = Field(default=[], description="The int collection") + + +class RangeInvocation(BaseInvocation): + """Creates a range""" + + type: Literal["range"] = "range" + + # Inputs + start: int = Field(default=0, description="The start of the range") + stop: int = Field(default=10, description="The stop of the range") + step: int = Field(default=1, description="The step of the range") + + def invoke(self, context: InvocationContext) -> IntCollectionOutput: + return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step))) 
+ + +class RandomRangeInvocation(BaseInvocation): + """Creates a collection of random numbers""" + + type: Literal["random_range"] = "random_range" + + # Inputs + low: int = Field(default=0, description="The inclusive low value") + high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value") + size: int = Field(default=1, description="The number of values to generate") + + def invoke(self, context: InvocationContext) -> IntCollectionOutput: + return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size))) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index b8140b11e9..d6e624b325 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -1,22 +1,19 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from datetime import datetime, timezone -from typing import Any, Literal, Optional, Union +from functools import partial +from typing import Literal, Optional, Union import numpy as np - from torch import Tensor -from PIL import Image + from pydantic import Field -from skimage.exposure.histogram_matching import match_histograms from ..services.image_storage import ImageType -from ..services.invocation_services import InvocationServices from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput -from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator +from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util.util import image_to_dataURL +from ..util.util import diffusers_step_callback_adapter, CanceledException SAMPLER_NAME_VALUES = Literal[ tuple(InvokeAIGenerator.schedulers()) @@ -45,32 +42,26 @@ class TextToImageInvocation(BaseInvocation): # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, sample: Tensor, step: int - ) -> None: - # TODO: only output a preview image when requested - image = Generator.sample_to_lowres_estimated_image(sample) + self, context: InvocationContext, intermediate_state: PipelineIntermediateState + ) -> None: + if (context.services.queue.is_canceled(context.graph_execution_state_id)): + raise CanceledException - (width, height) = image.size - width *= 8 - height *= 8 - - dataURL = image_to_dataURL(image, image_format="JPEG") - - context.services.events.emit_generator_progress( - context.graph_execution_state_id, - self.id, - { - "width": width, - "height": height, - "dataURL": dataURL - }, - step, - self.steps, - ) + step = intermediate_state.step + if intermediate_state.predicted_original is not None: + # Some schedulers report not only the noisy latents at the current timestep, + # but also their estimate so far of what the de-noised latents will be. 
+ sample = intermediate_state.predicted_original + else: + sample = intermediate_state.latents + + diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context) def invoke(self, context: InvocationContext) -> ImageOutput: - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, state.latents, state.step) + # def step_callback(state: PipelineIntermediateState): + # if (context.services.queue.is_canceled(context.graph_execution_state_id)): + # raise CanceledException + # self.dispatch_progress(context, state.latents, state.step) # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache @@ -79,7 +70,7 @@ class TextToImageInvocation(BaseInvocation): model= context.services.model_manager.get_model() outputs = Txt2Img(model).generate( prompt=self.prompt, - step_callback=step_callback, + step_callback=partial(self.dispatch_progress, context), **self.dict( exclude={"prompt"} ), # Shorthand for passing all of the parameters above manually @@ -116,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation): description="Whether or not the result should be fit to the aspect ratio of the input image", ) + def dispatch_progress( + self, context: InvocationContext, intermediate_state: PipelineIntermediateState + ) -> None: + if (context.services.queue.is_canceled(context.graph_execution_state_id)): + raise CanceledException + + step = intermediate_state.step + if intermediate_state.predicted_original is not None: + # Some schedulers report not only the noisy latents at the current timestep, + # but also their estimate so far of what the de-noised latents will be. + sample = intermediate_state.predicted_original + else: + sample = intermediate_state.latents + + diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context) + def invoke(self, context: InvocationContext) -> ImageOutput: image = ( None @@ -126,24 +133,23 @@ class ImageToImageInvocation(TextToImageInvocation): ) mask = None - def step_callback(sample, step=0): - self.dispatch_progress(context, sample, step) - # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? model = context.services.model_manager.get_model() - generator_output = next( - Img2Img(model).generate( + outputs = Img2Img(model).generate( prompt=self.prompt, init_image=image, init_mask=mask, - step_callback=step_callback, + step_callback=partial(self.dispatch_progress, context), **self.dict( exclude={"prompt", "image", "mask"} ), # Shorthand for passing all of the parameters above manually ) - ) + + # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object + # each time it is called. We only need the first one. + generator_output = next(outputs) result_image = generator_output.image @@ -173,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation): description="The amount by which to replace masked areas with latent noise", ) + def dispatch_progress( + self, context: InvocationContext, intermediate_state: PipelineIntermediateState + ) -> None: + if (context.services.queue.is_canceled(context.graph_execution_state_id)): + raise CanceledException + + step = intermediate_state.step + if intermediate_state.predicted_original is not None: + # Some schedulers report not only the noisy latents at the current timestep, + # but also their estimate so far of what the de-noised latents will be. 
+ sample = intermediate_state.predicted_original + else: + sample = intermediate_state.latents + + diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context) + def invoke(self, context: InvocationContext) -> ImageOutput: image = ( None @@ -187,24 +209,23 @@ class InpaintInvocation(ImageToImageInvocation): else context.services.images.get(self.mask.image_type, self.mask.image_name) ) - def step_callback(sample, step=0): - self.dispatch_progress(context, sample, step) - # Handle invalid model parameter # TODO: figure out if this can be done via a validator that uses the model_cache # TODO: How to get the default model name now? - manager = context.services.model_manager.get_model() - generator_output = next( - Inpaint(model).generate( + model = context.services.model_manager.get_model() + outputs = Inpaint(model).generate( prompt=self.prompt, - init_image=image, - mask_image=mask, - step_callback=step_callback, + init_img=image, + init_mask=mask, + step_callback=partial(self.dispatch_progress, context), **self.dict( exclude={"prompt", "image", "mask"} ), # Shorthand for passing all of the parameters above manually ) - ) + + # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object + # each time it is called. We only need the first one. + generator_output = next(outputs) result_image = generator_output.image diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 7330cd73be..65ea4c3edb 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput): image: ImageField = Field(default=None, description="The output image") #fmt: on + class Config: + schema_extra = { + 'required': [ + 'type', + 'image', + ] + } + class MaskOutput(BaseInvocationOutput): """Base class for invocations that output a mask""" #fmt: off type: Literal["mask"] = "mask" mask: ImageField = Field(default=None, description="The output mask") - #fomt: on + #fmt: on + + class Config: + schema_extra = { + 'required': [ + 'type', + 'mask', + ] + } # TODO: this isn't really necessary anymore class LoadImageInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py new file mode 100644 index 0000000000..0481282ba9 --- /dev/null +++ b/invokeai/app/invocations/latent.py @@ -0,0 +1,321 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Literal, Optional +from pydantic import BaseModel, Field +from torch import Tensor +import torch + +from ...backend.model_management.model_manager import ModelManager +from ...backend.util.devices import CUDA_DEVICE, torch_dtype +from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings +from ...backend.image_util.seamless import configure_model_padding +from ...backend.prompting.conditioning import get_uc_and_c_and_ec +from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +import numpy as np +from accelerate.utils import set_seed +from ..services.image_storage import ImageType +from .baseinvocation import BaseInvocation, InvocationContext +from .image import ImageField, ImageOutput +from ...backend.generator import Generator +from ...backend.stable_diffusion import PipelineIntermediateState +from ...backend.util.util import image_to_dataURL +from 
diffusers.schedulers import SchedulerMixin as Scheduler +import diffusers +from diffusers import DiffusionPipeline + + +class LatentsField(BaseModel): + """A latents field used for passing latents between invocations""" + + latents_name: Optional[str] = Field(default=None, description="The name of the latents") + + +class LatentsOutput(BaseInvocationOutput): + """Base class for invocations that output latents""" + #fmt: off + type: Literal["latent_output"] = "latent_output" + latents: LatentsField = Field(default=None, description="The output latents") + #fmt: on + +class NoiseOutput(BaseInvocationOutput): + """Invocation noise output""" + #fmt: off + type: Literal["noise_output"] = "noise_output" + noise: LatentsField = Field(default=None, description="The output noise") + #fmt: on + + +# TODO: this seems like a hack +scheduler_map = dict( + ddim=diffusers.DDIMScheduler, + dpmpp_2=diffusers.DPMSolverMultistepScheduler, + k_dpm_2=diffusers.KDPM2DiscreteScheduler, + k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler, + k_dpmpp_2=diffusers.DPMSolverMultistepScheduler, + k_euler=diffusers.EulerDiscreteScheduler, + k_euler_a=diffusers.EulerAncestralDiscreteScheduler, + k_heun=diffusers.HeunDiscreteScheduler, + k_lms=diffusers.LMSDiscreteScheduler, + plms=diffusers.PNDMScheduler, +) + + +SAMPLER_NAME_VALUES = Literal[ + tuple(list(scheduler_map.keys())) +] + + +def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: + scheduler_class = scheduler_map.get(scheduler_name,'ddim') + scheduler = scheduler_class.from_config(model.scheduler.config) + # hack copied over from generate.py + if not hasattr(scheduler, 'uses_inpainting_model'): + scheduler.uses_inpainting_model = lambda: False + return scheduler + + +def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8): + # limit noise to only the diffusion image channels, not the mask channels + input_channels = min(latent_channels, 4) + use_device = "cpu" if (use_mps_noise or device.type == "mps") else device + generator = torch.Generator(device=use_device).manual_seed(seed) + x = torch.randn( + [ + 1, + input_channels, + height // downsampling_factor, + width // downsampling_factor, + ], + dtype=torch_dtype(device), + device=use_device, + generator=generator, + ).to(device) + # if self.perlin > 0.0: + # perlin_noise = self.get_perlin_noise( + # width // self.downsampling_factor, height // self.downsampling_factor + # ) + # x = (1 - self.perlin) * x + self.perlin * perlin_noise + return x + + +class NoiseInvocation(BaseInvocation): + """Generates latent noise.""" + + type: Literal["noise"] = "noise" + + # Inputs + seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", ) + width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", ) + height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", ) + + def invoke(self, context: InvocationContext) -> NoiseOutput: + device = torch.device(CUDA_DEVICE) + noise = get_noise(self.width, self.height, device, self.seed) + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, noise) + return NoiseOutput( + noise=LatentsField(latents_name=name) + ) + + +# Text to image +class TextToLatentsInvocation(BaseInvocation): + """Generates latents from a prompt.""" + + type: Literal["t2l"] = "t2l" + + # Inputs + # TODO: consider making 
prompt optional to enable providing prompt through a link + # fmt: off + prompt: Optional[str] = Field(description="The prompt to generate an image from") + seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", ) + noise: Optional[LatentsField] = Field(description="The noise to use") + steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") + width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", ) + height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", ) + cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) + sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" ) + seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + model: str = Field(default="", description="The model to use (currently ignored)") + progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) + # fmt: on + + # TODO: pass this an emitter method or something? or a session for dispatching? + def dispatch_progress( + self, context: InvocationContext, sample: Tensor, step: int + ) -> None: + # TODO: only output a preview image when requested + image = Generator.sample_to_lowres_estimated_image(sample) + + (width, height) = image.size + width *= 8 + height *= 8 + + dataURL = image_to_dataURL(image, image_format="JPEG") + + context.services.events.emit_generator_progress( + context.graph_execution_state_id, + self.id, + { + "width": width, + "height": height, + "dataURL": dataURL + }, + step, + self.steps, + ) + + def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline: + model_info = model_manager.get_model(self.model) + model_name = model_info['model_name'] + model_hash = model_info['hash'] + model: StableDiffusionGeneratorPipeline = model_info['model'] + model.scheduler = get_scheduler( + model=model, + scheduler_name=self.sampler_name + ) + + if isinstance(model, DiffusionPipeline): + for component in [model.unet, model.vae]: + configure_model_padding(component, + self.seamless, + self.seamless_axes + ) + else: + configure_model_padding(model, + self.seamless, + self.seamless_axes + ) + + return model + + + def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData: + uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model) + conditioning_data = ConditioningData( + uc, + c, + self.cfg_scale, + extra_conditioning_info, + postprocessing_settings=PostprocessingSettings( + threshold=0.0,#threshold, + warmup=0.2,#warmup, + h_symmetry_time_pct=None,#h_symmetry_time_pct, + v_symmetry_time_pct=None#v_symmetry_time_pct, + ), + ).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta) + return conditioning_data + + + def invoke(self, context: InvocationContext) -> LatentsOutput: + noise = context.services.latents.get(self.noise.latents_name) + + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state.latents, state.step) + + model = self.get_model(context.services.model_manager) + conditioning_data = self.get_conditioning_data(model) + + 
# TODO: Verify the noise is the right size + + result_latents, result_attention_map_saver = model.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + callback=step_callback + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, result_latents) + return LatentsOutput( + latents=LatentsField(latents_name=name) + ) + + +class LatentsToLatentsInvocation(TextToLatentsInvocation): + """Generates latents using latents as base image.""" + + type: Literal["l2l"] = "l2l" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to use as a base image") + strength: float = Field(default=0.5, description="The strength of the latents to use") + + def invoke(self, context: InvocationContext) -> LatentsOutput: + noise = context.services.latents.get(self.noise.latents_name) + latent = context.services.latents.get(self.latents.latents_name) + + def step_callback(state: PipelineIntermediateState): + self.dispatch_progress(context, state.latents, state.step) + + model = self.get_model(context.services.model_manager) + conditioning_data = self.get_conditioning_data(model) + + # TODO: Verify the noise is the right size + + initial_latents = latent if self.strength < 1.0 else torch.zeros_like( + latent, device=model.device, dtype=latent.dtype + ) + + timesteps, _ = model.get_img2img_timesteps( + self.steps, + self.strength, + device=model.device, + ) + + result_latents, result_attention_map_saver = model.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + callback=step_callback + ) + + # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.set(name, result_latents) + return LatentsOutput( + latents=LatentsField(latents_name=name) + ) + + +# Latent to image +class LatentsToImageInvocation(BaseInvocation): + """Generates an image from latents.""" + + type: Literal["l2i"] = "l2i" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + model: str = Field(default="", description="The model to use") + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.services.latents.get(self.latents.latents_name) + + # TODO: this only really needs the vae + model_info = context.services.model_manager.get_model(self.model) + model: StableDiffusionGeneratorPipeline = model_info['model'] + + with torch.inference_mode(): + np_image = model.decode_latents(latents) + image = model.numpy_to_pil(np_image)[0] + + image_type = ImageType.RESULT + image_name = context.services.images.create_name( + context.graph_execution_state_id, self.id + ) + context.services.images.save(image_type, image_name, image) + return ImageOutput( + image=ImageField(image_type=image_type, image_name=image_name) + ) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py new file mode 100644 index 0000000000..ecdcc834c7 --- /dev/null +++ b/invokeai/app/invocations/math.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from datetime import datetime, 
timezone +from typing import Literal, Optional + +import numpy +from PIL import Image, ImageFilter, ImageOps +from pydantic import BaseModel, Field + +from ..services.image_storage import ImageType +from ..services.invocation_services import InvocationServices +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext + + +class IntOutput(BaseInvocationOutput): + """An integer output""" + #fmt: off + type: Literal["int_output"] = "int_output" + a: int = Field(default=None, description="The output integer") + #fmt: on + + +class AddInvocation(BaseInvocation): + """Adds two numbers""" + #fmt: off + type: Literal["add"] = "add" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a + self.b) + + +class SubtractInvocation(BaseInvocation): + """Subtracts two numbers""" + #fmt: off + type: Literal["sub"] = "sub" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a - self.b) + + +class MultiplyInvocation(BaseInvocation): + """Multiplies two numbers""" + #fmt: off + type: Literal["mul"] = "mul" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=self.a * self.b) + + +class DivideInvocation(BaseInvocation): + """Divides two numbers""" + #fmt: off + type: Literal["div"] = "div" + a: int = Field(default=0, description="The first number") + b: int = Field(default=0, description="The second number") + #fmt: on + + def invoke(self, context: InvocationContext) -> IntOutput: + return IntOutput(a=int(self.a / self.b)) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 3544f30859..0c7e3069df 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput): prompt: str = Field(default=None, description="The output prompt") #fmt: on + + class Config: + schema_extra = { + 'required': [ + 'type', + 'prompt', + ] + } diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 0d4102c416..98c2f29308 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception): class GraphInvocationOutput(BaseInvocationOutput): type: Literal["graph_output"] = "graph_output" + class Config: + schema_extra = { + 'required': [ + 'type', + 'image', + ] + } # TODO: Fill this out and move to invocations class GraphInvocation(BaseInvocation): @@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput): item: Any = Field(description="The item being iterated over") + class Config: + schema_extra = { + 'required': [ + 'type', + 'item', + ] + } # TODO: Fill this out and move to invocations class IterateInvocation(BaseInvocation): @@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput): collection: list[Any] = Field(description="The collection of input items") + class Config: + schema_extra = { + 'required': [ + 'type', + 'collection', + ] + } class CollectInvocation(BaseInvocation): """Collects values into a collection""" @@ -1048,9 +1069,8 @@ class 
GraphExecutionState(BaseModel): n for n in prepared_nodes if all( - pit + nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators - if nx.has_path(execution_graph, pit[0], n) ) ), None, diff --git a/invokeai/app/services/image_storage.py b/invokeai/app/services/image_storage.py index ad0ff23f14..c80a4bfb31 100644 --- a/invokeai/app/services/image_storage.py +++ b/invokeai/app/services/image_storage.py @@ -9,6 +9,7 @@ from queue import Queue from typing import Dict from PIL.Image import Image +from invokeai.app.util.save_thumbnail import save_thumbnail from invokeai.backend.image_util import PngWriter @@ -66,6 +67,9 @@ class DiskImageStorage(ImageStorageBase): Path(os.path.join(output_folder, image_type)).mkdir( parents=True, exist_ok=True ) + Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( + parents=True, exist_ok=True + ) def get(self, image_type: ImageType, image_name: str) -> Image: image_path = self.get_path(image_type, image_name) @@ -87,7 +91,11 @@ class DiskImageStorage(ImageStorageBase): self.__pngWriter.save_image_and_prompt_to_png( image, "", image_subpath, None ) # TODO: just pass full path to png writer - + save_thumbnail( + image=image, + filename=image_name, + path=os.path.join(self.__output_folder, image_type, "thumbnails"), + ) image_path = self.get_path(image_type, image_name) self.__set_cache(image_path, image) diff --git a/invokeai/app/services/invocation_queue.py b/invokeai/app/services/invocation_queue.py index 88a4f8708d..4a42789b12 100644 --- a/invokeai/app/services/invocation_queue.py +++ b/invokeai/app/services/invocation_queue.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from queue import Queue +import time # TODO: make this serializable @@ -10,6 +11,7 @@ class InvocationQueueItem: graph_execution_state_id: str invocation_id: str invoke_all: bool + timestamp: float def __init__( self, @@ -22,6 +24,7 @@ class InvocationQueueItem: self.graph_execution_state_id = graph_execution_state_id self.invocation_id = invocation_id self.invoke_all = invoke_all + self.timestamp = time.time() class InvocationQueueABC(ABC): @@ -35,15 +38,44 @@ class InvocationQueueABC(ABC): def put(self, item: InvocationQueueItem | None) -> None: pass + @abstractmethod + def cancel(self, graph_execution_state_id: str) -> None: + pass + + @abstractmethod + def is_canceled(self, graph_execution_state_id: str) -> bool: + pass + class MemoryInvocationQueue(InvocationQueueABC): __queue: Queue + __cancellations: dict[str, float] def __init__(self): self.__queue = Queue() + self.__cancellations = dict() def get(self) -> InvocationQueueItem: - return self.__queue.get() + item = self.__queue.get() + + while isinstance(item, InvocationQueueItem) \ + and item.graph_execution_state_id in self.__cancellations \ + and self.__cancellations[item.graph_execution_state_id] > item.timestamp: + item = self.__queue.get() + + # Clear old items + for graph_execution_state_id in list(self.__cancellations.keys()): + if self.__cancellations[graph_execution_state_id] < item.timestamp: + del self.__cancellations[graph_execution_state_id] + + return item def put(self, item: InvocationQueueItem | None) -> None: self.__queue.put(item) + + def cancel(self, graph_execution_state_id: str) -> None: + if graph_execution_state_id not in self.__cancellations: + self.__cancellations[graph_execution_state_id] = time.time() + + def is_canceled(self, graph_execution_state_id: str) -> bool: + return graph_execution_state_id in self.__cancellations diff --git 
a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 7f24c34378..2cd0f55fd9 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -2,6 +2,7 @@ from invokeai.backend import ModelManager from .events import EventServiceBase +from .latent_storage import LatentsStorageBase from .image_storage import ImageStorageBase from .restoration_services import RestorationServices from .invocation_queue import InvocationQueueABC @@ -11,6 +12,7 @@ class InvocationServices: """Services that can be used by invocations""" events: EventServiceBase + latents: LatentsStorageBase images: ImageStorageBase queue: InvocationQueueABC model_manager: ModelManager @@ -24,6 +26,7 @@ class InvocationServices: self, model_manager: ModelManager, events: EventServiceBase, + latents: LatentsStorageBase, images: ImageStorageBase, queue: InvocationQueueABC, graph_execution_manager: ItemStorageABC["GraphExecutionState"], @@ -32,6 +35,7 @@ class InvocationServices: ): self.model_manager = model_manager self.events = events + self.latents = latents self.images = images self.queue = queue self.graph_execution_manager = graph_execution_manager diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index f234cd827b..e3fa6da851 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -33,7 +33,6 @@ class Invoker: self.services.graph_execution_manager.set(graph_execution_state) # Queue the invocation - print(f"queueing item {invocation.id}") self.services.queue.put( InvocationQueueItem( # session_id = session.id, @@ -50,6 +49,10 @@ class Invoker: new_state = GraphExecutionState(graph=Graph() if graph is None else graph) self.services.graph_execution_manager.set(new_state) return new_state + + def cancel(self, graph_execution_state_id: str) -> None: + """Cancels the given execution state""" + self.services.queue.cancel(graph_execution_state_id) def __start_service(self, service) -> None: # Call start() method on any services that have it diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py new file mode 100644 index 0000000000..0184692e05 --- /dev/null +++ b/invokeai/app/services/latent_storage.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +import os +from abc import ABC, abstractmethod +from pathlib import Path +from queue import Queue +from typing import Dict + +import torch + +class LatentsStorageBase(ABC): + """Responsible for storing and retrieving latents.""" + + @abstractmethod + def get(self, name: str) -> torch.Tensor: + pass + + @abstractmethod + def set(self, name: str, data: torch.Tensor) -> None: + pass + + @abstractmethod + def delete(self, name: str) -> None: + pass + + +class ForwardCacheLatentsStorage(LatentsStorageBase): + """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" + + __cache: Dict[str, torch.Tensor] + __cache_ids: Queue + __max_cache_size: int + __underlying_storage: LatentsStorageBase + + def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20): + self.__underlying_storage = underlying_storage + self.__cache = dict() + self.__cache_ids = Queue() + self.__max_cache_size = max_cache_size + + def get(self, name: str) -> torch.Tensor: + cache_item = self.__get_cache(name) + if cache_item is not None: + return cache_item + + latent = self.__underlying_storage.get(name) + self.__set_cache(name, 
latent) + return latent + + def set(self, name: str, data: torch.Tensor) -> None: + self.__underlying_storage.set(name, data) + self.__set_cache(name, data) + + def delete(self, name: str) -> None: + self.__underlying_storage.delete(name) + if name in self.__cache: + del self.__cache[name] + + def __get_cache(self, name: str) -> torch.Tensor|None: + return None if name not in self.__cache else self.__cache[name] + + def __set_cache(self, name: str, data: torch.Tensor): + if not name in self.__cache: + self.__cache[name] = data + self.__cache_ids.put(name) + if self.__cache_ids.qsize() > self.__max_cache_size: + self.__cache.pop(self.__cache_ids.get()) + + +class DiskLatentsStorage(LatentsStorageBase): + """Stores latents in a folder on disk without caching""" + + __output_folder: str + + def __init__(self, output_folder: str): + self.__output_folder = output_folder + Path(output_folder).mkdir(parents=True, exist_ok=True) + + def get(self, name: str) -> torch.Tensor: + latent_path = self.get_path(name) + return torch.load(latent_path) + + def set(self, name: str, data: torch.Tensor) -> None: + latent_path = self.get_path(name) + torch.save(data, latent_path) + + def delete(self, name: str) -> None: + latent_path = self.get_path(name) + os.remove(latent_path) + + def get_path(self, name: str) -> str: + return os.path.join(self.__output_folder, name) + \ No newline at end of file diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 5baa64503c..b460563278 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -4,7 +4,7 @@ from threading import Event, Thread from ..invocations.baseinvocation import InvocationContext from .invocation_queue import InvocationQueueItem from .invoker import InvocationProcessorABC, Invoker - +from ..util.util import CanceledException class DefaultInvocationProcessor(InvocationProcessorABC): __invoker_thread: Thread @@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): ) ) + # Check queue to see if this is canceled, and skip if so + if self.__invoker.services.queue.is_canceled( + graph_execution_state.id + ): + continue + # Save outputs and history graph_execution_state.complete(invocation.id, outputs) @@ -76,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC): except KeyboardInterrupt: pass + except CanceledException: + pass + except Exception as e: error = traceback.format_exc() @@ -95,6 +104,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): ) pass + + # Check queue to see if this is canceled, and skip if so + if self.__invoker.services.queue.is_canceled( + graph_execution_state.id + ): + continue # Queue any further commands if invoking all is_complete = graph_execution_state.is_complete() diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index e5bba4ad31..fd089014bb 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -59,6 +59,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""", (item.json(),), ) + self._conn.commit() finally: self._lock.release() self._on_changed(item) @@ -84,6 +85,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): self._cursor.execute( f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) ) + self._conn.commit() finally: self._lock.release() self._on_deleted(id) diff --git a/invokeai/app/util/save_thumbnail.py b/invokeai/app/util/save_thumbnail.py new file mode 100644 
index 0000000000..86fdbe7ef6 --- /dev/null +++ b/invokeai/app/util/save_thumbnail.py @@ -0,0 +1,25 @@ +import os +from PIL import Image + + +def save_thumbnail( + image: Image.Image, + filename: str, + path: str, + size: int = 256, +) -> str: + """ + Saves a thumbnail of an image, returning its path. + """ + base_filename = os.path.splitext(filename)[0] + thumbnail_path = os.path.join(path, base_filename + ".webp") + + if os.path.exists(thumbnail_path): + return thumbnail_path + + image_copy = image.copy() + image_copy.thumbnail(size=(size, size)) + + image_copy.save(thumbnail_path, "WEBP") + + return thumbnail_path diff --git a/invokeai/app/util/util.py b/invokeai/app/util/util.py new file mode 100644 index 0000000000..60a5072cb0 --- /dev/null +++ b/invokeai/app/util/util.py @@ -0,0 +1,42 @@ +import torch +from PIL import Image +from ..invocations.baseinvocation import InvocationContext +from ...backend.util.util import image_to_dataURL +from ...backend.generator.base import Generator +from ...backend.stable_diffusion import PipelineIntermediateState + +class CanceledException(Exception): + pass + +def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ): + # TODO: only output a preview image when requested + image = Generator.sample_to_lowres_estimated_image(sample) + + (width, height) = image.size + width *= 8 + height *= 8 + + dataURL = image_to_dataURL(image, image_format="JPEG") + + context.services.events.emit_generator_progress( + context.graph_execution_state_id, + id, + { + "width": width, + "height": height, + "dataURL": dataURL + }, + step, + steps, + ) + +def diffusers_step_callback_adapter(*cb_args, **kwargs): + """ + txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState. + This adapter grabs the needed data and passes it along to the callback function. 
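A minimal sketch of how this adapter might be wired up as a pipeline step callback, assuming the import path of the `invokeai/app/util/util.py` module added above; the `make_step_callback` helper name is an assumption for illustration and is not part of the patch:

```py
from functools import partial

from invokeai.app.util.util import diffusers_step_callback_adapter


def make_step_callback(context, node_id: str, steps: int):
    # Bind the static keyword arguments once; the pipeline then calls the
    # result with either (sample, step) or a PipelineIntermediateState, and
    # the adapter normalizes both into fast_latents_step_callback.
    return partial(
        diffusers_step_callback_adapter,
        steps=steps,
        id=node_id,
        context=context,
    )
```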
+ """ + if isinstance(cb_args[0], PipelineIntermediateState): + progress_state: PipelineIntermediateState = cb_args[0] + return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs) + else: + return fast_latents_step_callback(*cb_args, **kwargs) diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index e30b77ec33..ee56077fa8 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter from accelerate.utils import set_seed from diffusers import DiffusionPipeline from tqdm import trange -from typing import List, Iterator, Type +from typing import Callable, List, Iterator, Optional, Type from dataclasses import dataclass, field from diffusers.schedulers import SchedulerMixin as Scheduler @@ -35,23 +35,23 @@ downsampling = 8 @dataclass class InvokeAIGeneratorBasicParams: - seed: int=None + seed: Optional[int]=None width: int=512 height: int=512 - cfg_scale: int=7.5 + cfg_scale: float=7.5 steps: int=20 ddim_eta: float=0.0 - scheduler: int='ddim' + scheduler: str='ddim' precision: str='float16' perlin: float=0.0 - threshold: int=0.0 + threshold: float=0.0 seamless: bool=False seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) - h_symmetry_time_pct: float=None - v_symmetry_time_pct: float=None + h_symmetry_time_pct: Optional[float]=None + v_symmetry_time_pct: Optional[float]=None variation_amount: float = 0.0 with_variations: list=field(default_factory=list) - safety_checker: SafetyChecker=None + safety_checker: Optional[SafetyChecker]=None @dataclass class InvokeAIGeneratorOutput: @@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput: and the model hash, as well as all the generate() parameters that went into generating the image (in .params, also available as attributes) ''' - image: Image + image: Image.Image seed: int model_hash: str - attention_maps_images: List[Image] + attention_maps_images: List[Image.Image] params: Namespace # we are interposing a wrapper around the original Generator classes so that @@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta): def generate(self, prompt: str='', - callback: callable=None, - step_callback: callable=None, + callback: Optional[Callable]=None, + step_callback: Optional[Callable]=None, iterations: int=1, **keyword_args, )->Iterator[InvokeAIGeneratorOutput]: @@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator): # ------------------------------------ class Img2Img(InvokeAIGenerator): def generate(self, - init_image: Image | torch.FloatTensor, + init_image: Image.Image | torch.FloatTensor, strength: float=0.75, **keyword_args - )->List[InvokeAIGeneratorOutput]: + )->Iterator[InvokeAIGeneratorOutput]: return super().generate(init_image=init_image, strength=strength, **keyword_args @@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator): # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff class Inpaint(Img2Img): def generate(self, - mask_image: Image | torch.FloatTensor, + mask_image: Image.Image | torch.FloatTensor, # Seam settings - when 0, doesn't fill seam seam_size: int = 0, seam_blur: int = 0, @@ -236,7 +236,7 @@ class Inpaint(Img2Img): inpaint_height=None, inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), **keyword_args - )->List[InvokeAIGeneratorOutput]: + )->Iterator[InvokeAIGeneratorOutput]: return super().generate( mask_image=mask_image, seam_size=seam_size, @@ -263,7 +263,7 @@ class Embiggen(Txt2Img): embiggen: list=None, 
embiggen_tiles: list = None, strength: float=0.75, - **kwargs)->List[InvokeAIGeneratorOutput]: + **kwargs)->Iterator[InvokeAIGeneratorOutput]: return super().generate(embiggen=embiggen, embiggen_tiles=embiggen_tiles, strength=strength, diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 793ba024cf..b46586611d 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -372,22 +372,32 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False unet_key = "model.diffusion_model." # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100: - print(f" | Checkpoint {path} has both EMA and non-EMA weights.") + print(f" | Checkpoint {path} has both EMA and non-EMA weights.") if extract_ema: - print(" | Extracting EMA weights (usually better for inference)") + print(" | Extracting EMA weights (usually better for inference)") for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( - flat_ema_key - ) + flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:]) + if flat_ema_key in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) + elif flat_ema_key_alt in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key_alt + ) + else: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + key + ) else: print( - " | Extracting only the non-EMA weights (usually better for fine-tuning)" + " | Extracting only the non-EMA weights (usually better for fine-tuning)" ) for key in keys: - if key.startswith(unet_key): + if key.startswith("model.diffusion_model") and key in checkpoint: unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} @@ -1026,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint): return text_model +def replace_checkpoint_vae(checkpoint, vae_path:str): + if vae_path.endswith(".safetensors"): + vae_ckpt = load_file(vae_path) + else: + vae_ckpt = torch.load(vae_path, map_location="cpu") + state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt + for vae_key in state_dict: + new_key = f'first_stage_model.{vae_key}' + checkpoint[new_key] = state_dict[vae_key] def load_pipeline_from_original_stable_diffusion_ckpt( checkpoint_path: str, @@ -1038,8 +1057,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( extract_ema: bool = True, upcast_attn: bool = False, vae: AutoencoderKL = None, + vae_path: str = None, precision: torch.dtype = torch.float32, return_generator_pipeline: bool = False, + scan_needed:bool=True, ) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1067,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt( :param precision: precision to use - torch.float16, torch.float32 or torch.autocast :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when running stable diffusion 2.1. + :param vae: A diffusers VAE to load into the pipeline. 
+ :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline. """ with warnings.catch_warnings(): @@ -1074,12 +1097,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt( verbosity = dlogging.get_verbosity() dlogging.set_verbosity_error() - checkpoint = ( - torch.load(checkpoint_path) - if Path(checkpoint_path).suffix == ".ckpt" - else load_file(checkpoint_path) - - ) + if Path(checkpoint_path).suffix == '.ckpt': + if scan_needed: + ModelManager.scan_model(checkpoint_path,checkpoint_path) + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = load_file(checkpoint_path) + cache_dir = global_cache_dir("hub") pipeline_class = ( StableDiffusionGeneratorPipeline @@ -1091,7 +1115,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( if "global_step" in checkpoint: global_step = checkpoint["global_step"] else: - print(" | global_step key not found in model") + print(" | global_step key not found in model") global_step = None # sometimes there is a state_dict key and sometimes not @@ -1202,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt( unet.load_state_dict(converted_unet_checkpoint) - # Convert the VAE model, or use the one passed - if not vae: - print(" | Using checkpoint model's original VAE") + # If a replacement VAE path was specified, we'll incorporate that into + # the checkpoint model and then convert it + if vae_path: + print(f" | Converting VAE {vae_path}") + replace_checkpoint_vae(checkpoint,vae_path) + # otherwise we use the original VAE, provided that + # an externally loaded diffusers VAE was not passed + elif not vae: + print(" | Using checkpoint model's original VAE") + + if vae: + print(" | Using replacement diffusers VAE") + else: # convert the original or replacement VAE vae_config = create_vae_diffusers_config( original_config, image_size=image_size ) @@ -1214,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt( vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) - else: - print(" | Using external VAE specified in config") # Convert the text model. 
model_type = pipeline_type @@ -1232,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( cache_dir=cache_dir, ) pipe = pipeline_class( - vae=vae, - text_encoder=text_model, + vae=vae.to(precision), + text_encoder=text_model.to(precision), tokenizer=tokenizer, - unet=unet, + unet=unet.to(precision), scheduler=scheduler, safety_checker=None, feature_extractor=None, diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index c64560baf8..4a2bb56270 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -18,7 +18,7 @@ import warnings from enum import Enum from pathlib import Path from shutil import move, rmtree -from typing import Any, Optional, Union +from typing import Any, Optional, Union, Callable import safetensors import safetensors.torch @@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path from invokeai.backend.globals import Globals, global_cache_dir from ..stable_diffusion import StableDiffusionGeneratorPipeline -from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume +from ..util import CUDA_DEVICE, ask_user, download_with_resume class SDLegacyType(Enum): V1 = 1 @@ -45,9 +45,6 @@ class SDLegacyType(Enum): UNKNOWN = 99 DEFAULT_MAX_MODELS = 2 -VAE_TO_REPO_ID = { # hack, see note in convert_and_import() - "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse", -} class ModelManager(object): ''' @@ -285,13 +282,13 @@ class ModelManager(object): self.stack.remove(model_name) if delete_files: if weights: - print(f"** deleting file {weights}") + print(f"** Deleting file {weights}") Path(weights).unlink(missing_ok=True) elif path: - print(f"** deleting directory {path}") + print(f"** Deleting directory {path}") rmtree(path, ignore_errors=True) elif repo_id: - print(f"** deleting the cached model directory for {repo_id}") + print(f"** Deleting the cached model directory for {repo_id}") self._delete_model_from_cache(repo_id) def add_model( @@ -382,9 +379,9 @@ class ModelManager(object): print(f">> Loading diffusers model from {name_or_path}") if using_fp16: - print(" | Using faster float16 precision") + print(" | Using faster float16 precision") else: - print(" | Using more accurate float32 precision") + print(" | Using more accurate float32 precision") # TODO: scan weights maybe? pipeline_args: dict[str, Any] = dict( @@ -435,9 +432,8 @@ class ModelManager(object): # square images??? width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor height = width - - print(f" | Default image dimensions = {width} x {height}") - + print(f" | Default image dimensions = {width} x {height}") + return pipeline, width, height, model_hash def _load_ckpt_model(self, model_name, mconfig): @@ -457,15 +453,21 @@ class ModelManager(object): from . 
import load_pipeline_from_original_stable_diffusion_ckpt - self.offload_model(self.current_model) - if vae_config := self._choose_diffusers_vae(model_name): - vae = self._load_vae(vae_config) + try: + if self.list_models()[self.current_model]['status'] == 'active': + self.offload_model(self.current_model) + except Exception as e: + pass + + vae_path = None + if vae: + vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae)) if self._has_cuda(): torch.cuda.empty_cache() pipeline = load_pipeline_from_original_stable_diffusion_ckpt( checkpoint_path=weights, original_config_file=config, - vae=vae, + vae_path=vae_path, return_generator_pipeline=True, precision=torch.float16 if self.precision == "float16" else torch.float32, ) @@ -473,7 +475,6 @@ class ModelManager(object): pipeline.enable_offload_submodels(self.device) else: pipeline.to(self.device) - return ( pipeline, width, @@ -512,18 +513,20 @@ class ModelManager(object): print(f">> Offloading {model_name} to CPU") model = self.models[model_name]["model"] model.offload_all() + self.current_model = None gc.collect() if self._has_cuda(): torch.cuda.empty_cache() + @classmethod def scan_model(self, model_name, checkpoint): """ Apply picklescanner to the indicated checkpoint and issue a warning and option to exit if an infected file is identified. """ # scan model - print(f">> Scanning Model: {model_name}") + print(f" | Scanning Model: {model_name}") scan_result = scan_file_path(checkpoint) if scan_result.infected_files != 0: if scan_result.infected_files == 1: @@ -546,7 +549,7 @@ class ModelManager(object): print("### Exiting InvokeAI") sys.exit() else: - print(">> Model scanned ok") + print(" | Model scanned ok") def import_diffuser_model( self, @@ -627,14 +630,13 @@ class ModelManager(object): def heuristic_import( self, path_url_or_repo: str, - convert: bool = True, model_name: str = None, description: str = None, model_config_file: Path = None, commit_to_conf: Path = None, + config_file_callback: Callable[[Path], Path] = None, ) -> str: - """ - Accept a string which could be: + """Accept a string which could be: - a HF diffusers repo_id - a URL pointing to a legacy .ckpt or .safetensors file - a local path pointing to a legacy .ckpt or .safetensors file @@ -648,16 +650,20 @@ class ModelManager(object): The model_name and/or description can be provided. If not, they will be generated automatically. - If convert is true, legacy models will be converted to diffusers - before importing. - If commit_to_conf is provided, the newly loaded model will be written to the `models.yaml` file at the indicated path. Otherwise, the changes will only remain in memory. - The (potentially derived) name of the model is returned on success, or None - on failure. When multiple models are added from a directory, only the last - imported one is returned. + The routine will do its best to figure out the config file + needed to convert legacy checkpoint file, but if it can't it + will call the config_file_callback routine, if provided. The + callback accepts a single argument, the Path to the checkpoint + file, and returns a Path to the config file to use. + + The (potentially derived) name of the model is returned on + success, or None on failure. When multiple models are added + from a directory, only the last imported one is returned. 
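As an illustrative sketch only (the checkpoint path is an assumption and `manager` stands for an existing `ModelManager` instance), a caller could supply the fallback config through `config_file_callback` like this:

```py
from pathlib import Path


def choose_config(checkpoint_path: Path) -> Path:
    # Invoked only when the importer cannot work out the config file on its
    # own; here we simply fall back to the standard SD-v1 inference config.
    return Path("configs/stable-diffusion/v1-inference.yaml")


model_name = manager.heuristic_import(
    "models/ldm/stable-diffusion-v1/custom-model.safetensors",
    config_file_callback=choose_config,
)
```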
+ """ model_path: Path = None thing = path_url_or_repo # to save typing @@ -665,7 +671,7 @@ class ModelManager(object): print(f">> Probing {thing} for import") if thing.startswith(("http:", "https:", "ftp:")): - print(f" | {thing} appears to be a URL") + print(f" | {thing} appears to be a URL") model_path = self._resolve_path( thing, "models/ldm/stable-diffusion-v1" ) # _resolve_path does a download if needed @@ -673,15 +679,15 @@ class ModelManager(object): elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")): if Path(thing).stem in ["model", "diffusion_pytorch_model"]: print( - f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import" + f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import" ) return else: - print(f" | {thing} appears to be a checkpoint file on disk") + print(f" | {thing} appears to be a checkpoint file on disk") model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1") elif Path(thing).is_dir() and Path(thing, "model_index.json").exists(): - print(f" | {thing} appears to be a diffusers file on disk") + print(f" | {thing} appears to be a diffusers file on disk") model_name = self.import_diffuser_model( thing, vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), @@ -692,25 +698,25 @@ class ModelManager(object): elif Path(thing).is_dir(): if (Path(thing) / "model_index.json").exists(): - print(f" | {thing} appears to be a diffusers model.") + print(f" | {thing} appears to be a diffusers model.") model_name = self.import_diffuser_model( thing, commit_to_conf=commit_to_conf ) else: print( - f" |{thing} appears to be a directory. Will scan for models to import" + f" |{thing} appears to be a directory. Will scan for models to import" ) for m in list(Path(thing).rglob("*.ckpt")) + list( Path(thing).rglob("*.safetensors") ): if model_name := self.heuristic_import( - str(m), convert, commit_to_conf=commit_to_conf + str(m), commit_to_conf=commit_to_conf ): print(f" >> {model_name} successfully imported") return model_name elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing): - print(f" | {thing} appears to be a HuggingFace diffusers repo_id") + print(f" | {thing} appears to be a HuggingFace diffusers repo_id") model_name = self.import_diffuser_model( thing, commit_to_conf=commit_to_conf ) @@ -727,55 +733,75 @@ class ModelManager(object): return if model_path.stem in self.config: # already imported - print(" | Already imported. Skipping") + print(" | Already imported. Skipping") return model_path.stem # another round of heuristics to guess the correct config file. 
- checkpoint = ( - torch.load(model_path) - if model_path.suffix == ".ckpt" - else safetensors.torch.load_file(model_path) - ) + checkpoint = None + if model_path.suffix in [".ckpt",".pt"]: + self.scan_model(model_path,model_path) + checkpoint = torch.load(model_path) + else: + checkpoint = safetensors.torch.load_file(model_path) # additional probing needed if no config file provided if model_config_file is None: - model_type = self.probe_model_type(checkpoint) - if model_type == SDLegacyType.V1: - print(" | SD-v1 model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inference.yaml" - ) - elif model_type == SDLegacyType.V1_INPAINT: - print(" | SD-v1 inpainting model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml" - ) - elif model_type == SDLegacyType.V2_v: - print( - " | SD-v2-v model detected; model will be converted to diffusers format" - ) - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" - ) - convert = True - elif model_type == SDLegacyType.V2_e: - print( - " | SD-v2-e model detected; model will be converted to diffusers format" - ) - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference.yaml" - ) - convert = True - elif model_type == SDLegacyType.V2: - print( - f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." - ) - return + # look for a like-named .yaml file in same directory + if model_path.with_suffix(".yaml").exists(): + model_config_file = model_path.with_suffix(".yaml") + print(f" | Using config file {model_config_file.name}") + else: - print( - f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path." - ) - return + model_type = self.probe_model_type(checkpoint) + if model_type == SDLegacyType.V1: + print(" | SD-v1 model detected") + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v1-inference.yaml" + ) + elif model_type == SDLegacyType.V1_INPAINT: + print(" | SD-v1 inpainting model detected") + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml" + ) + elif model_type == SDLegacyType.V2_v: + print( + " | SD-v2-v model detected" + ) + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" + ) + elif model_type == SDLegacyType.V2_e: + print( + " | SD-v2-e model detected" + ) + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v2-inference.yaml" + ) + elif model_type == SDLegacyType.V2: + print( + f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." + ) + return + else: + print( + f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path." 
+ ) + return + + if not model_config_file and config_file_callback: + model_config_file = config_file_callback(model_path) + + # despite our best efforts, we could not find a model config file, so give up + if not model_config_file: + return + + # look for a custom vae, a like-named file ending with .vae in the same directory + vae_path = None + for suffix in ["pt", "ckpt", "safetensors"]: + if (model_path.with_suffix(f".vae.{suffix}")).exists(): + vae_path = model_path.with_suffix(f".vae.{suffix}") + print(f" | Using VAE file {vae_path.name}") + vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse") diffuser_path = Path( Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem @@ -783,23 +809,27 @@ class ModelManager(object): model_name = self.convert_and_import( model_path, diffusers_path=diffuser_path, - vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), + vae=vae, + vae_path=str(vae_path), model_name=model_name, model_description=description, original_config_file=model_config_file, commit_to_conf=commit_to_conf, + scan_needed=False, ) return model_name def convert_and_import( - self, - ckpt_path: Path, - diffusers_path: Path, - model_name=None, - model_description=None, - vae=None, - original_config_file: Path = None, - commit_to_conf: Path = None, + self, + ckpt_path: Path, + diffusers_path: Path, + model_name=None, + model_description=None, + vae:dict=None, + vae_path:Path=None, + original_config_file: Path = None, + commit_to_conf: Path = None, + scan_needed: bool=True, ) -> str: """ Convert a legacy ckpt weights file to diffuser model and import @@ -822,23 +852,28 @@ class ModelManager(object): return model_name = model_name or diffusers_path.name - model_description = model_description or f"Optimized version of {model_name}" - print(f">> Optimizing {model_name} (30-60s)") + model_description = model_description or f"Converted version of {model_name}" + print(f" | Converting {model_name} to diffusers (30-60s)") try: # By passing the specified VAE to the conversion function, the autoencoder # will be built into the model rather than tacked on afterward via the config file - vae_model = self._load_vae(vae) if vae else None + vae_model=None + if vae: + vae_model=self._load_vae(vae) + vae_path=None convert_ckpt_to_diffusers( ckpt_path, diffusers_path, extract_ema=True, original_config_file=original_config_file, vae=vae_model, + vae_path=vae_path, + scan_needed=scan_needed, ) print( - f" | Success. Optimized model is now located at {str(diffusers_path)}" + f" | Success. Converted model is now located at {str(diffusers_path)}" ) - print(f" | Writing new config file entry for {model_name}") + print(f" | Writing new config file entry for {model_name}") new_config = dict( path=str(diffusers_path), description=model_description, @@ -849,7 +884,7 @@ class ModelManager(object): self.add_model(model_name, new_config, True) if commit_to_conf: self.commit(commit_to_conf) - print(">> Conversion succeeded") + print(" | Conversion succeeded") except Exception as e: print(f"** Conversion failed: {str(e)}") print( @@ -879,36 +914,6 @@ class ModelManager(object): return search_folder, found_models - def _choose_diffusers_vae( - self, model_name: str, vae: str = None - ) -> Union[dict, str]: - # In the event that the original entry is using a custom ckpt VAE, we try to - # map that VAE onto a diffuser VAE using a hard-coded dictionary. 
- # I would prefer to do this differently: We load the ckpt model into memory, swap the - # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped - # VAE is built into the model. However, when I tried this I got obscure key errors. - if vae: - return vae - if model_name in self.config and ( - vae_ckpt_path := self.model_info(model_name).get("vae", None) - ): - vae_basename = Path(vae_ckpt_path).stem - diffusers_vae = None - if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None): - print( - f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version" - ) - vae = {"repo_id": diffusers_vae} - else: - print( - f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown' - ) - print( - '** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config' - ) - vae = {"repo_id": "stabilityai/sd-vae-ft-mse"} - return vae - def _make_cache_room(self) -> None: num_loaded_models = len(self.models) if num_loaded_models >= self.max_loaded_models: @@ -1105,7 +1110,7 @@ class ModelManager(object): with open(hashpath) as f: hash = f.read() return hash - print(" | Calculating sha256 hash of model files") + print(" | Calculating sha256 hash of model files") tic = time.time() sha = hashlib.sha256() count = 0 @@ -1117,7 +1122,7 @@ class ModelManager(object): sha.update(chunk) hash = sha.hexdigest() toc = time.time() - print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) + print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) with open(hashpath, "w") as f: f.write(hash) return hash @@ -1162,12 +1167,12 @@ class ModelManager(object): local_files_only=not Globals.internet_available, ) - print(f" | Loading diffusers VAE from {name_or_path}") + print(f" | Loading diffusers VAE from {name_or_path}") if using_fp16: vae_args.update(torch_dtype=torch.float16) fp_args_list = [{"revision": "fp16"}, {}] else: - print(" | Using more accurate float32 precision") + print(" | Using more accurate float32 precision") fp_args_list = [{}] vae = None @@ -1208,7 +1213,7 @@ class ModelManager(object): hashes_to_delete.add(revision.commit_hash) strategy = cache_info.delete_revisions(*hashes_to_delete) print( - f"** deletion of this model is expected to free {strategy.expected_freed_size_str}" + f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}" ) strategy.execute() diff --git a/invokeai/backend/stable_diffusion/textual_inversion_manager.py b/invokeai/backend/stable_diffusion/textual_inversion_manager.py index 2b043afab7..2dba2b88d3 100644 --- a/invokeai/backend/stable_diffusion/textual_inversion_manager.py +++ b/invokeai/backend/stable_diffusion/textual_inversion_manager.py @@ -1,16 +1,26 @@ -import os import traceback from dataclasses import dataclass from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, List +import safetensors.torch import torch + from compel.embeddings_provider import BaseTextualInversionManager from picklescan.scanner import scan_file_path from transformers import CLIPTextModel, CLIPTokenizer from .concepts_lib import HuggingFaceConceptsLibrary +@dataclass +class EmbeddingInfo: + name: str + embedding: torch.Tensor + num_vectors_per_token: int + token_dim: int + trained_steps: int = None + trained_model_name: str = None + trained_model_checksum: str = None @dataclass class TextualInversion: @@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager): if 
str(ckpt_path).endswith(".DS_Store"): return - try: - scan_result = scan_file_path(str(ckpt_path)) - if scan_result.infected_files == 1: + embedding_list = self._parse_embedding(str(ckpt_path)) + for embedding_info in embedding_list: + if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim): print( - f"\n### Security Issues Found in Model: {scan_result.issues_count}" + f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}." ) - print("### For your safety, InvokeAI will not load this embed.") - return - except Exception: - print( - f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt." - ) - return + continue - embedding_info = self._parse_embedding(str(ckpt_path)) - - if embedding_info is None: - # We've already put out an error message about the bad embedding in _parse_embedding, so just return. - return - elif ( - self.text_encoder.get_input_embeddings().weight.data[0].shape[0] - != embedding_info["token_dim"] - ): - print( - f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}." - ) - return - - # Resolve the situation in which an earlier embedding has claimed the same - # trigger string. We replace the trigger with '', as we used to. - trigger_str = embedding_info["name"] - sourcefile = ( - f"{ckpt_path.parent.name}/{ckpt_path.name}" - if ckpt_path.name == "learned_embeds.bin" - else ckpt_path.name - ) - - if trigger_str in self.trigger_to_sourcefile: - replacement_trigger_str = ( - f"<{ckpt_path.parent.name}>" + # Resolve the situation in which an earlier embedding has claimed the same + # trigger string. We replace the trigger with '', as we used to. + trigger_str = embedding_info.name + sourcefile = ( + f"{ckpt_path.parent.name}/{ckpt_path.name}" if ckpt_path.name == "learned_embeds.bin" - else f"<{ckpt_path.stem}>" + else ckpt_path.name ) - print( - f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}" - ) - trigger_str = replacement_trigger_str - try: - self._add_textual_inversion( - trigger_str, - embedding_info["embedding"], - defer_injecting_tokens=defer_injecting_tokens, - ) - # remember which source file claims this trigger - self.trigger_to_sourcefile[trigger_str] = sourcefile + if trigger_str in self.trigger_to_sourcefile: + replacement_trigger_str = ( + f"<{ckpt_path.parent.name}>" + if ckpt_path.name == "learned_embeds.bin" + else f"<{ckpt_path.stem}>" + ) + print( + f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. 
Trigger this concept with {replacement_trigger_str}" + ) + trigger_str = replacement_trigger_str - except ValueError as e: - print(f' | Ignoring incompatible embedding {embedding_info["name"]}') - print(f" | The error was {str(e)}") + try: + self._add_textual_inversion( + trigger_str, + embedding_info.embedding, + defer_injecting_tokens=defer_injecting_tokens, + ) + # remember which source file claims this trigger + self.trigger_to_sourcefile[trigger_str] = sourcefile + + except ValueError as e: + print(f' | Ignoring incompatible embedding {embedding_info["name"]}') + print(f" | The error was {str(e)}") def _add_textual_inversion( self, trigger_str, embedding, defer_injecting_tokens=False @@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager): return token_id - def _parse_embedding(self, embedding_file: str): - file_type = embedding_file.split(".")[-1] - if file_type == "pt": - return self._parse_embedding_pt(embedding_file) - elif file_type == "bin": - return self._parse_embedding_bin(embedding_file) + + def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]: + suffix = Path(embedding_file).suffix + try: + if suffix in [".pt",".ckpt",".bin"]: + scan_result = scan_file_path(embedding_file) + if scan_result.infected_files > 0: + print( + f" ** Security Issues Found in Model: {scan_result.issues_count}" + ) + print(" ** For your safety, InvokeAI will not load this embed.") + return list() + ckpt = torch.load(embedding_file,map_location="cpu") + else: + ckpt = safetensors.torch.load_file(embedding_file) + except Exception as e: + print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}") + return list() + + # try to figure out what kind of embedding file it is and parse accordingly + keys = list(ckpt.keys()) + if all(x in keys for x in ['string_to_token','string_to_param','name','step']): + return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt + + elif all(x in keys for x in ['string_to_token','string_to_param']): + return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt + + elif 'emb_params' in keys: + return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors + else: - print(f"** Notice: unrecognized embedding file format: {embedding_file}") - return None + return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file - def _parse_embedding_pt(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location="cpu") - embedding_info = {} + def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]: + basename = Path(file_path).stem + print(f' | Loading v1 embedding file: {basename}') - # Check if valid embedding file - if "string_to_token" and "string_to_param" in embedding_ckpt: - # Catch variants that do not have the expected keys or values. 
- try: - embedding_info["name"] = embedding_ckpt["name"] or os.path.basename( - os.path.splitext(embedding_file)[0] - ) + embeddings = list() + token_counter = -1 + for token,embedding in embedding_ckpt["string_to_param"].items(): + if token_counter < 0: + trigger = embedding_ckpt["name"] + elif token_counter == 0: + trigger = f'' + else: + trigger = f'<{basename}-{int(token_counter:=token_counter)}>' + token_counter += 1 + embedding_info = EmbeddingInfo( + name = trigger, + embedding = embedding, + num_vectors_per_token = embedding.size()[0], + token_dim = embedding.size()[1], + trained_steps = embedding_ckpt["step"], + trained_model_name = embedding_ckpt["sd_checkpoint_name"], + trained_model_checksum = embedding_ckpt["sd_checkpoint"] + ) + embeddings.append(embedding_info) + return embeddings - # Check num of embeddings and warn user only the first will be used - embedding_info["num_of_embeddings"] = len( - embedding_ckpt["string_to_token"] - ) - if embedding_info["num_of_embeddings"] > 1: - print(">> More than 1 embedding found. Will use the first one") - - embedding = list(embedding_ckpt["string_to_param"].values())[0] - except (AttributeError, KeyError): - return self._handle_broken_pt_variants(embedding_ckpt, embedding_file) - - embedding_info["embedding"] = embedding - embedding_info["num_vectors_per_token"] = embedding.size()[0] - embedding_info["token_dim"] = embedding.size()[1] - - try: - embedding_info["trained_steps"] = embedding_ckpt["step"] - embedding_info["trained_model_name"] = embedding_ckpt[ - "sd_checkpoint_name" - ] - embedding_info["trained_model_checksum"] = embedding_ckpt[ - "sd_checkpoint" - ] - except AttributeError: - print(">> No Training Details Found. Passing ...") - - # .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/ - # They are actually .bin files - elif len(embedding_ckpt.keys()) == 1: - embedding_info = self._parse_embedding_bin(embedding_file) - - else: - print(">> Invalid embedding format") - embedding_info = None - - return embedding_info - - def _parse_embedding_bin(self, embedding_file): - embedding_ckpt = torch.load(embedding_file, map_location="cpu") - embedding_info = {} - - if list(embedding_ckpt.keys()) == 0: - print(">> Invalid concepts file") - embedding_info = None - else: - for token in list(embedding_ckpt.keys()): - embedding_info["name"] = ( - token - or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" - ) - embedding_info["embedding"] = embedding_ckpt[token] - embedding_info[ - "num_vectors_per_token" - ] = 1 # All Concepts seem to default to 1 - embedding_info["token_dim"] = embedding_info["embedding"].size()[0] - - return embedding_info - - def _handle_broken_pt_variants( - self, embedding_ckpt: dict, embedding_file: str - ) -> dict: + def _parse_embedding_v2 ( + self, embedding_ckpt: dict, file_path: str + ) -> List[EmbeddingInfo]: """ - This handles the broken .pt file variants. We only know of one at present. + This handles embedding .pt file variant #2. 
""" - embedding_info = {} + basename = Path(file_path).stem + print(f' | Loading v2 embedding file: {basename}') + embeddings = list() + if isinstance( list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor ): - for token in list(embedding_ckpt["string_to_token"].keys()): - embedding_info["name"] = ( - token - if token != "*" - else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>" + token_counter = 0 + for token,embedding in embedding_ckpt["string_to_param"].items(): + trigger = token if token != '*' \ + else f'<{basename}>' if token_counter == 0 \ + else f'<{basename}-{int(token_counter:=token_counter+1)}>' + embedding_info = EmbeddingInfo( + name = trigger, + embedding = embedding, + num_vectors_per_token = embedding.size()[0], + token_dim = embedding.size()[1], ) - embedding_info["embedding"] = embedding_ckpt[ - "string_to_param" - ].state_dict()[token] - embedding_info["num_vectors_per_token"] = embedding_info[ - "embedding" - ].shape[0] - embedding_info["token_dim"] = embedding_info["embedding"].size()[1] + embeddings.append(embedding_info) else: - print(">> Invalid embedding format") - embedding_info = None + print(f" ** {basename}: Unrecognized embedding format") - return embedding_info + return embeddings + + def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]: + """ + Parse 'version 3' of the .pt textual inversion embedding files. + """ + basename = Path(file_path).stem + print(f' | Loading v3 embedding file: {basename}') + embedding = embedding_ckpt['emb_params'] + embedding_info = EmbeddingInfo( + name = f'<{basename}>', + embedding = embedding, + num_vectors_per_token = embedding.size()[0], + token_dim = embedding.size()[1], + ) + return [embedding_info] + + def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]: + """ + Parse 'version 4' of the textual inversion embedding files. This one + is usually associated with .bin files trained by HuggingFace diffusers. + """ + basename = Path(filepath).stem + short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name + + print(f' | Loading v4 embedding file: {short_path}') + + embeddings = list() + if list(embedding_ckpt.keys()) == 0: + print(f" ** Invalid embeddings file: {short_path}") + else: + for token,embedding in embedding_ckpt.items(): + embedding_info = EmbeddingInfo( + name = token or f"<{basename}>", + embedding = embedding, + num_vectors_per_token = 1, # All Concepts seem to default to 1 + token_dim = embedding.size()[0], + ) + embeddings.append(embedding_info) + return embeddings diff --git a/invokeai/backend/web/invoke_ai_web_server.py b/invokeai/backend/web/invoke_ai_web_server.py index dc77ff4723..7209e31449 100644 --- a/invokeai/backend/web/invoke_ai_web_server.py +++ b/invokeai/backend/web/invoke_ai_web_server.py @@ -1022,7 +1022,7 @@ class InvokeAIWebServer: "RGB" ) - def image_progress(sample, step): + def image_progress(intermediate_state: PipelineIntermediateState): if self.canceled.is_set(): raise CanceledException @@ -1030,6 +1030,14 @@ class InvokeAIWebServer: nonlocal generation_parameters nonlocal progress + step = intermediate_state.step + if intermediate_state.predicted_original is not None: + # Some schedulers report not only the noisy latents at the current timestep, + # but also their estimate so far of what the de-noised latents will be. 
+ sample = intermediate_state.predicted_original + else: + sample = intermediate_state.latents + generation_messages = { "txt2img": "common.statusGeneratingTextToImage", "img2img": "common.statusGeneratingImageToImage", @@ -1302,16 +1310,9 @@ class InvokeAIWebServer: progress.set_current_iteration(progress.current_iteration + 1) - def diffusers_step_callback_adapter(*cb_args, **kwargs): - if isinstance(cb_args[0], PipelineIntermediateState): - progress_state: PipelineIntermediateState = cb_args[0] - return image_progress(progress_state.latents, progress_state.step) - else: - return image_progress(*cb_args, **kwargs) - self.generate.prompt2image( **generation_parameters, - step_callback=diffusers_step_callback_adapter, + step_callback=image_progress, image_callback=image_done, ) diff --git a/invokeai/frontend/CLI/CLI.py b/invokeai/frontend/CLI/CLI.py index 17e1c314f7..22e1bbd49d 100644 --- a/invokeai/frontend/CLI/CLI.py +++ b/invokeai/frontend/CLI/CLI.py @@ -626,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer): completer.set_default_dir(opt.outdir) -def import_model(model_path: str, gen, opt, completer, convert=False): +def import_model(model_path: str, gen, opt, completer): """ model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; (3) a huggingface repository id; or (4) a local directory containing a @@ -657,7 +657,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False): model_path, model_name=model_name, description=model_desc, - convert=convert, ) if not imported_name: @@ -666,7 +665,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False): model_path, model_name=model_name, description=model_desc, - convert=convert, model_config_file=config_file, ) if not imported_name: @@ -757,7 +755,6 @@ def _get_model_name_and_desc( ) return model_name, model_description - def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer): model_name_or_path = model_name_or_path.replace("\\", "/") # windows manager = gen.model_manager @@ -772,16 +769,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer): original_config_file = Path(model_info["config"]) model_name = model_name_or_path model_description = model_info["description"] - vae = model_info["vae"] + vae_path = model_info.get("vae") else: print(f"** {model_name_or_path} is not a legacy .ckpt weights file") return - if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get( - Path(vae).stem - ): - vae_repo = dict(repo_id=vae_repo) - else: - vae_repo = None model_name = manager.convert_and_import( ckpt_path, diffusers_path=Path( @@ -790,11 +781,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer): model_name=model_name, model_description=model_description, original_config_file=original_config_file, - vae=vae_repo, + vae_path=vae_path, ) else: try: - import_model(model_name_or_path, gen, opt, completer, convert=True) + import_model(model_name_or_path, gen, opt, completer) except KeyboardInterrupt: return diff --git a/invokeai/frontend/web/public/locales/ar.json b/invokeai/frontend/web/public/locales/ar.json index 671341e9ab..e5168da4a8 100644 --- a/invokeai/frontend/web/public/locales/ar.json +++ b/invokeai/frontend/web/public/locales/ar.json @@ -8,7 +8,6 @@ "darkTheme": "داكن", "lightTheme": "فاتح", "greenTheme": "أخضر", - "text2img": "نص إلى صورة", "img2img": "صورة إلى صورة", "unifiedCanvas": "لوحة موحدة", "nodes": "عقد", diff --git 
a/invokeai/frontend/web/public/locales/de.json b/invokeai/frontend/web/public/locales/de.json index 29155a83c6..deeef34194 100644 --- a/invokeai/frontend/web/public/locales/de.json +++ b/invokeai/frontend/web/public/locales/de.json @@ -7,7 +7,6 @@ "darkTheme": "Dunkel", "lightTheme": "Hell", "greenTheme": "Grün", - "text2img": "Text zu Bild", "img2img": "Bild zu Bild", "nodes": "Knoten", "langGerman": "Deutsch", diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index ad3fdaf3ed..9c2b6bf983 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -8,7 +8,6 @@ "darkTheme": "Oscuro", "lightTheme": "Claro", "greenTheme": "Verde", - "text2img": "Texto a Imagen", "img2img": "Imagen a Imagen", "unifiedCanvas": "Lienzo Unificado", "nodes": "Nodos", @@ -70,7 +69,11 @@ "langHebrew": "Hebreo", "pinOptionsPanel": "Pin del panel de opciones", "loading": "Cargando", - "loadingInvokeAI": "Cargando invocar a la IA" + "loadingInvokeAI": "Cargando invocar a la IA", + "postprocessing": "Tratamiento posterior", + "txt2img": "De texto a imagen", + "accept": "Aceptar", + "cancel": "Cancelar" }, "gallery": { "generations": "Generaciones", @@ -404,7 +407,8 @@ "none": "ninguno", "pickModelType": "Elige el tipo de modelo", "v2_768": "v2 (768px)", - "addDifference": "Añadir una diferencia" + "addDifference": "Añadir una diferencia", + "scanForModels": "Buscar modelos" }, "parameters": { "images": "Imágenes", @@ -574,7 +578,7 @@ "autoSaveToGallery": "Guardar automáticamente en galería", "saveBoxRegionOnly": "Guardar solo región dentro de la caja", "limitStrokesToBox": "Limitar trazos a la caja", - "showCanvasDebugInfo": "Mostrar información de depuración de lienzo", + "showCanvasDebugInfo": "Mostrar la información adicional del lienzo", "clearCanvasHistory": "Limpiar historial de lienzo", "clearHistory": "Limpiar historial", "clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. 
Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.", diff --git a/invokeai/frontend/web/public/locales/fr.json b/invokeai/frontend/web/public/locales/fr.json index 472c437702..cf215d7d06 100644 --- a/invokeai/frontend/web/public/locales/fr.json +++ b/invokeai/frontend/web/public/locales/fr.json @@ -8,7 +8,6 @@ "darkTheme": "Sombre", "lightTheme": "Clair", "greenTheme": "Vert", - "text2img": "Texte en image", "img2img": "Image en image", "unifiedCanvas": "Canvas unifié", "nodes": "Nœuds", @@ -47,7 +46,19 @@ "statusLoadingModel": "Chargement du modèle", "statusModelChanged": "Modèle changé", "discordLabel": "Discord", - "githubLabel": "Github" + "githubLabel": "Github", + "accept": "Accepter", + "statusMergingModels": "Mélange des modèles", + "loadingInvokeAI": "Chargement de Invoke AI", + "cancel": "Annuler", + "langEnglish": "Anglais", + "statusConvertingModel": "Conversion du modèle", + "statusModelConverted": "Modèle converti", + "loading": "Chargement", + "pinOptionsPanel": "Épingler la page d'options", + "statusMergedModels": "Modèles mélangés", + "txt2img": "Texte vers image", + "postprocessing": "Post-Traitement" }, "gallery": { "generations": "Générations", @@ -518,5 +529,15 @@ "betaDarkenOutside": "Assombrir à l'extérieur", "betaLimitToBox": "Limiter à la boîte", "betaPreserveMasked": "Conserver masqué" + }, + "accessibility": { + "uploadImage": "Charger une image", + "reset": "Réinitialiser", + "nextImage": "Image suivante", + "previousImage": "Image précédente", + "useThisParameter": "Utiliser ce paramètre", + "zoomIn": "Zoom avant", + "zoomOut": "Zoom arrière", + "showOptionsPanel": "Montrer la page d'options" } } diff --git a/invokeai/frontend/web/public/locales/he.json b/invokeai/frontend/web/public/locales/he.json index 1e760b8edb..c9b4ff3b17 100644 --- a/invokeai/frontend/web/public/locales/he.json +++ b/invokeai/frontend/web/public/locales/he.json @@ -125,7 +125,6 @@ "langSimplifiedChinese": "סינית", "langUkranian": "אוקראינית", "langSpanish": "ספרדית", - "text2img": "טקסט לתמונה", "img2img": "תמונה לתמונה", "unifiedCanvas": "קנבס מאוחד", "nodes": "צמתים", diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index 61aa5c6a08..7df34173df 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -8,7 +8,6 @@ "darkTheme": "Scuro", "lightTheme": "Chiaro", "greenTheme": "Verde", - "text2img": "Testo a Immagine", "img2img": "Immagine a Immagine", "unifiedCanvas": "Tela unificata", "nodes": "Nodi", @@ -70,7 +69,11 @@ "loading": "Caricamento in corso", "oceanTheme": "Oceano", "langHebrew": "Ebraico", - "loadingInvokeAI": "Caricamento Invoke AI" + "loadingInvokeAI": "Caricamento Invoke AI", + "postprocessing": "Post Elaborazione", + "txt2img": "Testo a Immagine", + "accept": "Accetta", + "cancel": "Annulla" }, "gallery": { "generations": "Generazioni", @@ -404,7 +407,8 @@ "v2_768": "v2 (768px)", "none": "niente", "addDifference": "Aggiungi differenza", - "pickModelType": "Scegli il tipo di modello" + "pickModelType": "Scegli il tipo di modello", + "scanForModels": "Cerca modelli" }, "parameters": { "images": "Immagini", @@ -574,7 +578,7 @@ "autoSaveToGallery": "Salvataggio automatico nella Galleria", "saveBoxRegionOnly": "Salva solo l'area di selezione", "limitStrokesToBox": "Limita i tratti all'area di selezione", - "showCanvasDebugInfo": "Mostra informazioni di debug della Tela", + "showCanvasDebugInfo": 
"Mostra ulteriori informazioni sulla Tela", "clearCanvasHistory": "Cancella cronologia Tela", "clearHistory": "Cancella la cronologia", "clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.", @@ -612,7 +616,7 @@ "copyMetadataJson": "Copia i metadati JSON", "exitViewer": "Esci dal visualizzatore", "zoomIn": "Zoom avanti", - "zoomOut": "Zoom Indietro", + "zoomOut": "Zoom indietro", "rotateCounterClockwise": "Ruotare in senso antiorario", "rotateClockwise": "Ruotare in senso orario", "flipHorizontally": "Capovolgi orizzontalmente", diff --git a/invokeai/frontend/web/public/locales/ko.json b/invokeai/frontend/web/public/locales/ko.json index 888cdc9925..47cde5fec3 100644 --- a/invokeai/frontend/web/public/locales/ko.json +++ b/invokeai/frontend/web/public/locales/ko.json @@ -11,7 +11,6 @@ "langArabic": "العربية", "langEnglish": "English", "langDutch": "Nederlands", - "text2img": "텍스트->이미지", "unifiedCanvas": "통합 캔버스", "langFrench": "Français", "langGerman": "Deutsch", diff --git a/invokeai/frontend/web/public/locales/nl.json b/invokeai/frontend/web/public/locales/nl.json index c06eae06a6..70b836dbc1 100644 --- a/invokeai/frontend/web/public/locales/nl.json +++ b/invokeai/frontend/web/public/locales/nl.json @@ -8,7 +8,6 @@ "darkTheme": "Donker", "lightTheme": "Licht", "greenTheme": "Groen", - "text2img": "Tekst naar afbeelding", "img2img": "Afbeelding naar afbeelding", "unifiedCanvas": "Centraal canvas", "nodes": "Knooppunten", diff --git a/invokeai/frontend/web/public/locales/pl.json b/invokeai/frontend/web/public/locales/pl.json index 7736b27943..246271658a 100644 --- a/invokeai/frontend/web/public/locales/pl.json +++ b/invokeai/frontend/web/public/locales/pl.json @@ -8,7 +8,6 @@ "darkTheme": "Ciemny", "lightTheme": "Jasny", "greenTheme": "Zielony", - "text2img": "Tekst na obraz", "img2img": "Obraz na obraz", "unifiedCanvas": "Tryb uniwersalny", "nodes": "Węzły", diff --git a/invokeai/frontend/web/public/locales/pt.json b/invokeai/frontend/web/public/locales/pt.json index 6e26b9ea56..6d19e3ad92 100644 --- a/invokeai/frontend/web/public/locales/pt.json +++ b/invokeai/frontend/web/public/locales/pt.json @@ -20,7 +20,6 @@ "langSpanish": "Espanhol", "langRussian": "Русский", "langUkranian": "Украї́нська", - "text2img": "Texto para Imagem", "img2img": "Imagem para Imagem", "unifiedCanvas": "Tela Unificada", "nodes": "Nós", diff --git a/invokeai/frontend/web/public/locales/pt_BR.json b/invokeai/frontend/web/public/locales/pt_BR.json index 18b7ab57e1..e77ef14719 100644 --- a/invokeai/frontend/web/public/locales/pt_BR.json +++ b/invokeai/frontend/web/public/locales/pt_BR.json @@ -8,7 +8,6 @@ "darkTheme": "Noite", "lightTheme": "Dia", "greenTheme": "Verde", - "text2img": "Texto Para Imagem", "img2img": "Imagem Para Imagem", "unifiedCanvas": "Tela Unificada", "nodes": "Nódulos", diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json index d4178119e4..0280341dee 100644 --- a/invokeai/frontend/web/public/locales/ru.json +++ b/invokeai/frontend/web/public/locales/ru.json @@ -8,7 +8,6 @@ "darkTheme": "Темная", "lightTheme": "Светлая", "greenTheme": "Зеленая", - "text2img": "Изображение из текста (text2img)", "img2img": "Изображение в изображение (img2img)", "unifiedCanvas": "Универсальный холст", "nodes": "Ноды", diff --git a/invokeai/frontend/web/public/locales/uk.json b/invokeai/frontend/web/public/locales/uk.json index 
fbcc5014df..044cea64a4 100644 --- a/invokeai/frontend/web/public/locales/uk.json +++ b/invokeai/frontend/web/public/locales/uk.json @@ -8,7 +8,6 @@ "darkTheme": "Темна", "lightTheme": "Світла", "greenTheme": "Зелена", - "text2img": "Зображення із тексту (text2img)", "img2img": "Зображення із зображення (img2img)", "unifiedCanvas": "Універсальне полотно", "nodes": "Вузли", diff --git a/invokeai/frontend/web/public/locales/zh_CN.json b/invokeai/frontend/web/public/locales/zh_CN.json index 701933052e..b23ac8cc99 100644 --- a/invokeai/frontend/web/public/locales/zh_CN.json +++ b/invokeai/frontend/web/public/locales/zh_CN.json @@ -8,7 +8,6 @@ "darkTheme": "暗色", "lightTheme": "亮色", "greenTheme": "绿色", - "text2img": "文字到图像", "img2img": "图像到图像", "unifiedCanvas": "统一画布", "nodes": "节点", diff --git a/invokeai/frontend/web/public/locales/zh_Hant.json b/invokeai/frontend/web/public/locales/zh_Hant.json index af7b0cf328..98b4882018 100644 --- a/invokeai/frontend/web/public/locales/zh_Hant.json +++ b/invokeai/frontend/web/public/locales/zh_Hant.json @@ -33,7 +33,6 @@ "langBrPortuguese": "巴西葡萄牙語", "langRussian": "俄語", "langSpanish": "西班牙語", - "text2img": "文字到圖像", "unifiedCanvas": "統一畫布" } } diff --git a/invokeai/frontend/web/src/features/lightbox/components/ReactPanZoomButtons.tsx b/invokeai/frontend/web/src/features/lightbox/components/ReactPanZoomButtons.tsx index ee9be65cc1..2e592e83d7 100644 --- a/invokeai/frontend/web/src/features/lightbox/components/ReactPanZoomButtons.tsx +++ b/invokeai/frontend/web/src/features/lightbox/components/ReactPanZoomButtons.tsx @@ -34,7 +34,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.zoomIn')} - tooltip="Zoom In" + tooltip={t('accessibility.zoomIn')} onClick={() => zoomIn()} fontSize={20} /> @@ -42,7 +42,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.zoomOut')} - tooltip="Zoom Out" + tooltip={t('accessibility.zoomOut')} onClick={() => zoomOut()} fontSize={20} /> @@ -50,7 +50,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.rotateCounterClockwise')} - tooltip="Rotate Counter-Clockwise" + tooltip={t('accessibility.rotateCounterClockwise')} onClick={rotateCounterClockwise} fontSize={20} /> @@ -58,7 +58,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.rotateClockwise')} - tooltip="Rotate Clockwise" + tooltip={t('accessibility.rotateClockwise')} onClick={rotateClockwise} fontSize={20} /> @@ -66,7 +66,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.flipHorizontally')} - tooltip="Flip Horizontally" + tooltip={t('accessibility.flipHorizontally')} onClick={flipHorizontally} fontSize={20} /> @@ -74,7 +74,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.flipVertically')} - tooltip="Flip Vertically" + tooltip={t('accessibility.flipVertically')} onClick={flipVertically} fontSize={20} /> @@ -82,7 +82,7 @@ const ReactPanZoomButtons = ({ } aria-label={t('accessibility.reset')} - tooltip="Reset" + tooltip={t('accessibility.reset')} onClick={() => { resetTransform(); reset(); diff --git a/invokeai/frontend/web/src/i18n.ts b/invokeai/frontend/web/src/i18n.ts index 9b655b28be..faa30f7289 100644 --- a/invokeai/frontend/web/src/i18n.ts +++ b/invokeai/frontend/web/src/i18n.ts @@ -1,22 +1,38 @@ import i18n from 'i18next'; import LanguageDetector from 'i18next-browser-languagedetector'; import Backend from 'i18next-http-backend'; - import { initReactI18next } from 'react-i18next'; -i18n - .use(Backend) - .use(LanguageDetector) - .use(initReactI18next) - .init({ - fallbackLng: 'en', 
- debug: false, - backend: { - loadPath: '/locales/{{lng}}.json', + +import translationEN from '../dist/locales/en.json'; + +if (import.meta.env.MODE === 'package') { + i18n.use(initReactI18next).init({ + lng: 'en', + resources: { + en: { translation: translationEN }, }, + debug: false, interpolation: { escapeValue: false, }, returnNull: false, }); +} else { + i18n + .use(Backend) + .use(LanguageDetector) + .use(initReactI18next) + .init({ + fallbackLng: 'en', + debug: false, + backend: { + loadPath: '/locales/{{lng}}.json', + }, + interpolation: { + escapeValue: false, + }, + returnNull: false, + }); +} export default i18n; diff --git a/pyproject.toml b/pyproject.toml index f587694b01..a285b8fd0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,14 +38,14 @@ dependencies = [ "albumentations", "click", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", - "compel==1.0.4", + "compel==1.0.5", "datasets", "diffusers[torch]~=0.14", "dnspython==2.2.1", "einops", "eventlet", "facexlib", - "fastapi==0.94.1", + "fastapi==0.88.0", "fastapi-events==0.8.0", "fastapi-socketio==0.0.10", "flask==2.1.3", @@ -156,4 +156,3 @@ output = "coverage/index.xml" [flake8] max-line-length = 120 - diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index b722539935..506b8653f8 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -1,6 +1,8 @@ from .test_invoker import create_edge from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext +from invokeai.app.invocations.collections import RangeInvocation +from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory from invokeai.app.services.invocation_queue import MemoryInvocationQueue @@ -21,13 +23,14 @@ def simple_graph(): def mock_services(): # NOTE: none of these are actually called by the test invocations return InvocationServices( - model_manager = None, - events = None, - images = None, + model_manager = None, # type: ignore + events = None, # type: ignore + images = None, # type: ignore + latents = None, # type: ignore queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), - restoration = None, + restoration = None, # type: ignore ) def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: @@ -73,31 +76,23 @@ def test_graph_is_not_complete(simple_graph, mock_services): def test_graph_state_expands_iterator(mock_services): graph = Graph() - test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts))) - graph.add_node(IterateInvocation(id = "2")) - graph.add_node(ImageTestInvocation(id = "3")) - graph.add_edge(create_edge("1", "collection", "2", "collection")) - graph.add_edge(create_edge("2", "item", "3", "prompt")) + graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1)) + graph.add_node(IterateInvocation(id = "1")) + graph.add_node(MultiplyInvocation(id = "2", b = 
10)) + graph.add_node(AddInvocation(id = "3", b = 1)) + graph.add_edge(create_edge("0", "collection", "1", "collection")) + graph.add_edge(create_edge("1", "item", "2", "a")) + graph.add_edge(create_edge("2", "a", "3", "a")) g = GraphExecutionState(graph = graph) - n1 = invoke_next(g, mock_services) - n2 = invoke_next(g, mock_services) - n3 = invoke_next(g, mock_services) - n4 = invoke_next(g, mock_services) - n5 = invoke_next(g, mock_services) + while not g.is_complete(): + invoke_next(g, mock_services) + + prepared_add_nodes = g.source_prepared_mapping['3'] + results = set([g.results[n].a for n in prepared_add_nodes]) + expected = set([1, 11, 21]) + assert results == expected - assert g.prepared_source_mapping[n1[0].id] == "1" - assert g.prepared_source_mapping[n2[0].id] == "2" - assert g.prepared_source_mapping[n3[0].id] == "2" - assert g.prepared_source_mapping[n4[0].id] == "3" - assert g.prepared_source_mapping[n5[0].id] == "3" - - assert isinstance(n4[0], ImageTestInvocation) - assert isinstance(n5[0], ImageTestInvocation) - - prompts = [n4[0].prompt, n5[0].prompt] - assert sorted(prompts) == sorted(test_prompts) def test_graph_state_collects(mock_services): graph = Graph() diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 718baa7a1f..68df708bdd 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -24,10 +24,11 @@ def mock_services() -> InvocationServices: model_manager = None, # type: ignore events = TestEventService(), images = None, # type: ignore + latents = None, # type: ignore queue = MemoryInvocationQueue(), graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), processor = DefaultInvocationProcessor(), - restoration = None, + restoration = None, # type: ignore ) @pytest.fixture()
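
The rewritten iterator test above replaces the prompt-collection fixture with a purely numeric graph: `RangeInvocation(start=0, stop=3, step=1)` feeds an `IterateInvocation`, each item is multiplied by 10 and then incremented by 1, and the results of the prepared `AddInvocation` nodes are compared against `{1, 11, 21}`. A plain-Python sketch of that arithmetic (not of the node framework itself) makes the expected set easy to verify:

```python
# Plain-Python model of the graph the test builds:
# Range(0, 3, 1) -> Iterate -> Multiply(b=10) -> Add(b=1)
def expand_range(start: int, stop: int, step: int) -> list[int]:
    # mirrors the half-open integer range the RangeInvocation node produces
    return list(range(start, stop, step))


def expected_results() -> set[int]:
    items = expand_range(0, 3, 1)             # the iterator expands to one branch per item: 0, 1, 2
    return {item * 10 + 1 for item in items}  # Multiply(b=10) then Add(b=1) on each branch


assert expected_results() == {1, 11, 21}
```

Driving the state with `while not g.is_complete()` instead of a fixed sequence of `invoke_next` calls also means the assertion no longer depends on the exact number of prepared nodes.
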