mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main, resolve conflicts
Commit f2b2ebfffa
.gitignore (vendored): 6 changed lines
@@ -34,7 +34,7 @@ __pycache__/
.Python
build/
develop-eggs/
dist/
# dist/
downloads/
eggs/
.eggs/
@@ -79,6 +79,7 @@ cov.xml
.pytest.ini
cover/
junit/
notes/

# Translations
*.mo
@@ -201,6 +202,9 @@ checkpoints
# If it's a Mac
.DS_Store

invokeai/frontend/yarn.lock
invokeai/frontend/node_modules

# Let the frontend manage its own gitignore
!invokeai/frontend/web/*

README.md: 199 changed lines
@@ -36,9 +36,38 @@

</div>

InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.
_**Note: This is an alpha release. Bugs are expected and not all
features are fully implemented. Please use the GitHub [Issues
pages](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen)
to report unexpected problems. Also note that the InvokeAI root directory,
which contains models, outputs, and configuration files, has changed
between the 2.x and 3.x releases. If you wish to use your v2.3 root
directory with v3.0, please follow the directions in [Migrating a 2.3
root directory to 3.0](#migrating-to-3).**_

**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
InvokeAI is a leading creative engine built to empower professionals
and enthusiasts alike. Generate and create stunning visual media using
the latest AI-driven technologies. InvokeAI offers an industry leading
Web Interface, interactive Command Line Interface, and also serves as
the foundation for multiple commercial products.

**Quick links**: [[How to
Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a
href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a
href="https://invoke-ai.github.io/InvokeAI/">Documentation and
Tutorials</a>] [<a
href="https://github.com/invoke-ai/InvokeAI/">Code and
Downloads</a>] [<a
href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>]
[<a
href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion,
Ideas & Q&A</a>]

<div align="center">

![canvas preview](https://github.com/invoke-ai/InvokeAI/raw/main/docs/assets/canvas_preview.png)

</div>

## Table of Contents

@@ -63,6 +92,9 @@ Table of Contents 📝
For full installation and upgrade instructions, please see:
[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/)

If upgrading from version 2.3, please read [Migrating a 2.3 root
directory to 3.0](#migrating-to-3) first.

### Automatic Installer (suggested for 1st time users)

1. Go to the bottom of the [Latest Release Page](https://github.com/invoke-ai/InvokeAI/releases/latest)
@@ -100,7 +132,168 @@ and go to http://localhost:9090.

### Command-Line Installation (for developers and users familiar with Terminals)

Please see [InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/) for more details on installing and managing your virtual environment manually.
You must have Python 3.9 or 3.10 installed on your machine. Earlier or later versions are
not supported.
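To confirm that a suitable interpreter is available before you start (on some systems the command is `python3` rather than `python`), check the version first:

```terminal
python --version
```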

1. Open a command-line window on your machine. PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:

```terminal
mkdir invokeai
```

3. Create a virtual environment named `.venv` inside this directory and activate it:

```terminal
cd invokeai
python -m venv .venv --prompt InvokeAI
```

4. Activate the virtual environment (do this every time you run InvokeAI)

_For Linux/Mac users:_

```sh
source .venv/bin/activate
```

_For Windows users:_

```ps
.venv\Scripts\activate
```

5. Install the InvokeAI module and its dependencies. Choose the command suited for your platform & GPU.

_For Windows/Linux with an NVIDIA GPU:_

```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
```

_For Linux with an AMD GPU:_

```sh
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```

_For non-GPU systems:_
```terminal
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```

_For Macintoshes, either Intel or M1/M2:_

```sh
pip install InvokeAI --use-pep517
```

6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):

```terminal
invokeai-configure
```

7. Launch the web server (do this every time you run InvokeAI):

```terminal
invokeai --web
```

8. Point your browser to http://localhost:9090 to bring up the web interface.
9. Type `banana sushi` in the box on the top left and click `Invoke`.

Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`.
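For day-to-day use, a typical relaunch on Linux/Mac looks like the following, assuming the directory layout created in the steps above (Windows users would use `.venv\Scripts\activate` instead):

```sh
cd invokeai
source .venv/bin/activate
invokeai --web
```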

## Detailed Installation Instructions

This fork is supported across Linux, Windows and Macintosh. Linux
users can use either an Nvidia-based card (with CUDA support) or an
AMD card (using the ROCm driver). For full installation and upgrade
instructions, please see:
[InvokeAI Installation Overview](https://invoke-ai.github.io/InvokeAI/installation/INSTALL_SOURCE/)

<a name="migrating-to-3"></a>
### Migrating a v2.3 InvokeAI root directory

The InvokeAI root directory is where the InvokeAI startup file,
installed models, and generated images are stored. It is ordinarily
named `invokeai` and located in your home directory. The contents and
layout of this directory have changed between versions 2.3 and 3.0 and
cannot be used directly.

We currently recommend that you use the installer to create a new root
directory named differently from the 2.3 one, e.g. `invokeai-3`, and
then use a migration script to copy your 2.3 models into the new
location. However, if you choose, you can upgrade this directory in
place. This section gives both recipes.

#### Creating a new root directory and migrating old models

This is the safer recipe because it leaves your old root directory in
place to fall back on.

1. Follow the instructions above to create and install InvokeAI in a
directory that has a different name from the 2.3 invokeai directory.
In this example, we will use "invokeai-3".

2. When you are prompted to select models to install, select a minimal
set of models, such as stable-diffusion-v1.5 only.

3. After installation is complete, launch `invokeai.sh` (Linux/Mac) or
`invokeai.bat` (Windows) and select option 8, "Open the developer console". This
will take you to the command line.

4. Issue the command `invokeai-migrate3 --from /path/to/v2.3-root --to
/path/to/invokeai-3-root`. Provide the correct `--from` and `--to`
paths for your v2.3 and v3.0 root directories, respectively.

This will copy and convert your old models from 2.3 format to 3.0
format and create a new `models` directory in the 3.0 directory. The
old models directory (which contains the models selected at install
time) will be renamed `models.orig` and can be deleted once you have
confirmed that the migration was successful.
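For example, if your v2.3 root is at `~/invokeai` and the new root is at `~/invokeai-3` (illustrative paths; substitute your own), the migration command would be:

```sh
invokeai-migrate3 --from ~/invokeai --to ~/invokeai-3
```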

#### Migrating in place

For the adventurous, you may do an in-place upgrade from 2.3 to 3.0
without touching the command line. The recipe is as follows:

1. Launch the InvokeAI launcher script in your current v2.3 root directory.

2. Select option [9] "Update InvokeAI" to bring up the updater dialog.

3a. During the alpha release phase, select option [3] and manually
enter the tag name `v3.0.0+a2`.

3b. Once 3.0 is released, select option [1] to upgrade to the latest release.

4. Once the upgrade is finished, you will be returned to the launcher
menu. Select option [7] "Re-run the configure script to fix a broken
install or to complete a major upgrade".

This will run the configure script against the v2.3 directory and
update it to the 3.0 format. The following files will be replaced:

- The invokeai.init file, replaced by invokeai.yaml
- The models directory
- The configs/models.yaml model index

The original versions of these files will be saved with the suffix
".orig" appended to the end. Once you have confirmed that the upgrade
worked, you can safely remove these files. Alternatively, you can
restore a working v2.3 directory by removing the new files and
restoring the ".orig" files' original names.
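A rollback along those lines might look like the sketch below. The exact `.orig` names depend on what the configure script wrote to your root directory, so treat the filenames as illustrative and verify them before deleting anything:

```sh
# Run from inside the upgraded root directory; confirm the *.orig files exist first.
rm invokeai.yaml                                  # remove the new startup file
rm -rf models && mv models.orig models            # restore the old models directory
mv configs/models.yaml.orig configs/models.yaml   # restore the old model index
```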

#### Migration Caveats

The migration script will migrate your invokeai settings and models,
including textual inversion models, LoRAs and merges that you may have
installed previously. However, it does **not** migrate the generated
images stored in your 2.3-format outputs directory. The released
version of 3.0 is expected to have an interface for importing an
entire directory of image files as a batch.

## Hardware Requirements

@@ -149,7 +149,7 @@ class Installer:

return venv_dir

def install(self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
"""
Install the InvokeAI application into the given runtime path

@@ -14,7 +14,7 @@ echo 3. Run textual inversion training
echo 4. Merge models (diffusers type only)
echo 5. Download and install models
echo 6. Change InvokeAI startup options
echo 7. Re-run the configure script to fix a broken install
echo 7. Re-run the configure script to fix a broken install or to complete a major upgrade
echo 8. Open the developer console
echo 9. Update InvokeAI
echo 10. Command-line help

@@ -81,7 +81,7 @@ do_choice() {
;;
7)
clear
printf "Re-run the configure script to fix a broken install\n"
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
@@ -118,12 +118,12 @@ do_choice() {
do_dialog() {
options=(
1 "Generate images with a browser-based interface"
2 "Generate images using a command-line interface"
2 "Explore InvokeAI nodes using a command-line interface"
3 "Textual inversion training"
4 "Merge models (diffusers type only)"
5 "Download and install models"
6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install"
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
8 "Open the developer console"
9 "Update InvokeAI")

invokeai/app/api/routers/app_info.py (new file): 18 lines
@@ -0,0 +1,18 @@
from fastapi.routing import APIRouter
from pydantic import BaseModel

from invokeai.version import __version__

app_router = APIRouter(prefix="/v1/app", tags=['app'])


class AppVersion(BaseModel):
    """App Version Response"""
    version: str


@app_router.get('/version', operation_id="app_version",
                status_code=200,
                response_model=AppVersion)
async def get_version() -> AppVersion:
    return AppVersion(version=__version__)
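Once the server is running, the new endpoint can be exercised from the command line. This is only a sketch: it assumes the default port 9090 and the `/api` prefix under which the router is mounted, and the version string shown is a placeholder:

```sh
curl http://localhost:9090/api/v1/app/version
# {"version": "3.0.0"}   # example output; the actual value comes from invokeai.version
```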
@ -1,75 +1,30 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from fastapi import Query, Body
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from typing import Literal, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import AddModelResult
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
model_name: str = Field(description="The name of the model")
|
||||
model_type: str = Field(description="The type of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['folder'] = 'folder'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
config: str = Field(description="The path to the model config")
|
||||
weights: str = Field(description="The path to the model weights")
|
||||
vae: str = Field(description="The path to the model VAE")
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class SafetensorsModelInfo(CkptModelInfo):
|
||||
format: Literal['safetensors'] = 'safetensors'
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ImportModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the imported model")
|
||||
# base_model: str = Field(description="The base model")
|
||||
# model_type: str = Field(description="The model type")
|
||||
info: AddModelResult = Field(description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
@ -77,75 +32,103 @@ class ModelsList(BaseModel):
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
responses={200: {"description" : "The model was updated successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
400: {"description" : "Bad request"}
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info.dict()
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = parse_obj_as(UpdateModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
"/",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def import_model(
|
||||
name: str = Query(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
items_to_import = {name}
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
if info := installed_models.get(name):
|
||||
logger.info(f'Successfully imported {name}, got {info}')
|
||||
return ImportModelResponse(
|
||||
name = name,
|
||||
info = info,
|
||||
status = "success",
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
else:
|
||||
logger.error(f'Model {name} not imported')
|
||||
raise HTTPException(status_code=404, detail=f'Model {name} not found')
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=424)
|
||||
|
||||
logger.info(f'Successfully imported {location}, got {info}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
@ -156,144 +139,95 @@ async def import_model(
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
logger.error("Model not found")
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except KeyError:
|
||||
logger.error(f"Model not found: {model_name}")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: {"description" : "Bad request" },
|
||||
404: { "description": "Model not found" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: { "description": "Incompatible models" },
|
||||
404: { "description": "One or more models not found" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names}")
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name or "+".join(model_names),
|
||||
alpha,
|
||||
interp,
|
||||
force)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -24,10 +24,14 @@ logger = InvokeAILogger.getLogger(config=app_config)
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images, boards, board_images
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
import torch
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||
@ -82,6 +86,8 @@ app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
app.include_router(app_info.app_router, prefix='/api')
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
|
@ -47,7 +47,7 @@ def add_parsers(
|
||||
commands: list[type],
|
||||
command_field: str = "type",
|
||||
exclude_fields: list[str] = ["id", "type"],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
|
||||
):
|
||||
"""Adds parsers for each command to the subparsers"""
|
||||
|
||||
@ -72,7 +72,7 @@ def add_parsers(
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
@ -1,12 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
from typing import Union, get_type_hints
|
||||
from typing import Union, get_type_hints, Optional
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
@ -53,6 +52,10 @@ from .services.processor import DefaultInvocationProcessor
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
import torch
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes
|
||||
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
@ -348,7 +351,7 @@ def invoke_cli():
|
||||
|
||||
# Parse invocation
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
system_graph: Optional[LibraryGraph] = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
|
@ -1,19 +1,16 @@
|
||||
from typing import Literal, Optional, Union, List
|
||||
from pydantic import BaseModel, Field
|
||||
import re
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (Blend, Conjunction,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt, Fragment)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.model_management import ModelType
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .model import ClipField
|
||||
@ -95,6 +92,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
|
||||
text_encoder_info as text_encoder:
|
||||
|
||||
compel = Compel(
|
||||
@ -134,6 +132,24 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
type: Literal["clip_skip"] = "clip_skip"
|
||||
|
||||
clip: ClipField = Field(None, description="Clip to use")
|
||||
skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return ClipSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
)
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
|
@ -6,7 +6,7 @@ from builtins import float, bool
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List, Dict
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
@ -422,9 +422,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
|
||||
# Inputs
|
||||
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||
w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
# fmt: on
|
||||
|
||||
def run_processor(self, image):
|
||||
|
@ -1,11 +1,10 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
ResourceOrigin)
|
||||
@ -18,7 +17,6 @@ from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
import re
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
@ -76,7 +74,7 @@ class InpaintInvocation(BaseInvocation):
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
image: Optional[ImageField] = Field(description="The input image")
|
||||
strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The strength of the original image"
|
||||
)
|
||||
@ -86,7 +84,7 @@ class InpaintInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
mask: Optional[ImageField] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(
|
||||
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import io
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
@ -67,7 +66,7 @@ class LoadImageInvocation(BaseInvocation):
|
||||
type: Literal["load_image"] = "load_image"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to load"
|
||||
)
|
||||
# fmt: on
|
||||
@ -87,7 +86,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
type: Literal["show_image"] = "show_image"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to show"
|
||||
)
|
||||
|
||||
@ -112,7 +111,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_crop"] = "img_crop"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to crop")
|
||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||
@ -150,8 +149,8 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_paste"] = "img_paste"
|
||||
|
||||
# Inputs
|
||||
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
|
||||
base_image: Optional[ImageField] = Field(default=None, description="The base image")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to paste")
|
||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||
@ -203,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["tomask"] = "tomask"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
|
||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||
# fmt: on
|
||||
|
||||
@ -237,8 +236,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_mul"] = "img_mul"
|
||||
|
||||
# Inputs
|
||||
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
|
||||
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
|
||||
image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
|
||||
image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
@ -273,7 +272,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_chan"] = "img_chan"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||
# fmt: on
|
||||
|
||||
@ -308,7 +307,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_conv"] = "img_conv"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to convert")
|
||||
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||
# fmt: on
|
||||
|
||||
@ -340,7 +339,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_blur"] = "img_blur"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to blur")
|
||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||
# fmt: on
|
||||
@ -398,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_resize"] = "img_resize"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
@ -437,7 +436,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_scale"] = "img_scale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||
# fmt: on
|
||||
@ -477,7 +476,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_lerp"] = "img_lerp"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||
# fmt: on
|
||||
@ -513,7 +512,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
type: Literal["img_ilerp"] = "img_ilerp"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to lerp")
|
||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||
# fmt: on
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import Literal, Union, get_args
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
@ -68,7 +68,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
@ -125,7 +125,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Union[ImageField, None] = Field(
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
color: ColorField = Field(
|
||||
@ -162,7 +162,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Union[ImageField, None] = Field(
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
@ -202,7 +202,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Union[ImageField, None] = Field(
|
||||
image: Optional[ImageField] = Field(
|
||||
default=None, description="The image to infill"
|
||||
)
|
||||
|
||||
|
@ -1,19 +1,17 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers import ControlNetModel
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@ -23,7 +21,6 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||
PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .compel import ConditioningField
|
||||
@ -585,7 +582,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
image: Optional[ImageField] = Field(description="The image to encode")
|
||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||
tiled: bool = Field(
|
||||
default=False,
|
||||
|
@ -30,6 +30,7 @@ class UNetField(BaseModel):
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
@ -154,6 +155,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
|
@ -32,7 +32,7 @@ def get_noise(
|
||||
perlin: float = 0.0,
|
||||
):
|
||||
"""Generate noise for a given image size."""
|
||||
noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
|
||||
noise_device_type = "cpu" if use_cpu else device.type
|
||||
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(latent_channels, 4)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -15,7 +15,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
type: Literal["restore_face"] = "restore_face"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
image: Optional[ImageField] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
# fmt: on
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@ -16,7 +16,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
type: Literal["upscale"] = "upscale"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||
image: Optional[ImageField] = Field(description="The input image", default=None)
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
# fmt: on
|
||||
|
@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord
from typing import Optional, cast
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
@ -44,7 +43,7 @@ class BoardImageRecordStorageBase(ABC):
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
) -> Optional[str]:
"""Gets an image's board id, if it has one."""
pass
@ -215,7 +214,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(
@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import List, Union
from typing import List, Union, Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
@ -49,7 +49,7 @@ class BoardImagesServiceABC(ABC):
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
) -> Optional[str]:
"""Gets an image's board id, if it has one."""
pass
@ -126,13 +126,13 @@ class BoardImagesService(BoardImagesServiceABC):
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
) -> Optional[str]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
@ -171,6 +171,7 @@ from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
MODEL_CORE = Path('models/core')
DB_FILE = Path('invokeai.db')
LEGACY_INIT_FILE = Path('invokeai.init')
@ -324,16 +325,11 @@ class InvokeAISettings(BaseSettings):
help=field.field_info.description,
)
def _find_root()->Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif (
os.environ.get("VIRTUAL_ENV")
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
or
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
)
):
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
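The replacement `_find_root` above resolves the runtime root in a fixed order: an explicit INVOKEAI_ROOT environment variable wins, then a virtualenv whose parent directory already holds an init file or a models/core folder, then the default ~/invokeai. A standalone sketch of the same precedence (a sketch of the logic shown in the hunk, not the shipped function):

import os
from pathlib import Path

INIT_FILE = Path("invokeai.yaml")
LEGACY_INIT_FILE = Path("invokeai.init")
MODEL_CORE = Path("models/core")

def find_root() -> Path:
    # 1. an explicit override always wins
    if root := os.environ.get("INVOKEAI_ROOT"):
        return Path(root).resolve()
    # 2. a virtualenv sitting inside an already-initialized root directory
    venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
    if any((venv.parent / marker).exists() for marker in (INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE)):
        return venv.parent.resolve()
    # 3. fall back to the default location
    return Path("~/invokeai").expanduser().resolve()

print(find_root())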
@ -1,10 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase:
session_event: str = "session_event"
@ -28,7 +27,7 @@ class EventServiceBase:
graph_execution_state_id: str,
node: dict,
source_node_id: str,
progress_image: ProgressImage | None,
progress_image: Optional[ProgressImage],
step: int,
total_steps: int,
) -> None:
@ -3,7 +3,6 @@
import copy
import itertools
import uuid
from types import NoneType
from typing import (
Annotated,
Any,
@ -26,6 +25,8 @@ from ..invocations.baseinvocation import (
InvocationContext,
)
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection")
@ -60,8 +61,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None
return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
@ -846,7 +845,7 @@ class GraphExecutionState(BaseModel):
]
}
def next(self) -> BaseInvocation | None:
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
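The `NoneType = type(None)` assignment above is a small compatibility shim: `from types import NoneType` only exists on Python 3.10 and later, while `type(None)` works everywhere. A one-liner to confirm the equivalence:

import sys

NoneType = type(None)  # portable spelling, works on every supported Python version
if sys.version_info >= (3, 10):
    import types
    assert NoneType is types.NoneType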
@ -2,13 +2,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
@ -80,7 +79,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
__cache: Dict[Path, PILImageType]
|
||||
__max_cache_size: int
|
||||
|
||||
def __init__(self, output_folder: str | Path):
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = 10 # TODO: get this from config
|
||||
@ -164,7 +163,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
return path
|
||||
|
||||
def validate_path(self, path: str | Path) -> bool:
|
||||
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||
"""Validates the path given for an image or thumbnail."""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
@ -175,7 +174,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
for folder in folders:
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __get_cache(self, image_name: Path) -> PILImageType | None:
|
||||
def __get_cache(self, image_name: Path) -> Optional[PILImageType]:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||
|
@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
@ -116,7 +115,7 @@ class ImageRecordStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
@ -208,7 +207,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_name: str) -> Union[ImageRecord, None]:
|
||||
def get(self, image_name: str) -> Optional[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@ -220,7 +219,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
@ -475,7 +474,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_most_recent_image_for_board(
|
||||
self, board_id: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
) -> Optional[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -490,7 +489,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
|
@ -370,7 +370,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def _get_metadata(
|
||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||
) -> Union[ImageMetadata, None]:
|
||||
) -> Optional[ImageMetadata]:
|
||||
"""Get the metadata for a node."""
|
||||
metadata = None
|
||||
|
||||
|
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from typing import Optional
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
@ -22,7 +22,7 @@ class InvocationQueueABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: InvocationQueueItem | None) -> None:
|
||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -57,7 +57,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
||||
|
||||
return item
|
||||
|
||||
def put(self, item: InvocationQueueItem | None) -> None:
|
||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||
self.__queue.put(item)
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
|
@ -1,14 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC
|
||||
from threading import Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .graph import Graph, GraphExecutionState
|
||||
from .invocation_queue import InvocationQueueABC, InvocationQueueItem
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
class Invoker:
|
||||
"""The invoker, used to execute invocations"""
|
||||
@ -21,7 +18,7 @@ class Invoker:
|
||||
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> str | None:
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
@ -45,7 +42,7 @@ class Invoker:
|
||||
|
||||
return invocation.id
|
||||
|
||||
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState:
|
||||
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
|
@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from typing import Dict, Union, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -55,7 +55,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
|
||||
def __get_cache(self, name: str) -> torch.Tensor|None:
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
@ -69,9 +69,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: str | Path
|
||||
__output_folder: Union[str, Path]
|
||||
|
||||
def __init__(self, output_folder: str | Path):
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -91,4 +91,4 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional
|
||||
import networkx as nx
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
@ -34,7 +34,7 @@ class CoreMetadataService(MetadataServiceBase):
|
||||
|
||||
return metadata
|
||||
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
|
||||
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
|
||||
"""
|
||||
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||
|
||||
@ -65,7 +65,7 @@ class CoreMetadataService(MetadataServiceBase):
|
||||
|
||||
def _get_additional_metadata(
|
||||
self, graph: Graph, node_id: str
|
||||
) -> Union[dict[str, Any], None]:
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Returns additional metadata for a given node.
|
||||
|
||||
|
@ -2,22 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
from pydantic import Field
|
||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||
from types import ModuleType
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
from invokeai.backend.model_management import (
|
||||
ModelManager,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
AddModelResult,
|
||||
SchedulerPredictionType,
|
||||
ModelMerger,
|
||||
MergeInterpolationMethod,
|
||||
)
|
||||
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from .config import InvokeAIAppConfig
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC):
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False
|
||||
) -> None:
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
KeyError exception if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError if the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusers pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to the 2nd and 3rd models
|
||||
:param interp: Interpolation method. None (default)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: types.ModuleType,
|
||||
logger: ModuleType,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
@ -191,6 +264,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
logger.debug(f'config file={config_file}')
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
logger.debug(f'GPU device = {device}')
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
@ -299,12 +374,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
# ) -> dict:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -320,9 +402,28 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'add/update model {model_name}')
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
KeyError exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f'update model {model_name}')
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise KeyError(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -334,8 +435,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError if the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
"""
|
||||
@ -387,9 +509,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
return self.mgr.logger
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
@ -406,4 +528,35 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to the 2nd and 3rd models
:param interp: Interpolation method. None (default)
"""
merger = ModelMerger(self.mgr)
try:
result = merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)
except AssertionError as e:
raise ValueError(e)
return result
||||
|
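As a usage sketch, the service method above can be driven roughly like this. The model names, the "sd-1" base tag, and the `model_manager` handle (an instance of the ModelManagerService shown above) are placeholders, not values taken from the commit:

result = model_manager.merge_models(
    model_names=["stable-diffusion-1.5", "analog-diffusion-1.0"],  # hypothetical installed models
    base_model="sd-1",            # assumed base-model tag; must match the models' base
    merged_model_name="analog-sd15-blend",
    alpha=0.5,                    # weight applied to the 2nd (and 3rd) model
    interp=None,                  # None = default interpolation, per the docstring
    force=False,
)
print(result.name, result.model_type)  # AddModelResult fields per the definition later in this commit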
@ -88,7 +88,7 @@ class ImageUrlsDTO(BaseModel):
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Union[str, None] = Field(
|
||||
board_id: Optional[str] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
@ -96,7 +96,7 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
|
@ -1,6 +1,6 @@
|
||||
import sqlite3
|
||||
from threading import Lock
|
||||
from typing import Generic, TypeVar, Union, get_args
|
||||
from typing import Generic, TypeVar, Optional, Union, get_args
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
|
||||
@ -63,7 +63,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Union[T, None]:
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
|
@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm import trange
|
||||
from typing import Callable, List, Iterator, Optional, Type
|
||||
from typing import Callable, List, Iterator, Optional, Type, Union
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
@ -178,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
init_image: Union[Image.Image, torch.FloatTensor],
|
||||
strength: float=0.75,
|
||||
**keyword_args
|
||||
)->Iterator[InvokeAIGeneratorOutput]:
|
||||
@ -195,7 +195,7 @@ class Img2Img(InvokeAIGenerator):
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
@ -570,28 +570,16 @@ class Generator:
|
||||
device = self.model.device
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
if self.use_mps_noise or device.type == "mps":
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device="cpu",
|
||||
).to(device)
|
||||
else:
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device,
|
||||
)
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device,
|
||||
)
|
||||
if self.perlin > 0.0:
|
||||
perlin_noise = self.get_perlin_noise(
|
||||
width // self.downsampling_factor, height // self.downsampling_factor
|
||||
|
@ -88,10 +88,7 @@ class Img2Img(Generator):
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == "mps":
|
||||
x = torch.randn_like(like, device="cpu").to(device)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device)
|
||||
x = torch.randn_like(like, device=device)
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||
|
@ -4,11 +4,10 @@ invokeai.backend.generator.inpaint descends from .generator
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
|
||||
@ -76,7 +75,7 @@ class Inpaint(Img2Img):
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
@ -203,8 +202,8 @@ class Inpaint(Img2Img):
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
init_image: Union[Image.Image, torch.FloatTensor],
|
||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
|
@ -45,6 +45,7 @@ from invokeai.app.services.config import (
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
SingleSelectColumns,
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
@ -76,7 +77,7 @@ Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
PRECISION_CHOICES = ['auto','float16','float32','autocast']
|
||||
PRECISION_CHOICES = ['auto','float16','float32']
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
@ -359,9 +360,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """If you have an account at HuggingFace you may optionally paste your access token here
|
||||
to allow InvokeAI to download restricted styles & subjects from the "Concept Library". See https://huggingface.co/settings/tokens.
|
||||
"""
|
||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||
for line in textwrap.wrap(label,width=window_width-6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
@ -423,6 +422,7 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
||||
)
|
||||
self.precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
columns = 2,
|
||||
name="Precision",
|
||||
values=PRECISION_CHOICES,
|
||||
value=PRECISION_CHOICES.index(precision),
|
||||
|
@ -3,7 +3,6 @@ Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
'''
|
||||
|
||||
import io
|
||||
import os
|
||||
import argparse
|
||||
import shutil
|
||||
@ -28,9 +27,10 @@ from transformers import (
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import (
|
||||
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo
|
||||
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
@ -47,52 +47,27 @@ class ModelPaths:
|
||||
|
||||
class MigrateTo3(object):
|
||||
def __init__(self,
|
||||
root_directory: Path,
|
||||
dest_models: Path,
|
||||
yaml_file: io.TextIOBase,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = root_directory
|
||||
self.dest_models = dest_models
|
||||
self.dest_yaml = yaml_file
|
||||
self.model_names = set()
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.src_paths = src_paths
|
||||
|
||||
self._initialize_yaml()
|
||||
|
||||
def _initialize_yaml(self):
|
||||
self.dest_yaml.write(
|
||||
yaml.dump(
|
||||
{
|
||||
'__metadata__':
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, 'w') as file:
|
||||
file.write(
|
||||
yaml.dump(
|
||||
{
|
||||
'version':'3.0.0'}
|
||||
}
|
||||
'__metadata__': {'version':'3.0.0'}
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def unique_name(self,name,info)->str:
|
||||
'''
|
||||
Create a unique name for a model for use within models.yaml.
|
||||
'''
|
||||
done = False
|
||||
|
||||
# some model names have slashes in them, which really screws things up
|
||||
name = name.replace('/','_')
|
||||
|
||||
key = ModelManager.create_key(name,info.base_type,info.model_type)
|
||||
unique_name = key
|
||||
counter = 1
|
||||
while not done:
|
||||
if unique_name in self.model_names:
|
||||
unique_name = f'{key}-{counter:0>2d}'
|
||||
counter += 1
|
||||
else:
|
||||
done = True
|
||||
self.model_names.add(unique_name)
|
||||
name,_,_ = ModelManager.parse_key(unique_name)
|
||||
return name
|
||||
|
||||
def create_directory_structure(self):
|
||||
'''
|
||||
Create the basic directory structure for the models folder.
|
||||
@ -140,23 +115,8 @@ class MigrateTo3(object):
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
'''
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir):
|
||||
for f in files:
|
||||
# hack - don't copy raw learned_embeds.bin, let them
|
||||
# be copied as part of a tree copy operation
|
||||
if f == 'learned_embeds.bin':
|
||||
continue
|
||||
try:
|
||||
model = Path(root,f)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root,d)
|
||||
@ -165,6 +125,29 @@ class MigrateTo3(object):
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
directories_scanned.add(model)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
|
||||
continue
|
||||
model = Path(root,f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -267,28 +250,6 @@ class MigrateTo3(object):
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def write_yaml(self, model_name: str, path:Path, info:ModelProbeInfo, **kwargs):
|
||||
'''
|
||||
Write a stanza for a moved model into the new models.yaml file.
|
||||
'''
|
||||
name = self.unique_name(model_name, info)
|
||||
stanza = {
|
||||
f'{info.base_type.value}/{info.model_type.value}/{name}': {
|
||||
'name': model_name,
|
||||
'path': str(path),
|
||||
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||
'format': info.format,
|
||||
'image_size': info.image_size,
|
||||
'base': info.base_type.value,
|
||||
'variant': info.variant_type.value,
|
||||
'prediction_type': info.prediction_type.value,
|
||||
'upcast_attention': info.prediction_type == SchedulerPredictionType.VPrediction,
|
||||
**kwargs,
|
||||
}
|
||||
}
|
||||
self.dest_yaml.write(yaml.dump(stanza))
|
||||
self.dest_yaml.flush()
|
||||
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
@ -332,6 +293,7 @@ class MigrateTo3(object):
|
||||
elif repo_id := vae.get('repo_id'):
|
||||
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
||||
vae_path = 'models/core/convert/sd-vae-ft-mse'
|
||||
return vae_path
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
||||
|
||||
@ -344,7 +306,10 @@ class MigrateTo3(object):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
self.copy_dir(vae_path,dest)
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path,dest)
|
||||
else:
|
||||
self.copy_file(vae_path,dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
@ -353,7 +318,7 @@ class MigrateTo3(object):
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
|
||||
def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
|
||||
'''
|
||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||
'''
|
||||
@ -385,11 +350,15 @@ class MigrateTo3(object):
|
||||
if not info:
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / repo_name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / model_name
|
||||
self._save_pretrained(pipeline, dest)
|
||||
|
||||
rel_path = Path('models',dest.relative_to(dest_dir))
|
||||
self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
|
||||
self._add_model(model_name, info, rel_path, **extra_config)
|
||||
|
||||
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
||||
'''
|
||||
@ -399,20 +368,49 @@ class MigrateTo3(object):
|
||||
# handle relative paths
|
||||
dest_dir = self.dest_models
|
||||
location = self.root_directory / location
|
||||
model_name = model_name or location.stem
|
||||
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||
return
|
||||
|
||||
# uh oh, weights is in the old models directory - move it into the new one
|
||||
if Path(location).is_relative_to(self.src_paths.models):
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||
self.copy_dir(location,dest)
|
||||
if location.is_dir():
|
||||
self.copy_dir(location,dest)
|
||||
else:
|
||||
self.copy_file(location,dest)
|
||||
location = Path('models', info.base_type.value, info.model_type.value, location.name)
|
||||
model_name = model_name or location.stem
|
||||
model_name = self.unique_name(model_name, info)
|
||||
self.write_yaml(model_name, path=location, info=info, **extra_config)
|
||||
|
||||
self._add_model(model_name, info, location, **extra_config)
|
||||
|
||||
def _add_model(self,
|
||||
model_name: str,
|
||||
info: ModelProbeInfo,
|
||||
location: Path,
|
||||
**extra_config):
|
||||
if info.model_type != ModelType.Main:
|
||||
return
|
||||
|
||||
self.mgr.add_model(
|
||||
model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
model_type = info.model_type,
|
||||
clobber = True,
|
||||
model_attributes = {
|
||||
'path': str(location),
|
||||
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||
'model_format': info.format,
|
||||
'variant': info.variant_type.value,
|
||||
**extra_config,
|
||||
}
|
||||
)
|
||||
|
||||
def migrate_defined_models(self):
|
||||
'''
|
||||
Migrate models defined in models.yaml
|
||||
@ -434,6 +432,9 @@ class MigrateTo3(object):
|
||||
|
||||
if config := stanza.get('config'):
|
||||
passthru_args['config'] = config
|
||||
|
||||
if description:= stanza.get('description'):
|
||||
passthru_args['description'] = description
|
||||
|
||||
if repo_id := stanza.get('repo_id'):
|
||||
logger.info(f'Migrating diffusers model {model_name}')
|
||||
@ -514,31 +515,50 @@ def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / 'configs' / 'models.yaml.3'
|
||||
dest_models = dest_directory / 'models.3'
|
||||
|
||||
dest_models = dest_directory / 'models-3.0'
|
||||
dest_yaml = dest_directory / 'configs/models.yaml-3.0'
|
||||
version_3 = (dest_directory / 'models' / 'core').exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
# file already exists, then we write into a copy of it to
|
||||
# avoid deleting its previous customizations. Otherwise we
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
|
||||
except:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / 'models').replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(
|
||||
from_root = src_directory,
|
||||
to_models = dest_models,
|
||||
model_manager = mgr,
|
||||
src_paths = paths
|
||||
)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
with open(dest_yaml,'w') as yaml_file:
|
||||
migrator = MigrateTo3(src_directory,
|
||||
dest_models,
|
||||
yaml_file,
|
||||
src_paths = paths,
|
||||
)
|
||||
migrator.migrate()
|
||||
|
||||
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
|
||||
(dest_directory / 'models').replace(dest_directory / 'models.orig')
|
||||
dest_models.replace(dest_directory / 'models')
|
||||
|
||||
(dest_directory /'configs/models.yaml').replace(dest_directory / 'configs/models.yaml.orig')
|
||||
dest_yaml.replace(dest_directory / 'configs/models.yaml')
|
||||
print(f"""Migration successful.
|
||||
Original models directory moved to {dest_directory}/models.orig
|
||||
Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig
|
||||
""")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / 'models').replace(src_directory / 'models.orig')
|
||||
print(f'Original models directory moved to {dest_directory}/models.orig')
|
||||
|
||||
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
|
||||
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
|
||||
|
||||
config_file.replace(config_file.with_suffix(''))
|
||||
dest_models.replace(dest_models.with_suffix(''))
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
|
||||
description="""
|
||||
@ -550,36 +570,34 @@ It is safe to provide the same directory for both arguments, but it is better to
|
||||
script, which will perform a full upgrade in place."""
|
||||
)
|
||||
parser.add_argument('--from-directory',
|
||||
dest='root_directory',
|
||||
dest='src_root',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
|
||||
)
|
||||
parser.add_argument('--to-directory',
|
||||
dest='dest_directory',
|
||||
dest='dest_root',
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
||||
)
|
||||
# TO DO: Implement full directory scanning
|
||||
# parser.add_argument('--all-models',
|
||||
# action="store_true",
|
||||
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
|
||||
# )
|
||||
args = parser.parse_args()
|
||||
root_directory = args.root_directory
|
||||
assert root_directory.is_dir(), f"{root_directory} is not a valid directory"
|
||||
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory"
|
||||
assert (root_directory / 'invokeai.init').exists() or (root_directory / 'invokeai.yaml').exists(), f"{root_directory} does not contain an InvokeAI init file."
|
||||
src_root = args.src_root
|
||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||
|
||||
dest_directory = args.dest_directory
|
||||
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
|
||||
dest_root = args.dest_root
|
||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(['--root',str(dest_root)])
|
||||
|
||||
# TODO: revisit
|
||||
# assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
|
||||
# assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
|
||||
# assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory"
|
||||
# assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file."
|
||||
|
||||
do_migrate(root_directory,dest_directory)
|
||||
do_migrate(src_root,dest_root)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
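For reference, the reworked migration entry point can be driven either through the `invokeai-migrate3` console script with `--from-directory`/`--to-directory`, or programmatically via `do_migrate`. A hedged sketch of the programmatic call (the module path is assumed here, not shown in the hunk, and both directories must already exist and be initialized as the assertions above require):

from pathlib import Path

# module path assumed; adjust to wherever migrate_to_3 lives in your checkout
from invokeai.backend.install.migrate_to_3 import do_migrate

src = Path("/opt/invokeai-2.3")   # a 2.3 root containing invokeai.init and models/
dest = Path("/opt/invokeai-3.0")  # a 3.0 root containing invokeai.yaml

do_migrate(src, dest)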
@ -11,6 +11,7 @@ from typing import List, Dict, Callable, Union, Set
|
||||
|
||||
import requests
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
@ -153,6 +154,9 @@ class ModelInstall(object):
|
||||
return defaults[0]
|
||||
|
||||
def install(self, selections: InstallSelections):
|
||||
verbosity = dlogging.get_verbosity() # quench NSFW nags
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
job = 1
|
||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||
|
||||
@ -160,20 +164,28 @@ class ModelInstall(object):
|
||||
for key in selections.remove_models:
|
||||
name,base,mtype = self.mgr.parse_key(key)
|
||||
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
|
||||
self.mgr.del_model(name,base,mtype)
|
||||
try:
|
||||
self.mgr.del_model(name,base,mtype)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(e)
|
||||
job += 1
|
||||
|
||||
# add requested models
|
||||
for path in selections.install_models:
|
||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||
self.heuristic_import(path)
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(self,
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
|
||||
model_path_id_or_url: Union[str,Path],
|
||||
models_installed: Set[Path]=None,
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
@ -186,62 +198,53 @@ class ModelInstall(object):
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path):self._install_path(path)})
|
||||
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update(self._install_path(path))
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||
]
|
||||
):
|
||||
models_installed.update(self._install_path(path))
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(path).split('/')) == 2:
|
||||
models_installed.update(self._install_repo(str(path)))
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
# a URL
|
||||
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update(self._install_url(model_path_id_or_url))
|
||||
|
||||
else:
|
||||
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
except ValueError as e:
logger.error(str(e))
else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')

return models_installed

# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
try:
model_result = None
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
model_result = self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)
except Exception as e:
logger.warning(f'{str(e)} Skipping registration.')
return {}
return {str(path): model_result}
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
if not info:
logger.warning(f'Unable to parse format of {path}')
return None
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path,info)
return self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
model_type = info.model_type,
model_attributes = attributes,
)

def _install_url(self, url: str)->dict:
# copy to a staging area, probe, import and delete
def _install_url(self, url: str)->AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
if not location:
@ -253,7 +256,7 @@ class ModelInstall(object):
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)

def _install_repo(self, repo_id: str)->dict:
def _install_repo(self, repo_id: str)->AddModelResult:
hinfo = HfApi().model_info(repo_id)

# we try to figure out how to download this most economically
@ -279,16 +282,16 @@ class ModelInstall(object):
location = self._download_hf_model(repo_id, files, staging)
break
elif f'learned_embeds.{suffix}' in files:
location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging)
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
break
if not location:
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
return
return {}

info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f'Could not probe {location}. Skipping install.')
return
return {}
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
if dest.exists():
shutil.rmtree(dest)
@ -1,7 +1,8 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import ModelMerger, MergeInterpolationMethod
@ -2,8 +2,8 @@ from __future__ import annotations

import copy
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Any, Union, List
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, List

import torch
from compel.embeddings_provider import BaseTextualInversionManager
@ -615,6 +615,24 @@ class ModelPatcher:
text_encoder.resize_token_embeddings(init_tokens_count)

@classmethod
@contextmanager
def apply_clip_skip(
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
):
skipped_layers = []
try:
for i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))

yield

finally:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())

class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280]
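A minimal usage sketch of the new `apply_clip_skip` context manager above (not part of this commit; the model name and import path are assumptions for illustration). It temporarily removes the last `clip_skip` text-encoder layers and restores them on exit:

```python
from transformers import CLIPTextModel
# assumed import path for ModelPatcher; adjust to where the class actually lives
from invokeai.backend.model_management.lora import ModelPatcher

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Drop the last two encoder layers for the duration of the block, then restore them.
with ModelPatcher.apply_clip_skip(text_encoder, clip_skip=2):
    pass  # run prompt conditioning with the truncated encoder here
```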
@ -342,7 +342,9 @@ class ModelCache(object):
for model_key, cache_entry in self._cached_models.items():
if not cache_entry.locked and cache_entry.loaded:
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
cache_entry.model.to(self.storage_device)
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device)
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')

def _local_model_hash(self, model_path: Union[str, Path]) -> str:
sha = hashlib.sha256()
@ -52,7 +52,7 @@ A typical example is:
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
model_type=ModelType.Main,
base_model=BaseModelType.StableDiffusion1,
submodel_type=SubModelType.Unet)
submodel_type=SubModelType.UNet)
with sd1_5 as unet:
run_some_inference(unet)

@ -234,7 +234,7 @@ import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree
from shutil import rmtree, move

import torch
from omegaconf import OmegaConf
@ -279,7 +279,7 @@ class InvalidModelError(Exception):
pass

class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
@ -311,7 +311,6 @@ class ModelManager(object):
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""

self.config_path = None
if isinstance(config, (str, Path)):
self.config_path = Path(config)
@ -491,17 +490,32 @@ class ModelManager(object):
"""
return [(self.parse_key(x)) for x in self.models.keys()]

def list_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model,model_type,model_name)
return models[0] if models else None

def list_models(
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_name: Optional[str] = None,
) -> list[dict]:
"""
Return a list of models.
"""

model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = []
for model_key in sorted(self.models, key=str.casefold):
for model_key in model_keys:
model_config = self.models[model_key]

cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@ -546,10 +560,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None)

if model_cfg is None:
self.logger.error(
f"Unknown model {model_key}"
)
return
raise KeyError(f"Unknown model {model_key}")

# note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, [])
@ -615,6 +626,7 @@ class ModelManager(object):
self.cache.uncache_model(cache_id)

self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
@ -622,6 +634,60 @@ class ModelManager(object):
config = model_config,
)

def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
'''
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']

This will raise a ValueError unless the model is a checkpoint.
'''
info = self.model_info(model_name, base_model, model_type)
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")

# We are taking advantage of a side effect of get_model() that converts check points
# into cached diffusers directories stored at `location`. It doesn't matter
# what submodeltype we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
model = self.get_model(model_name,
base_model,
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")

try:
move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')

result = self.add_model(model_name, base_model, model_type,
model_attributes = info,
clobber=True)
except:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
rmtree(new_diffusers_path)
raise

if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
checkpoint_path.unlink()

return result

def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -703,6 +769,7 @@ class ModelManager(object):
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
else:
self.models.pop(model_key, None)
else:
@ -821,6 +888,10 @@ class ModelManager(object):
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.

May return the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
'''
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
@ -830,11 +901,7 @@ class ModelManager(object):
prediction_type_helper = prediction_type_helper,
model_manager = self)
for thing in items_to_import:
try:
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
except Exception as e:
self.logger.warning(f'{thing} could not be imported: {str(e)}')

installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
self.commit()
return successfully_installed
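A hedged sketch of how the reworked ModelManager API above can be exercised (not part of the diff; the config path, model names, and keyword names are illustrative assumptions):

```python
# ModelManager accepts a str/Path config per the constructor shown above
mgr = ModelManager("models.yaml")

# heuristic_import() delegates to ModelInstall and returns a dict of AddModelResult
installed = mgr.heuristic_import(["https://example.com/some-model.safetensors"])
for source, result in installed.items():
    print(source, result.name, result.base_model, result.model_type)

# convert_model() turns a registered checkpoint into a diffusers folder in place
mgr.convert_model("my-checkpoint", BaseModelType.StableDiffusion1, ModelType.Main)
```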
131
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,131 @@
"""
invokeai.backend.model_management.model_merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml

Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""

import warnings
from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List, Union

import invokeai.backend.util.logging as logger

from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult

class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"

class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager

def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()

pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe

def merge_diffusion_models_and_save (
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None

for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])

merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name

merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
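A hedged usage sketch for the new ModelMerger class above (illustrative only; the config attribute and model names are assumptions, not taken from this commit):

```python
# assumed: config is an InvokeAIAppConfig with a model_conf_path attribute
manager = ModelManager(config.model_conf_path)
merger = ModelMerger(manager)

result = merger.merge_diffusion_models_and_save(
    model_names=["modelA", "modelB"],          # diffusers models sharing the same base
    base_model=BaseModelType.StableDiffusion1,
    merged_model_name="modelA+modelB",
    alpha=0.5,
    interp=MergeInterpolationMethod.Sigmoid,
)
print(result.name, result.base_model, result.model_type)
```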
@ -6,7 +6,7 @@ from dataclasses import dataclass

from diffusers import ModelMixin, ConfigMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from typing import Callable, Literal, Union, Dict, Optional
from picklescan.scanner import scan_file_path

from .models import (
@ -64,8 +64,8 @@ class ModelProbe(object):
@classmethod
def probe(cls,
model_path: Path,
model: Union[Dict, ModelMixin] = None,
prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]] = None)->ModelProbeInfo:
'''
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
@ -168,7 +168,7 @@ class ModelProbe(object):
return type

# give up
raise ValueError("Unable to determine model type for {folder_path}")
raise ValueError(f"Unable to determine model type for {folder_path}")

@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
@ -68,7 +68,11 @@ def get_model_config_enums():
enums = list()

for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)

if hasattr(inspect,'get_annotations'):
fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
try:
field = fields["model_format"]
except:

@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
)
)
else:
return model_path
@ -7,7 +7,7 @@ import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
from pydantic import Field

import einops
import PIL.Image
@ -17,12 +17,11 @@ import psutil
import torch
import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.controlnet import MultiControlNetModel

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
@ -46,7 +45,7 @@ from .diffusion import (
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .offloading import FullyLoadedModelGroup, ModelGroup

@dataclass
class PipelineIntermediateState:
@ -105,7 +104,7 @@ class AddsMaskGuidance:
_debug: Optional[Callable] = None

def __call__(
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning
self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data.

@ -361,37 +360,34 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
):
self.enable_xformers_memory_efficient_attention()
else:
if torch.backends.mps.is_available():
# until pytorch #91617 is fixed, slicing is borked on MPS
# https://github.com/pytorch/pytorch/issues/91617
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
pass
if self.device.type == "cpu" or self.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else:
if self.device.type == "cpu" or self.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device))
else:
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = (
latents.element_size() + 4
)
max_size_required_for_baddbmm = (
16
* latents.size(dim=2)
* latents.size(dim=3)
* latents.size(dim=2)
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
if max_size_required_for_baddbmm > (
mem_free * 3.0 / 4.0
): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max")
else:
self.disable_attention_slicing()
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
bytes_per_element_needed_for_baddbmm_duplication = (
latents.element_size() + 4
)
max_size_required_for_baddbmm = (
16
* latents.size(dim=2)
* latents.size(dim=3)
* latents.size(dim=2)
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
if max_size_required_for_baddbmm > (
mem_free * 3.0 / 4.0
): # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max")
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
self.enable_attention_slicing(slice_size="max")
else:
self.disable_attention_slicing()

def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
# overridden method; types match the superclass.
@ -917,20 +913,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():
if device.type == "mps":
# workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
# TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
self.vae.to(CPU_DEVICE)
init_image = init_image.to(CPU_DEVICE)
else:
self._model_group.load(self.vae)
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(
dtype=dtype
) # FIXME: uses torch.randn. make reproducible!
if device.type == "mps":
self.vae.to(device)
init_latents = init_latents.to(device)

init_latents = 0.18215 * init_latents
return init_latents
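A worked example of the attention-slicing heuristic used in the pipeline change above. The buffer size is 16 * (h/8 * w/8)^2 * (element_size + 4), and slicing is enabled when it exceeds three quarters of free memory; a 512x512 fp16 generation is assumed here for illustration:

```python
# latents for a 512x512 image have shape [1, 4, 64, 64]
h_latent, w_latent = 512 // 8, 512 // 8
bytes_per_element = 2 + 4                                 # fp16 element_size() + 4
max_size_required_for_baddbmm = 16 * (h_latent * w_latent) ** 2 * bytes_per_element
print(max_size_required_for_baddbmm / 2**30)              # 1.5 GiB
# slicing is enabled when this exceeds mem_free * 3.0 / 4.0 on the active device
```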
@ -248,9 +248,6 @@ class InvokeAIDiffuserComponent:
x_twice, sigma_twice, both_conditionings, **kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x

def _apply_standard_conditioning_sequentially(
@ -264,9 +261,6 @@ class InvokeAIDiffuserComponent:
# low-memory sequential path
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x

# TODO: looks unused
@ -4,7 +4,7 @@ import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable
from typing import Callable, Union

import torch
from accelerate.utils import send_to_device
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
"""

_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel]
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]

def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
@ -4,6 +4,7 @@ from contextlib import nullcontext

import torch
from torch import autocast
from typing import Union
from invokeai.app.services.config import InvokeAIAppConfig

CPU_DEVICE = torch.device("cpu")
@ -28,6 +29,8 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16"
elif device.type == "mps":
return "float16"
return "float32"

@ -49,7 +52,7 @@ def choose_autocast(precision):
return nullcontext

def normalize_device(device: str | torch.device) -> torch.device:
def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate."""
device = torch.device(device)
if device.index is None:
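A small hedged sketch of how the device helpers touched above behave (illustrative; assumes a machine with a CUDA build of torch and is not taken from this commit):

```python
import torch
from invokeai.backend.util.devices import normalize_device, choose_precision  # assumed module path

dev = normalize_device("cuda")   # fills in the missing index -> torch.device("cuda", 0)
precision = choose_precision(dev)
# "float16" on most CUDA cards and on MPS; "float32" on GTX 1650/1660 cards and CPU
print(dev, precision)
```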
63
invokeai/backend/util/mps_fixes.py
Normal file
@ -0,0 +1,63 @@
import torch

if torch.backends.mps.is_available():
torch.empty = torch.zeros

_torch_layer_norm = torch.nn.functional.layer_norm
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
if weight is not None:
weight = weight.float()
if bias is not None:
bias = bias.float()
return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
else:
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

torch.nn.functional.layer_norm = new_layer_norm

_torch_tensor_permute = torch.Tensor.permute
def new_torch_tensor_permute(input, *dims):
result = _torch_tensor_permute(input, *dims)
if input.device == "mps" and input.dtype == torch.float16:
result = result.contiguous()
return result

torch.Tensor.permute = new_torch_tensor_permute

_torch_lerp = torch.lerp
def new_torch_lerp(input, end, weight, *, out=None):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
end = end.float()
if isinstance(weight, torch.Tensor):
weight = weight.float()
if out is not None:
out_fp32 = torch.zeros_like(out, dtype=torch.float32)
else:
out_fp32 = None
result = _torch_lerp(input, end, weight, out=out_fp32)
if out is not None:
out.copy_(out_fp32.half())
del out_fp32
return result.half()

else:
return _torch_lerp(input, end, weight, out=out)

torch.lerp = new_torch_lerp

_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)

torch.nn.functional.interpolate = new_torch_interpolate
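The new mps_fixes.py module above follows one pattern throughout: wrap a torch op so that float16 inputs on MPS are upcast to float32, run through the original op, and downcast on return. A minimal hedged sketch of that pattern, applied to an op chosen purely for illustration (gelu is not patched by this commit):

```python
import torch

_orig_gelu = torch.nn.functional.gelu

def _gelu_mps_safe(x, approximate="none"):
    # upcast fp16 MPS tensors to fp32 for the op, then cast the result back
    if x.device.type == "mps" and x.dtype == torch.float16:
        return _orig_gelu(x.float(), approximate=approximate).half()
    return _orig_gelu(x, approximate=approximate)

torch.nn.functional.gelu = _gelu_mps_safe
```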
@ -108,11 +108,11 @@ def main():

print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
if release:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip' --use-pep517 --upgrade"
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
elif tag:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip' --use-pep517 --upgrade"
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
else:
cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip' --use-pep517 --upgrade"
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
print('')
print('')
if os.system(cmd)==0:
@ -382,10 +382,21 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
)
return min(cols, len(self.installed_models))

def confirm_deletions(self, selections: InstallSelections)->bool:
remove_models = selections.remove_models
if len(remove_models) > 0:
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
return npyscreen.notify_ok_cancel(f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}")
else:
return True

def on_execute(self):
self.monitor.entry_widget.buffer(['Processing...'],scroll_end=True)
self.marshall_arguments()
app = self.parentApp
if not self.confirm_deletions(app.install_selections):
return

self.monitor.entry_widget.buffer(['Processing...'],scroll_end=True)
self.ok_button.hidden = True
self.display()

@ -417,6 +428,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):

def on_done(self):
self.marshall_arguments()
if not self.confirm_deletions(self.parentApp.install_selections):
return
self.parentApp.setNextForm(None)
self.parentApp.user_cancelled = False
self.editing = False
@ -18,7 +18,7 @@ from curses import BUTTON2_CLICKED,BUTTON3_CLICKED

# minimum size for UIs
MIN_COLS = 130
MIN_LINES = 40
MIN_LINES = 45

# -------------------------------------
def set_terminal_size(columns: int, lines: int, launch_command: str=None):
19
invokeai/frontend/legacy_launch_invokeai.py
Normal file
@ -0,0 +1,19 @@
import os
import sys
import argparse

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--web', action='store_true')
opts,_ = parser.parse_known_args()

if opts.web:
sys.argv.pop(sys.argv.index('--web'))
from invokeai.app.api_app import invoke_api
invoke_api()
else:
from invokeai.app.cli_app import invoke_cli
invoke_cli()

if __name__ == '__main__':
main()
@ -1,4 +1,5 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
from .merge_diffusers import main as invokeai_merge_diffusers
@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import os
import sys
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Union
@ -20,99 +18,15 @@ from npyscreen import widget
from omegaconf import OmegaConf

import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns

DEST_MERGED_MODEL_DIR = "merged_models"
config = InvokeAIAppConfig.get_config()

def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()

pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", config.cache_dir),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe

def merge_diffusion_models_and_commit(
models: List["str"],
merged_model_name: str,
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
):
"""
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = config.model_conf_path
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
assert (
model_manager.model_info(mod).get("format", None) == "diffusers"
), f"** {mod} is not a diffusers model. It must be optimized before merging."
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]

merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR

os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
import_args = dict(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
logger.info(f"Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)

def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging")
parser.add_argument(
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
)
parser.add_argument(
"--models",
dest="model_names",
type=str,
nargs="+",
help="Two to three model names to be merged",
)
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
)
parser.add_argument(
"--merged_model_name",
"--destination",
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx()

self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
'Models Built on SD-1.x',
'Models Built on SD-2.x',
],
value=[self.current_base],
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
self.base_select.on_changed = self._populate_models
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
rely=7,
)
self.add_widget_intelligent(
npyscreen.FixedText,
@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_width=max_width,
scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
)
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
TextBox,
name="Name for merged model:",
labelColor="CONTROL",
max_height=3,
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
name="Force merge of models created by different diffusers library versions",
labelColor="CONTROL",
value=False,
value=True,
scroll_exit=True,
)
self.nextrely += 1
self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Merge Method:",
@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]]

args = dict(
models=models,
model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True

def get_model_names(self) -> List[str]:
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
model_names = [
name
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
info["name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
]
return sorted(model_names)

def _populate_models(self,value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)

models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model1.values = self.model_names
self.model2.values = self.model_names
self.model3.values = models_plus_none

self.display()

class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
def __init__(self, model_manager:ModelManager):
super().__init__()
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
self.model_manager = model_manager

def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):

def run_gui(args: Namespace):
mergeapp = Mergeapp()
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()

args = mergeapp.merge_arguments
merge_diffusion_models_and_commit(**args)
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')

def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
args.models and len(args.models) >= 1 and len(args.models) <= 3
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
args.merged_model_name = "+".join(args.model_names)
logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)

model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
model_manager = ModelManager(config.model_conf_path)
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

merge_diffusion_models_and_commit(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')

def main():
args = _parse_args()
config.root = args.root_dir

cache_dir = config.cache_dir
os.environ[
"HF_HOME"
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
args.cache_dir = cache_dir
config.parse_args(['--root',str(args.root_dir)])

try:
if args.front_end:
1
invokeai/frontend/web/dist/assets/App-4c33c38e.css
vendored
Normal file
File diff suppressed because one or more lines are too long
199
invokeai/frontend/web/dist/assets/App-9a48e001.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/MantineProvider-d7f01775.js
vendored
Normal file
File diff suppressed because one or more lines are too long
302
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-419b361f.js
vendored
Normal file
File diff suppressed because one or more lines are too long
9
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-8d49f92d.css
vendored
Normal file
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-581af3d4.js
vendored
Normal file
File diff suppressed because one or more lines are too long
BIN
invokeai/frontend/web/dist/assets/inter-cyrillic-ext-wght-normal-848492d3.woff2
vendored
Normal file
Binary file not shown.
BIN
invokeai/frontend/web/dist/assets/inter-cyrillic-wght-normal-262a1054.woff2
vendored
Normal file
Binary file not shown.
BIN
invokeai/frontend/web/dist/assets/inter-greek-ext-wght-normal-fe977ddb.woff2
vendored
Normal file
Binary file not shown.
BIN
invokeai/frontend/web/dist/assets/inter-greek-wght-normal-89b4a3fe.woff2
vendored
Normal file
Binary file not shown.
BIN
invokeai/frontend/web/dist/assets/inter-latin-ext-wght-normal-45606f83.woff2
vendored
Normal file
Binary file not shown.
BIN
invokeai/frontend/web/dist/assets/inter-latin-wght-normal-450f3ba4.woff2
vendored
Normal file
Binary file not shown.
BIN
invokeai/frontend/web/dist/assets/inter-vietnamese-wght-normal-ac4e131c.woff2
vendored
Normal file
Binary file not shown.
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-c0367e37.js"></script>
<script type="module" crossorigin src="./assets/index-581af3d4.js"></script>
</head>

<body dir="ltr">
1
invokeai/frontend/web/dist/locales/en.json
vendored
@ -52,6 +52,7 @@
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Node Editor",
"batch": "Batch Manager",
"modelmanager": "Model Manager",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
122
invokeai/frontend/web/dist/locales/fi.json
vendored
Normal file
@ -0,0 +1,122 @@
{
"accessibility": {
"reset": "Resetoi",
"useThisParameter": "Käytä tätä parametria",
"modelSelect": "Mallin Valinta",
"exitViewer": "Poistu katselimesta",
"uploadImage": "Lataa kuva",
"copyMetadataJson": "Kopioi metadata JSON:iin",
"invokeProgressBar": "Invoken edistymispalkki",
"nextImage": "Seuraava kuva",
"previousImage": "Edellinen kuva",
"zoomIn": "Lähennä",
"flipHorizontally": "Käännä vaakasuoraan",
"zoomOut": "Loitonna",
"rotateCounterClockwise": "Kierrä vastapäivään",
"rotateClockwise": "Kierrä myötäpäivään",
"flipVertically": "Käännä pystysuoraan",
"showGallery": "Näytä galleria",
"modifyConfig": "Muokkaa konfiguraatiota",
"toggleAutoscroll": "Kytke automaattinen vieritys",
"toggleLogViewer": "Kytke lokin katselutila",
"showOptionsPanel": "Näytä asetukset"
},
"common": {
"postProcessDesc2": "Erillinen käyttöliittymä tullaan julkaisemaan helpottaaksemme työnkulkua jälkikäsittelyssä.",
"training": "Kouluta",
"statusLoadingModel": "Ladataan mallia",
"statusModelChanged": "Malli vaihdettu",
"statusConvertingModel": "Muunnetaan mallia",
"statusModelConverted": "Malli muunnettu",
"langFrench": "Ranska",
"langItalian": "Italia",
"languagePickerLabel": "Kielen valinta",
"hotkeysLabel": "Pikanäppäimet",
"reportBugLabel": "Raportoi Bugista",
"langPolish": "Puola",
"themeLabel": "Teema",
"langDutch": "Hollanti",
"settingsLabel": "Asetukset",
"githubLabel": "Github",
"darkTheme": "Tumma",
"lightTheme": "Vaalea",
"greenTheme": "Vihreä",
"langGerman": "Saksa",
"langPortuguese": "Portugali",
"discordLabel": "Discord",
"langEnglish": "Englanti",
"oceanTheme": "Meren sininen",
"langRussian": "Venäjä",
"langUkranian": "Ukraina",
"langSpanish": "Espanja",
"upload": "Lataa",
"statusMergedModels": "Mallit yhdistelty",
"img2img": "Kuva kuvaksi",
"nodes": "Solmut",
"nodesDesc": "Solmupohjainen järjestelmä kuvien generoimiseen on parhaillaan kehitteillä. Pysy kuulolla päivityksistä tähän uskomattomaan ominaisuuteen liittyen.",
"postProcessDesc1": "Invoke AI tarjoaa monenlaisia jälkikäsittelyominaisuukisa. Kuvan laadun skaalaus sekä kasvojen korjaus ovat jo saatavilla WebUI:ssä. Voit ottaa ne käyttöön lisäasetusten valikosta teksti kuvaksi sekä kuva kuvaksi -välilehdiltä. Voit myös suoraan prosessoida kuvia käyttämällä kuvan toimintapainikkeita nykyisen kuvan yläpuolella tai tarkastelussa.",
"postprocessing": "Jälkikäsitellään",
"postProcessing": "Jälkikäsitellään",
"cancel": "Peruuta",
"close": "Sulje",
"accept": "Hyväksy",
"statusConnected": "Yhdistetty",
"statusError": "Virhe",
"statusProcessingComplete": "Prosessointi valmis",
"load": "Lataa",
"back": "Takaisin",
"statusGeneratingTextToImage": "Generoidaan tekstiä kuvaksi",
"trainingDesc2": "InvokeAI tukee jo mukautettujen upotusten kouluttamista tekstin inversiolla käyttäen pääskriptiä.",
"statusDisconnected": "Yhteys katkaistu",
"statusPreparing": "Valmistellaan",
"statusIterationComplete": "Iteraatio valmis",
"statusMergingModels": "Yhdistellään malleja",
"statusProcessingCanceled": "Valmistelu peruutettu",
"statusSavingImage": "Tallennetaan kuvaa",
"statusGeneratingImageToImage": "Generoidaan kuvaa kuvaksi",
"statusRestoringFacesGFPGAN": "Korjataan kasvoja (GFPGAN)",
"statusRestoringFacesCodeFormer": "Korjataan kasvoja (CodeFormer)",
"statusGeneratingInpainting": "Generoidaan sisällemaalausta",
"statusGeneratingOutpainting": "Generoidaan ulosmaalausta",
"statusRestoringFaces": "Korjataan kasvoja",
"pinOptionsPanel": "Kiinnitä asetukset -paneeli",
"loadingInvokeAI": "Ladataan Invoke AI:ta",
"loading": "Ladataan",
"statusGenerating": "Generoidaan",
"txt2img": "Teksti kuvaksi",
"trainingDesc1": "Erillinen työnkulku omien upotusten ja tarkastuspisteiden kouluttamiseksi käyttäen tekstin inversiota ja dreamboothia selaimen käyttöliittymässä.",
"postProcessDesc3": "Invoke AI:n komentorivi tarjoaa paljon muita ominaisuuksia, kuten esimerkiksi Embiggenin.",
"unifiedCanvas": "Yhdistetty kanvas",
"statusGenerationComplete": "Generointi valmis"
},
"gallery": {
"uploads": "Lataukset",
"showUploads": "Näytä lataukset",
"galleryImageResetSize": "Resetoi koko",
"maintainAspectRatio": "Säilytä kuvasuhde",
"galleryImageSize": "Kuvan koko",
"pinGallery": "Kiinnitä galleria",
"showGenerations": "Näytä generaatiot",
"singleColumnLayout": "Yhden sarakkeen asettelu",
"generations": "Generoinnit",
"gallerySettings": "Gallerian asetukset",
"autoSwitchNewImages": "Vaihda uusiin kuviin automaattisesti",
"allImagesLoaded": "Kaikki kuvat ladattu",
"noImagesInGallery": "Ei kuvia galleriassa",
"loadMore": "Lataa lisää"
},
"hotkeys": {
"keyboardShortcuts": "näppäimistön pikavalinnat",
"appHotkeys": "Sovelluksen pikanäppäimet",
"generalHotkeys": "Yleiset pikanäppäimet",
"galleryHotkeys": "Gallerian pikanäppäimet",
"unifiedCanvasHotkeys": "Yhdistetyn kanvaan pikanäppäimet",
"cancel": {
"desc": "Peruuta kuvan luominen",
"title": "Peruuta"
},
"invoke": {
"desc": "Luo kuva"
}
}
}
1
invokeai/frontend/web/dist/locales/mn.json
vendored
Normal file
@ -0,0 +1 @@
{}
254
invokeai/frontend/web/dist/locales/sv.json
vendored
Normal file
@ -0,0 +1,254 @@
{
"accessibility": {
"copyMetadataJson": "Kopiera metadata JSON",
"zoomIn": "Zooma in",
"exitViewer": "Avslutningsvisare",
"modelSelect": "Välj modell",
"uploadImage": "Ladda upp bild",
"invokeProgressBar": "Invoke förloppsmätare",
"nextImage": "Nästa bild",
"toggleAutoscroll": "Växla automatisk rullning",
"flipHorizontally": "Vänd vågrätt",
"flipVertically": "Vänd lodrätt",
"zoomOut": "Zooma ut",
"toggleLogViewer": "Växla logvisare",
"reset": "Starta om",
"previousImage": "Föregående bild",
"useThisParameter": "Använd denna parametern",
"showGallery": "Visa galleri",
"rotateCounterClockwise": "Rotera moturs",
"rotateClockwise": "Rotera medurs",
"modifyConfig": "Ändra konfiguration",
"showOptionsPanel": "Visa inställningspanelen"
},
"common": {
"hotkeysLabel": "Snabbtangenter",
"reportBugLabel": "Rapportera bugg",
"githubLabel": "Github",
"discordLabel": "Discord",
"settingsLabel": "Inställningar",
"darkTheme": "Mörk",
"lightTheme": "Ljus",
"greenTheme": "Grön",
"oceanTheme": "Hav",
"langEnglish": "Engelska",
"langDutch": "Nederländska",
"langFrench": "Franska",
"langGerman": "Tyska",
"langItalian": "Italienska",
"langArabic": "العربية",
"langHebrew": "עברית",
"langPolish": "Polski",
"langPortuguese": "Português",
"langBrPortuguese": "Português do Brasil",
"langSimplifiedChinese": "简体中文",
"langJapanese": "日本語",
"langKorean": "한국어",
"langRussian": "Русский",
"unifiedCanvas": "Förenad kanvas",
"nodesDesc": "Ett nodbaserat system för bildgenerering är under utveckling. Håll utkik för uppdateringar om denna fantastiska funktion.",
"langUkranian": "Украї́нська",
"langSpanish": "Español",
"postProcessDesc2": "Ett dedikerat användargränssnitt kommer snart att släppas för att underlätta mer avancerade arbetsflöden av efterbehandling.",
"trainingDesc1": "Ett dedikerat arbetsflöde för träning av dina egna inbäddningar och kontrollpunkter genom Textual Inversion eller Dreambooth från webbgränssnittet.",
"trainingDesc2": "InvokeAI stöder redan träning av anpassade inbäddningar med hjälp av Textual Inversion genom huvudscriptet.",
"upload": "Ladda upp",
"close": "Stäng",
"cancel": "Avbryt",
"accept": "Acceptera",
"statusDisconnected": "Frånkopplad",
"statusGeneratingTextToImage": "Genererar text till bild",
"statusGeneratingImageToImage": "Genererar Bild till bild",
"statusGeneratingInpainting": "Genererar Måla i",
"statusGenerationComplete": "Generering klar",
"statusModelConverted": "Modell konverterad",
"statusMergingModels": "Sammanfogar modeller",
"pinOptionsPanel": "Nåla fast inställningspanelen",
"loading": "Laddar",
"loadingInvokeAI": "Laddar Invoke AI",
"statusRestoringFaces": "Återskapar ansikten",
"languagePickerLabel": "Språkväljare",
"themeLabel": "Tema",
"txt2img": "Text till bild",
"nodes": "Noder",
"img2img": "Bild till bild",
"postprocessing": "Efterbehandling",
"postProcessing": "Efterbehandling",
"load": "Ladda",
"training": "Träning",
"postProcessDesc1": "Invoke AI erbjuder ett brett utbud av efterbehandlingsfunktioner. Uppskalning och ansiktsåterställning finns redan tillgängligt i webbgränssnittet. Du kommer åt dem ifrån Avancerade inställningar-menyn under Bild till bild-fliken. Du kan också behandla bilder direkt genom att använda knappen bildåtgärder ovanför nuvarande bild eller i bildvisaren.",
"postProcessDesc3": "Invoke AI's kommandotolk erbjuder många olika funktioner, bland annat \"Förstora\".",
|
||||
"statusGenerating": "Genererar",
|
||||
"statusError": "Fel",
|
||||
"back": "Bakåt",
|
||||
"statusConnected": "Ansluten",
|
||||
"statusPreparing": "Förbereder",
|
||||
"statusProcessingCanceled": "Bearbetning avbruten",
|
||||
"statusProcessingComplete": "Bearbetning färdig",
|
||||
"statusGeneratingOutpainting": "Genererar Fyll ut",
|
||||
"statusIterationComplete": "Itterering klar",
|
||||
"statusSavingImage": "Sparar bild",
|
||||
"statusRestoringFacesGFPGAN": "Återskapar ansikten (GFPGAN)",
|
||||
"statusRestoringFacesCodeFormer": "Återskapar ansikten (CodeFormer)",
|
||||
"statusUpscaling": "Skala upp",
|
||||
"statusUpscalingESRGAN": "Uppskalning (ESRGAN)",
|
||||
"statusModelChanged": "Modell ändrad",
|
||||
"statusLoadingModel": "Laddar modell",
|
||||
"statusConvertingModel": "Konverterar modell",
|
||||
"statusMergedModels": "Modeller sammanfogade"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generationer",
|
||||
"showGenerations": "Visa generationer",
|
||||
"uploads": "Uppladdningar",
|
||||
"showUploads": "Visa uppladdningar",
|
||||
"galleryImageSize": "Bildstorlek",
|
||||
"allImagesLoaded": "Alla bilder laddade",
|
||||
"loadMore": "Ladda mer",
|
||||
"galleryImageResetSize": "Återställ storlek",
|
||||
"gallerySettings": "Galleriinställningar",
|
||||
"maintainAspectRatio": "Behåll bildförhållande",
|
||||
"pinGallery": "Nåla fast galleri",
|
||||
"noImagesInGallery": "Inga bilder i galleriet",
|
||||
"autoSwitchNewImages": "Ändra automatiskt till nya bilder",
|
||||
"singleColumnLayout": "Enkolumnslayout"
|
||||
},
|
||||
"hotkeys": {
|
||||
"generalHotkeys": "Allmänna snabbtangenter",
|
||||
"galleryHotkeys": "Gallerisnabbtangenter",
|
||||
"unifiedCanvasHotkeys": "Snabbtangenter för sammanslagskanvas",
|
||||
"invoke": {
|
||||
"title": "Anropa",
|
||||
"desc": "Genererar en bild"
|
||||
},
|
||||
"cancel": {
|
||||
"title": "Avbryt",
|
||||
"desc": "Avbryt bildgenerering"
|
||||
},
|
||||
"focusPrompt": {
|
||||
"desc": "Fokusera området för promptinmatning",
|
||||
"title": "Fokusprompt"
|
||||
},
|
||||
"pinOptions": {
|
||||
"desc": "Nåla fast alternativpanelen",
|
||||
"title": "Nåla fast alternativ"
|
||||
},
|
||||
"toggleOptions": {
|
||||
"title": "Växla inställningar",
|
||||
"desc": "Öppna och stäng alternativpanelen"
|
||||
},
|
||||
"toggleViewer": {
|
||||
"title": "Växla visaren",
|
||||
"desc": "Öppna och stäng bildvisaren"
|
||||
},
|
||||
"toggleGallery": {
|
||||
"title": "Växla galleri",
|
||||
"desc": "Öppna eller stäng galleribyrån"
|
||||
},
|
||||
"maximizeWorkSpace": {
|
||||
"title": "Maximera arbetsyta",
|
||||
"desc": "Stäng paneler och maximera arbetsyta"
|
||||
},
|
||||
"changeTabs": {
|
||||
"title": "Växla flik",
|
||||
"desc": "Byt till en annan arbetsyta"
|
||||
},
|
||||
"consoleToggle": {
|
||||
"title": "Växla konsol",
|
||||
"desc": "Öppna och stäng konsol"
|
||||
},
|
||||
"setSeed": {
|
||||
"desc": "Använd seed för nuvarande bild",
|
||||
"title": "välj seed"
|
||||
},
|
||||
"setParameters": {
|
||||
"title": "Välj parametrar",
|
||||
"desc": "Använd alla parametrar från nuvarande bild"
|
||||
},
|
||||
"setPrompt": {
|
||||
"desc": "Använd prompt för nuvarande bild",
|
||||
"title": "Välj prompt"
|
||||
},
|
||||
"restoreFaces": {
|
||||
"title": "Återskapa ansikten",
|
||||
"desc": "Återskapa nuvarande bild"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Skala upp",
|
||||
"desc": "Skala upp nuvarande bild"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Visa info",
|
||||
"desc": "Visa metadata för nuvarande bild"
|
||||
},
|
||||
"sendToImageToImage": {
|
||||
"title": "Skicka till Bild till bild",
|
||||
"desc": "Skicka nuvarande bild till Bild till bild"
|
||||
},
|
||||
"deleteImage": {
|
||||
"title": "Radera bild",
|
||||
"desc": "Radera nuvarande bild"
|
||||
},
|
||||
"closePanels": {
|
||||
"title": "Stäng paneler",
|
||||
"desc": "Stäng öppna paneler"
|
||||
},
|
||||
"previousImage": {
|
||||
"title": "Föregående bild",
|
||||
"desc": "Visa föregående bild"
|
||||
},
|
||||
"nextImage": {
|
||||
"title": "Nästa bild",
|
||||
"desc": "Visa nästa bild"
|
||||
},
|
||||
"toggleGalleryPin": {
|
||||
"title": "Växla gallerinål",
|
||||
"desc": "Nålar fast eller nålar av galleriet i gränssnittet"
|
||||
},
|
||||
"increaseGalleryThumbSize": {
|
||||
"title": "Förstora galleriets bildstorlek",
|
||||
"desc": "Förstora miniatyrbildernas storlek"
|
||||
},
|
||||
"decreaseGalleryThumbSize": {
|
||||
"title": "Minska gelleriets bildstorlek",
|
||||
"desc": "Minska miniatyrbildernas storlek i galleriet"
|
||||
},
|
||||
"decreaseBrushSize": {
|
||||
"desc": "Förminska storleken på kanvas- pensel eller suddgummi",
|
||||
"title": "Minska penselstorlek"
|
||||
},
|
||||
"increaseBrushSize": {
|
||||
"title": "Öka penselstorlek",
|
||||
"desc": "Öka stoleken på kanvas- pensel eller suddgummi"
|
||||
},
|
||||
"increaseBrushOpacity": {
|
||||
"title": "Öka penselns opacitet",
|
||||
"desc": "Öka opaciteten för kanvaspensel"
|
||||
},
|
||||
"decreaseBrushOpacity": {
|
||||
"desc": "Minska kanvaspenselns opacitet",
|
||||
"title": "Minska penselns opacitet"
|
||||
},
|
||||
"moveTool": {
|
||||
"title": "Flytta",
|
||||
"desc": "Tillåt kanvasnavigation"
|
||||
},
|
||||
"fillBoundingBox": {
|
||||
"title": "Fyll ram",
|
||||
"desc": "Fyller ramen med pensels färg"
|
||||
},
|
||||
"keyboardShortcuts": "Snabbtangenter",
|
||||
"appHotkeys": "Appsnabbtangenter",
|
||||
"selectBrush": {
|
||||
"desc": "Välj kanvaspensel",
|
||||
"title": "Välj pensel"
|
||||
},
|
||||
"selectEraser": {
|
||||
"desc": "Välj kanvassuddgummi",
|
||||
"title": "Välj suddgummi"
|
||||
},
|
||||
"eraseBoundingBox": {
|
||||
"title": "Ta bort ram"
|
||||
}
|
||||
}
|
||||
}
|
64
invokeai/frontend/web/dist/locales/tr.json
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
{
|
||||
"accessibility": {
|
||||
"invokeProgressBar": "Invoke ilerleme durumu",
|
||||
"nextImage": "Sonraki Resim",
|
||||
"useThisParameter": "Kullanıcı parametreleri",
|
||||
"copyMetadataJson": "Metadata verilerini kopyala (JSON)",
|
||||
"exitViewer": "Görüntüleme Modundan Çık",
|
||||
"zoomIn": "Yakınlaştır",
|
||||
"zoomOut": "Uzaklaştır",
|
||||
"rotateCounterClockwise": "Döndür (Saat yönünün tersine)",
|
||||
"rotateClockwise": "Döndür (Saat yönünde)",
|
||||
"flipHorizontally": "Yatay Çevir",
|
||||
"flipVertically": "Dikey Çevir",
|
||||
"modifyConfig": "Ayarları Değiştir",
|
||||
"toggleAutoscroll": "Otomatik kaydırmayı aç/kapat",
|
||||
"toggleLogViewer": "Günlük Görüntüleyici Aç/Kapa",
|
||||
"showOptionsPanel": "Ayarlar Panelini Göster",
|
||||
"modelSelect": "Model Seçin",
|
||||
"reset": "Sıfırla",
|
||||
"uploadImage": "Resim Yükle",
|
||||
"previousImage": "Önceki Resim",
|
||||
"menu": "Menü",
|
||||
"showGallery": "Galeriyi Göster"
|
||||
},
|
||||
"common": {
|
||||
"hotkeysLabel": "Kısayol Tuşları",
|
||||
"themeLabel": "Tema",
|
||||
"languagePickerLabel": "Dil Seçimi",
|
||||
"reportBugLabel": "Hata Bildir",
|
||||
"githubLabel": "Github",
|
||||
"discordLabel": "Discord",
|
||||
"settingsLabel": "Ayarlar",
|
||||
"darkTheme": "Karanlık Tema",
|
||||
"lightTheme": "Aydınlık Tema",
|
||||
"greenTheme": "Yeşil Tema",
|
||||
"oceanTheme": "Okyanus Tema",
|
||||
"langArabic": "Arapça",
|
||||
"langEnglish": "İngilizce",
|
||||
"langDutch": "Hollandaca",
|
||||
"langFrench": "Fransızca",
|
||||
"langGerman": "Almanca",
|
||||
"langItalian": "İtalyanca",
|
||||
"langJapanese": "Japonca",
|
||||
"langPolish": "Lehçe",
|
||||
"langPortuguese": "Portekizce",
|
||||
"langBrPortuguese": "Portekizcr (Brezilya)",
|
||||
"langRussian": "Rusça",
|
||||
"langSimplifiedChinese": "Çince (Basit)",
|
||||
"langUkranian": "Ukraynaca",
|
||||
"langSpanish": "İspanyolca",
|
||||
"txt2img": "Metinden Resime",
|
||||
"img2img": "Resimden Metine",
|
||||
"linear": "Çizgisel",
|
||||
"nodes": "Düğümler",
|
||||
"postprocessing": "İşlem Sonrası",
|
||||
"postProcessing": "İşlem Sonrası",
|
||||
"postProcessDesc2": "Daha gelişmiş özellikler için ve iş akışını kolaylaştırmak için özel bir kullanıcı arayüzü çok yakında yayınlanacaktır.",
|
||||
"postProcessDesc3": "Invoke AI komut satırı arayüzü, bir çok yeni özellik sunmaktadır.",
|
||||
"langKorean": "Korece",
|
||||
"unifiedCanvas": "Akıllı Tuval",
|
||||
"nodesDesc": "Görüntülerin oluşturulmasında hazırladığımız yeni bir sistem geliştirme aşamasındadır. Bu harika özellikler ve çok daha fazlası için bizi takip etmeye devam edin.",
|
||||
"postProcessDesc1": "Invoke AI son kullanıcıya yönelik bir çok özellik sunar. Görüntü kalitesi yükseltme, yüz restorasyonu WebUI üzerinden kullanılabilir. Metinden resime ve resimden metne araçlarına gelişmiş seçenekler menüsünden ulaşabilirsiniz. İsterseniz mevcut görüntü ekranının üzerindeki veya görüntüleyicideki görüntüyü doğrudan düzenleyebilirsiniz."
|
||||
}
|
||||
}
|
1
invokeai/frontend/web/dist/locales/vi.json
vendored
Normal file
@ -0,0 +1 @@
|
||||
{}
|
@ -23,7 +23,7 @@
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t",
|
||||
"typegen": "npx ts-node scripts/typegen.ts",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
@ -83,7 +83,7 @@
|
||||
"konva": "^9.2.0",
|
||||
"lodash-es": "^4.17.21",
|
||||
"nanostores": "^0.9.2",
|
||||
"openapi-fetch": "0.4.0",
|
||||
"openapi-fetch": "^0.6.1",
|
||||
"overlayscrollbars": "^2.2.0",
|
||||
"overlayscrollbars-react": "^0.5.0",
|
||||
"patch-package": "^7.0.0",
|
||||
|
@ -1,55 +0,0 @@
|
||||
diff --git a/node_modules/openapi-fetch/dist/index.js b/node_modules/openapi-fetch/dist/index.js
|
||||
index cd4528a..8976b51 100644
|
||||
--- a/node_modules/openapi-fetch/dist/index.js
|
||||
+++ b/node_modules/openapi-fetch/dist/index.js
|
||||
@@ -1,5 +1,5 @@
|
||||
// settings & const
|
||||
-const DEFAULT_HEADERS = {
|
||||
+const CONTENT_TYPE_APPLICATION_JSON = {
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
const TRAILING_SLASH_RE = /\/*$/;
|
||||
@@ -29,18 +29,29 @@ export function createFinalURL(url, options) {
|
||||
}
|
||||
return finalURL;
|
||||
}
|
||||
+function stringifyBody(body) {
|
||||
+ if (body instanceof ArrayBuffer || body instanceof File || body instanceof DataView || body instanceof Blob || ArrayBuffer.isView(body) || body instanceof URLSearchParams || body instanceof FormData) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ if (typeof body === "string") {
|
||||
+ return body;
|
||||
+ }
|
||||
+
|
||||
+ return JSON.stringify(body);
|
||||
+ }
|
||||
+
|
||||
export default function createClient(clientOptions = {}) {
|
||||
const { fetch = globalThis.fetch, ...options } = clientOptions;
|
||||
- const defaultHeaders = new Headers({
|
||||
- ...DEFAULT_HEADERS,
|
||||
- ...(options.headers ?? {}),
|
||||
- });
|
||||
+ const defaultHeaders = new Headers(options.headers ?? {});
|
||||
async function coreFetch(url, fetchOptions) {
|
||||
const { headers, body: requestBody, params = {}, parseAs = "json", querySerializer = defaultSerializer, ...init } = fetchOptions || {};
|
||||
// URL
|
||||
const finalURL = createFinalURL(url, { baseUrl: options.baseUrl, params, querySerializer });
|
||||
+ // Stringify body if needed
|
||||
+ const stringifiedBody = stringifyBody(requestBody);
|
||||
// headers
|
||||
- const baseHeaders = new Headers(defaultHeaders); // clone defaults (don’t overwrite!)
|
||||
+ const baseHeaders = new Headers(stringifiedBody ? { ...CONTENT_TYPE_APPLICATION_JSON, ...defaultHeaders } : defaultHeaders); // clone defaults (don’t overwrite!)
|
||||
const headerOverrides = new Headers(headers);
|
||||
for (const [k, v] of headerOverrides.entries()) {
|
||||
if (v === undefined || v === null)
|
||||
@@ -54,7 +65,7 @@ export default function createClient(clientOptions = {}) {
|
||||
...options,
|
||||
...init,
|
||||
headers: baseHeaders,
|
||||
- body: typeof requestBody === "string" ? requestBody : JSON.stringify(requestBody),
|
||||
+ body: stringifiedBody ?? requestBody,
|
||||
});
|
||||
// handle empty content
|
||||
// note: we return `{}` because we want user truthy checks for `.data` or `.error` to succeed
|
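The patch above is deleted in this commit, alongside the `openapi-fetch` bump from `0.4.0` to `^0.6.1` in `package.json`. What it did: stop the client from unconditionally sending `Content-Type: application/json`, and set that header only when the request body is actually serialized to JSON, so `File`, `Blob`, `FormData`, `URLSearchParams` and typed-array bodies pass through untouched. Dropping the `patch-package` override presumably means the upgraded release covers this behaviour upstream; that is an inference from the version bump, not something the diff states. A standalone sketch of the same serialization rule (plain TypeScript, not the library source):

```ts
// Sketch of the body-serialization rule the deleted patch implemented.
// Standalone TypeScript, not the openapi-fetch source itself.

/** JSON-stringify plain values; return undefined for bodies fetch() can send as-is. */
function stringifyBody(body: unknown): string | undefined {
  if (
    body instanceof ArrayBuffer ||
    body instanceof Blob || // File extends Blob
    body instanceof FormData ||
    body instanceof URLSearchParams ||
    ArrayBuffer.isView(body) // DataView and typed arrays
  ) {
    return undefined; // let fetch() pick the Content-Type itself
  }
  return typeof body === 'string' ? body : JSON.stringify(body);
}

function buildRequest(url: string, body: unknown): Request {
  const serialized = stringifyBody(body);
  return new Request(url, {
    method: 'POST',
    // claim application/json only when we produced a textual body ourselves
    headers: serialized ? { 'Content-Type': 'application/json' } : undefined,
    body: serialized ?? (body as BodyInit),
  });
}
```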
@ -527,7 +527,8 @@
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview",
|
||||
"controlNetControlMode": "Control Mode"
|
||||
"controlNetControlMode": "Control Mode",
|
||||
"clipSkip": "Clip Skip"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
@ -551,7 +552,8 @@
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"favoriteSchedulers": "Favorite Schedulers",
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited"
|
||||
"favoriteSchedulersPlaceholder": "No schedulers favorited",
|
||||
"showAdvancedOptions": "Show Advanced Options"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
|
3
invokeai/frontend/web/scripts/package.json
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"type": "module"
|
||||
}
|
23
invokeai/frontend/web/scripts/typegen.ts
Normal file
@ -0,0 +1,23 @@
|
||||
import fs from 'node:fs';
|
||||
import openapiTS from 'openapi-typescript';
|
||||
|
||||
const OPENAPI_URL = 'http://localhost:9090/openapi.json';
|
||||
const OUTPUT_FILE = 'src/services/api/schema.d.ts';
|
||||
|
||||
async function main() {
|
||||
process.stdout.write(
|
||||
`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`
|
||||
);
|
||||
const types = await openapiTS(OPENAPI_URL, {
|
||||
exportType: true,
|
||||
transform: (schemaObject, metadata) => {
|
||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||
}
|
||||
},
|
||||
});
|
||||
fs.writeFileSync(OUTPUT_FILE, types);
|
||||
process.stdout.write(` OK!\r\n`);
|
||||
}
|
||||
|
||||
main();
|
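The new `typegen.ts` backs the updated `typegen` script in `package.json` (`npx ts-node scripts/typegen.ts`, with `"type": "module"` added in `scripts/package.json`, presumably so the script is treated as an ES module). It fetches the schema from the backend at `http://localhost:9090/openapi.json`, rewrites `format: binary` fields to `Blob` (or `Blob | null`), and writes the result to `src/services/api/schema.d.ts`. A hedged sketch of consuming that generated file; the `ImageDTO` key is assumed for illustration, and the real names come from the backend's OpenAPI document:

```ts
// Sketch only: consuming the generated types. openapi-typescript emits `paths`
// and `components` type exports; 'ImageDTO' below is an assumed schema name.
import type { components } from 'services/api/schema';

export type ImageDTO = components['schemas']['ImageDTO'];

// Fields declared with `format: binary` in the OpenAPI document surface here
// as `Blob` (or `Blob | null`) thanks to the transform in typegen.ts.
```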
@ -1,5 +1,6 @@
|
||||
import { Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
@ -46,6 +47,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
dispatch(configChanged(config));
|
||||
}, [dispatch, config, log]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(appStarted());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
|
@ -55,6 +55,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
||||
}
|
||||
|
||||
if (props.dragData.payloadType === 'IMAGE_DTO') {
|
||||
const { thumbnail_url, width, height } = props.dragData.payload.imageDTO;
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
@ -72,7 +73,10 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
||||
sx={{
|
||||
...STYLES,
|
||||
}}
|
||||
src={props.dragData.payload.imageDTO.thumbnail_url}
|
||||
objectFit="contain"
|
||||
src={thumbnail_url}
|
||||
width={width}
|
||||
height={height}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
|
@ -9,4 +9,5 @@ export const actionsDenylist = [
|
||||
'canvas/addPointToCurrentLine',
|
||||
'socket/socketGeneratorProgress',
|
||||
'socket/appSocketGeneratorProgress',
|
||||
'hotkeys/shiftKeyPressed',
|
||||
];
|
||||
|
@ -1,49 +1,67 @@
|
||||
import type { TypedAddListener, TypedStartListening } from '@reduxjs/toolkit';
|
||||
import {
|
||||
createListenerMiddleware,
|
||||
addListener,
|
||||
ListenerEffect,
|
||||
AnyAction,
|
||||
ListenerEffect,
|
||||
addListener,
|
||||
createListenerMiddleware,
|
||||
} from '@reduxjs/toolkit';
|
||||
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
|
||||
|
||||
import type { RootState, AppDispatch } from '../../store';
|
||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||
import type { AppDispatch, RootState } from '../../store';
|
||||
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
||||
import { addAppStartedListener } from './listeners/appStarted';
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||
import {
|
||||
addImageUploadedFulfilledListener,
|
||||
addImageUploadedRejectedListener,
|
||||
} from './listeners/imageUploaded';
|
||||
addImageAddedToBoardFulfilledListener,
|
||||
addImageAddedToBoardRejectedListener,
|
||||
} from './listeners/imageAddedToBoard';
|
||||
import {
|
||||
addImageDeletedFulfilledListener,
|
||||
addImageDeletedPendingListener,
|
||||
addImageDeletedRejectedListener,
|
||||
addRequestedImageDeletionListener,
|
||||
} from './listeners/imageDeleted';
|
||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
|
||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||
import {
|
||||
addImageMetadataReceivedFulfilledListener,
|
||||
addImageMetadataReceivedRejectedListener,
|
||||
} from './listeners/imageMetadataReceived';
|
||||
import {
|
||||
addImageRemovedFromBoardFulfilledListener,
|
||||
addImageRemovedFromBoardRejectedListener,
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
|
||||
import {
|
||||
addImageUpdatedFulfilledListener,
|
||||
addImageUpdatedRejectedListener,
|
||||
} from './listeners/imageUpdated';
|
||||
import {
|
||||
addImageUploadedFulfilledListener,
|
||||
addImageUploadedRejectedListener,
|
||||
} from './listeners/imageUploaded';
|
||||
import {
|
||||
addImageUrlsReceivedFulfilledListener,
|
||||
addImageUrlsReceivedRejectedListener,
|
||||
} from './listeners/imageUrlsReceived';
|
||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||
import { addModelSelectedListener } from './listeners/modelSelected';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
import {
|
||||
addReceivedPageOfImagesFulfilledListener,
|
||||
addReceivedPageOfImagesRejectedListener,
|
||||
} from './listeners/receivedPageOfImages';
|
||||
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
|
||||
import {
|
||||
addSessionCanceledFulfilledListener,
|
||||
addSessionCanceledPendingListener,
|
||||
addSessionCanceledRejectedListener,
|
||||
} from './listeners/sessionCanceled';
|
||||
import {
|
||||
addSessionCreatedFulfilledListener,
|
||||
addSessionCreatedPendingListener,
|
||||
@ -54,38 +72,21 @@ import {
|
||||
addSessionInvokedPendingListener,
|
||||
addSessionInvokedRejectedListener,
|
||||
} from './listeners/sessionInvoked';
|
||||
import {
|
||||
addSessionCanceledFulfilledListener,
|
||||
addSessionCanceledPendingListener,
|
||||
addSessionCanceledRejectedListener,
|
||||
} from './listeners/sessionCanceled';
|
||||
import {
|
||||
addImageUpdatedFulfilledListener,
|
||||
addImageUpdatedRejectedListener,
|
||||
} from './listeners/imageUpdated';
|
||||
import {
|
||||
addReceivedPageOfImagesFulfilledListener,
|
||||
addReceivedPageOfImagesRejectedListener,
|
||||
} from './listeners/receivedPageOfImages';
|
||||
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
|
||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
||||
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
|
||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
import {
|
||||
addImageAddedToBoardFulfilledListener,
|
||||
addImageAddedToBoardRejectedListener,
|
||||
} from './listeners/imageAddedToBoard';
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import {
|
||||
addImageRemovedFromBoardFulfilledListener,
|
||||
addImageRemovedFromBoardRejectedListener,
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
|
||||
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
|
||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
|
||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -195,9 +196,6 @@ addSessionCanceledRejectedListener();
|
||||
addReceivedPageOfImagesFulfilledListener();
|
||||
addReceivedPageOfImagesRejectedListener();
|
||||
|
||||
// Gallery
|
||||
addImageCategoriesChangedListener();
|
||||
|
||||
// ControlNet
|
||||
addControlNetImageProcessedListener();
|
||||
addControlNetAutoProcessListener();
|
||||
@ -220,3 +218,9 @@ addSelectionAddedToBatchListener();
|
||||
|
||||
// DND
|
||||
addImageDroppedListener();
|
||||
|
||||
// Models
|
||||
addModelSelectedListener();
|
||||
|
||||
// app startup
|
||||
addAppStartedListener();
|
||||
|
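Most of the churn in this file is the import list being reordered alphabetically; the substantive change is that the new `addModelSelectedListener` and `addAppStartedListener` are registered while the gallery `imageCategoriesChanged` listener is dropped. The typed wiring implied by the `TypedStartListening`/`TypedAddListener`/`RootState`/`AppDispatch` imports follows the standard Redux Toolkit listener-middleware pattern; a sketch under that assumption (the code outside the hunks shown here is not in the diff):

```ts
// Sketch of the typed listener-middleware setup implied by the imports above.
// Everything beyond `createListenerMiddleware()` is an assumption.
import {
  addListener,
  createListenerMiddleware,
  type TypedAddListener,
  type TypedStartListening,
} from '@reduxjs/toolkit';
import type { AppDispatch, RootState } from '../../store';

export const listenerMiddleware = createListenerMiddleware();

export type AppStartListening = TypedStartListening<RootState, AppDispatch>;

// Listener modules call this to register their effects against the app's store
// types, e.g. addAppStartedListener() and addModelSelectedListener().
export const startAppListening =
  listenerMiddleware.startListening as AppStartListening;

export const addAppListener = addListener as TypedAddListener<
  RootState,
  AppDispatch
>;
```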
@ -0,0 +1,43 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import {
|
||||
INITIAL_IMAGE_LIMIT,
|
||||
isLoadingChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const appStarted = createAction('app/appStarted');
|
||||
|
||||
export const addAppStartedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: appStarted,
|
||||
effect: async (
|
||||
action,
|
||||
{ getState, dispatch, unsubscribe, cancelActiveListeners }
|
||||
) => {
|
||||
cancelActiveListeners();
|
||||
unsubscribe();
|
||||
// fill up the gallery tab with images
|
||||
await dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
is_intermediate: false,
|
||||
offset: 0,
|
||||
limit: INITIAL_IMAGE_LIMIT,
|
||||
})
|
||||
);
|
||||
|
||||
// fill up the assets tab with images
|
||||
await dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['control', 'mask', 'user', 'other'],
|
||||
is_intermediate: false,
|
||||
offset: 0,
|
||||
limit: INITIAL_IMAGE_LIMIT,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(isLoadingChanged(false));
|
||||
},
|
||||
});
|
||||
};
|
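`appStarted` is an ordinary action created with `createAction`; its listener cancels any in-flight copies of itself, unsubscribes so it only ever runs once, prefetches a first page of `general` images for the gallery tab and `control`/`mask`/`user`/`other` images for the assets tab, then clears the gallery loading flag. It is dispatched from `App.tsx` in the `useEffect` added earlier in this diff; a minimal sketch of that wiring (the component name is illustrative):

```tsx
// Minimal sketch mirroring the useEffect added to App.tsx: dispatch appStarted
// once on mount. The listener unsubscribes itself after its first run.
import { useEffect } from 'react';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch } from 'app/store/storeHooks';

const AppStartup = () => {
  const dispatch = useAppDispatch();

  useEffect(() => {
    dispatch(appStarted());
  }, [dispatch]);

  return null;
};

export default AppStartup;
```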
@ -1,29 +0,0 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { receivedPageOfImages } from 'services/api/thunks/image';
|
||||
import {
|
||||
imageCategoriesChanged,
|
||||
selectFilteredImages,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'gallery' });
|
||||
|
||||
export const addImageCategoriesChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageCategoriesChanged,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const filteredImagesCount = selectFilteredImages(state).length;
|
||||
|
||||
if (!filteredImagesCount) {
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: action.payload,
|
||||
board_id: state.boards.selectedBoardId,
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,42 @@
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
modelChanged,
|
||||
vaeSelected,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { zMainModel } from 'features/parameters/store/parameterZodSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { lorasCleared } from '../../../../../features/lora/store/loraSlice';
|
||||
|
||||
export const addModelSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: modelSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const [base_model, type, name] = action.payload.split('/');
|
||||
|
||||
if (state.generation.model?.base_model !== base_model) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Base model changed, clearing submodels',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
dispatch(vaeSelected(null));
|
||||
dispatch(lorasCleared());
|
||||
// TODO: controlnet cleared
|
||||
}
|
||||
|
||||
const newModel = zMainModel.parse({
|
||||
id: action.payload,
|
||||
base_model,
|
||||
name,
|
||||
});
|
||||
|
||||
dispatch(modelChanged(newModel));
|
||||
},
|
||||
});
|
||||
};
|
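`modelSelected` carries a single string payload that the listener splits on `/` into `base_model`, model type and model name. When the base model differs from the one currently selected, it shows a warning toast and clears the VAE and LoRA selections before committing the parsed model via `modelChanged`. A hedged sketch of dispatching it; the model identifier in the comment is hypothetical and only illustrates the `base_model/type/name` shape:

```ts
// Sketch: a hook that selects a main model by its combined id. The example id
// in the comment is hypothetical; it only illustrates the base_model/type/name
// shape that the listener splits on '/'.
import { useCallback } from 'react';
import { modelSelected } from 'features/parameters/store/actions';
import { useAppDispatch } from 'app/store/storeHooks';

export const useSelectMainModel = () => {
  const dispatch = useAppDispatch();

  return useCallback(
    (combinedId: string) => {
      // e.g. 'sd-1/main/stable-diffusion-v1-5'
      dispatch(modelSelected(combinedId));
    },
    [dispatch]
  );
};
```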
@ -93,7 +93,8 @@ export type AppFeature =
|
||||
| 'discordLink'
|
||||
| 'bugLink'
|
||||
| 'localization'
|
||||
| 'consoleLogging';
|
||||
| 'consoleLogging'
|
||||
| 'dynamicPrompting';
|
||||
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
@ -104,7 +105,10 @@ export type SDFeature =
|
||||
| 'variation'
|
||||
| 'symmetry'
|
||||
| 'seamless'
|
||||
| 'hires';
|
||||
| 'hires'
|
||||
| 'lora'
|
||||
| 'embedding'
|
||||
| 'vae';
|
||||
|
||||
/**
|
||||
* Configuration options for the InvokeAI UI.
|
||||
|
@ -5,8 +5,10 @@ import {
|
||||
Input,
|
||||
InputProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { ChangeEvent, KeyboardEvent, memo, useCallback } from 'react';
|
||||
|
||||
interface IAIInputProps extends InputProps {
|
||||
label?: string;
|
||||
@ -25,6 +27,25 @@ const IAIInput = (props: IAIInputProps) => {
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(true));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (!e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl
|
||||
isInvalid={isInvalid}
|
||||
@ -32,7 +53,12 @@ const IAIInput = (props: IAIInputProps) => {
|
||||
{...formControlProps}
|
||||
>
|
||||
{label !== '' && <FormLabel>{label}</FormLabel>}
|
||||
<Input {...rest} onPaste={stopPastePropagation} />
|
||||
<Input
|
||||
{...rest}
|
||||
onPaste={stopPastePropagation}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
@ -1,7 +1,9 @@
|
||||
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { RefObject, memo } from 'react';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
@ -11,6 +13,7 @@ type IAIMultiSelectProps = MultiSelectProps & {
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
base50,
|
||||
base100,
|
||||
@ -31,11 +34,32 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const [boxShadow] = useToken('shadows', ['dark-lg']);
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(true));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (!e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
||||
<MultiSelect
|
||||
ref={inputRef}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
searchable={searchable}
|
||||
maxDropdownHeight={300}
|
||||
styles={() => ({
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
@ -66,6 +90,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
'&[data-disabled]': {
|
||||
backgroundColor: mode(base300, base700)(colorMode),
|
||||
color: mode(base600, base400)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
value: {
|
||||
@ -108,6 +133,10 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
color: mode('white', base50)(colorMode),
|
||||
},
|
||||
},
|
||||
'&[data-disabled]': {
|
||||
color: mode(base500, base600)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 24,
|
||||
|
@ -1,7 +1,9 @@
|
||||
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||
import { memo } from 'react';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
export type IAISelectDataType = {
|
||||
@ -12,10 +14,12 @@ export type IAISelectDataType = {
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
const { searchable = true, tooltip, ...rest } = props;
|
||||
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
base50,
|
||||
base100,
|
||||
@ -35,13 +39,54 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
} = useChakraThemeTokens();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const [searchValue, setSearchValue] = useState('');
|
||||
|
||||
// we want to capture shift keypressed even when an input is focused
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(true));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (!e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
// wrap onChange to clear search value on select
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
setSearchValue('');
|
||||
|
||||
if (!onChange) {
|
||||
return;
|
||||
}
|
||||
|
||||
onChange(v);
|
||||
},
|
||||
[onChange]
|
||||
);
|
||||
|
||||
const [boxShadow] = useToken('shadows', ['dark-lg']);
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Select
|
||||
ref={inputRef}
|
||||
searchValue={searchValue}
|
||||
onSearchChange={setSearchValue}
|
||||
onChange={handleChange}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
searchable={searchable}
|
||||
maxDropdownHeight={300}
|
||||
styles={() => ({
|
||||
label: {
|
||||
color: mode(base700, base300)(colorMode),
|
||||
@ -67,6 +112,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
'&[data-disabled]': {
|
||||
backgroundColor: mode(base300, base700)(colorMode),
|
||||
color: mode(base600, base400)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
value: {
|
||||
@ -109,6 +155,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
color: mode('white', base50)(colorMode),
|
||||
},
|
||||
},
|
||||
'&[data-disabled]': {
|
||||
color: mode(base500, base600)(colorMode),
|
||||
cursor: 'not-allowed',
|
||||
},
|
||||
},
|
||||
rightSection: {
|
||||
width: 32,
|
||||
|
@ -0,0 +1,31 @@
|
||||
import { Box, Tooltip } from '@chakra-ui/react';
|
||||
import { Text } from '@mantine/core';
|
||||
import { forwardRef, memo } from 'react';
|
||||
|
||||
interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
|
||||
label: string;
|
||||
description?: string;
|
||||
tooltip?: string;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
|
||||
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Box ref={ref} {...others}>
|
||||
<Box>
|
||||
<Text>{label}</Text>
|
||||
{description && (
|
||||
<Text size="xs" color="base.600">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</Box>
|
||||
</Box>
|
||||
</Tooltip>
|
||||
)
|
||||
);
|
||||
|
||||
IAIMantineSelectItemWithTooltip.displayName = 'IAIMantineSelectItemWithTooltip';
|
||||
|
||||
export default memo(IAIMantineSelectItemWithTooltip);
|
@ -14,10 +14,19 @@ import {
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
|
||||
import { FocusEvent, memo, useEffect, useState } from 'react';
|
||||
import {
|
||||
FocusEvent,
|
||||
KeyboardEvent,
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react';
|
||||
|
||||
const numberStringRegex = /^-?(0\.)?\.?$/;
|
||||
|
||||
@ -60,6 +69,8 @@ const IAINumberInput = (props: Props) => {
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
/**
|
||||
* Using a controlled input with a value that accepts decimals needs special
|
||||
* handling. If the user starts to type in "1.5", by the time they press the
|
||||
@ -109,6 +120,24 @@ const IAINumberInput = (props: Props) => {
|
||||
onChange(clamped);
|
||||
};
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(true));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (!e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tooltip {...tooltipProps}>
|
||||
<FormControl
|
||||
@ -128,7 +157,11 @@ const IAINumberInput = (props: Props) => {
|
||||
{...rest}
|
||||
onPaste={stopPastePropagation}
|
||||
>
|
||||
<NumberInputField {...numberInputFieldProps} />
|
||||
<NumberInputField
|
||||
{...numberInputFieldProps}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
/>
|
||||
{showStepper && (
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper {...numberInputStepperProps} />
|
||||
|
@ -26,9 +26,12 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { clamp } from 'lodash-es';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import {
|
||||
FocusEvent,
|
||||
KeyboardEvent,
|
||||
memo,
|
||||
MouseEvent,
|
||||
useCallback,
|
||||
@ -36,9 +39,9 @@ import {
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
@ -56,7 +59,6 @@ export type IAIFullSliderProps = {
|
||||
withInput?: boolean;
|
||||
isInteger?: boolean;
|
||||
inputWidth?: string | number;
|
||||
inputReadOnly?: boolean;
|
||||
withReset?: boolean;
|
||||
handleReset?: () => void;
|
||||
tooltipSuffix?: string;
|
||||
@ -90,7 +92,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
withInput = false,
|
||||
isInteger = false,
|
||||
inputWidth = 16,
|
||||
inputReadOnly = false,
|
||||
withReset = false,
|
||||
hideTooltip = false,
|
||||
isCompact = false,
|
||||
@ -109,7 +110,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sliderIAIIconButtonProps,
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [localInputValue, setLocalInputValue] = useState<
|
||||
@ -152,6 +153,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
);
|
||||
|
||||
const handleInputChange = useCallback((v: number | string) => {
|
||||
console.log('input');
|
||||
setLocalInputValue(v);
|
||||
}, []);
|
||||
|
||||
@ -168,6 +170,24 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(true));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (!e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl
|
||||
onClick={forceInputBlur}
|
||||
@ -311,7 +331,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
{...sliderNumberInputProps}
|
||||
>
|
||||
<NumberInputField
|
||||
readOnly={inputReadOnly}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
minWidth={inputWidth}
|
||||
{...sliderNumberInputFieldProps}
|
||||
/>
|
||||
|
@ -1,9 +1,38 @@
|
||||
import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { memo } from 'react';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { KeyboardEvent, memo, useCallback } from 'react';
|
||||
|
||||
const IAITextarea = forwardRef((props: TextareaProps, ref) => {
|
||||
return <Textarea ref={ref} onPaste={stopPastePropagation} {...props} />;
|
||||
const dispatch = useAppDispatch();
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(true));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleKeyUp = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (!e.shiftKey) {
|
||||
dispatch(shiftKeyPressed(false));
|
||||
}
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Textarea
|
||||
ref={ref}
|
||||
onPaste={stopPastePropagation}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
export default memo(IAITextarea);
|
||||
|
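The same `onKeyDown`/`onKeyUp` pair is wired into `IAIInput`, `IAIMantineMultiSelect`, `IAIMantineSelect`, `IAINumberInput`, `IAISlider` and `IAITextarea` so that the `hotkeys` slice keeps tracking the Shift key even while a text field has focus (the matching `'hotkeys/shiftKeyPressed'` action is also added to `actionsDenylist` earlier in the diff). The repetition is not factored out in this commit; a hypothetical hook that would do so, shown only as a sketch:

```ts
// Hypothetical refactor, not part of the commit: a shared hook returning the
// keydown/keyup handlers that mirror the Shift key into the hotkeys slice.
import { useCallback } from 'react';
import type { KeyboardEvent } from 'react';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';

export const useShiftKeyTracking = <T extends HTMLElement>() => {
  const dispatch = useAppDispatch();

  const onKeyDown = useCallback(
    (e: KeyboardEvent<T>) => {
      if (e.shiftKey) {
        dispatch(shiftKeyPressed(true));
      }
    },
    [dispatch]
  );

  const onKeyUp = useCallback(
    (e: KeyboardEvent<T>) => {
      if (!e.shiftKey) {
        dispatch(shiftKeyPressed(false));
      }
    },
    [dispatch]
  );

  return { onKeyDown, onKeyUp };
};
```

Each component could then spread the returned handlers onto its input element instead of re-declaring them.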
Some files were not shown because too many files have changed in this diff.