merge with main, resolve conflicts

This commit is contained in:
Lincoln Stein 2023-07-09 13:25:32 -04:00
commit f2b2ebfffa
163 changed files with 4355 additions and 6666 deletions

6
.gitignore vendored
View File

@ -34,7 +34,7 @@ __pycache__/
.Python .Python
build/ build/
develop-eggs/ develop-eggs/
dist/ # dist/
downloads/ downloads/
eggs/ eggs/
.eggs/ .eggs/
@ -79,6 +79,7 @@ cov.xml
.pytest.ini .pytest.ini
cover/ cover/
junit/ junit/
notes/
# Translations # Translations
*.mo *.mo
@ -201,6 +202,9 @@ checkpoints
# If it's a Mac # If it's a Mac
.DS_Store .DS_Store
invokeai/frontend/yarn.lock
invokeai/frontend/node_modules
# Let the frontend manage its own gitignore # Let the frontend manage its own gitignore
!invokeai/frontend/web/* !invokeai/frontend/web/*

199
README.md
View File

@ -36,9 +36,38 @@
</div> </div>
InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products. _**Note: This is an alpha release. Bugs are expected and not all
features are fully implemented. Please use the GitHub [Issues
pages](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen)
to report unexpected problems. Also note that InvokeAI root directory
which contains models, outputs and configuration files, has changed
between the 2.x and 3.x release. If you wish to use your v2.3 root
directory with v3.0, please follow the directions in [Migrating a 2.3
root directory to 3.0](#migrating-to-3).**_
**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>] InvokeAI is a leading creative engine built to empower professionals
and enthusiasts alike. Generate and create stunning visual media using
the latest AI-driven technologies. InvokeAI offers an industry leading
Web Interface, interactive Command Line Interface, and also serves as
the foundation for multiple commercial products.
**Quick links**: [[How to
Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a
href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a
href="https://invoke-ai.github.io/InvokeAI/">Documentation and
Tutorials</a>] [<a
href="https://github.com/invoke-ai/InvokeAI/">Code and
Downloads</a>] [<a
href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>]
[<a
href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion,
Ideas & Q&A</a>]
<div align="center">
![canvas preview](https://github.com/invoke-ai/InvokeAI/raw/main/docs/assets/canvas_preview.png)
</div>
## Table of Contents ## Table of Contents
@ -63,6 +92,9 @@ Table of Contents 📝
For full installation and upgrade instructions, please see: For full installation and upgrade instructions, please see:
[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/) [InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)
If upgrading from version 2.3, please read [Migrating a 2.3 root
directory to 3.0](#migrating-to-3) first.
### Automatic Installer (suggested for 1st time users) ### Automatic Installer (suggested for 1st time users)
1. Go to the bottom of the [Latest Release Page](https://github.com/invoke-ai/InvokeAI/releases/latest) 1. Go to the bottom of the [Latest Release Page](https://github.com/invoke-ai/InvokeAI/releases/latest)
@ -100,7 +132,168 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals) ### Command-Line Installation (for developers and users familiar with Terminals)
Please see [InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/) for more details on installing and managing your virtual environment manually. You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
not supported.
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
```terminal
mkdir invokeai
````
3. Create a virtual environment named `.venv` inside this directory and activate it:
```terminal
cd invokeai
python -m venv .venv --prompt InvokeAI
```
4. Activate the virtual environment (do it every time you run InvokeAI)
_For Linux/Mac users:_
```sh
source .venv/bin/activate
```
_For Windows users:_
```ps
.venv\Scripts\activate
```
5. Install the InvokeAI module and its dependencies. Choose the command suited for your platform & GPU.
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
```
_For Linux with an AMD GPU:_
```sh
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
_For non-GPU systems:_
```terminal
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2:_
```sh
pip install InvokeAI --use-pep517
```
6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
```terminal
invokeai-configure
```
7. Launch the web server (do it every time you run InvokeAI):
```terminal
invokeai --web
```
8. Point your browser to http://localhost:9090 to bring up the web interface.
9. Type `banana sushi` in the box on the top left and click `Invoke`.
Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
## Detailed Installation Instructions
This fork is supported across Linux, Windows and Macintosh. Linux
users can use either an Nvidia-based card (with CUDA support) or an
AMD card (using the ROCm driver). For full installation and upgrade
instructions, please see:
[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_SOURCE/)
<a name="migrating-to-3"></a>
### Migrating a v2.3 InvokeAI root directory
The InvokeAI root directory is where the InvokeAI startup file,
installed models, and generated images are stored. It is ordinarily
named `invokeai` and located in your home directory. The contents and
layout of this directory has changed between versions 2.3 and 3.0 and
cannot be used directly.
We currently recommend that you use the installer to create a new root
directory named differently from the 2.3 one, e.g. `invokeai-3` and
then use a migration script to copy your 2.3 models into the new
location. However, if you choose, you can upgrade this directory in
place. This section gives both recipes.
#### Creating a new root directory and migrating old models
This is the safer recipe because it leaves your old root directory in
place to fall back on.
1. Follow the instructions above to create and install InvokeAI in a
directory that has a different name from the 2.3 invokeai directory.
In this example, we will use "invokeai-3"
2. When you are prompted to select models to install, select a minimal
set of models, such as stable-diffusion-v1.5 only.
3. After installation is complete launch `invokeai.sh` (Linux/Mac) or
`invokeai.bat` and select option 8 "Open the developers console". This
will take you to the command line.
4. Issue the command `invokeai-migrate3 --from /path/to/v2.3-root --to
/path/to/invokeai-3-root`. Provide the correct `--from` and `--to`
paths for your v2.3 and v3.0 root directories respectively.
This will copy and convert your old models from 2.3 format to 3.0
format and create a new `models` directory in the 3.0 directory. The
old models directory (which contains the models selected at install
time) will be renamed `models.orig` and can be deleted once you have
confirmed that the migration was successful.
#### Migrating in place
For the adventurous, you may do an in-place upgrade from 2.3 to 3.0
without touching the command line. The recipe is as follows>
1. Launch the InvokeAI launcher script in your current v2.3 root directory.
2. Select option [9] "Update InvokeAI" to bring up the updater dialog.
3a. During the alpha release phase, select option [3] and manually
enter the tag name `v3.0.0+a2`.
3b. Once 3.0 is released, select option [1] to upgrade to the latest release.
4. Once the upgrade is finished you will be returned to the launcher
menu. Select option [7] "Re-run the configure script to fix a broken
install or to complete a major upgrade".
This will run the configure script against the v2.3 directory and
update it to the 3.0 format. The following files will be replaced:
- The invokeai.init file, replaced by invokeai.yaml
- The models directory
- The configs/models.yaml model index
The original versions of these files will be saved with the suffix
".orig" appended to the end. Once you have confirmed that the upgrade
worked, you can safely remove these files. Alternatively you can
restore a working v2.3 directory by removing the new files and
restoring the ".orig" files' original names.
#### Migration Caveats
The migration script will migrate your invokeai settings and models,
including textual inversion models, LoRAs and merges that you may have
installed previously. However it does **not** migrate the generated
images stored in your 2.3-format outputs directory. The released
version of 3.0 is expected to have an interface for importing an
entire directory of image files as a batch.
## Hardware Requirements ## Hardware Requirements

View File

@ -149,7 +149,7 @@ class Installer:
return venv_dir return venv_dir
def install(self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None: def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
""" """
Install the InvokeAI application into the given runtime path Install the InvokeAI application into the given runtime path

View File

@ -14,7 +14,7 @@ echo 3. Run textual inversion training
echo 4. Merge models (diffusers type only) echo 4. Merge models (diffusers type only)
echo 5. Download and install models echo 5. Download and install models
echo 6. Change InvokeAI startup options echo 6. Change InvokeAI startup options
echo 7. Re-run the configure script to fix a broken install echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
echo 8. Open the developer console echo 8. Open the developer console
echo 9. Update InvokeAI echo 9. Update InvokeAI
echo 10. Command-line help echo 10. Command-line help

View File

@ -81,7 +81,7 @@ do_choice() {
;; ;;
7) 7)
clear clear
printf "Re-run the configure script to fix a broken install\n" printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;; ;;
8) 8)
@ -118,12 +118,12 @@ do_choice() {
do_dialog() { do_dialog() {
options=( options=(
1 "Generate images with a browser-based interface" 1 "Generate images with a browser-based interface"
2 "Generate images using a command-line interface" 2 "Explore InvokeAI nodes using a command-line interface"
3 "Textual inversion training" 3 "Textual inversion training"
4 "Merge models (diffusers type only)" 4 "Merge models (diffusers type only)"
5 "Download and install models" 5 "Download and install models"
6 "Change InvokeAI startup options" 6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install" 7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
8 "Open the developer console" 8 "Open the developer console"
9 "Update InvokeAI") 9 "Update InvokeAI")

View File

@ -0,0 +1,18 @@
from fastapi.routing import APIRouter
from pydantic import BaseModel
from invokeai.version import __version__
app_router = APIRouter(prefix="/v1/app", tags=['app'])
class AppVersion(BaseModel):
"""App Version Response"""
version: str
@app_router.get('/version', operation_id="app_version",
status_code=200,
response_model=AppVersion)
async def get_version() -> AppVersion:
return AppVersion(version=__version__)

View File

@ -1,75 +1,30 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
from typing import Literal, Optional, Union
from fastapi import Query, Body from typing import Literal, List, Optional, Union
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from fastapi import Body, Path, Query, Response
from ..dependencies import ApiDependencies from fastapi.routing import APIRouter
from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management.models import (
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType OPENAPI_MODEL_CONFIGS,
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] SchedulerPredictionType,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel): UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
repo_id: str = Field(description="The repo ID to use for this VAE") ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
path: Optional[str] = Field(description="The path to the VAE") ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[MODEL_CONFIGS] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get( @models_router.get(
"/", "/",
@ -77,75 +32,103 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
base_model: Optional[BaseModelType] = Query( base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
default=None, description="Base model" model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
),
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList: ) -> ModelsList:
"""Gets a list of models""" """Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models
@models_router.post( @models_router.patch(
"/", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"status": "success"}}, responses={200: {"description" : "The model was updated successfully"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
},
status_code = 200,
response_model = UpdateModelResponse,
) )
async def update_model( async def update_model(
model_request: CreateModelRequest base_model: BaseModelType = Path(description="Base model"),
) -> CreateModelResponse: model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
""" Add Model """ """ Add Model """
model_request_info = model_request.info try:
info_dict = model_request_info.dict() ApiDependencies.invoker.services.model_manager.update_model(
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success") model_name=model_name,
base_model=base_model,
ApiDependencies.invoker.services.model_manager.add_model( model_type=model_type,
model_name=model_request.name, model_attributes=info.dict()
model_attributes=info_dict,
clobber=True,
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return model_response return model_response
@models_router.post( @models_router.post(
"/import", "/",
operation_id="import_model", operation_id="import_model",
responses= { responses= {
201: {"description" : "The model imported successfully"}, 201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"}, 404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse
) )
async def import_model( async def import_model(
name: str = Query(description="A model path, repo_id or URL to import"), location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """ """ Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import, items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type) prediction_type_helper = lambda x: prediction_types.get(prediction_type)
) )
if info := installed_models.get(name): info = installed_models.get(location)
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse( if not info:
name = name, logger.error("Import failed")
info = info, raise HTTPException(status_code=424)
status = "success",
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name,
base_model=info.base_model,
model_type=info.model_type
) )
else: return parse_obj_as(ImportModelResponse, model_raw)
logger.error(f'Model {name} not imported')
raise HTTPException(status_code=404, detail=f'Model {name} not found') except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete( @models_router.delete(
"/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="del_model", operation_id="del_model",
responses={ responses={
204: { 204: {
@ -156,144 +139,95 @@ async def import_model(
} }
}, },
) )
async def delete_model(model_name: str) -> None: async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model""" """Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
model_exists = model_name in model_names
# check if model exists try:
logger.info(f"Checking for model {model_name}...") ApiDependencies.invoker.services.model_manager.del_model(model_name,
base_model = base_model,
if model_exists: model_type = model_type
logger.info(f"Deleting Model: {model_name}") )
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) logger.info(f"Deleted model: {model_name}")
logger.info(f"Model Deleted: {model_name}") return Response(status_code=204)
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") except KeyError:
logger.error(f"Model not found: {model_name}")
else:
logger.error("Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: { "description": "Model converted successfully" },
400: {"description" : "Bad request" },
404: { "description": "Model not found" },
},
status_code = 200,
response_model = ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# @socketio.on("convertToDiffusers") @models_router.put(
# def convert_to_diffusers(model_to_convert: dict): "/merge/{base_model}",
# try: operation_id="merge_models",
# if model_info := self.generate.model_manager.model_info( responses={
# model_name=model_to_convert["model_name"] 200: { "description": "Model converted successfully" },
# ): 400: { "description": "Incompatible models" },
# if "weights" in model_info: 404: { "description": "One or more models not found" },
# ckpt_path = Path(model_info["weights"]) },
# original_config_file = Path(model_info["config"]) status_code = 200,
# model_name = model_to_convert["model_name"] response_model = MergeModelResponse,
# model_description = model_info["description"] )
# else: async def merge_models(
# self.socketio.emit( base_model: BaseModelType = Path(description="Base model"),
# "error", {"message": "Model is not a valid checkpoint file"} model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
# ) merged_model_name: Optional[str] = Body(description="Name of destination model"),
# else: alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# self.socketio.emit( interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
# "error", {"message": "Could not retrieve model info."} force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
# ) ) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
# if not ckpt_path.is_absolute(): logger = ApiDependencies.invoker.services.logger
# ckpt_path = Path(Globals.root, ckpt_path) try:
logger.info(f"Merging models: {model_names}")
# if original_config_file and not original_config_file.is_absolute(): result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
# original_config_file = Path(Globals.root, original_config_file) base_model,
merged_model_name or "+".join(model_names),
# diffusers_path = Path( alpha,
# ckpt_path.parent.absolute(), f"{model_name}_diffusers" interp,
# ) force)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
# if model_to_convert["save_location"] == "root": base_model = base_model,
# diffusers_path = Path( model_type = ModelType.Main,
# global_converted_ckpts_dir(), f"{model_name}_diffusers" )
# ) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
# if ( raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
# model_to_convert["save_location"] == "custom" except ValueError as e:
# and model_to_convert["custom_location"] is not None raise HTTPException(status_code=400, detail=str(e))
# ): return response
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:

View File

@ -24,10 +24,14 @@ logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
import torch
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None) app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
@ -82,6 +86,8 @@ app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix='/api')
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi(): def custom_openapi():

View File

@ -47,7 +47,7 @@ def add_parsers(
commands: list[type], commands: list[type],
command_field: str = "type", command_field: str = "type",
exclude_fields: list[str] = ["id", "type"], exclude_fields: list[str] = ["id", "type"],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
): ):
"""Adds parsers for each command to the subparsers""" """Adds parsers for each command to the subparsers"""
@ -72,7 +72,7 @@ def add_parsers(
def add_graph_parsers( def add_graph_parsers(
subparsers, subparsers,
graphs: list[LibraryGraph], graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
): ):
for graph in graphs: for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description) command_parser = subparsers.add_parser(graph.name, help=graph.description)

View File

@ -1,12 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse import argparse
import os
import re import re
import shlex import shlex
import sys import sys
import time import time
from typing import Union, get_type_hints from typing import Union, get_type_hints, Optional
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
@ -53,6 +52,10 @@ from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
import torch
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
class CliCommand(BaseModel): class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -348,7 +351,7 @@ def invoke_cli():
# Parse invocation # Parse invocation
command: CliCommand = None # type:ignore command: CliCommand = None # type:ignore
system_graph: LibraryGraph|None = None system_graph: Optional[LibraryGraph] = None
if args['type'] in system_graph_names: if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))

View File

@ -1,19 +1,16 @@
from typing import Literal, Optional, Union, List
from pydantic import BaseModel, Field
import re import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import torch import torch
from compel import Compel from compel import Compel
from compel.prompt_parser import (Blend, Conjunction, from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute, CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment) FlattenedPrompt, Fragment)
from pydantic import BaseModel, Field from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .model import ClipField from .model import ClipField
@ -95,6 +92,7 @@ class CompelInvocation(BaseInvocation):
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
text_encoder_info as text_encoder: text_encoder_info as text_encoder:
compel = Compel( compel = Compel(
@ -134,6 +132,24 @@ class CompelInvocation(BaseInvocation):
), ),
) )
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"
clip: ClipField = Field(None, description="Clip with skipped layers")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip"
clip: ClipField = Field(None, description="Clip to use")
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return ClipSkipInvocationOutput(
clip=self.clip,
)
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],

View File

@ -6,7 +6,7 @@ from builtins import float, bool
import cv2 import cv2
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List, Dict from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
@ -422,9 +422,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter") h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter") w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter") f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):

View File

@ -1,11 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial from functools import partial
from typing import Literal, Optional, Union, get_args from typing import Literal, Optional, get_args
import torch import torch
from diffusers import ControlNetModel from pydantic import Field
from pydantic import BaseModel, Field
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField, from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
ResourceOrigin) ResourceOrigin)
@ -18,7 +17,6 @@ from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput from .image import ImageOutput
import re
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, VaeField from .model import UNetField, VaeField
@ -76,7 +74,7 @@ class InpaintInvocation(BaseInvocation):
vae: VaeField = Field(default=None, description="Vae model") vae: VaeField = Field(default=None, description="Vae model")
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image") image: Optional[ImageField] = Field(description="The input image")
strength: float = Field( strength: float = Field(
default=0.75, gt=0, le=1, description="The strength of the original image" default=0.75, gt=0, le=1, description="The strength of the original image"
) )
@ -86,7 +84,7 @@ class InpaintInvocation(BaseInvocation):
) )
# Inputs # Inputs
mask: Union[ImageField, None] = Field(description="The mask") mask: Optional[ImageField] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field( seam_blur: int = Field(
default=16, ge=0, description="The seam inpaint blur radius (px)" default=16, ge=0, description="The seam inpaint blur radius (px)"

View File

@ -1,7 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
@ -67,7 +66,7 @@ class LoadImageInvocation(BaseInvocation):
type: Literal["load_image"] = "load_image" type: Literal["load_image"] = "load_image"
# Inputs # Inputs
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to load" default=None, description="The image to load"
) )
# fmt: on # fmt: on
@ -87,7 +86,7 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image" type: Literal["show_image"] = "show_image"
# Inputs # Inputs
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to show" default=None, description="The image to show"
) )
@ -112,7 +111,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_crop"] = "img_crop" type: Literal["img_crop"] = "img_crop"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to crop") image: Optional[ImageField] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle") x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle") y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle") width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@ -150,8 +149,8 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_paste"] = "img_paste" type: Literal["img_paste"] = "img_paste"
# Inputs # Inputs
base_image: Union[ImageField, None] = Field(default=None, description="The base image") base_image: Optional[ImageField] = Field(default=None, description="The base image")
image: Union[ImageField, None] = Field(default=None, description="The image to paste") image: Optional[ImageField] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image") x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image") y: int = Field(default=0, description="The top y coordinate at which to paste the image")
@ -203,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask" type: Literal["tomask"] = "tomask"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from") image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask") invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on # fmt: on
@ -237,8 +236,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_mul"] = "img_mul" type: Literal["img_mul"] = "img_mul"
# Inputs # Inputs
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply") image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply") image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -273,7 +272,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_chan"] = "img_chan" type: Literal["img_chan"] = "img_chan"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from") image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get") channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on # fmt: on
@ -308,7 +307,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_conv"] = "img_conv" type: Literal["img_conv"] = "img_conv"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to convert") image: Optional[ImageField] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to") mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on # fmt: on
@ -340,7 +339,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_blur"] = "img_blur" type: Literal["img_blur"] = "img_blur"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to blur") image: Optional[ImageField] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius") radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on # fmt: on
@ -398,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_resize"] = "img_resize" type: Literal["img_resize"] = "img_resize"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize") image: Optional[ImageField] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
@ -437,7 +436,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_scale"] = "img_scale" type: Literal["img_scale"] = "img_scale"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale") image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image") scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on # fmt: on
@ -477,7 +476,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_lerp"] = "img_lerp" type: Literal["img_lerp"] = "img_lerp"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp") image: Optional[ImageField] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value") min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value") max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on # fmt: on
@ -513,7 +512,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_ilerp"] = "img_ilerp" type: Literal["img_ilerp"] = "img_ilerp"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp") image: Optional[ImageField] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value") min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value") max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on # fmt: on

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Union, get_args from typing import Literal, Optional, get_args
import numpy as np import numpy as np
import math import math
@ -68,7 +68,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
def tile_fill_missing( def tile_fill_missing(
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image: ) -> Image.Image:
# Only fill if there's an alpha layer # Only fill if there's an alpha layer
if im.mode != "RGBA": if im.mode != "RGBA":
@ -125,7 +125,7 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba" type: Literal["infill_rgba"] = "infill_rgba"
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to infill" default=None, description="The image to infill"
) )
color: ColorField = Field( color: ColorField = Field(
@ -162,7 +162,7 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile" type: Literal["infill_tile"] = "infill_tile"
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to infill" default=None, description="The image to infill"
) )
tile_size: int = Field(default=32, ge=1, description="The tile size (px)") tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
@ -202,7 +202,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch" type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to infill" default=None, description="The image to infill"
) )

View File

@ -1,19 +1,17 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
import torch import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.image_util.seamless import configure_model_padding from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -23,7 +21,6 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
@ -585,7 +582,7 @@ class ImageToLatentsInvocation(BaseInvocation):
type: Literal["i2l"] = "i2l" type: Literal["i2l"] = "i2l"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The image to encode") image: Optional[ImageField] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field( tiled: bool = Field(
default=False, default=False,

View File

@ -30,6 +30,7 @@ class UNetField(BaseModel):
class ClipField(BaseModel): class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading") loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
@ -154,6 +155,7 @@ class MainModelLoaderInvocation(BaseInvocation):
submodel=SubModelType.TextEncoder, submodel=SubModelType.TextEncoder,
), ),
loras=[], loras=[],
skipped_layers=0,
), ),
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(

View File

@ -32,7 +32,7 @@ def get_noise(
perlin: float = 0.0, perlin: float = 0.0,
): ):
"""Generate noise for a given image size.""" """Generate noise for a given image size."""
noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type noise_device_type = "cpu" if use_cpu else device.type
# limit noise to only the diffusion image channels, not the mask channels # limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4) input_channels = min(latent_channels, 4)

View File

@ -1,4 +1,4 @@
from typing import Literal, Union from typing import Literal, Optional
from pydantic import Field from pydantic import Field
@ -15,7 +15,7 @@ class RestoreFaceInvocation(BaseInvocation):
type: Literal["restore_face"] = "restore_face" type: Literal["restore_face"] = "restore_face"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image") image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" ) strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on # fmt: on

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Union from typing import Literal, Optional
from pydantic import Field from pydantic import Field
@ -16,7 +16,7 @@ class UpscaleInvocation(BaseInvocation):
type: Literal["upscale"] = "upscale" type: Literal["upscale"] = "upscale"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None) image: Optional[ImageField] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength") strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level") level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on # fmt: on

View File

@ -1,8 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import sqlite3 import sqlite3
import threading import threading
from typing import Union, cast from typing import Optional, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
@ -44,7 +43,7 @@ class BoardImageRecordStorageBase(ABC):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
"""Gets an image's board id, if it has one.""" """Gets an image's board id, if it has one."""
pass pass
@ -215,7 +214,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import List, Union from typing import List, Union, Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import ( from invokeai.app.services.board_record_storage import (
BoardRecord, BoardRecord,
@ -49,7 +49,7 @@ class BoardImagesServiceABC(ABC):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
"""Gets an image's board id, if it has one.""" """Gets an image's board id, if it has one."""
pass pass
@ -126,13 +126,13 @@ class BoardImagesService(BoardImagesServiceABC):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
board_id = self._services.board_image_records.get_board_for_image(image_name) board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id return board_id
def board_record_to_dto( def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
) -> BoardDTO: ) -> BoardDTO:
"""Converts a board record to a board DTO.""" """Converts a board record to a board DTO."""
return BoardDTO( return BoardDTO(

View File

@ -171,6 +171,7 @@ from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml') INIT_FILE = Path('invokeai.yaml')
MODEL_CORE = Path('models/core')
DB_FILE = Path('invokeai.db') DB_FILE = Path('invokeai.db')
LEGACY_INIT_FILE = Path('invokeai.init') LEGACY_INIT_FILE = Path('invokeai.init')
@ -324,16 +325,11 @@ class InvokeAISettings(BaseSettings):
help=field.field_info.description, help=field.field_info.description,
) )
def _find_root()->Path: def _find_root()->Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"): if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve() root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif ( elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
os.environ.get("VIRTUAL_ENV") root = (venv.parent).resolve()
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
or
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
)
):
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
else: else:
root = Path("~/invokeai").expanduser().resolve() root = Path("~/invokeai").expanduser().resolve()
return root return root

View File

@ -1,10 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any from typing import Any, Optional
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase: class EventServiceBase:
session_event: str = "session_event" session_event: str = "session_event"
@ -28,7 +27,7 @@ class EventServiceBase:
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict, node: dict,
source_node_id: str, source_node_id: str,
progress_image: ProgressImage | None, progress_image: Optional[ProgressImage],
step: int, step: int,
total_steps: int, total_steps: int,
) -> None: ) -> None:

View File

@ -3,7 +3,6 @@
import copy import copy
import itertools import itertools
import uuid import uuid
from types import NoneType
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
@ -26,6 +25,8 @@ from ..invocations.baseinvocation import (
InvocationContext, InvocationContext,
) )
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
class EdgeConnection(BaseModel): class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection") node_id: str = Field(description="The id of the node for this edge connection")
@ -60,8 +61,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None node_input_field = node_inputs.get(field) or None
return node_input_field return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2): def is_union_subtype(t1, t2):
t1_args = get_args(t1) t1_args = get_args(t1)
t2_args = get_args(t2) t2_args = get_args(t2)
@ -846,7 +845,7 @@ class GraphExecutionState(BaseModel):
] ]
} }
def next(self) -> BaseInvocation | None: def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes

View File

@ -2,13 +2,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -80,7 +79,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__cache: Dict[Path, PILImageType] __cache: Dict[Path, PILImageType]
__max_cache_size: int __max_cache_size: int
def __init__(self, output_folder: str | Path): def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
@ -164,7 +163,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
return path return path
def validate_path(self, path: str | Path) -> bool: def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image or thumbnail.""" """Validates the path given for an image or thumbnail."""
path = path if isinstance(path, Path) else Path(path) path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()
@ -175,7 +174,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
for folder in folders: for folder in folders:
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None: def __get_cache(self, image_name: Path) -> Optional[PILImageType]:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):

View File

@ -3,7 +3,6 @@ from datetime import datetime
from typing import Generic, Optional, TypeVar, cast from typing import Generic, Optional, TypeVar, cast
import sqlite3 import sqlite3
import threading import threading
from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
@ -116,7 +115,7 @@ class ImageRecordStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None: def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board.""" """Gets the most recent image for a board."""
pass pass
@ -208,7 +207,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get(self, image_name: str) -> Union[ImageRecord, None]: def get(self, image_name: str) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -220,7 +219,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,), (image_name,),
) )
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise ImageRecordNotFoundException from e raise ImageRecordNotFoundException from e
@ -475,7 +474,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_most_recent_image_for_board( def get_most_recent_image_for_board(
self, board_id: str self, board_id: str
) -> Union[ImageRecord, None]: ) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -490,7 +489,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(board_id,), (board_id,),
) )
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
finally: finally:
self._lock.release() self._lock.release()
if result is None: if result is None:

View File

@ -370,7 +370,7 @@ class ImageService(ImageServiceABC):
def _get_metadata( def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]: ) -> Optional[ImageMetadata]:
"""Get the metadata for a node.""" """Get the metadata for a node."""
metadata = None metadata = None

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional
class InvocationQueueItem(BaseModel): class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state") graph_execution_state_id: str = Field(description="The ID of the graph execution state")
@ -22,7 +22,7 @@ class InvocationQueueABC(ABC):
pass pass
@abstractmethod @abstractmethod
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: Optional[InvocationQueueItem]) -> None:
pass pass
@abstractmethod @abstractmethod
@ -57,7 +57,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
return item return item
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: Optional[InvocationQueueItem]) -> None:
self.__queue.put(item) self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None: def cancel(self, graph_execution_state_id: str) -> None:

View File

@ -1,14 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC from abc import ABC
from threading import Event, Thread from typing import Optional
from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueABC, InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invocation_services import InvocationServices from .invocation_services import InvocationServices
from .item_storage import ItemStorageABC
class Invoker: class Invoker:
"""The invoker, used to execute invocations""" """The invoker, used to execute invocations"""
@ -21,7 +18,7 @@ class Invoker:
def invoke( def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None: ) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed. """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -45,7 +42,7 @@ class Invoker:
return invocation.id return invocation.id
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState: def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph""" """Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph) new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict from typing import Dict, Union, Optional
import torch import torch
@ -55,7 +55,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
if name in self.__cache: if name in self.__cache:
del self.__cache[name] del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None: def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name] return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor): def __set_cache(self, name: str, data: torch.Tensor):
@ -69,9 +69,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching""" """Stores latents in a folder on disk without caching"""
__output_folder: str | Path __output_folder: Union[str, Path]
def __init__(self, output_folder: str | Path): def __init__(self, output_folder: Union[str, Path]):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True) self.__output_folder.mkdir(parents=True, exist_ok=True)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Union from typing import Any, Optional
import networkx as nx import networkx as nx
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
@ -34,7 +34,7 @@ class CoreMetadataService(MetadataServiceBase):
return metadata return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]: def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
""" """
Finds the id of the nearest ancestor (of a valid type) of a given node. Finds the id of the nearest ancestor (of a valid type) of a given node.
@ -65,7 +65,7 @@ class CoreMetadataService(MetadataServiceBase):
def _get_additional_metadata( def _get_additional_metadata(
self, graph: Graph, node_id: str self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]: ) -> Optional[dict[str, Any]]:
""" """
Returns additional metadata for a given node. Returns additional metadata for a given node.

View File

@ -2,22 +2,29 @@
from __future__ import annotations from __future__ import annotations
import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING from pydantic import Field
from dataclasses import dataclass from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management import (
ModelManager, ModelManager,
BaseModelType, BaseModelType,
ModelType, ModelType,
SubModelType, SubModelType,
ModelInfo, ModelInfo,
AddModelResult,
SchedulerPredictionType,
ModelMerger,
MergeInterpolationMethod,
) )
import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
logger: types.ModuleType, logger: ModuleType,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" Uses the exact format as the omegaconf stanza.
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
""" """
pass pass
@ -102,6 +103,19 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod @abstractmethod
def add_model( def add_model(
@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
model_type: ModelType, model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False clobber: bool = False
) -> None: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyErrorException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod @abstractmethod
def del_model( def del_model(
self, self,
@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod @abstractmethod
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->Dict[str, AddModelResult]: )->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Path = None) -> None: def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the If no conf_file is provided, then replaces the
@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
logger: types.ModuleType, logger: ModuleType,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
@ -191,6 +264,8 @@ class ModelManagerService(ModelManagerServiceBase):
logger.debug(f'config file={config_file}') logger.debug(f'config file={config_file}')
device = torch.device(choose_torch_device()) device = torch.device(choose_torch_device())
logger.debug(f'GPU device = {device}')
precision = config.precision precision = config.precision
if precision == "auto": if precision == "auto":
precision = choose_precision(device) precision = choose_precision(device)
@ -299,12 +374,19 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None model_type: Optional[ModelType] = None
) -> list[dict]: ) -> list[dict]:
# ) -> dict:
""" """
Return a list of models. Return a list of models.
""" """
return self.mgr.list_models(base_model, model_type) return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name,
base_model=base_model,
model_type=model_type)
def add_model( def add_model(
self, self,
model_name: str, model_name: str,
@ -320,8 +402,27 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model( def del_model(
self, self,
@ -334,8 +435,29 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk. as well. Call commit() to write to disk.
""" """
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type) self.mgr.del_model(model_name, base_model, model_type)
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type)
def commit(self, conf_file: Optional[Path]=None): def commit(self, conf_file: Optional[Path]=None):
""" """
@ -387,9 +509,9 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.logger return self.mgr.logger
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->Dict[str, AddModelResult]: )->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported. :param items_to_import: Set of strings corresponding to models to be imported.
@ -407,3 +529,34 @@ class ModelManagerService(ModelManagerServiceBase):
that model. that model.
''' '''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper) return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
"""
merger = ModelMerger(self.mgr)
try:
result = merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)
except AssertionError as e:
raise ValueError(e)
return result

View File

@ -88,7 +88,7 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO): class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend.""" """Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field( board_id: Optional[str] = Field(
description="The id of the board the image belongs to, if one exists." description="The id of the board the image belongs to, if one exists."
) )
"""The id of the board the image belongs to, if one exists.""" """The id of the board the image belongs to, if one exists."""
@ -96,7 +96,7 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
def image_record_to_dto( def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None] image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(

View File

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from threading import Lock from threading import Lock
from typing import Generic, TypeVar, Union, get_args from typing import Generic, TypeVar, Optional, Union, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_raw_as
@ -63,7 +63,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
def get(self, id: str) -> Union[T, None]: def get(self, id: str) -> Optional[T]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import Callable, List, Iterator, Optional, Type from typing import Callable, List, Iterator, Optional, Type, Union
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -178,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(self,
init_image: Image.Image | torch.FloatTensor, init_image: Union[Image.Image, torch.FloatTensor],
strength: float=0.75, strength: float=0.75,
**keyword_args **keyword_args
)->Iterator[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
@ -195,7 +195,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image.Image | torch.FloatTensor, mask_image: Union[Image.Image, torch.FloatTensor],
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 96, seam_size: int = 96,
seam_blur: int = 16, seam_blur: int = 16,
@ -570,18 +570,6 @@ class Generator:
device = self.model.device device = self.model.device
# limit noise to only the diffusion image channels, not the mask channels # limit noise to only the diffusion image channels, not the mask channels
input_channels = min(self.latent_channels, 4) input_channels = min(self.latent_channels, 4)
if self.use_mps_noise or device.type == "mps":
x = torch.randn(
[
1,
input_channels,
height // self.downsampling_factor,
width // self.downsampling_factor,
],
dtype=self.torch_dtype(),
device="cpu",
).to(device)
else:
x = torch.randn( x = torch.randn(
[ [
1, 1,

View File

@ -88,9 +88,6 @@ class Img2Img(Generator):
def get_noise_like(self, like: torch.Tensor): def get_noise_like(self, like: torch.Tensor):
device = like.device device = like.device
if device.type == "mps":
x = torch.randn_like(like, device="cpu").to(device)
else:
x = torch.randn_like(like, device=device) x = torch.randn_like(like, device=device)
if self.perlin > 0.0: if self.perlin > 0.0:
shape = like.shape shape = like.shape

View File

@ -4,11 +4,10 @@ invokeai.backend.generator.inpaint descends from .generator
from __future__ import annotations from __future__ import annotations
import math import math
from typing import Tuple, Union from typing import Tuple, Union, Optional
import cv2 import cv2
import numpy as np import numpy as np
import PIL
import torch import torch
from PIL import Image, ImageChops, ImageFilter, ImageOps from PIL import Image, ImageChops, ImageFilter, ImageOps
@ -76,7 +75,7 @@ class Inpaint(Img2Img):
return im_patched return im_patched
def tile_fill_missing( def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image: ) -> Image.Image:
# Only fill if there's an alpha layer # Only fill if there's an alpha layer
if im.mode != "RGBA": if im.mode != "RGBA":
@ -203,8 +202,8 @@ class Inpaint(Img2Img):
cfg_scale, cfg_scale,
ddim_eta, ddim_eta,
conditioning, conditioning,
init_image: Image.Image | torch.FloatTensor, init_image: Union[Image.Image, torch.FloatTensor],
mask_image: Image.Image | torch.FloatTensor, mask_image: Union[Image.Image, torch.FloatTensor],
strength: float, strength: float,
mask_blur_radius: int = 8, mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam

View File

@ -45,6 +45,7 @@ from invokeai.app.services.config import (
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress, CenteredButtonPress,
IntTitleSlider, IntTitleSlider,
set_min_terminal_size, set_min_terminal_size,
@ -76,7 +77,7 @@ Weights_dir = "ldm/stable-diffusion-v1/"
Default_config_file = config.model_conf_path Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ['auto','float16','float32','autocast'] PRECISION_CHOICES = ['auto','float16','float32']
INIT_FILE_PREAMBLE = """# InvokeAI initialization file INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values. # This is the InvokeAI initialization file, which contains command-line default values.
@ -359,9 +360,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
label = """If you have an account at HuggingFace you may optionally paste your access token here label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
"""
for line in textwrap.wrap(label,width=window_width-6): for line in textwrap.wrap(label,width=window_width-6):
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.FixedText, npyscreen.FixedText,
@ -423,6 +422,7 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
) )
self.precision = self.add_widget_intelligent( self.precision = self.add_widget_intelligent(
npyscreen.TitleSelectOne, npyscreen.TitleSelectOne,
columns = 2,
name="Precision", name="Precision",
values=PRECISION_CHOICES, values=PRECISION_CHOICES,
value=PRECISION_CHOICES.index(precision), value=PRECISION_CHOICES.index(precision),

View File

@ -3,7 +3,6 @@ Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0. InvokeAI 2.3 installation to 3.0.0.
''' '''
import io
import os import os
import argparse import argparse
import shutil import shutil
@ -28,9 +27,10 @@ from transformers import (
) )
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import ( from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo ModelProbe, ModelType, BaseModelType, ModelProbeInfo
) )
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -47,52 +47,27 @@ class ModelPaths:
class MigrateTo3(object): class MigrateTo3(object):
def __init__(self, def __init__(self,
root_directory: Path, from_root: Path,
dest_models: Path, to_models: Path,
yaml_file: io.TextIOBase, model_manager: ModelManager,
src_paths: ModelPaths, src_paths: ModelPaths,
): ):
self.root_directory = root_directory self.root_directory = from_root
self.dest_models = dest_models self.dest_models = to_models
self.dest_yaml = yaml_file self.mgr = model_manager
self.model_names = set()
self.src_paths = src_paths self.src_paths = src_paths
self._initialize_yaml() @classmethod
def initialize_yaml(cls, yaml_file: Path):
def _initialize_yaml(self): with open(yaml_file, 'w') as file:
self.dest_yaml.write( file.write(
yaml.dump( yaml.dump(
{ {
'__metadata__': '__metadata__': {'version':'3.0.0'}
{
'version':'3.0.0'}
} }
) )
) )
def unique_name(self,name,info)->str:
'''
Create a unique name for a model for use within models.yaml.
'''
done = False
# some model names have slashes in them, which really screws things up
name = name.replace('/','_')
key = ModelManager.create_key(name,info.base_type,info.model_type)
unique_name = key
counter = 1
while not done:
if unique_name in self.model_names:
unique_name = f'{key}-{counter:0>2d}'
counter += 1
else:
done = True
self.model_names.add(unique_name)
name,_,_ = ModelManager.parse_key(unique_name)
return name
def create_directory_structure(self): def create_directory_structure(self):
''' '''
Create the basic directory structure for the models folder. Create the basic directory structure for the models folder.
@ -140,23 +115,8 @@ class MigrateTo3(object):
that looks like a model, and copy the model into the that looks like a model, and copy the model into the
appropriate location within the destination models directory. appropriate location within the destination models directory.
''' '''
directories_scanned = set()
for root, dirs, files in os.walk(src_dir): for root, dirs, files in os.walk(src_dir):
for f in files:
# hack - don't copy raw learned_embeds.bin, let them
# be copied as part of a tree copy operation
if f == 'learned_embeds.bin':
continue
try:
model = Path(root,f)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for d in dirs: for d in dirs:
try: try:
model = Path(root,d) model = Path(root,d)
@ -165,6 +125,29 @@ class MigrateTo3(object):
continue continue
dest = self._model_probe_to_path(info) / model.name dest = self._model_probe_to_path(info) / model.name
self.copy_dir(model, dest) self.copy_dir(model, dest)
directories_scanned.add(model)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
try:
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
continue
model = Path(root,f)
if model.parent in directories_scanned:
continue
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except Exception as e: except Exception as e:
@ -267,28 +250,6 @@ class MigrateTo3(object):
except Exception as e: except Exception as e:
logger.error(str(e)) logger.error(str(e))
def write_yaml(self, model_name: str, path:Path, info:ModelProbeInfo, **kwargs):
'''
Write a stanza for a moved model into the new models.yaml file.
'''
name = self.unique_name(model_name, info)
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{name}': {
'name': model_name,
'path': str(path),
'description': f'A {info.base_type.value} {info.model_type.value} model',
'format': info.format,
'image_size': info.image_size,
'base': info.base_type.value,
'variant': info.variant_type.value,
'prediction_type': info.prediction_type.value,
'upcast_attention': info.prediction_type == SchedulerPredictionType.VPrediction,
**kwargs,
}
}
self.dest_yaml.write(yaml.dump(stanza))
self.dest_yaml.flush()
def _model_probe_to_path(self, info: ModelProbeInfo)->Path: def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
return Path(self.dest_models, info.base_type.value, info.model_type.value) return Path(self.dest_models, info.base_type.value, info.model_type.value)
@ -332,6 +293,7 @@ class MigrateTo3(object):
elif repo_id := vae.get('repo_id'): elif repo_id := vae.get('repo_id'):
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
vae_path = 'models/core/convert/sd-vae-ft-mse' vae_path = 'models/core/convert/sd-vae-ft-mse'
return vae_path
else: else:
vae_path = self._download_vae(repo_id, vae.get('subfolder')) vae_path = self._download_vae(repo_id, vae.get('subfolder'))
@ -344,7 +306,10 @@ class MigrateTo3(object):
info = ModelProbe().heuristic_probe(vae_path) info = ModelProbe().heuristic_probe(vae_path)
dest = self._model_probe_to_path(info) / vae_path.name dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists(): if not dest.exists():
if vae_path.is_dir():
self.copy_dir(vae_path,dest) self.copy_dir(vae_path,dest)
else:
self.copy_file(vae_path,dest)
vae_path = dest vae_path = dest
if vae_path.is_relative_to(self.dest_models): if vae_path.is_relative_to(self.dest_models):
@ -353,7 +318,7 @@ class MigrateTo3(object):
else: else:
return vae_path return vae_path
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config): def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
''' '''
Migrate a locally-cached diffusers pipeline identified with a repo_id Migrate a locally-cached diffusers pipeline identified with a repo_id
''' '''
@ -385,11 +350,15 @@ class MigrateTo3(object):
if not info: if not info:
return return
dest = self._model_probe_to_path(info) / repo_name if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
return
dest = self._model_probe_to_path(info) / model_name
self._save_pretrained(pipeline, dest) self._save_pretrained(pipeline, dest)
rel_path = Path('models',dest.relative_to(dest_dir)) rel_path = Path('models',dest.relative_to(dest_dir))
self.write_yaml(model_name, path=rel_path, info=info, **extra_config) self._add_model(model_name, info, rel_path, **extra_config)
def migrate_path(self, location: Path, model_name: str=None, **extra_config): def migrate_path(self, location: Path, model_name: str=None, **extra_config):
''' '''
@ -399,19 +368,48 @@ class MigrateTo3(object):
# handle relative paths # handle relative paths
dest_dir = self.dest_models dest_dir = self.dest_models
location = self.root_directory / location location = self.root_directory / location
model_name = model_name or location.stem
info = ModelProbe().heuristic_probe(location) info = ModelProbe().heuristic_probe(location)
if not info: if not info:
return return
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
return
# uh oh, weights is in the old models directory - move it into the new one # uh oh, weights is in the old models directory - move it into the new one
if Path(location).is_relative_to(self.src_paths.models): if Path(location).is_relative_to(self.src_paths.models):
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name) dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
if location.is_dir():
self.copy_dir(location,dest) self.copy_dir(location,dest)
else:
self.copy_file(location,dest)
location = Path('models', info.base_type.value, info.model_type.value, location.name) location = Path('models', info.base_type.value, info.model_type.value, location.name)
model_name = model_name or location.stem
model_name = self.unique_name(model_name, info) self._add_model(model_name, info, location, **extra_config)
self.write_yaml(model_name, path=location, info=info, **extra_config)
def _add_model(self,
model_name: str,
info: ModelProbeInfo,
location: Path,
**extra_config):
if info.model_type != ModelType.Main:
return
self.mgr.add_model(
model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
clobber = True,
model_attributes = {
'path': str(location),
'description': f'A {info.base_type.value} {info.model_type.value} model',
'model_format': info.format,
'variant': info.variant_type.value,
**extra_config,
}
)
def migrate_defined_models(self): def migrate_defined_models(self):
''' '''
@ -435,6 +433,9 @@ class MigrateTo3(object):
if config := stanza.get('config'): if config := stanza.get('config'):
passthru_args['config'] = config passthru_args['config'] = config
if description:= stanza.get('description'):
passthru_args['description'] = description
if repo_id := stanza.get('repo_id'): if repo_id := stanza.get('repo_id'):
logger.info(f'Migrating diffusers model {model_name}') logger.info(f'Migrating diffusers model {model_name}')
self.migrate_repo_id(repo_id, model_name, **passthru_args) self.migrate_repo_id(repo_id, model_name, **passthru_args)
@ -514,30 +515,49 @@ def get_legacy_embeddings(root: Path) -> ModelPaths:
return _parse_legacy_yamlfile(root, path) return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path): def do_migrate(src_directory: Path, dest_directory: Path):
"""
Migrate models from src to dest InvokeAI root directories
"""
config_file = dest_directory / 'configs' / 'models.yaml.3'
dest_models = dest_directory / 'models.3'
dest_models = dest_directory / 'models-3.0' version_3 = (dest_directory / 'models' / 'core').exists()
dest_yaml = dest_directory / 'configs/models.yaml-3.0'
# Here we create the destination models.yaml file.
# If we are writing into a version 3 directory and the
# file already exists, then we write into a copy of it to
# avoid deleting its previous customizations. Otherwise we
# create a new empty one.
if version_3: # write into the dest directory
try:
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
except:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / 'models').replace(dest_models)
else:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file)
paths = get_legacy_embeddings(src_directory) paths = get_legacy_embeddings(src_directory)
migrator = MigrateTo3(
with open(dest_yaml,'w') as yaml_file: from_root = src_directory,
migrator = MigrateTo3(src_directory, to_models = dest_models,
dest_models, model_manager = mgr,
yaml_file, src_paths = paths
src_paths = paths,
) )
migrator.migrate() migrator.migrate()
print("Migration successful.")
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True) if not version_3:
(dest_directory / 'models').replace(dest_directory / 'models.orig') (dest_directory / 'models').replace(src_directory / 'models.orig')
dest_models.replace(dest_directory / 'models') print(f'Original models directory moved to {dest_directory}/models.orig')
(dest_directory /'configs/models.yaml').replace(dest_directory / 'configs/models.yaml.orig') (dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
dest_yaml.replace(dest_directory / 'configs/models.yaml') print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
print(f"""Migration successful.
Original models directory moved to {dest_directory}/models.orig config_file.replace(config_file.with_suffix(''))
Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig dest_models.replace(dest_models.with_suffix(''))
""")
def main(): def main():
parser = argparse.ArgumentParser(prog="invokeai-migrate3", parser = argparse.ArgumentParser(prog="invokeai-migrate3",
@ -550,36 +570,34 @@ It is safe to provide the same directory for both arguments, but it is better to
script, which will perform a full upgrade in place.""" script, which will perform a full upgrade in place."""
) )
parser.add_argument('--from-directory', parser.add_argument('--from-directory',
dest='root_directory', dest='src_root',
type=Path, type=Path,
required=True, required=True,
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")' help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
) )
parser.add_argument('--to-directory', parser.add_argument('--to-directory',
dest='dest_directory', dest='dest_root',
type=Path, type=Path,
required=True, required=True,
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")' help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
) )
# TO DO: Implement full directory scanning
# parser.add_argument('--all-models',
# action="store_true",
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
# )
args = parser.parse_args() args = parser.parse_args()
root_directory = args.root_directory src_root = args.src_root
assert root_directory.is_dir(), f"{root_directory} is not a valid directory" assert src_root.is_dir(), f"{src_root} is not a valid directory"
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory" assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
assert (root_directory / 'invokeai.init').exists() or (root_directory / 'invokeai.yaml').exists(), f"{root_directory} does not contain an InvokeAI init file." assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
dest_directory = args.dest_directory dest_root = args.dest_root
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
config = InvokeAIAppConfig.get_config()
config.parse_args(['--root',str(dest_root)])
# TODO: revisit # TODO: revisit
# assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" # assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory"
# assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." # assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file."
do_migrate(root_directory,dest_directory) do_migrate(src_root,dest_root)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -11,6 +11,7 @@ from typing import List, Dict, Callable, Union, Set
import requests import requests
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from diffusers import logging as dlogging
from huggingface_hub import hf_hub_url, HfFolder, HfApi from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
@ -153,6 +154,9 @@ class ModelInstall(object):
return defaults[0] return defaults[0]
def install(self, selections: InstallSelections): def install(self, selections: InstallSelections):
verbosity = dlogging.get_verbosity() # quench NSFW nags
dlogging.set_verbosity_error()
job = 1 job = 1
jobs = len(selections.remove_models) + len(selections.install_models) jobs = len(selections.remove_models) + len(selections.install_models)
@ -160,20 +164,28 @@ class ModelInstall(object):
for key in selections.remove_models: for key in selections.remove_models:
name,base,mtype = self.mgr.parse_key(key) name,base,mtype = self.mgr.parse_key(key)
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]') logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
try:
self.mgr.del_model(name,base,mtype) self.mgr.del_model(name,base,mtype)
except FileNotFoundError as e:
logger.warning(e)
job += 1 job += 1
# add requested models # add requested models
for path in selections.install_models: for path in selections.install_models:
logger.info(f'Installing {path} [{job}/{jobs}]') logger.info(f'Installing {path} [{job}/{jobs}]')
try:
self.heuristic_import(path) self.heuristic_import(path)
except (ValueError, KeyError) as e:
logger.error(str(e))
job += 1 job += 1
dlogging.set_verbosity(verbosity)
self.mgr.commit() self.mgr.commit()
def heuristic_import(self, def heuristic_import(self,
model_path_id_or_url: Union[str,Path], model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Dict[str, AddModelResult]: models_installed: Set[Path]=None,
)->Dict[str, AddModelResult]:
''' '''
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation :param models_installed: Set of installed models, used for recursive invocation
@ -186,11 +198,9 @@ class ModelInstall(object):
# A little hack to allow nested routines to retrieve info on the requested ID # A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url) path = Path(model_path_id_or_url)
try:
# checkpoint file, or similar # checkpoint file, or similar
if path.is_file(): if path.is_file():
models_installed.update(self._install_path(path)) models_installed.update({str(path):self._install_path(path)})
# folders style or similar # folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \ elif path.is_dir() and any([(path/x).exists() for x in \
@ -205,43 +215,36 @@ class ModelInstall(object):
self.heuristic_import(child, models_installed=models_installed) self.heuristic_import(child, models_installed=models_installed)
# huggingface repo # huggingface repo
elif len(str(path).split('/')) == 2: elif len(str(model_path_id_or_url).split('/')) == 2:
models_installed.update(self._install_repo(str(path))) models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL # a URL
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update(self._install_url(model_path_id_or_url)) models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else: else:
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
except ValueError as e:
logger.error(str(e))
return models_installed return models_installed
# install a model from a local path. The optional info parameter is there to prevent # install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed. # the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]: def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
try:
model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
if not info:
logger.warning(f'Unable to parse format of {path}')
return None
model_name = path.stem if path.is_file() else path.name model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.') raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info) attributes = self._make_attributes(path,info)
model_result = self.mgr.add_model(model_name = model_name, return self.mgr.add_model(model_name = model_name,
base_model = info.base_type, base_model = info.base_type,
model_type = info.model_type, model_type = info.model_type,
model_attributes = attributes, model_attributes = attributes,
) )
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
return {}
return {str(path): model_result}
def _install_url(self, url: str)->dict: def _install_url(self, url: str)->AddModelResult:
# copy to a staging area, probe, import and delete
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging)) location = download_with_resume(url,Path(staging))
if not location: if not location:
@ -253,7 +256,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time # staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info) return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str)->dict: def _install_repo(self, repo_id: str)->AddModelResult:
hinfo = HfApi().model_info(repo_id) hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically # we try to figure out how to download this most economically
@ -279,16 +282,16 @@ class ModelInstall(object):
location = self._download_hf_model(repo_id, files, staging) location = self._download_hf_model(repo_id, files, staging)
break break
elif f'learned_embeds.{suffix}' in files: elif f'learned_embeds.{suffix}' in files:
location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging) location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
break break
if not location: if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.') logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
return return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper) info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info: if not info:
logger.warning(f'Could not probe {location}. Skipping install.') logger.warning(f'Could not probe {location}. Skipping install.')
return return {}
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location) dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
if dest.exists(): if dest.exists():
shutil.rmtree(dest) shutil.rmtree(dest)

View File

@ -1,7 +1,8 @@
""" """
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo, AddModelResult from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -2,8 +2,8 @@ from __future__ import annotations
import copy import copy
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, List
import torch import torch
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager
@ -615,6 +615,24 @@ class ModelPatcher:
text_encoder.resize_token_embeddings(init_tokens_count) text_encoder.resize_token_embeddings(init_tokens_count)
@classmethod
@contextmanager
def apply_clip_skip(
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
):
skipped_layers = []
try:
for i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
finally:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
class TextualInversionModel: class TextualInversionModel:
name: str name: str
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]

View File

@ -342,7 +342,9 @@ class ModelCache(object):
for model_key, cache_entry in self._cached_models.items(): for model_key, cache_entry in self._cached_models.items():
if not cache_entry.locked and cache_entry.loaded: if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}') self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device) cache_entry.model.to(self.storage_device)
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
def _local_model_hash(self, model_path: Union[str, Path]) -> str: def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256() sha = hashlib.sha256()

View File

@ -52,7 +52,7 @@ A typical example is:
sd1_5 = mgr.get_model('stable-diffusion-v1-5', sd1_5 = mgr.get_model('stable-diffusion-v1-5',
model_type=ModelType.Main, model_type=ModelType.Main,
base_model=BaseModelType.StableDiffusion1, base_model=BaseModelType.StableDiffusion1,
submodel_type=SubModelType.Unet) submodel_type=SubModelType.UNet)
with sd1_5 as unet: with sd1_5 as unet:
run_some_inference(unet) run_some_inference(unet)
@ -234,7 +234,7 @@ import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree from shutil import rmtree, move
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
pass pass
class AddModelResult(BaseModel): class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import") name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model") model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model") base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model") config: ModelConfigBase = Field(description="The configuration of the model")
@ -311,7 +311,6 @@ class ModelManager(object):
and sequential_offload boolean. Note that the default device and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision. type and precision are set up for a CUDA system running at half precision.
""" """
self.config_path = None self.config_path = None
if isinstance(config, (str, Path)): if isinstance(config, (str, Path)):
self.config_path = Path(config) self.config_path = Path(config)
@ -491,17 +490,32 @@ class ModelManager(object):
""" """
return [(self.parse_key(x)) for x in self.models.keys()] return [(self.parse_key(x)) for x in self.models.keys()]
def list_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model,model_type,model_name)
return models[0] if models else None
def list_models( def list_models(
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
model_name: Optional[str] = None,
) -> list[dict]: ) -> list[dict]:
""" """
Return a list of models. Return a list of models.
""" """
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = [] models = []
for model_key in sorted(self.models, key=str.casefold): for model_key in model_keys:
model_config = self.models[model_key] model_config = self.models[model_key]
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -546,10 +560,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None) model_cfg = self.models.pop(model_key, None)
if model_cfg is None: if model_cfg is None:
self.logger.error( raise KeyError(f"Unknown model {model_key}")
f"Unknown model {model_key}"
)
return
# note: it not garantie to release memory(model can has other references) # note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, []) cache_ids = self.cache_keys.pop(model_key, [])
@ -615,6 +626,7 @@ class ModelManager(object):
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
self.models[model_key] = model_config self.models[model_key] = model_config
self.commit()
return AddModelResult( return AddModelResult(
name = model_name, name = model_name,
model_type = model_type, model_type = model_type,
@ -622,6 +634,60 @@ class ModelManager(object):
config = model_config, config = model_config,
) )
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
'''
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError unless the model is a checkpoint.
'''
info = self.model_info(model_name, base_model, model_type)
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
# We are taking advantage of a side effect of get_model() that converts check points
# into cached diffusers directories stored at `location`. It doesn't matter
# what submodeltype we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
model = self.get_model(model_name,
base_model,
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try:
move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
result = self.add_model(model_name, base_model, model_type,
model_attributes = info,
clobber=True)
except:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
checkpoint_path.unlink()
return result
def search_models(self, search_folder): def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}") self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -703,6 +769,7 @@ class ModelManager(object):
model_class = MODEL_CLASSES[cur_base_model][cur_model_type] model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config: if model_class.save_to_config:
model_config.error = ModelError.NotFound model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
else: else:
@ -821,6 +888,10 @@ class ModelManager(object):
The result is a set of successfully installed models. Each element The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model. that model.
May return the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
''' '''
# avoid circular import here # avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
@ -830,11 +901,7 @@ class ModelManager(object):
prediction_type_helper = prediction_type_helper, prediction_type_helper = prediction_type_helper,
model_manager = self) model_manager = self)
for thing in items_to_import: for thing in items_to_import:
try:
installed = installer.heuristic_import(thing) installed = installer.heuristic_import(thing)
successfully_installed.update(installed) successfully_installed.update(installed)
except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}')
self.commit() self.commit()
return successfully_installed return successfully_installed

View File

@ -0,0 +1,131 @@
"""
invokeai.backend.model_management.model_merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List, Union
import invokeai.backend.util.logging as logger
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_save (
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
from diffusers import ModelMixin, ConfigMixin from diffusers import ModelMixin, ConfigMixin
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union, Dict from typing import Callable, Literal, Union, Dict, Optional
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from .models import ( from .models import (
@ -64,8 +64,8 @@ class ModelProbe(object):
@classmethod @classmethod
def probe(cls, def probe(cls,
model_path: Path, model_path: Path,
model: Union[Dict, ModelMixin] = None, model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo: prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]] = None)->ModelProbeInfo:
''' '''
Probe the model at model_path and return sufficient information about it Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is to place it somewhere in the models directory hierarchy. If the model is
@ -168,7 +168,7 @@ class ModelProbe(object):
return type return type
# give up # give up
raise ValueError("Unable to determine model type for {folder_path}") raise ValueError(f"Unable to determine model type for {folder_path}")
@classmethod @classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict: def _scan_and_load_checkpoint(cls,model_path: Path)->dict:

View File

@ -68,7 +68,11 @@ def get_model_config_enums():
enums = list() enums = list()
for model_config in MODEL_CONFIGS: for model_config in MODEL_CONFIGS:
if hasattr(inspect,'get_annotations'):
fields = inspect.get_annotations(model_config) fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
try: try:
field = fields["model_format"] field = fields["model_format"]
except: except:

View File

@ -7,7 +7,7 @@ import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field from pydantic import Field
import einops import einops
import PIL.Image import PIL.Image
@ -17,12 +17,11 @@ import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
@ -46,7 +45,7 @@ from .diffusion import (
InvokeAIDiffuserComponent, InvokeAIDiffuserComponent,
PostprocessingSettings, PostprocessingSettings,
) )
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup from .offloading import FullyLoadedModelGroup, ModelGroup
@dataclass @dataclass
class PipelineIntermediateState: class PipelineIntermediateState:
@ -105,7 +104,7 @@ class AddsMaskGuidance:
_debug: Optional[Callable] = None _debug: Optional[Callable] = None
def __call__( def __call__(
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
) -> BaseOutput: ) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data. output_class = step_output.__class__ # We'll create a new one with masked data.
@ -360,12 +359,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
and not config.disable_xformers and not config.disable_xformers
): ):
self.enable_xformers_memory_efficient_attention() self.enable_xformers_memory_efficient_attention()
else:
if torch.backends.mps.is_available():
# until pytorch #91617 is fixed, slicing is borked on MPS
# https://github.com/pytorch/pytorch/issues/91617
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
pass
else: else:
if self.device.type == "cpu" or self.device.type == "mps": if self.device.type == "cpu" or self.device.type == "mps":
mem_free = psutil.virtual_memory().free mem_free = psutil.virtual_memory().free
@ -390,6 +383,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
mem_free * 3.0 / 4.0 mem_free * 3.0 / 4.0
): # 3.3 / 4.0 is from old Invoke code ): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max") self.enable_attention_slicing(slice_size="max")
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
self.enable_attention_slicing(slice_size="max")
else: else:
self.disable_attention_slicing() self.disable_attention_slicing()
@ -917,20 +913,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype): def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
init_image = init_image.to(device=device, dtype=dtype) init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode(): with torch.inference_mode():
if device.type == "mps":
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to(CPU_DEVICE)
init_image = init_image.to(CPU_DEVICE)
else:
self._model_group.load(self.vae) self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to( init_latents = init_latent_dist.sample().to(
dtype=dtype dtype=dtype
) # FIXME: uses torch.randn. make reproducible! ) # FIXME: uses torch.randn. make reproducible!
if device.type == "mps":
self.vae.to(device)
init_latents = init_latents.to(device)
init_latents = 0.18215 * init_latents init_latents = 0.18215 * init_latents
return init_latents return init_latents

View File

@ -248,9 +248,6 @@ class InvokeAIDiffuserComponent:
x_twice, sigma_twice, both_conditionings, **kwargs, x_twice, sigma_twice, both_conditionings, **kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
def _apply_standard_conditioning_sequentially( def _apply_standard_conditioning_sequentially(
@ -264,9 +261,6 @@ class InvokeAIDiffuserComponent:
# low-memory sequential path # low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
# TODO: looks unused # TODO: looks unused

View File

@ -4,7 +4,7 @@ import warnings
import weakref import weakref
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping from collections.abc import MutableMapping
from typing import Callable from typing import Callable, Union
import torch import torch
from accelerate.utils import send_to_device from accelerate.utils import send_to_device
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
""" """
_hooks: MutableMapping[torch.nn.Module, RemovableHandle] _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel] _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device): def __init__(self, execution_device: torch.device):
super().__init__(execution_device) super().__init__(execution_device)

View File

@ -4,6 +4,7 @@ from contextlib import nullcontext
import torch import torch
from torch import autocast from torch import autocast
from typing import Union
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
@ -28,6 +29,8 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device) device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name): if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16" return "float16"
elif device.type == "mps":
return "float16"
return "float32" return "float32"
@ -49,7 +52,7 @@ def choose_autocast(precision):
return nullcontext return nullcontext
def normalize_device(device: str | torch.device) -> torch.device: def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate.""" """Ensure device has a device index defined, if appropriate."""
device = torch.device(device) device = torch.device(device)
if device.index is None: if device.index is None:

View File

@ -0,0 +1,63 @@
import torch
if torch.backends.mps.is_available():
torch.empty = torch.zeros
_torch_layer_norm = torch.nn.functional.layer_norm
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
if weight is not None:
weight = weight.float()
if bias is not None:
bias = bias.float()
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
else:
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
torch.nn.functional.layer_norm = new_layer_norm
_torch_tensor_permute = torch.Tensor.permute
def new_torch_tensor_permute(input, *dims):
result = _torch_tensor_permute(input, *dims)
if input.device == "mps" and input.dtype == torch.float16:
result = result.contiguous()
return result
torch.Tensor.permute = new_torch_tensor_permute
_torch_lerp = torch.lerp
def new_torch_lerp(input, end, weight, *, out=None):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
end = end.float()
if isinstance(weight, torch.Tensor):
weight = weight.float()
if out is not None:
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
else:
out_fp32 = None
result = _torch_lerp(input, end, weight, out=out_fp32)
if out is not None:
out.copy_(out_fp32.half())
del out_fp32
return result.half()
else:
return _torch_lerp(input, end, weight, out=out)
torch.lerp = new_torch_lerp
_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
torch.nn.functional.interpolate = new_torch_interpolate

View File

@ -108,11 +108,11 @@ def main():
print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]') print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
if release: if release:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip' --use-pep517 --upgrade" cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
elif tag: elif tag:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip' --use-pep517 --upgrade" cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
else: else:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip' --use-pep517 --upgrade" cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
print('') print('')
print('') print('')
if os.system(cmd)==0: if os.system(cmd)==0:

View File

@ -382,10 +382,21 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
) )
return min(cols, len(self.installed_models)) return min(cols, len(self.installed_models))
def confirm_deletions(self, selections: InstallSelections)->bool:
remove_models = selections.remove_models
if len(remove_models) > 0:
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
return npyscreen.notify_ok_cancel(f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}")
else:
return True
def on_execute(self): def on_execute(self):
self.monitor.entry_widget.buffer(['Processing...'],scroll_end=True)
self.marshall_arguments() self.marshall_arguments()
app = self.parentApp app = self.parentApp
if not self.confirm_deletions(app.install_selections):
return
self.monitor.entry_widget.buffer(['Processing...'],scroll_end=True)
self.ok_button.hidden = True self.ok_button.hidden = True
self.display() self.display()
@ -417,6 +428,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def on_done(self): def on_done(self):
self.marshall_arguments() self.marshall_arguments()
if not self.confirm_deletions(self.parentApp.install_selections):
return
self.parentApp.setNextForm(None) self.parentApp.setNextForm(None)
self.parentApp.user_cancelled = False self.parentApp.user_cancelled = False
self.editing = False self.editing = False

View File

@ -18,7 +18,7 @@ from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
# minimum size for UIs # minimum size for UIs
MIN_COLS = 130 MIN_COLS = 130
MIN_LINES = 40 MIN_LINES = 45
# ------------------------------------- # -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None): def set_terminal_size(columns: int, lines: int, launch_command: str=None):

View File

@ -0,0 +1,19 @@
import os
import sys
import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--web', action='store_true')
opts,_ = parser.parse_known_args()
if opts.web:
sys.argv.pop(sys.argv.index('--web'))
from invokeai.app.api_app import invoke_api
invoke_api()
else:
from invokeai.app.cli_app import invoke_cli
invoke_cli()
if __name__ == '__main__':
main()

View File

@ -1,4 +1,5 @@
""" """
Initialization file for invokeai.frontend.merge Initialization file for invokeai.frontend.merge
""" """
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
""" """
import argparse import argparse
import curses import curses
import os
import sys import sys
import warnings
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Union
@ -20,99 +18,15 @@ from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager from invokeai.backend.model_management import (
from ...frontend.install.widgets import FloatTitleSlider ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
DEST_MERGED_MODEL_DIR = "merged_models"
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", config.cache_dir),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_commit(
models: List["str"],
merged_model_name: str,
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
):
"""
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = config.model_conf_path
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
assert (
model_manager.model_info(mod).get("format", None) == "diffusers"
), f"** {mod} is not a diffusers model. It must be optimized before merging."
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
import_args = dict(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
logger.info(f"Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)
def _parse_args() -> Namespace: def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging") parser = argparse.ArgumentParser(description="InvokeAI model merging")
parser.add_argument( parser.add_argument(
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
) )
parser.add_argument( parser.add_argument(
"--models", "--models",
dest="model_names",
type=str, type=str,
nargs="+", nargs="+",
help="Two to three model names to be merged", help="Two to three model names to be merged",
) )
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
)
parser.add_argument( parser.add_argument(
"--merged_model_name", "--merged_model_name",
"--destination", "--destination",
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx() window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names() self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names]) max_width = max([len(x) for x in self.model_names])
max_width += 6 max_width += 6
horizontal_layout = max_width * 3 < window_width horizontal_layout = max_width * 3 < window_width
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.", value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False, editable=False,
) )
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
'Models Built on SD-1.x',
'Models Built on SD-2.x',
],
value=[self.current_base],
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
self.base_select.on_changed = self._populate_models
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.FixedText, npyscreen.FixedText,
value="MODEL 1", value="MODEL 1",
color="GOOD", color="GOOD",
editable=False, editable=False,
rely=4 if horizontal_layout else None, rely=6 if horizontal_layout else None,
) )
self.model1 = self.add_widget_intelligent( self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne, npyscreen.SelectOne,
@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names), max_height=len(self.model_names),
max_width=max_width, max_width=max_width,
scroll_exit=True, scroll_exit=True,
rely=5, rely=7,
) )
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.FixedText, npyscreen.FixedText,
@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD", color="GOOD",
editable=False, editable=False,
relx=max_width + 3 if horizontal_layout else None, relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None, rely=6 if horizontal_layout else None,
) )
self.model2 = self.add_widget_intelligent( self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne, npyscreen.SelectOne,
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names), max_height=len(self.model_names),
max_width=max_width, max_width=max_width,
relx=max_width + 3 if horizontal_layout else None, relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None, rely=7 if horizontal_layout else None,
scroll_exit=True, scroll_exit=True,
) )
self.add_widget_intelligent( self.add_widget_intelligent(
@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD", color="GOOD",
editable=False, editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None, relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None, rely=6 if horizontal_layout else None,
) )
models_plus_none = self.model_names.copy() models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None") models_plus_none.insert(0, "None")
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_width=max_width, max_width=max_width,
scroll_exit=True, scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None, relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None, rely=7 if horizontal_layout else None,
) )
for m in [self.model1, self.model2, self.model3]: for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent( self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText, TextBox,
name="Name for merged model:", name="Name for merged model:",
labelColor="CONTROL", labelColor="CONTROL",
max_height=3,
value="", value="",
scroll_exit=True, scroll_exit=True,
) )
self.force = self.add_widget_intelligent( self.force = self.add_widget_intelligent(
npyscreen.Checkbox, npyscreen.Checkbox,
name="Force merge of incompatible models", name="Force merge of models created by different diffusers library versions",
labelColor="CONTROL", labelColor="CONTROL",
value=False, value=True,
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1
self.merge_method = self.add_widget_intelligent( self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne, npyscreen.TitleSelectOne,
name="Merge Method:", name="Merge Method:",
@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]] interp = self.interpolations[self.merge_method.value[0]]
args = dict( args = dict(
models=models, model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value, alpha=self.alpha.value,
interp=interp, interp=interp,
force=self.force.value, force=self.force.value,
@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else: else:
return True return True
def get_model_names(self) -> List[str]: def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
model_names = [ model_names = [
name info["name"]
for name in self.model_manager.model_names() for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if self.model_manager.model_info(name).get("format") == "diffusers" if info["model_format"] == "diffusers"
] ]
return sorted(model_names) return sorted(model_names)
def _populate_models(self,value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model1.values = self.model_names
self.model2.values = self.model_names
self.model3.values = models_plus_none
self.display()
class Mergeapp(npyscreen.NPSAppManaged): class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self): def __init__(self, model_manager:ModelManager):
super().__init__() super().__init__()
conf = OmegaConf.load(config.model_conf_path) self.model_manager = model_manager
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
def onStart(self): def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme) npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace): def run_gui(args: Namespace):
mergeapp = Mergeapp() model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run() mergeapp.run()
args = mergeapp.merge_arguments args = mergeapp.merge_arguments
merge_diffusion_models_and_commit(**args) merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace): def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert ( assert (
args.models and len(args.models) >= 1 and len(args.models) <= 3 args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name: if not args.merged_model_name:
args.merged_model_name = "+".join(args.models) args.merged_model_name = "+".join(args.model_names)
logger.info( logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"' f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
) )
model_manager = ModelManager(OmegaConf.load(config.model_conf_path)) model_manager = ModelManager(config.model_conf_path)
assert ( assert (
args.clobber or args.merged_model_name not in model_manager.model_names() not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args)) merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".') logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main(): def main():
args = _parse_args() args = _parse_args()
config.root = args.root_dir config.parse_args(['--root',str(args.root_dir)])
cache_dir = config.cache_dir
os.environ[
"HF_HOME"
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir
try: try:
if args.front_end: if args.front_end:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script> <script type="module" crossorigin src="./assets/index-581af3d4.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -52,6 +52,7 @@
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
"linear": "Linear", "linear": "Linear",
"nodes": "Node Editor", "nodes": "Node Editor",
"batch": "Batch Manager",
"modelmanager": "Model Manager", "modelmanager": "Model Manager",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",

View File

@ -0,0 +1,122 @@
{
"accessibility": {
"reset": "Resetoi",
"useThisParameter": "Käytä tätä parametria",
"modelSelect": "Mallin Valinta",
"exitViewer": "Poistu katselimesta",
"uploadImage": "Lataa kuva",
"copyMetadataJson": "Kopioi metadata JSON:iin",
"invokeProgressBar": "Invoken edistymispalkki",
"nextImage": "Seuraava kuva",
"previousImage": "Edellinen kuva",
"zoomIn": "Lähennä",
"flipHorizontally": "Käännä vaakasuoraan",
"zoomOut": "Loitonna",
"rotateCounterClockwise": "Kierrä vastapäivään",
"rotateClockwise": "Kierrä myötäpäivään",
"flipVertically": "Käännä pystysuoraan",
"showGallery": "Näytä galleria",
"modifyConfig": "Muokkaa konfiguraatiota",
"toggleAutoscroll": "Kytke automaattinen vieritys",
"toggleLogViewer": "Kytke lokin katselutila",
"showOptionsPanel": "Näytä asetukset"
},
"common": {
"postProcessDesc2": "Erillinen käyttöliittymä tullaan julkaisemaan helpottaaksemme työnkulkua jälkikäsittelyssä.",
"training": "Kouluta",
"statusLoadingModel": "Ladataan mallia",
"statusModelChanged": "Malli vaihdettu",
"statusConvertingModel": "Muunnetaan mallia",
"statusModelConverted": "Malli muunnettu",
"langFrench": "Ranska",
"langItalian": "Italia",
"languagePickerLabel": "Kielen valinta",
"hotkeysLabel": "Pikanäppäimet",
"reportBugLabel": "Raportoi Bugista",
"langPolish": "Puola",
"themeLabel": "Teema",
"langDutch": "Hollanti",
"settingsLabel": "Asetukset",
"githubLabel": "Github",
"darkTheme": "Tumma",
"lightTheme": "Vaalea",
"greenTheme": "Vihreä",
"langGerman": "Saksa",
"langPortuguese": "Portugali",
"discordLabel": "Discord",
"langEnglish": "Englanti",
"oceanTheme": "Meren sininen",
"langRussian": "Venäjä",
"langUkranian": "Ukraina",
"langSpanish": "Espanja",
"upload": "Lataa",
"statusMergedModels": "Mallit yhdistelty",
"img2img": "Kuva kuvaksi",
"nodes": "Solmut",
"nodesDesc": "Solmupohjainen järjestelmä kuvien generoimiseen on parhaillaan kehitteillä. Pysy kuulolla päivityksistä tähän uskomattomaan ominaisuuteen liittyen.",
"postProcessDesc1": "Invoke AI tarjoaa monenlaisia jälkikäsittelyominaisuukisa. Kuvan laadun skaalaus sekä kasvojen korjaus ovat jo saatavilla WebUI:ssä. Voit ottaa ne käyttöön lisäasetusten valikosta teksti kuvaksi sekä kuva kuvaksi -välilehdiltä. Voit myös suoraan prosessoida kuvia käyttämällä kuvan toimintapainikkeita nykyisen kuvan yläpuolella tai tarkastelussa.",
"postprocessing": "Jälkikäsitellään",
"postProcessing": "Jälkikäsitellään",
"cancel": "Peruuta",
"close": "Sulje",
"accept": "Hyväksy",
"statusConnected": "Yhdistetty",
"statusError": "Virhe",
"statusProcessingComplete": "Prosessointi valmis",
"load": "Lataa",
"back": "Takaisin",
"statusGeneratingTextToImage": "Generoidaan tekstiä kuvaksi",
"trainingDesc2": "InvokeAI tukee jo mukautettujen upotusten kouluttamista tekstin inversiolla käyttäen pääskriptiä.",
"statusDisconnected": "Yhteys katkaistu",
"statusPreparing": "Valmistellaan",
"statusIterationComplete": "Iteraatio valmis",
"statusMergingModels": "Yhdistellään malleja",
"statusProcessingCanceled": "Valmistelu peruutettu",
"statusSavingImage": "Tallennetaan kuvaa",
"statusGeneratingImageToImage": "Generoidaan kuvaa kuvaksi",
"statusRestoringFacesGFPGAN": "Korjataan kasvoja (GFPGAN)",
"statusRestoringFacesCodeFormer": "Korjataan kasvoja (CodeFormer)",
"statusGeneratingInpainting": "Generoidaan sisällemaalausta",
"statusGeneratingOutpainting": "Generoidaan ulosmaalausta",
"statusRestoringFaces": "Korjataan kasvoja",
"pinOptionsPanel": "Kiinnitä asetukset -paneeli",
"loadingInvokeAI": "Ladataan Invoke AI:ta",
"loading": "Ladataan",
"statusGenerating": "Generoidaan",
"txt2img": "Teksti kuvaksi",
"trainingDesc1": "Erillinen työnkulku omien upotusten ja tarkastuspisteiden kouluttamiseksi käyttäen tekstin inversiota ja dreamboothia selaimen käyttöliittymässä.",
"postProcessDesc3": "Invoke AI:n komentorivi tarjoaa paljon muita ominaisuuksia, kuten esimerkiksi Embiggenin.",
"unifiedCanvas": "Yhdistetty kanvas",
"statusGenerationComplete": "Generointi valmis"
},
"gallery": {
"uploads": "Lataukset",
"showUploads": "Näytä lataukset",
"galleryImageResetSize": "Resetoi koko",
"maintainAspectRatio": "Säilytä kuvasuhde",
"galleryImageSize": "Kuvan koko",
"pinGallery": "Kiinnitä galleria",
"showGenerations": "Näytä generaatiot",
"singleColumnLayout": "Yhden sarakkeen asettelu",
"generations": "Generoinnit",
"gallerySettings": "Gallerian asetukset",
"autoSwitchNewImages": "Vaihda uusiin kuviin automaattisesti",
"allImagesLoaded": "Kaikki kuvat ladattu",
"noImagesInGallery": "Ei kuvia galleriassa",
"loadMore": "Lataa lisää"
},
"hotkeys": {
"keyboardShortcuts": "näppäimistön pikavalinnat",
"appHotkeys": "Sovelluksen pikanäppäimet",
"generalHotkeys": "Yleiset pikanäppäimet",
"galleryHotkeys": "Gallerian pikanäppäimet",
"unifiedCanvasHotkeys": "Yhdistetyn kanvaan pikanäppäimet",
"cancel": {
"desc": "Peruuta kuvan luominen",
"title": "Peruuta"
},
"invoke": {
"desc": "Luo kuva"
}
}
}

View File

@ -0,0 +1 @@
{}

View File

@ -0,0 +1,254 @@
{
"accessibility": {
"copyMetadataJson": "Kopiera metadata JSON",
"zoomIn": "Zooma in",
"exitViewer": "Avslutningsvisare",
"modelSelect": "Välj modell",
"uploadImage": "Ladda upp bild",
"invokeProgressBar": "Invoke förloppsmätare",
"nextImage": "Nästa bild",
"toggleAutoscroll": "Växla automatisk rullning",
"flipHorizontally": "Vänd vågrätt",
"flipVertically": "Vänd lodrätt",
"zoomOut": "Zooma ut",
"toggleLogViewer": "Växla logvisare",
"reset": "Starta om",
"previousImage": "Föregående bild",
"useThisParameter": "Använd denna parametern",
"showGallery": "Visa galleri",
"rotateCounterClockwise": "Rotera moturs",
"rotateClockwise": "Rotera medurs",
"modifyConfig": "Ändra konfiguration",
"showOptionsPanel": "Visa inställningspanelen"
},
"common": {
"hotkeysLabel": "Snabbtangenter",
"reportBugLabel": "Rapportera bugg",
"githubLabel": "Github",
"discordLabel": "Discord",
"settingsLabel": "Inställningar",
"darkTheme": "Mörk",
"lightTheme": "Ljus",
"greenTheme": "Grön",
"oceanTheme": "Hav",
"langEnglish": "Engelska",
"langDutch": "Nederländska",
"langFrench": "Franska",
"langGerman": "Tyska",
"langItalian": "Italienska",
"langArabic": "العربية",
"langHebrew": "עברית",
"langPolish": "Polski",
"langPortuguese": "Português",
"langBrPortuguese": "Português do Brasil",
"langSimplifiedChinese": "简体中文",
"langJapanese": "日本語",
"langKorean": "한국어",
"langRussian": "Русский",
"unifiedCanvas": "Förenad kanvas",
"nodesDesc": "Ett nodbaserat system för bildgenerering är under utveckling. Håll utkik för uppdateringar om denna fantastiska funktion.",
"langUkranian": "Украї́нська",
"langSpanish": "Español",
"postProcessDesc2": "Ett dedikerat användargränssnitt kommer snart att släppas för att underlätta mer avancerade arbetsflöden av efterbehandling.",
"trainingDesc1": "Ett dedikerat arbetsflöde för träning av dina egna inbäddningar och kontrollpunkter genom Textual Inversion eller Dreambooth från webbgränssnittet.",
"trainingDesc2": "InvokeAI stöder redan träning av anpassade inbäddningar med hjälp av Textual Inversion genom huvudscriptet.",
"upload": "Ladda upp",
"close": "Stäng",
"cancel": "Avbryt",
"accept": "Acceptera",
"statusDisconnected": "Frånkopplad",
"statusGeneratingTextToImage": "Genererar text till bild",
"statusGeneratingImageToImage": "Genererar Bild till bild",
"statusGeneratingInpainting": "Genererar Måla i",
"statusGenerationComplete": "Generering klar",
"statusModelConverted": "Modell konverterad",
"statusMergingModels": "Sammanfogar modeller",
"pinOptionsPanel": "Nåla fast inställningspanelen",
"loading": "Laddar",
"loadingInvokeAI": "Laddar Invoke AI",
"statusRestoringFaces": "Återskapar ansikten",
"languagePickerLabel": "Språkväljare",
"themeLabel": "Tema",
"txt2img": "Text till bild",
"nodes": "Noder",
"img2img": "Bild till bild",
"postprocessing": "Efterbehandling",
"postProcessing": "Efterbehandling",
"load": "Ladda",
"training": "Träning",
"postProcessDesc1": "Invoke AI erbjuder ett brett utbud av efterbehandlingsfunktioner. Uppskalning och ansiktsåterställning finns redan tillgängligt i webbgränssnittet. Du kommer åt dem ifrån Avancerade inställningar-menyn under Bild till bild-fliken. Du kan också behandla bilder direkt genom att använda knappen bildåtgärder ovanför nuvarande bild eller i bildvisaren.",
"postProcessDesc3": "Invoke AI's kommandotolk erbjuder många olika funktioner, bland annat \"Förstora\".",
"statusGenerating": "Genererar",
"statusError": "Fel",
"back": "Bakåt",
"statusConnected": "Ansluten",
"statusPreparing": "Förbereder",
"statusProcessingCanceled": "Bearbetning avbruten",
"statusProcessingComplete": "Bearbetning färdig",
"statusGeneratingOutpainting": "Genererar Fyll ut",
"statusIterationComplete": "Itterering klar",
"statusSavingImage": "Sparar bild",
"statusRestoringFacesGFPGAN": "Återskapar ansikten (GFPGAN)",
"statusRestoringFacesCodeFormer": "Återskapar ansikten (CodeFormer)",
"statusUpscaling": "Skala upp",
"statusUpscalingESRGAN": "Uppskalning (ESRGAN)",
"statusModelChanged": "Modell ändrad",
"statusLoadingModel": "Laddar modell",
"statusConvertingModel": "Konverterar modell",
"statusMergedModels": "Modeller sammanfogade"
},
"gallery": {
"generations": "Generationer",
"showGenerations": "Visa generationer",
"uploads": "Uppladdningar",
"showUploads": "Visa uppladdningar",
"galleryImageSize": "Bildstorlek",
"allImagesLoaded": "Alla bilder laddade",
"loadMore": "Ladda mer",
"galleryImageResetSize": "Återställ storlek",
"gallerySettings": "Galleriinställningar",
"maintainAspectRatio": "Behåll bildförhållande",
"pinGallery": "Nåla fast galleri",
"noImagesInGallery": "Inga bilder i galleriet",
"autoSwitchNewImages": "Ändra automatiskt till nya bilder",
"singleColumnLayout": "Enkolumnslayout"
},
"hotkeys": {
"generalHotkeys": "Allmänna snabbtangenter",
"galleryHotkeys": "Gallerisnabbtangenter",
"unifiedCanvasHotkeys": "Snabbtangenter för sammanslagskanvas",
"invoke": {
"title": "Anropa",
"desc": "Genererar en bild"
},
"cancel": {
"title": "Avbryt",
"desc": "Avbryt bildgenerering"
},
"focusPrompt": {
"desc": "Fokusera området för promptinmatning",
"title": "Fokusprompt"
},
"pinOptions": {
"desc": "Nåla fast alternativpanelen",
"title": "Nåla fast alternativ"
},
"toggleOptions": {
"title": "Växla inställningar",
"desc": "Öppna och stäng alternativpanelen"
},
"toggleViewer": {
"title": "Växla visaren",
"desc": "Öppna och stäng bildvisaren"
},
"toggleGallery": {
"title": "Växla galleri",
"desc": "Öppna eller stäng galleribyrån"
},
"maximizeWorkSpace": {
"title": "Maximera arbetsyta",
"desc": "Stäng paneler och maximera arbetsyta"
},
"changeTabs": {
"title": "Växla flik",
"desc": "Byt till en annan arbetsyta"
},
"consoleToggle": {
"title": "Växla konsol",
"desc": "Öppna och stäng konsol"
},
"setSeed": {
"desc": "Använd seed för nuvarande bild",
"title": "välj seed"
},
"setParameters": {
"title": "Välj parametrar",
"desc": "Använd alla parametrar från nuvarande bild"
},
"setPrompt": {
"desc": "Använd prompt för nuvarande bild",
"title": "Välj prompt"
},
"restoreFaces": {
"title": "Återskapa ansikten",
"desc": "Återskapa nuvarande bild"
},
"upscale": {
"title": "Skala upp",
"desc": "Skala upp nuvarande bild"
},
"showInfo": {
"title": "Visa info",
"desc": "Visa metadata för nuvarande bild"
},
"sendToImageToImage": {
"title": "Skicka till Bild till bild",
"desc": "Skicka nuvarande bild till Bild till bild"
},
"deleteImage": {
"title": "Radera bild",
"desc": "Radera nuvarande bild"
},
"closePanels": {
"title": "Stäng paneler",
"desc": "Stäng öppna paneler"
},
"previousImage": {
"title": "Föregående bild",
"desc": "Visa föregående bild"
},
"nextImage": {
"title": "Nästa bild",
"desc": "Visa nästa bild"
},
"toggleGalleryPin": {
"title": "Växla gallerinål",
"desc": "Nålar fast eller nålar av galleriet i gränssnittet"
},
"increaseGalleryThumbSize": {
"title": "Förstora galleriets bildstorlek",
"desc": "Förstora miniatyrbildernas storlek"
},
"decreaseGalleryThumbSize": {
"title": "Minska gelleriets bildstorlek",
"desc": "Minska miniatyrbildernas storlek i galleriet"
},
"decreaseBrushSize": {
"desc": "Förminska storleken på kanvas- pensel eller suddgummi",
"title": "Minska penselstorlek"
},
"increaseBrushSize": {
"title": "Öka penselstorlek",
"desc": "Öka stoleken på kanvas- pensel eller suddgummi"
},
"increaseBrushOpacity": {
"title": "Öka penselns opacitet",
"desc": "Öka opaciteten för kanvaspensel"
},
"decreaseBrushOpacity": {
"desc": "Minska kanvaspenselns opacitet",
"title": "Minska penselns opacitet"
},
"moveTool": {
"title": "Flytta",
"desc": "Tillåt kanvasnavigation"
},
"fillBoundingBox": {
"title": "Fyll ram",
"desc": "Fyller ramen med pensels färg"
},
"keyboardShortcuts": "Snabbtangenter",
"appHotkeys": "Appsnabbtangenter",
"selectBrush": {
"desc": "Välj kanvaspensel",
"title": "Välj pensel"
},
"selectEraser": {
"desc": "Välj kanvassuddgummi",
"title": "Välj suddgummi"
},
"eraseBoundingBox": {
"title": "Ta bort ram"
}
}
}

View File

@ -0,0 +1,64 @@
{
"accessibility": {
"invokeProgressBar": "Invoke ilerleme durumu",
"nextImage": "Sonraki Resim",
"useThisParameter": "Kullanıcı parametreleri",
"copyMetadataJson": "Metadata verilerini kopyala (JSON)",
"exitViewer": "Görüntüleme Modundan Çık",
"zoomIn": "Yakınlaştır",
"zoomOut": "Uzaklaştır",
"rotateCounterClockwise": "Döndür (Saat yönünün tersine)",
"rotateClockwise": "Döndür (Saat yönünde)",
"flipHorizontally": "Yatay Çevir",
"flipVertically": "Dikey Çevir",
"modifyConfig": "Ayarları Değiştir",
"toggleAutoscroll": "Otomatik kaydırmayı aç/kapat",
"toggleLogViewer": "Günlük Görüntüleyici Aç/Kapa",
"showOptionsPanel": "Ayarlar Panelini Göster",
"modelSelect": "Model Seçin",
"reset": "Sıfırla",
"uploadImage": "Resim Yükle",
"previousImage": "Önceki Resim",
"menu": "Menü",
"showGallery": "Galeriyi Göster"
},
"common": {
"hotkeysLabel": "Kısayol Tuşları",
"themeLabel": "Tema",
"languagePickerLabel": "Dil Seçimi",
"reportBugLabel": "Hata Bildir",
"githubLabel": "Github",
"discordLabel": "Discord",
"settingsLabel": "Ayarlar",
"darkTheme": "Karanlık Tema",
"lightTheme": "Aydınlık Tema",
"greenTheme": "Yeşil Tema",
"oceanTheme": "Okyanus Tema",
"langArabic": "Arapça",
"langEnglish": "İngilizce",
"langDutch": "Hollandaca",
"langFrench": "Fransızca",
"langGerman": "Almanca",
"langItalian": "İtalyanca",
"langJapanese": "Japonca",
"langPolish": "Lehçe",
"langPortuguese": "Portekizce",
"langBrPortuguese": "Portekizcr (Brezilya)",
"langRussian": "Rusça",
"langSimplifiedChinese": "Çince (Basit)",
"langUkranian": "Ukraynaca",
"langSpanish": "İspanyolca",
"txt2img": "Metinden Resime",
"img2img": "Resimden Metine",
"linear": "Çizgisel",
"nodes": "Düğümler",
"postprocessing": "İşlem Sonrası",
"postProcessing": "İşlem Sonrası",
"postProcessDesc2": "Daha gelişmiş özellikler için ve iş akışını kolaylaştırmak için özel bir kullanıcı arayüzü çok yakında yayınlanacaktır.",
"postProcessDesc3": "Invoke AI komut satırı arayüzü, bir çok yeni özellik sunmaktadır.",
"langKorean": "Korece",
"unifiedCanvas": "Akıllı Tuval",
"nodesDesc": "Görüntülerin oluşturulmasında hazırladığımız yeni bir sistem geliştirme aşamasındadır. Bu harika özellikler ve çok daha fazlası için bizi takip etmeye devam edin.",
"postProcessDesc1": "Invoke AI son kullanıcıya yönelik bir çok özellik sunar. Görüntü kalitesi yükseltme, yüz restorasyonu WebUI üzerinden kullanılabilir. Metinden resime ve resimden metne araçlarına gelişmiş seçenekler menüsünden ulaşabilirsiniz. İsterseniz mevcut görüntü ekranının üzerindeki veya görüntüleyicideki görüntüyü doğrudan düzenleyebilirsiniz."
}
}

View File

@ -0,0 +1 @@
{}

View File

@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"", "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build", "build": "yarn run lint && vite build",
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t", "typegen": "npx ts-node scripts/typegen.ts",
"preview": "vite preview", "preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx", "lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .", "lint:eslint": "eslint --max-warnings=0 .",
@ -83,7 +83,7 @@
"konva": "^9.2.0", "konva": "^9.2.0",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"nanostores": "^0.9.2", "nanostores": "^0.9.2",
"openapi-fetch": "0.4.0", "openapi-fetch": "^0.6.1",
"overlayscrollbars": "^2.2.0", "overlayscrollbars": "^2.2.0",
"overlayscrollbars-react": "^0.5.0", "overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0", "patch-package": "^7.0.0",

View File

@ -1,55 +0,0 @@
diff --git a/node_modules/openapi-fetch/dist/index.js b/node_modules/openapi-fetch/dist/index.js
index cd4528a..8976b51 100644
--- a/node_modules/openapi-fetch/dist/index.js
+++ b/node_modules/openapi-fetch/dist/index.js
@@ -1,5 +1,5 @@
// settings & const
-const DEFAULT_HEADERS = {
+const CONTENT_TYPE_APPLICATION_JSON = {
"Content-Type": "application/json",
};
const TRAILING_SLASH_RE = /\/*$/;
@@ -29,18 +29,29 @@ export function createFinalURL(url, options) {
}
return finalURL;
}
+function stringifyBody(body) {
+ if (body instanceof ArrayBuffer || body instanceof File || body instanceof DataView || body instanceof Blob || ArrayBuffer.isView(body) || body instanceof URLSearchParams || body instanceof FormData) {
+ return;
+ }
+
+ if (typeof body === "string") {
+ return body;
+ }
+
+ return JSON.stringify(body);
+ }
+
export default function createClient(clientOptions = {}) {
const { fetch = globalThis.fetch, ...options } = clientOptions;
- const defaultHeaders = new Headers({
- ...DEFAULT_HEADERS,
- ...(options.headers ?? {}),
- });
+ const defaultHeaders = new Headers(options.headers ?? {});
async function coreFetch(url, fetchOptions) {
const { headers, body: requestBody, params = {}, parseAs = "json", querySerializer = defaultSerializer, ...init } = fetchOptions || {};
// URL
const finalURL = createFinalURL(url, { baseUrl: options.baseUrl, params, querySerializer });
+ // Stringify body if needed
+ const stringifiedBody = stringifyBody(requestBody);
// headers
- const baseHeaders = new Headers(defaultHeaders); // clone defaults (dont overwrite!)
+ const baseHeaders = new Headers(stringifiedBody ? { ...CONTENT_TYPE_APPLICATION_JSON, ...defaultHeaders } : defaultHeaders); // clone defaults (dont overwrite!)
const headerOverrides = new Headers(headers);
for (const [k, v] of headerOverrides.entries()) {
if (v === undefined || v === null)
@@ -54,7 +65,7 @@ export default function createClient(clientOptions = {}) {
...options,
...init,
headers: baseHeaders,
- body: typeof requestBody === "string" ? requestBody : JSON.stringify(requestBody),
+ body: stringifiedBody ?? requestBody,
});
// handle empty content
// note: we return `{}` because we want user truthy checks for `.data` or `.error` to succeed

View File

@ -527,7 +527,8 @@
"showOptionsPanel": "Show Options Panel", "showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview", "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode" "controlNetControlMode": "Control Mode",
"clipSkip": "Clip Skip"
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",
@ -551,7 +552,8 @@
"generation": "Generation", "generation": "Generation",
"ui": "User Interface", "ui": "User Interface",
"favoriteSchedulers": "Favorite Schedulers", "favoriteSchedulers": "Favorite Schedulers",
"favoriteSchedulersPlaceholder": "No schedulers favorited" "favoriteSchedulersPlaceholder": "No schedulers favorited",
"showAdvancedOptions": "Show Advanced Options"
}, },
"toast": { "toast": {
"serverError": "Server Error", "serverError": "Server Error",

View File

@ -0,0 +1,3 @@
{
"type": "module"
}

View File

@ -0,0 +1,23 @@
import fs from 'node:fs';
import openapiTS from 'openapi-typescript';
const OPENAPI_URL = 'http://localhost:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
async function main() {
process.stdout.write(
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
transform: (schemaObject, metadata) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(` OK!\r\n`);
}
main();

View File

@ -1,5 +1,6 @@
import { Flex, Grid, Portal } from '@chakra-ui/react'; import { Flex, Grid, Portal } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger'; import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
@ -46,6 +47,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
dispatch(configChanged(config)); dispatch(configChanged(config));
}, [dispatch, config, log]); }, [dispatch, config, log]);
useEffect(() => {
dispatch(appStarted());
}, [dispatch]);
return ( return (
<> <>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden"> <Grid w="100vw" h="100vh" position="relative" overflow="hidden">

View File

@ -55,6 +55,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
} }
if (props.dragData.payloadType === 'IMAGE_DTO') { if (props.dragData.payloadType === 'IMAGE_DTO') {
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
return ( return (
<Box <Box
sx={{ sx={{
@ -72,7 +73,10 @@ const DragPreview = (props: OverlayDragImageProps) => {
sx={{ sx={{
...STYLES, ...STYLES,
}} }}
src={props.dragData.payload.imageDTO.thumbnail_url} objectFit="contain"
src={thumbnail_url}
width={width}
height={height}
/> />
</Box> </Box>
); );

View File

@ -9,4 +9,5 @@ export const actionsDenylist = [
'canvas/addPointToCurrentLine', 'canvas/addPointToCurrentLine',
'socket/socketGeneratorProgress', 'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress', 'socket/appSocketGeneratorProgress',
'hotkeys/shiftKeyPressed',
]; ];

View File

@ -1,49 +1,67 @@
import type { TypedAddListener, TypedStartListening } from '@reduxjs/toolkit';
import { import {
createListenerMiddleware,
addListener,
ListenerEffect,
AnyAction, AnyAction,
ListenerEffect,
addListener,
createListenerMiddleware,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store'; import type { AppDispatch, RootState } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addAppStartedListener } from './listeners/appStarted';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { import {
addImageUploadedFulfilledListener, addImageAddedToBoardFulfilledListener,
addImageUploadedRejectedListener, addImageAddedToBoardRejectedListener,
} from './listeners/imageUploaded'; } from './listeners/imageAddedToBoard';
import { import {
addImageDeletedFulfilledListener, addImageDeletedFulfilledListener,
addImageDeletedPendingListener, addImageDeletedPendingListener,
addImageDeletedRejectedListener, addImageDeletedRejectedListener,
addRequestedImageDeletionListener, addRequestedImageDeletionListener,
} from './listeners/imageDeleted'; } from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addImageDroppedListener } from './listeners/imageDropped';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
import { import {
addImageMetadataReceivedFulfilledListener, addImageMetadataReceivedFulfilledListener,
addImageMetadataReceivedRejectedListener, addImageMetadataReceivedRejectedListener,
} from './listeners/imageMetadataReceived'; } from './listeners/imageMetadataReceived';
import {
addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard';
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
import {
addImageUpdatedFulfilledListener,
addImageUpdatedRejectedListener,
} from './listeners/imageUpdated';
import {
addImageUploadedFulfilledListener,
addImageUploadedRejectedListener,
} from './listeners/imageUploaded';
import { import {
addImageUrlsReceivedFulfilledListener, addImageUrlsReceivedFulfilledListener,
addImageUrlsReceivedRejectedListener, addImageUrlsReceivedRejectedListener,
} from './listeners/imageUrlsReceived'; } from './listeners/imageUrlsReceived';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import {
addReceivedPageOfImagesFulfilledListener,
addReceivedPageOfImagesRejectedListener,
} from './listeners/receivedPageOfImages';
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import {
addSessionCanceledFulfilledListener,
addSessionCanceledPendingListener,
addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled';
import { import {
addSessionCreatedFulfilledListener, addSessionCreatedFulfilledListener,
addSessionCreatedPendingListener, addSessionCreatedPendingListener,
@ -54,38 +72,21 @@ import {
addSessionInvokedPendingListener, addSessionInvokedPendingListener,
addSessionInvokedRejectedListener, addSessionInvokedRejectedListener,
} from './listeners/sessionInvoked'; } from './listeners/sessionInvoked';
import { import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
addSessionCanceledFulfilledListener, import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
addSessionCanceledPendingListener, import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
addSessionCanceledRejectedListener, import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
} from './listeners/sessionCanceled'; import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
addImageUpdatedFulfilledListener, import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
addImageUpdatedRejectedListener, import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
} from './listeners/imageUpdated'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
addReceivedPageOfImagesFulfilledListener,
addReceivedPageOfImagesRejectedListener,
} from './listeners/receivedPageOfImages';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import {
addImageAddedToBoardFulfilledListener,
addImageAddedToBoardRejectedListener,
} from './listeners/imageAddedToBoard';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import {
addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import { addImageDroppedListener } from './listeners/imageDropped';
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -195,9 +196,6 @@ addSessionCanceledRejectedListener();
addReceivedPageOfImagesFulfilledListener(); addReceivedPageOfImagesFulfilledListener();
addReceivedPageOfImagesRejectedListener(); addReceivedPageOfImagesRejectedListener();
// Gallery
addImageCategoriesChangedListener();
// ControlNet // ControlNet
addControlNetImageProcessedListener(); addControlNetImageProcessedListener();
addControlNetAutoProcessListener(); addControlNetAutoProcessListener();
@ -220,3 +218,9 @@ addSelectionAddedToBatchListener();
// DND // DND
addImageDroppedListener(); addImageDroppedListener();
// Models
addModelSelectedListener();
// app startup
addAppStartedListener();

View File

@ -0,0 +1,43 @@
import { createAction } from '@reduxjs/toolkit';
import {
INITIAL_IMAGE_LIMIT,
isLoadingChanged,
} from 'features/gallery/store/gallerySlice';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { startAppListening } from '..';
export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = () => {
startAppListening({
actionCreator: appStarted,
effect: async (
action,
{ getState, dispatch, unsubscribe, cancelActiveListeners }
) => {
cancelActiveListeners();
unsubscribe();
// fill up the gallery tab with images
await dispatch(
receivedPageOfImages({
categories: ['general'],
is_intermediate: false,
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
})
);
// fill up the assets tab with images
await dispatch(
receivedPageOfImages({
categories: ['control', 'mask', 'user', 'other'],
is_intermediate: false,
offset: 0,
limit: INITIAL_IMAGE_LIMIT,
})
);
dispatch(isLoadingChanged(false));
},
});
};

View File

@ -1,29 +0,0 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { receivedPageOfImages } from 'services/api/thunks/image';
import {
imageCategoriesChanged,
selectFilteredImages,
} from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'gallery' });
export const addImageCategoriesChangedListener = () => {
startAppListening({
actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => {
const state = getState();
const filteredImagesCount = selectFilteredImages(state).length;
if (!filteredImagesCount) {
dispatch(
receivedPageOfImages({
categories: action.payload,
board_id: state.boards.selectedBoardId,
is_intermediate: false,
})
);
}
},
});
};

View File

@ -0,0 +1,42 @@
import { makeToast } from 'app/components/Toaster';
import { modelSelected } from 'features/parameters/store/actions';
import {
modelChanged,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
export const addModelSelectedListener = () => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
const state = getState();
const [base_model, type, name] = action.payload.split('/');
if (state.generation.model?.base_model !== base_model) {
dispatch(
addToast(
makeToast({
title: 'Base model changed, clearing submodels',
status: 'warning',
})
)
);
dispatch(vaeSelected(null));
dispatch(lorasCleared());
// TODO: controlnet cleared
}
const newModel = zMainModel.parse({
id: action.payload,
base_model,
name,
});
dispatch(modelChanged(newModel));
},
});
};

View File

@ -93,7 +93,8 @@ export type AppFeature =
| 'discordLink' | 'discordLink'
| 'bugLink' | 'bugLink'
| 'localization' | 'localization'
| 'consoleLogging'; | 'consoleLogging'
| 'dynamicPrompting';
/** /**
* A disable-able Stable Diffusion feature * A disable-able Stable Diffusion feature
@ -104,7 +105,10 @@ export type SDFeature =
| 'variation' | 'variation'
| 'symmetry' | 'symmetry'
| 'seamless' | 'seamless'
| 'hires'; | 'hires'
| 'lora'
| 'embedding'
| 'vae';
/** /**
* Configuration options for the InvokeAI UI. * Configuration options for the InvokeAI UI.

View File

@ -5,8 +5,10 @@ import {
Input, Input,
InputProps, InputProps,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation'; import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { ChangeEvent, memo } from 'react'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react';
interface IAIInputProps extends InputProps { interface IAIInputProps extends InputProps {
label?: string; label?: string;
@ -25,6 +27,25 @@ const IAIInput = (props: IAIInputProps) => {
...rest ...rest
} = props; } = props;
const dispatch = useAppDispatch();
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
return ( return (
<FormControl <FormControl
isInvalid={isInvalid} isInvalid={isInvalid}
@ -32,7 +53,12 @@ const IAIInput = (props: IAIInputProps) => {
{...formControlProps} {...formControlProps}
> >
{label !== '' && <FormLabel>{label}</FormLabel>} {label !== '' && <FormLabel>{label}</FormLabel>}
<Input {...rest} onPaste={stopPastePropagation} /> <Input
{...rest}
onPaste={stopPastePropagation}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
/>
</FormControl> </FormControl>
); );
}; };

View File

@ -1,7 +1,9 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { RefObject, memo } from 'react'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = MultiSelectProps & {
@ -11,6 +13,7 @@ type IAIMultiSelectProps = MultiSelectProps & {
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props; const { searchable = true, tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch();
const { const {
base50, base50,
base100, base100,
@ -31,11 +34,32 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const [boxShadow] = useToken('shadows', ['dark-lg']); const [boxShadow] = useToken('shadows', ['dark-lg']);
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
ref={inputRef} ref={inputRef}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}
maxDropdownHeight={300}
styles={() => ({ styles={() => ({
label: { label: {
color: mode(base700, base300)(colorMode), color: mode(base700, base300)(colorMode),
@ -66,6 +90,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
'&[data-disabled]': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
}, },
}, },
value: { value: {
@ -108,6 +133,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
color: mode('white', base50)(colorMode), color: mode('white', base50)(colorMode),
}, },
}, },
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
}, },
rightSection: { rightSection: {
width: 24, width: 24,

View File

@ -1,7 +1,9 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { memo } from 'react'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
export type IAISelectDataType = { export type IAISelectDataType = {
@ -12,10 +14,12 @@ export type IAISelectDataType = {
type IAISelectProps = SelectProps & { type IAISelectProps = SelectProps & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, ...rest } = props; const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const { const {
base50, base50,
base100, base100,
@ -35,13 +39,54 @@ const IAIMantineSelect = (props: IAISelectProps) => {
} = useChakraThemeTokens(); } = useChakraThemeTokens();
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const [boxShadow] = useToken('shadows', ['dark-lg']); const [boxShadow] = useToken('shadows', ['dark-lg']);
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select <Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}
maxDropdownHeight={300}
styles={() => ({ styles={() => ({
label: { label: {
color: mode(base700, base300)(colorMode), color: mode(base700, base300)(colorMode),
@ -67,6 +112,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&[data-disabled]': { '&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode), backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode), color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
}, },
}, },
value: { value: {
@ -109,6 +155,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
color: mode('white', base50)(colorMode), color: mode('white', base50)(colorMode),
}, },
}, },
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
}, },
rightSection: { rightSection: {
width: 32, width: 32,

View File

@ -0,0 +1,31 @@
import { Box, Tooltip } from '@chakra-ui/react';
import { Text } from '@mantine/core';
import { forwardRef, memo } from 'react';
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
label: string;
description?: string;
tooltip?: string;
disabled?: boolean;
}
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
<Tooltip label={tooltip} placement="top" hasArrow>
<Box ref={ref} {...others}>
<Box>
<Text>{label}</Text>
{description && (
<Text size="xs" color="base.600">
{description}
</Text>
)}
</Box>
</Box>
</Tooltip>
)
);
IAIMantineSelectItemWithTooltip.displayName = 'IAIMantineSelectItemWithTooltip';
export default memo(IAIMantineSelectItemWithTooltip);

View File

@ -14,10 +14,19 @@ import {
Tooltip, Tooltip,
TooltipProps, TooltipProps,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation'; import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { FocusEvent, memo, useEffect, useState } from 'react'; import {
FocusEvent,
KeyboardEvent,
memo,
useCallback,
useEffect,
useState,
} from 'react';
const numberStringRegex = /^-?(0\.)?\.?$/; const numberStringRegex = /^-?(0\.)?\.?$/;
@ -60,6 +69,8 @@ const IAINumberInput = (props: Props) => {
...rest ...rest
} = props; } = props;
const dispatch = useAppDispatch();
/** /**
* Using a controlled input with a value that accepts decimals needs special * Using a controlled input with a value that accepts decimals needs special
* handling. If the user starts to type in "1.5", by the time they press the * handling. If the user starts to type in "1.5", by the time they press the
@ -109,6 +120,24 @@ const IAINumberInput = (props: Props) => {
onChange(clamped); onChange(clamped);
}; };
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
return ( return (
<Tooltip {...tooltipProps}> <Tooltip {...tooltipProps}>
<FormControl <FormControl
@ -128,7 +157,11 @@ const IAINumberInput = (props: Props) => {
{...rest} {...rest}
onPaste={stopPastePropagation} onPaste={stopPastePropagation}
> >
<NumberInputField {...numberInputFieldProps} /> <NumberInputField
{...numberInputFieldProps}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
/>
{showStepper && ( {showStepper && (
<NumberInputStepper> <NumberInputStepper>
<NumberIncrementStepper {...numberInputStepperProps} /> <NumberIncrementStepper {...numberInputStepperProps} />

View File

@ -26,9 +26,12 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { useAppDispatch } from 'app/store/storeHooks';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { import {
FocusEvent, FocusEvent,
KeyboardEvent,
memo, memo,
MouseEvent, MouseEvent,
useCallback, useCallback,
@ -36,9 +39,9 @@ import {
useMemo, useMemo,
useState, useState,
} from 'react'; } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi'; import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = { const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5, mt: 1.5,
@ -56,7 +59,6 @@ export type IAIFullSliderProps = {
withInput?: boolean; withInput?: boolean;
isInteger?: boolean; isInteger?: boolean;
inputWidth?: string | number; inputWidth?: string | number;
inputReadOnly?: boolean;
withReset?: boolean; withReset?: boolean;
handleReset?: () => void; handleReset?: () => void;
tooltipSuffix?: string; tooltipSuffix?: string;
@ -90,7 +92,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
withInput = false, withInput = false,
isInteger = false, isInteger = false,
inputWidth = 16, inputWidth = 16,
inputReadOnly = false,
withReset = false, withReset = false,
hideTooltip = false, hideTooltip = false,
isCompact = false, isCompact = false,
@ -109,7 +110,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
sliderIAIIconButtonProps, sliderIAIIconButtonProps,
...rest ...rest
} = props; } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const [localInputValue, setLocalInputValue] = useState< const [localInputValue, setLocalInputValue] = useState<
@ -152,6 +153,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
); );
const handleInputChange = useCallback((v: number | string) => { const handleInputChange = useCallback((v: number | string) => {
console.log('input');
setLocalInputValue(v); setLocalInputValue(v);
}, []); }, []);
@ -168,6 +170,24 @@ const IAISlider = (props: IAIFullSliderProps) => {
} }
}, []); }, []);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
return ( return (
<FormControl <FormControl
onClick={forceInputBlur} onClick={forceInputBlur}
@ -311,7 +331,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderNumberInputProps} {...sliderNumberInputProps}
> >
<NumberInputField <NumberInputField
readOnly={inputReadOnly} onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
minWidth={inputWidth} minWidth={inputWidth}
{...sliderNumberInputFieldProps} {...sliderNumberInputFieldProps}
/> />

View File

@ -1,9 +1,38 @@
import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react'; import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { stopPastePropagation } from 'common/util/stopPastePropagation'; import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { memo } from 'react'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, memo, useCallback } from 'react';
const IAITextarea = forwardRef((props: TextareaProps, ref) => { const IAITextarea = forwardRef((props: TextareaProps, ref) => {
return <Textarea ref={ref} onPaste={stopPastePropagation} {...props} />; const dispatch = useAppDispatch();
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
return (
<Textarea
ref={ref}
onPaste={stopPastePropagation}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
{...props}
/>
);
}); });
export default memo(IAITextarea); export default memo(IAITextarea);

View File

@ -235,7 +235,6 @@ const IAICanvasToolChooserOptions = () => {
withInput withInput
onChange={(newSize) => dispatch(setBrushSize(newSize))} onChange={(newSize) => dispatch(setBrushSize(newSize))}
sliderNumberInputProps={{ max: 500 }} sliderNumberInputProps={{ max: 500 }}
inputReadOnly={false}
/> />
</Flex> </Flex>
<IAIColorPicker <IAIColorPicker

Some files were not shown because too many files have changed in this diff Show More